# %%
"""Train a detection model on a subset of iNaturalist species.

Usage:
    python <script>.py                            # start a fresh, timestamped run
    python <script>.py MODEL_PATH [EPOCH_START]   # resume a previous run

Every run keeps four artifacts that share one path root under ``models/``:
``.pth`` (model weights), ``.json`` (categories and model type),
``.oth`` (optimizer state) and ``.sth`` (LR-scheduler state).
"""
import datetime as dt
import json
import os
import sys

import pandas as pd
import torch

import utils
from data import iNaturalistDataset
from engine import train_one_epoch, evaluate
from model import Model

os.makedirs("models", exist_ok=True)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Suffixes of the per-run artifact files (see module docstring).
_ARTIFACT_SUFFIXES = (".pth", ".json", ".oth", ".sth")

default_model_root = "models/" + dt.datetime.now().strftime("%Y%m%d_%H%M%S")
default_model_path = default_model_root + ".pth"
default_model_info = default_model_root + ".json"
default_state_path = default_model_root + ".oth"
default_sched_path = default_model_root + ".sth"

# Small hand-picked subset, kept for quick experiments:
# species_list = set(["Poecile atricapillus", "Archilochus colubris", "Icterus galbula"])
csv_path = '/home/thebears/Seafile/Designs/ML/inaturalist_models/species_occurence.csv'
df = pd.read_csv(csv_path)
# Train on every species whose 'count' column exceeds 1000.
species_list = set(df.loc[df['count'] > 1000, 'species'])

# Lighter alternative configuration:
# model_type = "fasterrcnn_mobilenet_v3_large_fpn"
# batch_size = 16
model_type = 'fasterrcnn_resnet50_fpn'
batch_size = 8
num_epochs = 10


def _strip_artifact_suffix(name):
    """Return *name* without one trailing run-artifact suffix, if present.

    ``str.rstrip('.pth')`` must not be used here: it strips any trailing
    '.', 'p', 't', 'h' characters, so e.g. ``models/depth.pth`` would become
    ``models/de``.
    """
    for suffix in _ARTIFACT_SUFFIXES:
        if name.endswith(suffix):
            return name[:-len(suffix)]
    return name


def _saved_model_type(model_info, fallback):
    """Return the ``model_type`` recorded in a run's JSON file, else *fallback*.

    Resuming must rebuild the same architecture the weights were saved from,
    which is not necessarily the current module-level ``model_type``.
    """
    if not os.path.exists(model_info):
        return fallback
    with open(model_info) as js_p:
        return json.load(js_p).get("model_type", fallback)


def run(model_name=None, epoch_start=0):
    """Train for epochs ``epoch_start .. num_epochs - 1``, checkpointing every epoch.

    Args:
        model_name: ``None`` to start a fresh run under the timestamped default
            root. Otherwise, the path of an existing run artifact, with or
            without a .pth/.json/.oth/.sth suffix; its weights are loaded and
            training resumes.
        epoch_start: Index of the first epoch to train.

    After each epoch, the model, optimizer and scheduler states are saved and
    the model is evaluated on the validation set.
    """
    val_dataset = iNaturalistDataset(
        validation=True,
        species=species_list,
    )
    train_dataset = iNaturalistDataset(
        train=True,
        species=species_list,
    )
    print(len(val_dataset.categories))
    print(len(train_dataset.categories))

    fresh_start = model_name is None
    if fresh_start:
        model_info = default_model_info
        model_path = default_model_path
        state_path = default_state_path
        sched_path = default_sched_path
        run_model_type = model_type
        with open(model_info, "w") as js_p:
            json.dump(
                {"categories": train_dataset.categories, "model_type": model_type},
                js_p,
                default=str,
                indent=4,
            )
    else:
        root = _strip_artifact_suffix(model_name)
        model_info = root + '.json'
        model_path = root + '.pth'
        state_path = root + '.oth'
        sched_path = root + '.sth'
        print('Continuing run')
        run_model_type = _saved_model_type(model_info, model_type)

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=10,
        collate_fn=utils.collate_fn,
    )
    # Evaluation does not benefit from shuffling; keep its order deterministic.
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=10,
        collate_fn=utils.collate_fn,
    )

    # NOTE(review): the +1 presumably reserves a background class, as
    # torchvision detectors expect -- confirm against Model.
    num_classes = len(train_dataset.categories) + 1
    model = Model(num_classes, run_model_type)
    model.to(device)

    if not fresh_start:
        print('Loading state dict')
        model.load_state_dict(torch.load(model_path, map_location=device))

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
    if os.path.exists(state_path):
        print('Loading optimizer')
        optimizer.load_state_dict(torch.load(state_path, map_location=device))

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
    if os.path.exists(sched_path):
        print('Loading scheduler')
        lr_scheduler.load_state_dict(torch.load(sched_path, map_location=device))

    for epoch in range(epoch_start, num_epochs):
        print('Epoch ' + str(epoch))
        train_one_epoch(
            model, optimizer, train_data_loader, device, epoch, print_freq=10
        )
        lr_scheduler.step()
        # Checkpoint before evaluating so a crash during evaluation loses no training.
        torch.save(model.state_dict(), model_path)
        torch.save(optimizer.state_dict(), state_path)
        torch.save(lr_scheduler.state_dict(), sched_path)
        evaluate(model, val_data_loader, device=device)


if __name__ == "__main__":
    cli_args = sys.argv[1:]
    if not cli_args:
        run()
    elif len(cli_args) <= 2:
        # A model path alone resumes from epoch 0; previously it was silently
        # ignored and a fresh run was started instead.
        run(
            model_name=cli_args[0],
            epoch_start=int(cli_args[1]) if len(cli_args) == 2 else 0,
        )
    else:
        sys.exit(
            "usage: " + os.path.basename(sys.argv[0]) + " [MODEL_PATH [EPOCH_START]]"
        )