# %%
"""Train a ``Model`` on iNaturalist data and checkpoint the weights.

When executed as a script this module builds training and validation
datasets for the species in ``species_list``, writes the dataset categories
and model type to a JSON file next to the checkpoint, and then trains with
SGD under a step learning-rate schedule.
"""
import datetime as dt
import json
import os

import torch

import utils
from data import iNaturalistDataset
from engine import evaluate, train_one_epoch
from model import Model

# Create the output directory up front. ``exist_ok=True`` replaces the
# original exists()-then-mkdir() pair, which could raise FileExistsError
# when two runs start at the same moment.
os.makedirs("models", exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Every run gets a timestamped stem shared by the checkpoint (.pth) and the
# metadata file (.json).
model_root = "models/" + dt.datetime.now().strftime("%Y%m%d_%H%M%S")
model_path = model_root + ".pth"
model_info = model_root + ".json"

species_list = {"Poecile atricapillus", "Archilochus colubris", "Icterus galbula"}
model_type = "fasterrcnn_mobilenet_v3_large_fpn"


def run():
    """Build the datasets, train for a fixed number of epochs, and save.

    Side effects: writes ``model_info`` (JSON with the training dataset's
    categories and the model type) and writes the model ``state_dict`` to
    ``model_path``.
    """
    val_dataset = iNaturalistDataset(
        validation=True,
        species=species_list,
    )
    train_dataset = iNaturalistDataset(
        train=True,
        species=species_list,
    )

    # Persist the category mapping so the checkpoint can be interpreted
    # later. ``default=str`` stringifies any value json cannot serialize.
    with open(model_info, "w", encoding="utf-8") as js_p:
        json.dump(
            {"categories": train_dataset.categories, "model_type": model_type},
            js_p,
            default=str,
            indent=4,
        )

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=16,
        shuffle=True,
        num_workers=4,
        collate_fn=utils.collate_fn,
    )
    # NOTE(review): shuffling the validation set is kept as in the original;
    # it is probably unnecessary for evaluation -- confirm before changing.
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=16,
        shuffle=True,
        num_workers=4,
        collate_fn=utils.collate_fn,
    )

    # Presumably the extra class is the detector's background label, and
    # this count is expected to agree with ``train_dataset.categories`` --
    # TODO confirm against ``Model`` and ``iNaturalistDataset``.
    num_classes = len(species_list) + 1
    model = Model(num_classes, model_type)
    model.to(device)

    # Only optimize parameters that are not frozen.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
    # Multiply the learning rate by 0.1 every 3 scheduler steps.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    num_epochs = 10
    for epoch in range(num_epochs):
        train_one_epoch(
            model, optimizer, train_data_loader, device, epoch, print_freq=10
        )
        lr_scheduler.step()
        # NOTE(review): the source's indentation was lost, so whether the
        # save/evaluate calls ran per epoch or once after training cannot be
        # determined from it. They are kept per epoch here so the checkpoint
        # on disk always reflects the latest completed epoch -- confirm.
        torch.save(model.state_dict(), model_path)
        evaluate(model, val_data_loader, device=device)


if __name__ == "__main__":
    run()