import bisect
import copy
import math
from collections import defaultdict
from itertools import chain, repeat

import numpy as np
import torch
import torch.utils.data
import torchvision
from PIL import Image
from torch.utils.data.sampler import BatchSampler, Sampler
from torch.utils.model_zoo import tqdm


def _repeat_to_at_least(iterable, n):
    # repeat the iterable enough times so that the result has at least n elements
    repeat_times = math.ceil(n / len(iterable))
    repeated = chain.from_iterable(repeat(iterable, repeat_times))
    return list(repeated)


class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield a mini-batch of indices.
    It enforces that each batch contains only elements from the same group.
    It also tries to provide mini-batches that follow an ordering which is
    as close as possible to the ordering of the original sampler.

    Args:
        sampler (Sampler): Base sampler.
        group_ids (list[int]): If the sampler produces indices in range [0, N),
            `group_ids` must be a list of `N` ints which contains the group id of each sample.
            The group ids must be a continuous set of integers starting from
            0, i.e. they must be in the range [0, num_groups).
        batch_size (int): Size of mini-batch.
    """

    def __init__(self, sampler, group_ids, batch_size):
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of "
                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )
        self.sampler = sampler
        self.group_ids = group_ids
        self.batch_size = batch_size

    def __iter__(self):
        buffer_per_group = defaultdict(list)
        samples_per_group = defaultdict(list)

        num_batches = 0
        for idx in self.sampler:
            group_id = self.group_ids[idx]
            # buffer indices per group and yield a batch as soon as a group is full
            buffer_per_group[group_id].append(idx)
            samples_per_group[group_id].append(idx)
            if len(buffer_per_group[group_id]) == self.batch_size:
                yield buffer_per_group[group_id]
                num_batches += 1
                del buffer_per_group[group_id]
            assert len(buffer_per_group[group_id]) < self.batch_size

        # now we have run out of elements that satisfy
        # the group criteria, let's return the remaining
        # elements so that the size of the sampler is
        # deterministic
        expected_num_batches = len(self)
        num_remaining = expected_num_batches - num_batches
        if num_remaining > 0:
            # for the remaining batches, take first the buffers with the largest
            # number of elements
            for group_id, _ in sorted(buffer_per_group.items(), key=lambda x: len(x[1]), reverse=True):
                remaining = self.batch_size - len(buffer_per_group[group_id])
                samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining)
                buffer_per_group[group_id].extend(samples_from_group_id[:remaining])
                assert len(buffer_per_group[group_id]) == self.batch_size
                yield buffer_per_group[group_id]
                num_remaining -= 1
                if num_remaining == 0:
                    break
        assert num_remaining == 0

    def __len__(self):
        return len(self.sampler) // self.batch_size


def _compute_aspect_ratios_slow(dataset, indices=None):
    print(
        "Your dataset doesn't support the fast path for "
        "computing the aspect ratios, so will iterate over "
        "the full dataset and load every image instead. "
        "This might take some time..."
    )
    if indices is None:
        indices = range(len(dataset))

    class SubsetSampler(Sampler):
        def __init__(self, indices):
            self.indices = indices

        def __iter__(self):
            return iter(self.indices)

        def __len__(self):
            return len(self.indices)

    sampler = SubsetSampler(indices)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        sampler=sampler,
        num_workers=14,  # you might want to increase it for faster processing
        collate_fn=lambda x: x[0],
    )
    aspect_ratios = []
    with tqdm(total=len(dataset)) as pbar:
        for img, _ in data_loader:
            pbar.update(1)
            height, width = img.shape[-2:]
            aspect_ratio = float(width) / float(height)
            aspect_ratios.append(aspect_ratio)
    return aspect_ratios


def _compute_aspect_ratios_custom_dataset(dataset, indices=None):
    if indices is None:
        indices = range(len(dataset))
    aspect_ratios = []
    for i in indices:
        height, width = dataset.get_height_and_width(i)
        aspect_ratio = float(width) / float(height)
        aspect_ratios.append(aspect_ratio)
    return aspect_ratios


def _compute_aspect_ratios_coco_dataset(dataset, indices=None):
    if indices is None:
        indices = range(len(dataset))
    aspect_ratios = []
    for i in indices:
        # image sizes are available in the COCO annotations, so no image is decoded
        img_info = dataset.coco.imgs[dataset.ids[i]]
        aspect_ratio = float(img_info["width"]) / float(img_info["height"])
        aspect_ratios.append(aspect_ratio)
    return aspect_ratios


def _compute_aspect_ratios_voc_dataset(dataset, indices=None):
    if indices is None:
        indices = range(len(dataset))
    aspect_ratios = []
    for i in indices:
        # this doesn't load the data into memory, because PIL loads it lazily
        width, height = Image.open(dataset.images[i]).size
        aspect_ratio = float(width) / float(height)
        aspect_ratios.append(aspect_ratio)
    return aspect_ratios


def _compute_aspect_ratios_subset_dataset(dataset, indices=None):
    if indices is None:
        indices = range(len(dataset))

    # translate indices in the subset to indices in the underlying dataset
    ds_indices = [dataset.indices[i] for i in indices]
    return compute_aspect_ratios(dataset.dataset, ds_indices)


def compute_aspect_ratios(dataset, indices=None):
    if hasattr(dataset, "get_height_and_width"):
        return _compute_aspect_ratios_custom_dataset(dataset, indices)

    if isinstance(dataset, torchvision.datasets.CocoDetection):
        return _compute_aspect_ratios_coco_dataset(dataset, indices)

    if isinstance(dataset, torchvision.datasets.VOCDetection):
        return _compute_aspect_ratios_voc_dataset(dataset, indices)

    if isinstance(dataset, torch.utils.data.Subset):
        return _compute_aspect_ratios_subset_dataset(dataset, indices)

    # slow path
    return _compute_aspect_ratios_slow(dataset, indices)


def _quantize(x, bins):
    bins = sorted(copy.deepcopy(bins))
    quantized = [bisect.bisect_right(bins, y) for y in x]
    return quantized


def create_aspect_ratio_groups(dataset, k=0):
    aspect_ratios = compute_aspect_ratios(dataset)
    # when k > 0, the bin edges are 2k + 1 values log-spaced between 0.5 and 2
    bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist() if k > 0 else [1.0]
    groups = _quantize(aspect_ratios, bins)
    # count number of elements per group
    counts = np.unique(groups, return_counts=True)[1]
    fbins = [0] + bins + [np.inf]
    print("Using {} as bins for aspect ratio quantization".format(fbins))
    print("Count of instances per bin: {}".format(counts))
    return groups