83 lines
2.1 KiB
Python
83 lines
2.1 KiB
Python
# %%
|
|
from engine import train_one_epoch, evaluate
|
|
from model import Model
|
|
from data import iNaturalistDataset
|
|
import torch
|
|
import os
|
|
import datetime as dt
|
|
import json
|
|
import utils
|
|
|
|
# Checkpoints (.pth) and their JSON metadata are written under models/.
# makedirs(exist_ok=True) replaces the exists-then-mkdir check, which could race
# with another process creating the directory between the check and the mkdir.
os.makedirs("models", exist_ok=True)
|
|
|
|
# Run on the GPU when PyTorch reports CUDA as available, otherwise use the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
# Each run gets a local-time timestamped stem such as models/20240131_142501.
# The weights file (.pth) and the metadata file (.json) share that stem.
model_root = f"models/{dt.datetime.now():%Y%m%d_%H%M%S}"
model_path = f"{model_root}.pth"
model_info = f"{model_root}.json"
|
|
|
|
|
|
# Species the detector is trained on (one class each, plus the extra class added
# in run()).
# NOTE(review): iteration order of a set of str varies between interpreter runs
# (hash randomization). If iNaturalistDataset derives label indices from this
# order they may differ per run; the categories actually used are saved to the
# run's JSON file.
species_list = {"Poecile atricapillus", "Archilochus colubris", "Icterus galbula"}
# Architecture name passed to Model and recorded in the JSON metadata.
model_type = "fasterrcnn_mobilenet_v3_large_fpn"
|
|
|
|
|
|
def _make_loader(dataset, *, batch_size, num_workers, shuffle):
    """Wrap *dataset* in a DataLoader that uses the project's ``utils.collate_fn``."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=utils.collate_fn,
    )


def run(
    *,
    num_epochs=10,
    batch_size=16,
    num_workers=4,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
    lr_step_size=3,
    lr_gamma=0.1,
):
    """Train a detection model on ``species_list`` and save weights and metadata.

    Steps:

    1. Write the training categories and ``model_type`` to ``model_info``
       (JSON) before training begins.
    2. For every epoch, train one epoch and step the LR scheduler.
    3. Overwrite the checkpoint at ``model_path`` with the current
       ``state_dict``.
    4. Evaluate on the validation split.

    Every keyword argument defaults to the value that used to be hard-coded
    here, so a bare ``run()`` trains exactly as before. The one difference is
    that the validation loader no longer shuffles.

    Args:
        num_epochs: Number of training epochs.
        batch_size: Batch size used by both the training and validation loaders.
        num_workers: DataLoader worker processes per loader.
        lr: SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        lr_step_size: StepLR period, in epochs, between learning-rate decays.
        lr_gamma: StepLR multiplicative decay factor.
    """
    val_dataset = iNaturalistDataset(validation=True, species=species_list)
    train_dataset = iNaturalistDataset(train=True, species=species_list)

    # Written before training starts, so the metadata exists even if training
    # is interrupted. default=str stringifies anything json cannot encode natively.
    with open(model_info, "w", encoding="utf-8") as js_p:
        json.dump(
            {"categories": train_dataset.categories, "model_type": model_type},
            js_p,
            default=str,
            indent=4,
        )

    train_data_loader = _make_loader(
        train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True
    )
    # Evaluation does not benefit from shuffling; a fixed order keeps it deterministic.
    val_data_loader = _make_loader(
        val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
    )

    # NOTE(review): presumably the +1 reserves a background class, as torchvision
    # detection models expect; confirm against Model.
    num_classes = len(species_list) + 1
    model = Model(num_classes, model_type)
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(
        params, lr=lr, momentum=momentum, weight_decay=weight_decay
    )
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=lr_step_size, gamma=lr_gamma
    )

    for epoch in range(num_epochs):
        train_one_epoch(
            model, optimizer, train_data_loader, device, epoch, print_freq=10
        )
        lr_scheduler.step()
        # Same path every epoch: the file always holds the most recent weights.
        torch.save(model.state_dict(), model_path)
        evaluate(model, val_data_loader, device=device)
|
|
|
|
|
|
if __name__ == "__main__":
    # Train only when executed as a script. For example, DataLoader workers
    # started with the "spawn" method re-import this module and must not start
    # training themselves.
    run()
|