# %%
import json
import os
import sys

import torch
import torchvision
from PIL import Image
from torchvision import transforms as T
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# The dataset location is only configured for the author's Windows machine;
# importing this module anywhere else fails loudly instead of guessing a path.
if sys.platform == 'win32':
    PATH_ROOT = r"D:\ishan\ml\inaturalist\\"
else:
    raise NotImplementedError("Not defined for this platform")


class _DetectionTransform:
    """Convert a PIL image to a tensor and optionally flip it with its boxes.

    ``torchvision.transforms.Compose`` only accepts an image, so it cannot be
    called as ``transform(img, target)`` and cannot keep bounding boxes
    aligned with a random flip. This pipeline applies the same two steps
    (``ToTensor`` then a random horizontal flip) while updating
    ``target["boxes"]`` whenever the image is mirrored.
    """

    def __init__(self, flip_prob=0.0):
        # Probability of mirroring a sample; 0.0 disables the flip entirely.
        self.flip_prob = flip_prob
        self._to_tensor = T.ToTensor()

    def __call__(self, image, target=None):
        """Transform ``image`` (and ``target`` when given).

        Returns the image tensor alone when ``target`` is None, so the object
        can still be called with a single image, otherwise returns the
        ``(image, target)`` pair.
        """
        image = self._to_tensor(image)
        if self.flip_prob > 0 and torch.rand(1).item() < self.flip_prob:
            image = image.flip(-1)
            if target is not None and "boxes" in target:
                width = image.shape[-1]
                boxes = target["boxes"]
                # Boxes are (x1, y1, x2, y2): mirroring swaps and reflects the
                # x-coordinates. The right-hand side is a copy produced by
                # advanced indexing, so the in-place assignment is safe.
                boxes[:, [0, 2]] = width - boxes[:, [2, 0]]
                target["boxes"] = boxes
        if target is None:
            return image
        return image, target


def get_transform(train):
    """Return the default preprocessing pipeline.

    Args:
        train: When truthy, samples are horizontally flipped with
            probability 0.5 in addition to the tensor conversion.

    Returns:
        A callable accepting ``(image)`` or ``(image, target)``.
    """
    return _DetectionTransform(flip_prob=0.5 if train else 0.0)


def create_map(list_in, from_key, to_key):
    """Build a dict mapping ``item[from_key]`` to ``item[to_key]``.

    Later items overwrite earlier ones that share the same ``from_key`` value.
    """
    return {item[from_key]: item[to_key] for item in list_in}


class iNaturalistDataset(torch.utils.data.Dataset):
    """Single-box detection dataset built from the iNaturalist 2017 bbox JSON.

    Each sample is an ``(image, target)`` pair where ``target`` holds the keys
    ``boxes``, ``labels``, ``image_id``, ``area`` and ``iscrowd``.

    NOTE: only one annotation is kept per image. If the JSON lists several
    annotations for the same image, the last one wins.
    """

    def __init__(self, validation=False, train=False, transforms=None, species=None):
        """Load and index the annotation file for one split.

        Args:
            validation: Load ``val_2017_bboxes.json``.
            train: Load ``train_2017_bboxes.json``.
            transforms: Optional callable invoked as
                ``transforms(img, target)`` and returning the transformed
                pair. Defaults to ``get_transform(train)``.
            species: Optional container of category names to keep. ``None``
                keeps every category in the file.

        Raises:
            ValueError: If neither or both of ``train`` and ``validation``
                are set.
        """
        self.validation = validation
        self.train = train

        # Exactly one split must be selected.
        if bool(self.train) == bool(self.validation):
            raise ValueError("Need to do either train or validation")

        self.transforms = (
            transforms if transforms is not None else get_transform(self.train)
        )

        split = "val_2017_bboxes" if validation else "train_2017_bboxes"
        json_path = os.path.join(PATH_ROOT, split, split + ".json")
        with open(json_path, "r", encoding="utf-8") as rj:
            f = json.load(rj)

        categories = []
        for category in f["categories"]:
            if species is None:
                categories.append(category)
            elif category["name"] in species:
                print(category["name"])
                categories.append(category)

        # Re-number the kept categories contiguously from 1 in name order;
        # num_classes below adds one extra slot (presumably background at 0).
        categories = sorted(categories, key=lambda k: k["name"])
        for idx, cat in enumerate(categories):
            cat["new_id"] = idx + 1
        orig_to_new_id = create_map(categories, "id", "new_id")

        image_info = {}
        for annot in f["annotations"]:
            new_id = orig_to_new_id.get(annot["category_id"])
            if new_id is None:
                continue  # annotation belongs to a category that was filtered out
            annot["new_category_id"] = new_id
            # Convert the box in place from (x, y, w, h) to (x1, y1, x2, y2).
            bbox = annot["bbox"]
            bbox[2] += bbox[0]
            bbox[3] += bbox[1]
            image_info.setdefault(annot["image_id"], {})["annotation"] = annot

        # Attach file path and size to every image that has a kept annotation.
        for img in f["images"]:
            record = image_info.get(img["id"])
            if record is not None:
                record.update(
                    {
                        "path": os.path.join(PATH_ROOT, img["file_name"]),
                        "height": img["height"],
                        "width": img["width"],
                    }
                )

        for idx, record in enumerate(image_info.values()):
            record["idx"] = idx

        self.images = image_info
        self.categories = categories
        # Positional index -> image id, in insertion order of image_info.
        self.idx_to_id = list(self.images)
        self.num_classes = len(self.categories) + 1
        self.num_samples = len(self.images)

    def __len__(self):
        """Return the number of images that have a kept annotation."""
        return self.num_samples

    def __getitem__(self, idx):
        """Load the image at position ``idx`` together with its target dict."""
        record = self.images[self.idx_to_id[idx]]
        annot = record["annotation"]

        with Image.open(record["path"]) as raw:
            img = raw.convert("RGB")

        target = {
            # Explicit float32 so integer JSON coordinates do not yield int64 boxes.
            "boxes": torch.as_tensor([annot["bbox"]], dtype=torch.float32),
            "labels": torch.as_tensor([annot["new_category_id"]], dtype=torch.int64),
            "image_id": torch.tensor([annot["image_id"]]),
            "area": torch.as_tensor([annot["area"]]),
            "iscrowd": torch.zeros((1,), dtype=torch.int64),
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target