yacwc
data.py (143 changed lines)
@@ -1,26 +1,20 @@
 # %%
 import os
-import numpy as np
 import torch
 from PIL import Image
 import torchvision
 from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
-from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
-from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
 import json
 import torch
 from torchvision import transforms as T
-import numpy as np
 import os
 import sys
 
-sys.path.append(r"K:\Designs\ML\inaturalist_models\data_aug")
-sys.path.append(r"K:\Designs\ML\inaturalist_models\vision")
-from references.detection import utils, engine
-import data_aug
-import bbox_util
-
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
 
+if sys.platform == 'win32':
+    PATH_ROOT = r"D:\ishan\ml\inaturalist\\"
+else:
+    raise NotImplementedError("Not defined for this platform")
 
 def get_transform(train):
     transforms = []
@@ -29,9 +23,6 @@ def get_transform(train):
         transforms.append(T.RandomHorizontalFlip(0.5))
     return T.Compose(transforms)
 
-PATH_ROOT = r"D:\ishan\ml\inaturalist\\"
-
-device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
 
 def create_map(list_in, from_key, to_key):
     cmap = dict()
@@ -41,18 +32,22 @@ def create_map(list_in, from_key, to_key):
 
 
 class iNaturalistDataset(torch.utils.data.Dataset):
-    def __init__(self, validation=False, train=False, transforms = None):
+    def __init__(self, validation=False, train=False, transforms = None, species = None):
 
         self.validation = validation
         self.train = train
-        self.transforms = transforms
+        if not (self.train or self.validation) or (self.train and self.validation):
+            raise Exception("Need to do either train or validation")
+
+        self.transforms = get_transform(self.train)
 
 
         if validation:
-            json_path = os.path.join(PATH_ROOT, r"val_2017_bboxes\val_2017_bboxes.json")
+            json_path = os.path.join(PATH_ROOT, "val_2017_bboxes", "val_2017_bboxes.json")
         elif train:
             json_path = os.path.join(
-                PATH_ROOT, r"train_2017_bboxes\train_2017_bboxes.json"
+                PATH_ROOT, "train_2017_bboxes", "train_2017_bboxes.json"
             )
 
         with open(json_path, "r") as rj:
@@ -61,12 +56,16 @@ class iNaturalistDataset(torch.utils.data.Dataset):
         categories = list()
         image_info = dict()
 
-        for category in f["categories"]:
-            if category["supercategory"] == "Aves":
-                if category['name'] in ['Archilochus colubris']:#,'Icterus galbula']:
-                    print(category['name'])
-                    categories.append(category)
+        for category in f["categories"]:
+            do_add = False
+            if species is None:
+                do_add = True
+            if do_add or category['name'] in species:
+                print(category['name'])
+                categories.append(category)
 
         categories = sorted(categories, key=lambda k: k["name"])
         for idx, cat in enumerate(categories):
             cat["new_id"] = idx + 1
@@ -94,6 +93,7 @@ class iNaturalistDataset(torch.utils.data.Dataset):
 
         for idx, (id, im_in) in enumerate(image_info.items()):
             im_in["idx"] = idx
+
         self.images = image_info
         self.categories = categories
         self.idx_to_id = [x for x in self.images]
@@ -120,102 +120,7 @@ class iNaturalistDataset(torch.utils.data.Dataset):
         target['area'] = torch.as_tensor([annot['area']])
         target['iscrowd'] = torch.zeros((1,), dtype=torch.int64)
 
-
         if self.transforms is not None:
             img, target = self.transforms(img, target)
 
         return img, target
-# %%
-# v = iNaturalistDataset(validation=True)
-
-
-# v = iNaturalistDataset(validation= True)
-# o = v[10]
-# %%
-# oimage = t.tensor(o[0]*255, dtype=t.uint8)
-# import matplotlib.pyplot as plt
-# ox = draw_bounding_boxes(oimage, o[1]['boxes'], width=1)
-# plt.imshow(ox.permute([1,2,0]))
-# plt.savefig('crap2.png')
-
-def get_model(num_classes):
-    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
-    num_classes = 2 # 1 class (person) + background
-    in_features = model.roi_heads.box_predictor.cls_score.in_features
-    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
-    return model
-
-
-import transforms as T
-
-def get_transform(train):
-    transforms = []
-    transforms.append(T.ToTensor())
-    if train:
-        transforms.append(T.RandomHorizontalFlip(0.5))
-    return T.Compose(transforms)
-
-from engine import train_one_epoch, evaluate
-import utils
-# %%
-def run():
-    val_dataset = iNaturalistDataset(validation=True, transforms = get_transform(train=True))
-    train_dataset = iNaturalistDataset(train=True, transforms = get_transform(train=False))
-
-
-    train_data_loader = torch.utils.data.DataLoader(
-        train_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=utils.collate_fn
-    )
-    val_data_loader = torch.utils.data.DataLoader(
-        val_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=utils.collate_fn
-    )
-
-    import torchvision
-    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
-    num_classes = 2
-
-
-    model = get_model(num_classes)
-    model.to(device)
-    # construct an optimizer
-    params = [p for p in model.parameters() if p.requires_grad]
-    optimizer = torch.optim.SGD(params, lr=0.005,
-                                momentum=0.9, weight_decay=0.0005)
-    # and a learning rate scheduler
-    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
-                                                   step_size=3,
-                                                   gamma=0.1)
-
-    # let's train it for 10 epochs
-    num_epochs = 10
-
-    for epoch in range(num_epochs):
-        print(epoch)
-        torch.save(model.state_dict(), 'model_weights_start_'+str(epoch)+ '.pth')
-        # train for one epoch, printing every 10 iterations
-        engine.train_one_epoch(model, optimizer, train_data_loader, device, epoch, print_freq=10)
-        torch.save(model.state_dict(), 'model_weights_post_train_'+str(epoch)+ '.pth')
-        # update the learning rate
-        lr_scheduler.step()
-        torch.save(model.state_dict(), 'model_weights_post_step_'+str(epoch)+ '.pth')
-        # evaluate on the test dataset
-        engine.evaluate(model, val_data_loader, device=device)
-
-
-
-if __name__ == "__main__":
-    run()
-
-
-
-# # %%
-# json_path = os.path.join(
-#     PATH_ROOT, r"train_2017_bboxes\train_2017_bboxes.json"
-# )
-# with open(json_path, "r") as rj:
-#     f = json.load(rj)
-
-
-# # %%
-# image_id: 2358
-
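For reference, a minimal usage sketch of the new species filter (hypothetical values; assumes the val/train 2017 bbox JSONs and images are present under PATH_ROOT on the Windows box):

from data import iNaturalistDataset

# species=None keeps every category; a list of names restricts the dataset
# to those species, each remapped to a contiguous label id starting at 1.
ds = iNaturalistDataset(train=True, species=['Archilochus colubris', 'Icterus galbula'])
print(ds.num_samples, ds.num_classes)  # annotated images; categories + background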
model.py (216 changed lines)
@@ -1,221 +1,11 @@
 # %%
-import os
-import numpy as np
-import torch
-from PIL import Image
-import torchvision
+from torchvision.models.detection import fasterrcnn_resnet50_fpn
 from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
-from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
-from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
-import json
-import torch
-from torchvision import transforms as T
-import numpy as np
-import os
-import sys
 
-sys.path.append(r"K:\Designs\ML\inaturalist_models\data_aug")
-sys.path.append(r"K:\Designs\ML\inaturalist_models\vision")
-from references.detection import utils, engine
-import data_aug
-import bbox_util
-
-
-
-def get_transform(train):
-    transforms = []
-    transforms.append(T.ToTensor())
-    if train:
-        transforms.append(T.RandomHorizontalFlip(0.5))
-    return T.Compose(transforms)
-
-PATH_ROOT = r"D:\ishan\ml\inaturalist\\"
-
-device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-
-def create_map(list_in, from_key, to_key):
-    cmap = dict()
-    for l in list_in:
-        cmap[l[from_key]] = l[to_key]
-    return cmap
-
-
-class iNaturalistDataset(torch.utils.data.Dataset):
-    def __init__(self, validation=False, train=False, transforms = None):
-
-        self.validation = validation
-        self.train = train
-        self.transforms = transforms
-
-
-        if validation:
-            json_path = os.path.join(PATH_ROOT, r"val_2017_bboxes\val_2017_bboxes.json")
-        elif train:
-            json_path = os.path.join(
-                PATH_ROOT, r"train_2017_bboxes\train_2017_bboxes.json"
-            )
-
-        with open(json_path, "r") as rj:
-            f = json.load(rj)
-
-        categories = list()
-        image_info = dict()
-
-        for category in f["categories"]:
-            if category["supercategory"] == "Aves":
-                if category['name'] in ['Archilochus colubris']:#,'Icterus galbula']:
-                    print(category['name'])
-                    categories.append(category)
-
-        categories = sorted(categories, key=lambda k: k["name"])
-        for idx, cat in enumerate(categories):
-            cat["new_id"] = idx + 1
-
-        orig_to_new_id = create_map(categories, "id", "new_id")
-
-        for annot in f["annotations"]:
-            if annot["category_id"] in orig_to_new_id:
-                annot["new_category_id"] = orig_to_new_id[annot["category_id"]]
-                id = annot["image_id"]
-                if id not in image_info:
-                    image_info[id] = dict()
-
-                annot["bbox"][2] += annot["bbox"][0]
-                annot["bbox"][3] += annot["bbox"][1]
-                image_info[id]["annotation"] = annot
-
-        for img in f["images"]:
-            id = img["id"]
-            path = os.path.join(PATH_ROOT, img["file_name"])
-            height = img["height"]
-            width = img["width"]
-            if id in image_info:
-                image_info[id].update({"path": path, "height": height, "width": width})
-
-        for idx, (id, im_in) in enumerate(image_info.items()):
-            im_in["idx"] = idx
-        self.images = image_info
-        self.categories = categories
-        self.idx_to_id = [x for x in self.images]
-        self.num_classes = len(self.categories) + 1
-        self.num_samples = len(self.images)
-
-
-    def __len__(self):
-        return self.num_samples
-
-    def __getitem__(self, idx):
-        idd = self.idx_to_id[idx]
-        c_image = self.images[idd]
-        img_path = c_image["path"]
-        img = Image.open(img_path).convert("RGB")
-
-        annot = c_image["annotation"]
-        bbox = annot["bbox"]
-        boxes = bbox
-        target = dict()
-        target["boxes"] = torch.as_tensor([boxes])
-        target["labels"] = torch.as_tensor([annot["new_category_id"]], dtype=torch.int64)
-        target['image_id'] = torch.tensor([annot['image_id']])
-        target['area'] = torch.as_tensor([annot['area']])
-        target['iscrowd'] = torch.zeros((1,), dtype=torch.int64)
-
-
-        if self.transforms is not None:
-            img, target = self.transforms(img, target)
-
-        return img, target
-# %%
-# v = iNaturalistDataset(validation=True)
-
-
-# v = iNaturalistDataset(validation= True)
-# o = v[10]
-# %%
-# oimage = t.tensor(o[0]*255, dtype=t.uint8)
-# import matplotlib.pyplot as plt
-# ox = draw_bounding_boxes(oimage, o[1]['boxes'], width=1)
-# plt.imshow(ox.permute([1,2,0]))
-# plt.savefig('crap2.png')
-
-def get_model(num_classes):
-    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
+def Model(num_classes):
+    model = fasterrcnn_resnet50_fpn(pretrained=True)
     num_classes = 2 # 1 class (person) + background
     in_features = model.roi_heads.box_predictor.cls_score.in_features
     model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
     return model
 
-
-import transforms as T
-
-def get_transform(train):
-    transforms = []
-    transforms.append(T.ToTensor())
-    if train:
-        transforms.append(T.RandomHorizontalFlip(0.5))
-    return T.Compose(transforms)
-
-from engine import train_one_epoch, evaluate
-import utils
-# %%
-def run():
-    val_dataset = iNaturalistDataset(validation=True, transforms = get_transform(train=True))
-    train_dataset = iNaturalistDataset(train=True, transforms = get_transform(train=False))
-
-
-    train_data_loader = torch.utils.data.DataLoader(
-        train_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=utils.collate_fn
-    )
-    val_data_loader = torch.utils.data.DataLoader(
-        val_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=utils.collate_fn
-    )
-
-    import torchvision
-    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
-    num_classes = 2
-
-
-    model = get_model(num_classes)
-    model.to(device)
-    # construct an optimizer
-    params = [p for p in model.parameters() if p.requires_grad]
-    optimizer = torch.optim.SGD(params, lr=0.005,
-                                momentum=0.9, weight_decay=0.0005)
-    # and a learning rate scheduler
-    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
-                                                   step_size=3,
-                                                   gamma=0.1)
-
-    # let's train it for 10 epochs
-    num_epochs = 10
-
-    for epoch in range(num_epochs):
-        print(epoch)
-        torch.save(model.state_dict(), 'model_weights_start_'+str(epoch)+ '.pth')
-        # train for one epoch, printing every 10 iterations
-        engine.train_one_epoch(model, optimizer, train_data_loader, device, epoch, print_freq=10)
-        torch.save(model.state_dict(), 'model_weights_post_train_'+str(epoch)+ '.pth')
-        # update the learning rate
-        lr_scheduler.step()
-        torch.save(model.state_dict(), 'model_weights_post_step_'+str(epoch)+ '.pth')
-        # evaluate on the test dataset
-        engine.evaluate(model, val_data_loader, device=device)
-
-
-
-if __name__ == "__main__":
-    run()
-
-
-
-# # %%
-# json_path = os.path.join(
-#     PATH_ROOT, r"train_2017_bboxes\train_2017_bboxes.json"
-# )
-# with open(json_path, "r") as rj:
-#     f = json.load(rj)
-
-
-# # %%
-# image_id: 2358
-
train.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+# %%
+from engine import train_one_epoch, evaluate
+from model import Model
+from data import iNaturalistDataset, get_transform
+import utils
+import torch
+import os
+import time
+
+if not os.path.exists('models/'):
+    os.makedirs('models')
+
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+def run():
+    val_dataset = iNaturalistDataset(validation=True, transforms = get_transform(train=True))
+    train_dataset = iNaturalistDataset(train=True, transforms = get_transform(train=False))
+
+
+    train_data_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=utils.collate_fn
+    )
+    val_data_loader = torch.utils.data.DataLoader(
+        val_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=utils.collate_fn
+    )
+
+    num_classes = 5
+    model = Model(num_classes)
+    model.to(device)
+
+    params = [p for p in model.parameters() if p.requires_grad]
+    optimizer = torch.optim.SGD(params, lr=0.005,
+                                momentum=0.9, weight_decay=0.0005)
+
+    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
+                                                   step_size=3,
+                                                   gamma=0.1)
+
+    num_epochs = 10
+    for epoch in range(num_epochs):
+        print(epoch)
+        torch.save(model.state_dict(), 'model_weights_start_'+str(epoch)+ '.pth')
+        # train for one epoch, printing every 10 iterations
+        train_one_epoch(model, optimizer, train_data_loader, device, epoch, print_freq=10)
+        torch.save(model.state_dict(), 'model_weights_post_train_'+str(epoch)+ '.pth')
+        # update the learning rate
+        lr_scheduler.step()
+        torch.save(model.state_dict(), 'model_weights_post_step_'+str(epoch)+ '.pth')
+        # evaluate on the test dataset
+        evaluate(model, val_data_loader, device=device)
+
+
+if __name__ == "__main__":
+    run()
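train.py assumes the torchvision detection reference helpers (engine, utils) are importable from the working directory; utils.collate_fn there only regroups a batch of (image, target) pairs rather than stacking variable-sized images. A minimal stand-in, sketched here in case those helpers are not on the path:

def collate_fn(batch):
    # [(img1, tgt1), (img2, tgt2), ...] -> ((img1, img2, ...), (tgt1, tgt2, ...))
    return tuple(zip(*batch))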