yacwc
This commit is contained in:
52
train.py
52
train.py
@@ -0,0 +1,52 @@
|
||||
# %%
|
||||
from engine import train_one_epoch, evaluate
|
||||
from model import Model
|
||||
from data import iNaturalistDataset
|
||||
import torch
|
||||
import os
|
||||
import time
|
||||
|
||||
if not os.path.exists('models/'):
|
||||
os.mkdirs('models')
|
||||
|
||||
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
||||
def run():
|
||||
val_dataset = iNaturalistDataset(validation=True, transforms = get_transform(train=True))
|
||||
train_dataset = iNaturalistDataset(train=True, transforms = get_transform(train=False))
|
||||
|
||||
|
||||
train_data_loader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=utils.collate_fn
|
||||
)
|
||||
val_data_loader = torch.utils.data.DataLoader(
|
||||
val_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=utils.collate_fn
|
||||
)
|
||||
|
||||
num_classes = 5
|
||||
model = Model(num_classes)
|
||||
model.to(device)
|
||||
|
||||
params = [p for p in model.parameters() if p.requires_grad]
|
||||
optimizer = torch.optim.SGD(params, lr=0.005,
|
||||
momentum=0.9, weight_decay=0.0005)
|
||||
|
||||
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
|
||||
step_size=3,
|
||||
gamma=0.1)
|
||||
|
||||
num_epochs = 10
|
||||
for epoch in range(num_epochs):
|
||||
print(epoch)
|
||||
torch.save(model.state_dict(), 'model_weights_start_'+str(epoch)+ '.pth')
|
||||
# train for one epoch, printing every 10 iterations
|
||||
engine.train_one_epoch(model, optimizer, train_data_loader, device, epoch, print_freq=10)
|
||||
torch.save(model.state_dict(), 'model_weights_post_train_'+str(epoch)+ '.pth')
|
||||
# update the learning rate
|
||||
lr_scheduler.step()
|
||||
torch.save(model.state_dict(), 'model_weights_post_step_'+str(epoch)+ '.pth')
|
||||
# evaluate on the test dataset
|
||||
engine.evaluate(model, val_data_loader, device=device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
Reference in New Issue
Block a user