yawc
This commit is contained in:
90
train.py
90
train.py
@@ -7,7 +7,8 @@ import os
|
||||
import datetime as dt
|
||||
import json
|
||||
import utils
|
||||
|
||||
import pandas as pd
|
||||
import sys
|
||||
if not os.path.exists("models/"):
|
||||
os.mkdir("models")
|
||||
|
||||
@@ -16,16 +17,27 @@ if torch.cuda.is_available():
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
model_root = "models/" + dt.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
model_path = model_root + ".pth"
|
||||
model_info = model_root + ".json"
|
||||
default_model_root = "models/" + dt.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
default_model_path = default_model_root + ".pth"
|
||||
default_model_info = default_model_root + ".json"
|
||||
default_state_path = default_model_root + ".oth"
|
||||
|
||||
|
||||
species_list = set(["Poecile atricapillus", "Archilochus colubris", "Icterus galbula"])
|
||||
model_type = "fasterrcnn_mobilenet_v3_large_fpn"
|
||||
#species_list = set(["Poecile atricapillus", "Archilochus colubris", "Icterus galbula"])
|
||||
|
||||
csv_path = '/home/thebears/Seafile/Designs/ML/inaturalist_models/species_occurence.csv'
|
||||
df = pd.read_csv(csv_path)
|
||||
species_list = set(list(df[df['count']>1000]['species']))
|
||||
|
||||
#model_type = "fasterrcnn_mobilenet_v3_large_fpn"
|
||||
#batch_size = 16
|
||||
|
||||
model_type = 'fasterrcnn_resnet50_fpn'
|
||||
batch_size = 8
|
||||
|
||||
|
||||
def run():
|
||||
|
||||
def run(model_name = None, epoch_start = 0):
|
||||
val_dataset = iNaturalistDataset(
|
||||
validation=True,
|
||||
species=species_list,
|
||||
@@ -35,27 +47,41 @@ def run():
|
||||
species=species_list,
|
||||
)
|
||||
|
||||
with open(model_info, "w") as js_p:
|
||||
json.dump(
|
||||
{"categories": train_dataset.categories, "model_type": model_type},
|
||||
js_p,
|
||||
default=str,
|
||||
indent=4,
|
||||
)
|
||||
|
||||
# Decide which files this run reads and writes: the model weights (.pth),
# the category/model-type metadata (.json) and the optimizer state (.oth).
# With no model_name a fresh, timestamped set of default paths is used;
# otherwise the paths are derived from the given name so training resumes
# from an existing run.
if model_name is None:
    fresh_start = True
    model_info = default_model_info
    model_path = default_model_path
    state_path = default_state_path
else:
    fresh_start = False
    # model_name may be passed with or without a '.pth' / '.json' extension.
    # str.rstrip() removes a *set of characters*, not a suffix, so
    # rstrip('.pth') would also eat trailing 'p', 't', 'h' and '.' from the
    # stem (e.g. "models/depth.pth" -> "models/de"). Strip whole suffixes.
    root = model_name
    for ext in ('.pth', '.json'):
        if root.endswith(ext):
            root = root[:-len(ext)]
    model_info = root + '.json'
    model_path = root + '.pth'
    state_path = root + '.oth'
|
||||
|
||||
|
||||
if fresh_start:
|
||||
with open(model_info, "w") as js_p:
|
||||
json.dump(
|
||||
{"categories": train_dataset.categories, "model_type": model_type},
|
||||
js_p,
|
||||
default=str,
|
||||
indent=4,
|
||||
)
|
||||
|
||||
train_data_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=16,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=4,
|
||||
num_workers=10,
|
||||
collate_fn=utils.collate_fn,
|
||||
)
|
||||
|
||||
val_data_loader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
batch_size=16,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=4,
|
||||
num_workers=10,
|
||||
collate_fn=utils.collate_fn,
|
||||
)
|
||||
|
||||
@@ -63,20 +89,38 @@ def run():
|
||||
model = Model(num_classes, model_type)
|
||||
model.to(device)
|
||||
|
||||
if not fresh_start:
|
||||
model.load_state_dict(
|
||||
torch.load(model_path, map_location = torch.device(device))
|
||||
)
|
||||
|
||||
params = [p for p in model.parameters() if p.requires_grad]
|
||||
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
|
||||
|
||||
if os.path.exists(state_path):
|
||||
optimizer.load_state_dict(torch.load(state_path, map_location = torch.device(device)))
|
||||
|
||||
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
|
||||
|
||||
|
||||
num_epochs = 10
|
||||
for epoch in range(num_epochs):
|
||||
|
||||
for epoch in range(epoch_start, num_epochs):
|
||||
train_one_epoch(
|
||||
model, optimizer, train_data_loader, device, epoch, print_freq=10
|
||||
)
|
||||
model, optimizer, train_data_loader, device, epoch, print_freq=10 )
|
||||
lr_scheduler.step()
|
||||
torch.save(model.state_dict(), model_path)
|
||||
torch.save(optimizer.state_dict(), state_path)
|
||||
evaluate(model, val_data_loader, device=device)
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
if len(sys.argv) == 3:
|
||||
model_name = sys.argv[1]
|
||||
epoch_start = int(sys.argv[2])
|
||||
run(model_name = model_name, epoch_start = epoch_start)
|
||||
else:
|
||||
run()
|
||||
|
||||
# run()
|
||||
|
||||
Reference in New Issue
Block a user