# iNaturalist 2017 bounding-box dataset utilities for torchvision detection models.
# %%
import json
import os
import sys

import torch
import torchvision
from PIL import Image
from torchvision import transforms as T
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor


# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Root directory that the annotation JSON folders and image file names are
# resolved against.  The INATURALIST_ROOT environment variable overrides it on
# any platform; otherwise only the hard-coded Windows location is known.
# NOTE(review): the raw string below ends in two backslashes (a raw string
# cannot end in a single one); os.path.join still appends the next component.
if os.environ.get("INATURALIST_ROOT"):
    PATH_ROOT = os.environ["INATURALIST_ROOT"]
elif sys.platform == 'win32':
    PATH_ROOT = r"D:\ishan\ml\inaturalist\\"
else:
    raise NotImplementedError("Not defined for this platform")


def get_transform(train):
    """Return the image/target transform used by ``iNaturalistDataset``.

    The returned callable converts a PIL image to a tensor with
    ``T.ToTensor()``.  When ``train`` is true it also flips the image
    horizontally with probability 0.5 and mirrors ``target["boxes"]`` so the
    annotations stay aligned with the pixels.

    Call it as ``transform(image)`` -> tensor, or as
    ``transform(image, target)`` -> ``(tensor, target)``.  The two-argument
    form is the one the dataset's ``__getitem__`` uses.  A plain ``T.Compose``
    accepts only a single argument, and ``T.RandomHorizontalFlip`` would not
    move the boxes.

    Args:
        train: enable the random horizontal flip augmentation.

    Returns:
        A callable ``(image, target=None)`` as described above.
    """
    to_tensor = T.ToTensor()

    def _transform(image, target=None):
        image = to_tensor(image)
        # Same draw as T.RandomHorizontalFlip(0.5): flip when rand(1) < p.
        if train and torch.rand(1).item() < 0.5:
            image = image.flip(-1)
            if target is not None and "boxes" in target:
                width = image.shape[-1]
                boxes = target["boxes"].clone()
                # Boxes are (x1, y1, x2, y2); mirroring swaps and reflects the x edges.
                boxes[:, [0, 2]] = width - boxes[:, [2, 0]]
                # Copy instead of mutating the caller's dict.
                target = {**target, "boxes": boxes}
        if target is None:
            return image
        return image, target

    return _transform


def create_map(list_in, from_key, to_key):
    """Build a lookup dict from a list of mapping records.

    Each record contributes ``record[from_key] -> record[to_key]``.  When two
    records share the same ``from_key`` value, the later record wins.

    Args:
        list_in: iterable of dict-like records.
        from_key: key whose value becomes the lookup key.
        to_key: key whose value becomes the looked-up value.

    Returns:
        A new ``dict``.
    """
    return {record[from_key]: record[to_key] for record in list_in}


class iNaturalistDataset(torch.utils.data.Dataset):
    """Detection dataset over the iNaturalist 2017 bounding-box annotations.

    Loads ``train_2017_bboxes/train_2017_bboxes.json`` or
    ``val_2017_bboxes/val_2017_bboxes.json`` under ``PATH_ROOT``.  Each item
    is an ``(image, target)`` pair.  ``target`` has one row per kept
    annotation on the image:

    - ``boxes``: float32, ``x1, y1, x2, y2``.
    - ``labels``: int64 category ids, remapped to start at 1.
    - ``image_id``.
    - ``area``.
    - ``iscrowd``.

    Args:
        validation: load the validation annotation file.
        train: load the training annotation file.  Exactly one of ``train`` /
            ``validation`` must be true.
        transforms: optional callable ``(image, target) -> (image, target)``
            applied in ``__getitem__``.  Defaults to ``get_transform(train)``.
        species: optional collection of category names to keep.  ``None``
            keeps every category.  Each matching name is printed.

    Raises:
        ValueError: if neither or both of ``train`` / ``validation`` are set.
    """

    def __init__(self, validation=False, train=False, transforms=None, species=None):
        self.validation = validation
        self.train = train

        # Exactly one split must be requested; neither or both is an error.
        if bool(self.train) == bool(self.validation):
            raise ValueError("Need to do either train or validation")

        # Use the caller's transform when one is given, else the default pipeline.
        self.transforms = transforms if transforms is not None else get_transform(self.train)

        if validation:
            json_path = os.path.join(PATH_ROOT, "val_2017_bboxes", "val_2017_bboxes.json")
        else:
            json_path = os.path.join(PATH_ROOT, "train_2017_bboxes", "train_2017_bboxes.json")

        with open(json_path, "r") as rj:
            f = json.load(rj)

        categories = self._select_categories(f["categories"], species)
        orig_to_new_id = create_map(categories, "id", "new_id")
        image_info = self._index_images(f, orig_to_new_id)

        # image id -> {"annotations", "annotation", "path", "height", "width", "idx"}
        self.images = image_info
        self.categories = categories
        # Position -> image id, in the same order as the "idx" values above.
        self.idx_to_id = list(self.images)
        # +1 for the extra class slot at label 0 (labels start at 1).
        self.num_classes = len(self.categories) + 1
        self.num_samples = len(self.images)

    @staticmethod
    def _select_categories(all_categories, species):
        """Return the kept categories, sorted by name, each tagged with ``new_id`` 1..N."""
        categories = []
        for category in all_categories:
            if species is None:
                categories.append(category)
            elif category["name"] in species:
                print(category["name"])
                categories.append(category)

        categories = sorted(categories, key=lambda k: k["name"])
        # Ids start at 1 to match num_classes = len(categories) + 1
        # (presumably label 0 is the detector's background class).
        for idx, cat in enumerate(categories):
            cat["new_id"] = idx + 1
        return categories

    @staticmethod
    def _index_images(f, orig_to_new_id):
        """Group kept annotations by image id and attach each image's file info."""
        image_info = {}
        for annot in f["annotations"]:
            if annot["category_id"] not in orig_to_new_id:
                continue
            annot["new_category_id"] = orig_to_new_id[annot["category_id"]]

            # [x, y, width, height] -> [x1, y1, x2, y2], in place.
            bbox = annot["bbox"]
            bbox[2] += bbox[0]
            bbox[3] += bbox[1]

            info = image_info.setdefault(annot["image_id"], {"annotations": []})
            info["annotations"].append(annot)
            # Kept for compatibility: the last annotation seen for this image.
            info["annotation"] = annot

        for img in f["images"]:
            info = image_info.get(img["id"])
            if info is not None:
                info.update({
                    "path": os.path.join(PATH_ROOT, img["file_name"]),
                    "height": img["height"],
                    "width": img["width"],
                })

        # Annotations whose image record is missing have no file to load; drop them.
        image_info = {img_id: info for img_id, info in image_info.items() if "path" in info}

        for idx, info in enumerate(image_info.values()):
            info["idx"] = idx
        return image_info

    def __len__(self):
        """Number of images that have at least one kept annotation."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the ``idx``-th image, after ``self.transforms``."""
        idd = self.idx_to_id[idx]
        c_image = self.images[idd]

        # The context manager closes the underlying file once the RGB copy exists.
        with Image.open(c_image["path"]) as im:
            img = im.convert("RGB")

        annots = c_image["annotations"]
        target = dict()
        # Explicit float32: integer coordinates in the JSON would otherwise give int64.
        target["boxes"] = torch.as_tensor([a["bbox"] for a in annots], dtype=torch.float32)
        target["labels"] = torch.as_tensor([a["new_category_id"] for a in annots], dtype=torch.int64)
        target["image_id"] = torch.tensor([idd])
        target["area"] = torch.as_tensor([a["area"] for a in annots], dtype=torch.float32)
        target["iscrowd"] = torch.zeros((len(annots),), dtype=torch.int64)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target