import
torchvision.transforms as transforms
from
torchvision.datasets
import
ImageFolder
from
tqdm
import
tqdm
import
torch
import
torchvision
import
torch.nn as nn
from
torch.utils.data
import
DataLoader
import
numpy as np
# Preprocessing applied to every image loaded by CustomDataset:
#   1. resize to a fixed 224x224 (aspect ratio is not preserved),
#   2. convert to a tensor,
#   3. normalise each channel with the mean/std values below, which are the
#      standard ImageNet statistics used with ImageNet-pretrained backbones
#      such as the ResNet-18 built by MyResNet18.
data_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over a directory tree read by ``ImageFolder``.

    Thin wrapper around ``torchvision.datasets.ImageFolder``; the wrapped
    instance is exposed as ``self.dataset`` (e.g. ``self.dataset.classes``).

    Args:
        data_dir: Root directory passed to ``ImageFolder``.
        transform: Transform applied to every image. ``None`` (the default)
            selects the module-level ``data_transform``, which is what this
            class always used before the parameter existed.
    """

    def __init__(self, data_dir, transform=None):
        # Resolve the default at call time, so nothing outside this class is
        # referenced while the class body is being defined.
        if transform is None:
            transform = data_transform
        self.dataset = ImageFolder(root=data_dir, transform=transform)

    def __len__(self):
        """Return the number of samples found under ``data_dir``."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at position ``idx``."""
        image, label = self.dataset[idx]
        return image, label
class MyResNet18(torch.nn.Module):
    """torchvision ResNet-18 with its final fully-connected layer replaced.

    Args:
        num_classes: Number of logits produced by the new classifier head.
    """

    def __init__(self, num_classes):
        super().__init__()
        # NOTE(review): `pretrained=True` is deprecated since torchvision 0.13
        # in favour of `weights=torchvision.models.ResNet18_Weights.DEFAULT`.
        # It is kept here so older torchvision releases keep working.
        self.resnet = torchvision.models.resnet18(pretrained=True)
        # Size the new head from the layer it replaces instead of
        # hard-coding 512.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        """Run ``x`` through the backbone and return the class logits."""
        return self.resnet(x)
def train(epoch, *, data_dir="dataset", batch_size=64, num_classes=91,
          lr=0.001, momentum=0.9, save_dir="model"):
    """Fine-tune MyResNet18 on an image-folder dataset and save its weights.

    Args:
        epoch: Number of full passes over the dataset. Despite the singular
            name, this is an epoch *count*, not an epoch index.
        data_dir: Root directory passed to ``CustomDataset``.
        batch_size: Mini-batch size for the ``DataLoader``.
        num_classes: Width of the classifier head. Must be at least the
            number of class folders found under ``data_dir``.
        lr: SGD learning rate.
        momentum: SGD momentum.
        save_dir: Directory the final ``state_dict`` is written to. It is
            created if missing.

    Raises:
        ValueError: if ``epoch`` < 1, if the dataset is empty, or if it has
            more classes than ``num_classes``.
    """
    import os  # function-local so this fix needs no file-level import change

    # Fail before any expensive work. With zero epochs or no batches,
    # `epoch_loss` would never be bound and the save step would crash.
    if epoch < 1:
        raise ValueError(f"epoch must be >= 1, got {epoch}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    custom_dataset = CustomDataset(data_dir)
    if len(custom_dataset) == 0:
        raise ValueError(f"no samples found under {data_dir!r}")
    found_classes = len(custom_dataset.dataset.classes)
    if found_classes > num_classes:
        # Labels >= num_classes would make CrossEntropyLoss fail mid-training.
        raise ValueError(
            f"dataset has {found_classes} classes but num_classes={num_classes}"
        )
    data_loader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=True)

    model = MyResNet18(num_classes=num_classes)
    model.to(device)
    model.train()  # explicit: nn.Module already defaults to training mode
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    for i in range(epoch):
        running_total = 0.0
        data_loader_tqdm = tqdm(data_loader)
        for n_batches, (inputs, labels) in enumerate(data_loader_tqdm, start=1):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # One .item() per step; it synchronises with the device.
            loss_value = loss.item()
            # Running mean in O(1), instead of re-averaging the whole history
            # on every batch.
            running_total += loss_value
            epoch_loss = running_total / n_batches
            data_loader_tqdm.set_description(
                f"This epoch is {i} and it's loss is {loss_value}, "
                f"average loss {epoch_loss}"
            )
            loss.backward()
            optimizer.step()

    # Create the target directory; otherwise the whole run is lost to a
    # missing folder at the very end.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(
        model.state_dict(),
        os.path.join(save_dir, f"my_resnet18_{epoch_loss}.pth"),
    )
    print("completed. Model saved.")
if __name__ == '__main__':
    # Script entry point: run a 50-epoch training session.
    num_epochs = 50
    train(num_epochs)