initial
This commit is contained in:
37
presets.py
Normal file
37
presets.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import transforms as T
|
||||
|
||||
|
||||
class DetectionPresetTrain:
|
||||
def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123., 117., 104.)):
|
||||
if data_augmentation == 'hflip':
|
||||
self.transforms = T.Compose([
|
||||
T.RandomHorizontalFlip(p=hflip_prob),
|
||||
T.ToTensor(),
|
||||
])
|
||||
elif data_augmentation == 'ssd':
|
||||
self.transforms = T.Compose([
|
||||
T.RandomPhotometricDistort(),
|
||||
T.RandomZoomOut(fill=list(mean)),
|
||||
T.RandomIoUCrop(),
|
||||
T.RandomHorizontalFlip(p=hflip_prob),
|
||||
T.ToTensor(),
|
||||
])
|
||||
elif data_augmentation == 'ssdlite':
|
||||
self.transforms = T.Compose([
|
||||
T.RandomIoUCrop(),
|
||||
T.RandomHorizontalFlip(p=hflip_prob),
|
||||
T.ToTensor(),
|
||||
])
|
||||
else:
|
||||
raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
|
||||
|
||||
def __call__(self, img, target):
|
||||
return self.transforms(img, target)
|
||||
|
||||
|
||||
class DetectionPresetEval:
|
||||
def __init__(self):
|
||||
self.transforms = T.ToTensor()
|
||||
|
||||
def __call__(self, img, target):
|
||||
return self.transforms(img, target)
|
||||
Reference in New Issue
Block a user