This commit is contained in:
2021-07-01 20:26:24 -04:00
parent 8b02bf9a8c
commit f46d193826
16 changed files with 433 additions and 146 deletions

View File

@@ -4,49 +4,79 @@ from model import Model
from data import iNaturalistDataset
import torch
import os
import time
import datetime as dt
import json
import utils
if not os.path.exists("models/"):
os.mkdir("models")
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
model_root = "models/" + dt.datetime.now().strftime("%Y%m%d_%H%M%S")
model_path = model_root + ".pth"
model_info = model_root + ".json"
species_list = set(["Poecile atricapillus", "Archilochus colubris", "Icterus galbula"])
model_type = "fasterrcnn_mobilenet_v3_large_fpn"
if not os.path.exists('models/'):
os.mkdirs('models')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def run():
val_dataset = iNaturalistDataset(validation=True, transforms = get_transform(train=True))
train_dataset = iNaturalistDataset(train=True, transforms = get_transform(train=False))
val_dataset = iNaturalistDataset(
validation=True,
species=species_list,
)
train_dataset = iNaturalistDataset(
train=True,
species=species_list,
)
with open(model_info, "w") as js_p:
json.dump(
{"categories": train_dataset.categories, "model_type": model_type},
js_p,
default=str,
indent=4,
)
train_data_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=utils.collate_fn
)
val_data_loader = torch.utils.data.DataLoader(
val_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=utils.collate_fn
train_dataset,
batch_size=8,
shuffle=True,
num_workers=4,
collate_fn=utils.collate_fn,
)
num_classes = 5
model = Model(num_classes)
val_data_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=8,
shuffle=True,
num_workers=4,
collate_fn=utils.collate_fn,
)
num_classes = len(species_list) + 1
model = Model(num_classes, model_type)
model.to(device)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
momentum=0.9, weight_decay=0.0005)
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
step_size=3,
gamma=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
num_epochs = 10
for epoch in range(num_epochs):
print(epoch)
torch.save(model.state_dict(), 'model_weights_start_'+str(epoch)+ '.pth')
# train for one epoch, printing every 10 iterations
engine.train_one_epoch(model, optimizer, train_data_loader, device, epoch, print_freq=10)
torch.save(model.state_dict(), 'model_weights_post_train_'+str(epoch)+ '.pth')
# update the learning rate
train_one_epoch(
model, optimizer, train_data_loader, device, epoch, print_freq=10
)
lr_scheduler.step()
torch.save(model.state_dict(), 'model_weights_post_step_'+str(epoch)+ '.pth')
# evaluate on the test dataset
engine.evaluate(model, val_data_loader, device=device)
torch.save(model.state_dict(), model_path)
evaluate(model, val_data_loader, device=device)
if __name__ == "__main__":
run()
if __name__ == "__main__":
run()