# %%
import os

import torch

from data import iNaturalistDataset
from engine import train_one_epoch, evaluate
from model import Model
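
# The code below references a collate function and a transform factory that
# this file never defines; in the torchvision detection tutorial they come
# from the reference scripts. Minimal local sketches follow (an assumption,
# not the original helpers); swap in the project's own versions if they exist.
def collate_fn(batch):
    # Detection targets vary in size per image, so keep batches as tuples
    # of (images, targets) rather than stacking tensors.
    return tuple(zip(*batch))


def get_transform(train):
    # Convert PIL images to float tensors and, for the training split only,
    # add horizontal-flip augmentation.
    from torchvision.transforms import v2
    transforms = [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
    if train:
        transforms.append(v2.RandomHorizontalFlip(0.5))
    return v2.Compose(transforms)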


# Make sure the checkpoint directory exists before training starts.
os.makedirs('models', exist_ok=True)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def run():
    # Augment only the training split; validation gets deterministic transforms.
    train_dataset = iNaturalistDataset(train=True, transforms=get_transform(train=True))
    val_dataset = iNaturalistDataset(validation=True, transforms=get_transform(train=False))

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=collate_fn
    )
    # No shuffling needed for evaluation.
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=8, shuffle=False, num_workers=1, collate_fn=collate_fn
    )
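
    # Build a model for the dataset's five classes and move it to the device.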
    num_classes = 5
    model = Model(num_classes)
    model.to(device)
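
    # SGD over just the trainable parameters; StepLR decays the learning
    # rate by a factor of 10 every 3 epochs.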
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005,
                                momentum=0.9, weight_decay=0.0005)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=3,
                                                   gamma=0.1)
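
    # Each epoch: checkpoint, train, step the LR schedule, checkpoint again,
    # then evaluate on the validation set.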
    num_epochs = 10
    for epoch in range(num_epochs):
        print(epoch)
        torch.save(model.state_dict(), 'models/model_weights_start_' + str(epoch) + '.pth')
        # train for one epoch, printing every 10 iterations
        train_one_epoch(model, optimizer, train_data_loader, device, epoch, print_freq=10)
        torch.save(model.state_dict(), 'models/model_weights_post_train_' + str(epoch) + '.pth')
        # update the learning rate
        lr_scheduler.step()
        torch.save(model.state_dict(), 'models/model_weights_post_step_' + str(epoch) + '.pth')
        # evaluate on the validation dataset
        evaluate(model, val_data_loader, device=device)


if __name__ == "__main__":
    run()