Files
inaturalist_pytorch_model/data.py
2021-07-01 12:41:54 -04:00

212 lines
7.0 KiB
Python

# %%
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from collections import defaultdict as ddict
import json
import torch
from torchvision import datasets, transforms as T
import cv2
import numpy as np
import os
import sys
sys.path.append(r"K:\Designs\ML\inaturalist_models\data_aug")
sys.path.append(r"K:\Designs\ML\inaturalist_models\vision")
from references.detection import utils, engine
import data_aug
import bbox_util
def get_transform(train):
    """Build the torchvision transform pipeline.

    Always converts to tensor; additionally applies a 50% random
    horizontal flip when *train* is truthy.
    """
    steps = [T.ToTensor()]
    if train:
        steps.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(steps)
# ImageNet channel statistics — currently only referenced by the
# commented-out Normalize transform inside the dataset class.
IMAGE_MEAN = [0.485, 0.456, 0.406]
IMAGE_STD = [0.229, 0.224, 0.225]
# NOTE(review): machine-specific dataset root; adjust per environment.
PATH_ROOT = r"D:\ishan\ml\inaturalist\\"
# Prefer GPU when available; used by run() to place the model.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def create_map(list_in, from_key, to_key):
    """Map each dict's *from_key* value to its *to_key* value.

    Later entries with a duplicate *from_key* value overwrite earlier ones.
    """
    return {entry[from_key]: entry[to_key] for entry in list_in}
class iNaturalistDataset(torch.utils.data.Dataset):
    """Bird (Aves) detection dataset over the iNaturalist 2017 bbox annotations.

    Each sample is ``(image, target)`` in the format torchvision detection
    models expect: ``target`` holds ``boxes`` (xyxy), ``labels``, ``image_id``,
    ``area`` and ``iscrowd``.

    Exactly one of *validation* / *train* must be True; it selects which
    annotation JSON under PATH_ROOT is loaded.
    """

    def __init__(self, validation=False, train=False):
        self.validation = validation
        self.train = train
        if validation:
            json_path = os.path.join(PATH_ROOT, r"val_2017_bboxes\val_2017_bboxes.json")
        elif train:
            json_path = os.path.join(
                PATH_ROOT, r"train_2017_bboxes\train_2017_bboxes.json"
            )
        else:
            # Previously this fell through to an undefined json_path (NameError);
            # fail loudly with a clear message instead.
            raise ValueError("iNaturalistDataset requires validation=True or train=True")
        with open(json_path, "r") as rj:
            f = json.load(rj)
        categories = list()
        image_info = dict()
        # Keep only bird categories; the two named species are printed as a
        # sanity check that they survived the filter.
        for category in f["categories"]:
            if category["supercategory"] == "Aves":
                if category['name'] in ['Archilochus colubris', 'Icterus galbula']:
                    print(category['name'])
                categories.append(category)
        # Deterministic label ids: sort by name, then number 1..N
        # (0 is reserved for background by torchvision detection models).
        categories = sorted(categories, key=lambda k: k["name"])
        for idx, cat in enumerate(categories):
            cat["new_id"] = idx + 1
        orig_to_new_id = create_map(categories, "id", "new_id")
        for annot in f["annotations"]:
            if annot["category_id"] in orig_to_new_id:
                annot["new_category_id"] = orig_to_new_id[annot["category_id"]]
                img_id = annot["image_id"]
                if img_id not in image_info:
                    image_info[img_id] = dict()
                # Convert [x, y, w, h] -> [x1, y1, x2, y2] in place.
                annot["bbox"][2] += annot["bbox"][0]
                annot["bbox"][3] += annot["bbox"][1]
                # NOTE(review): an image with multiple annotations keeps only
                # the last one — presumably each image has a single box; verify.
                image_info[img_id]["annotation"] = annot
        for img in f["images"]:
            img_id = img["id"]
            path = os.path.join(PATH_ROOT, img["file_name"])
            height = img["height"]
            width = img["width"]
            if img_id in image_info:
                image_info[img_id].update({"path": path, "height": height, "width": width})
        for idx, (img_id, im_in) in enumerate(image_info.items()):
            im_in["idx"] = idx
        self.images = image_info
        self.categories = categories
        self.idx_to_id = [x for x in self.images]
        self.num_classes = len(self.categories) + 1  # +1 for background
        self.num_samples = len(self.images)
        # Box-aware augmentations (data_aug operates on image+bbox pairs);
        # the earlier torchvision-only pipeline was dead code and is removed.
        self.transforms = [
            data_aug.RandomHorizontalFlip(0.5),
            data_aug.Resize(600),
        ]
        self.pre_transform = T.Compose([T.ToTensor()])  # Normalize intentionally disabled

    def __len__(self):
        return self.num_samples

    def transform(self, img, bbox):
        """Apply the box-aware augmentations, then convert the image to a tensor."""
        for aug in self.transforms:
            img, bbox = aug(img, bbox)
        img = self.pre_transform(img)
        return img, bbox

    def __getitem__(self, idx):
        idd = self.idx_to_id[idx]
        c_image = self.images[idd]
        # BGR -> RGB; copy() because the reversed view is not contiguous.
        image = np.asarray(cv2.imread(c_image["path"])[:, :, ::-1].copy(), dtype=np.float32)
        annot = c_image["annotation"]
        # Build a fresh [x1, y1, x2, y2, label] row WITHOUT mutating the stored
        # annotation: the old code appended the label to annot["bbox"] itself,
        # so every repeated access of the same index grew the list and
        # corrupted the dataset.
        bbox = np.asarray([annot["bbox"] + [annot["new_category_id"]]], dtype=np.float32)
        image, bbox = self.transform(image.copy(), bbox.copy())
        boxes = torch.as_tensor(bbox[:, :4], dtype=torch.float32)
        target = dict()
        target["boxes"] = boxes
        target["labels"] = torch.as_tensor([annot["new_category_id"]], dtype=torch.int64)
        target['image_id'] = torch.tensor([annot['image_id']])
        target['area'] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        target['iscrowd'] = torch.zeros((1,), dtype=torch.int64)
        return image, target
# v = iNaturalistDataset(validation= True)
# o = v[10]
# %%
# oimage = t.tensor(o[0]*255, dtype=t.uint8)
# import matplotlib.pyplot as plt
# ox = draw_bounding_boxes(oimage, o[1]['boxes'], width=1)
# plt.imshow(ox.permute([1,2,0]))
# plt.savefig('crap2.png')
# %%
def run():
    """Fine-tune a COCO-pretrained Faster R-CNN on the bird subset.

    Builds train/val loaders, swaps the detection head for one sized to the
    dataset's classes, then trains for 10 epochs with SGD + StepLR, saving
    checkpoints before/after each epoch's train step and LR step.
    """
    val_dataset = iNaturalistDataset(validation=True)
    train_dataset = iNaturalistDataset(train=True)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=16, shuffle=True, num_workers=4, collate_fn=utils.collate_fn
    )
    val_data_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=16, shuffle=True, num_workers=4, collate_fn=utils.collate_fn
    )
    # Load COCO-pretrained weights with the stock 91-class head, then replace
    # the box predictor. Passing num_classes != 91 together with
    # pretrained=True is rejected by torchvision (the pretrained head's
    # shapes wouldn't match) — the previous code did exactly that.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        pretrained=True, progress=True
    )
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, train_dataset.num_classes)
    model.to(device)
    # Construct an optimizer over trainable parameters only.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005,
                                momentum=0.9, weight_decay=0.0005)
    # Decay LR by 10x every 3 epochs.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=3,
                                                   gamma=0.1)
    num_epochs = 10
    for epoch in range(num_epochs):
        print(epoch)
        torch.save(model.state_dict(), 'model_weights_start_' + str(epoch) + '.pth')
        # Train for one epoch, printing every 10 iterations.
        engine.train_one_epoch(model, optimizer, train_data_loader, device, epoch, print_freq=10)
        torch.save(model.state_dict(), 'model_weights_post_train_' + str(epoch) + '.pth')
        # Update the learning rate.
        lr_scheduler.step()
        torch.save(model.state_dict(), 'model_weights_post_step_' + str(epoch) + '.pth')
        # Evaluate on the validation dataset.
        engine.evaluate(model, val_data_loader, device=device)


if __name__ == "__main__":
    run()
# # %%
# json_path = os.path.join(
# PATH_ROOT, r"train_2017_bboxes\train_2017_bboxes.json"
# )
# with open(json_path, "r") as rj:
# f = json.load(rj)
# # %%
# image_id: 2358