This commit is contained in:
2021-07-01 20:26:24 -04:00
parent 8b02bf9a8c
commit f46d193826
16 changed files with 433 additions and 146 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

24
bear_utils.py Normal file
View File

@@ -0,0 +1,24 @@
from math import sin, cos, sqrt, atan2, radians
def get_distance_from_home(lat_b, lon_b):
    """Return the distance from the fixed "home" point to (lat_b, lon_b).

    Args:
        lat_b: Latitude of the target point, in degrees.
        lon_b: Longitude of the target point, in degrees.

    Returns:
        Whatever ``distance_lat_lon`` computes between the hard-coded home
        coordinate and the given point.
    """
    # Hard-coded reference location as (latitude, longitude) in degrees.
    home = (42.295940, -83.751960)
    return distance_lat_lon(*home, lat_b, lon_b)
def distance_lat_lon(lat_a, lon_a, lat_b, lon_b):
    """Return the great-circle distance between two points (haversine formula).

    Args:
        lat_a: Latitude of the first point, in degrees.
        lon_a: Longitude of the first point, in degrees.
        lat_b: Latitude of the second point, in degrees.
        lon_b: Longitude of the second point, in degrees.

    Returns:
        The distance as a float, expressed in the same unit as the sphere
        radius used below.
    """
    # Sphere radius; presumably Earth's radius in kilometres -- TODO confirm.
    radius = 6373.0

    lat1, lon1 = radians(lat_a), radians(lon_a)
    lat2, lon2 = radians(lat_b), radians(lon_b)
    dlat = lat2 - lat1
    dlon = lon2 - lon1

    a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
    # Floating-point rounding can push `a` marginally outside [0, 1] for
    # near-antipodal points, which would make sqrt(1 - a) raise ValueError.
    a = min(1.0, max(0.0, a))
    c = 2 * atan2(sqrt(a), sqrt(1 - a))
    return radius * c

0
config.py Normal file
View File

125
data.py
View File

@@ -1,27 +1,29 @@
# %% # %%
import os import os
from unicodedata import category
import torch import torch
from PIL import Image from PIL import Image
import torchvision import sys
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import json import json
import torch import torch
from torchvision import transforms as T import transforms as T
import os import os
import numpy as np
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if sys.platform == 'win32': if sys.platform == "win32":
PATH_ROOT = r"D:\ishan\ml\inaturalist\\" PATH_ROOT = r"D:\ishan\ml\inaturalist\\"
else: else:
raise NotImplementedError("Not defined for this platform") raise NotImplementedError("Not defined for this platform")
def get_transform(train): def get_transform(train):
transforms = [] trsf = []
transforms.append(T.ToTensor()) trsf.append(T.ToTensor())
if train: if train:
transforms.append(T.RandomHorizontalFlip(0.5)) trsf.append(T.RandomHorizontalFlip(0.5))
return T.Compose(transforms) return T.Compose(trsf)
def create_map(list_in, from_key, to_key): def create_map(list_in, from_key, to_key):
@@ -32,40 +34,45 @@ def create_map(list_in, from_key, to_key):
class iNaturalistDataset(torch.utils.data.Dataset): class iNaturalistDataset(torch.utils.data.Dataset):
def __init__(self, validation=False, train=False, transforms = None, species = None): def __init__(self, validation=False, train=False, species=None):
self.validation = validation self.validation = validation
self.train = train self.train = train
if (self.train or self.validation) or (self.train and self.validation) if (not self.train and not self.validation) or (self.train and self.validation):
raise Exception("Need to do either train or validation") raise Exception("Need to do either train or validation")
self.transforms = get_transform(self.train)
self.transform = get_transform(self.train)
if validation: if validation:
json_path = os.path.join(PATH_ROOT, "val_2017_bboxes","val_2017_bboxes.json") json_path = os.path.join(
PATH_ROOT, "val_2017_bboxes", "val_2017_bboxes.json"
)
elif train: elif train:
json_path = os.path.join( json_path = os.path.join(
PATH_ROOT, "train_2017_bboxes","train_2017_bboxes.json" PATH_ROOT, "train_2017_bboxes", "train_2017_bboxes.json"
) )
with open(json_path, "r") as rj: with open(json_path, "r") as rj:
f = json.load(rj) f = json.load(rj)
self.raw_data = f
categories = list() categories = list()
image_info = dict() image_info = dict()
orig_id_to_name = dict()
for category in f["categories"]: for category in f["categories"]:
do_add = False do_add = False
orig_id_to_name[category["id"]] = category
if species is None: if species is None:
do_add = True do_add = True
if category['name'] in species: elif category["name"] in species:
print(category['name']) print(category["name"])
do_add = True
if do_add:
categories.append(category) categories.append(category)
categories = sorted(categories, key=lambda k: k["name"]) categories = sorted(categories, key=lambda k: k["name"])
for idx, cat in enumerate(categories): for idx, cat in enumerate(categories):
cat["new_id"] = idx + 1 cat["new_id"] = idx + 1
@@ -93,13 +100,13 @@ class iNaturalistDataset(torch.utils.data.Dataset):
for idx, (id, im_in) in enumerate(image_info.items()): for idx, (id, im_in) in enumerate(image_info.items()):
im_in["idx"] = idx im_in["idx"] = idx
self.images = image_info self.images = image_info
self.categories = categories self.categories = categories
self.orig_id_to_name = orig_id_to_name
self.idx_to_id = [x for x in self.images] self.idx_to_id = [x for x in self.images]
self.num_classes = len(self.categories) + 1 self.num_classes = len(self.categories) + 1
self.num_samples = len(self.images) self.num_samples = len(self.images)
def __len__(self): def __len__(self):
return self.num_samples return self.num_samples
@@ -109,18 +116,74 @@ class iNaturalistDataset(torch.utils.data.Dataset):
c_image = self.images[idd] c_image = self.images[idd]
img_path = c_image["path"] img_path = c_image["path"]
img = Image.open(img_path).convert("RGB") img = Image.open(img_path).convert("RGB")
annot = c_image["annotation"] annot = c_image["annotation"]
bbox = annot["bbox"] bbox = annot["bbox"]
boxes = bbox boxes = bbox
target = dict() target = dict()
target["boxes"] = torch.as_tensor([boxes]) target["boxes"] = torch.as_tensor([boxes])
target["labels"] = torch.as_tensor([annot["new_category_id"]], dtype=torch.int64) target["labels"] = torch.as_tensor(
target['image_id'] = torch.tensor([annot['image_id']]) [annot["new_category_id"]], dtype=torch.int64
target['area'] = torch.as_tensor([annot['area']]) )
target['iscrowd'] = torch.zeros((1,), dtype=torch.int64) target["image_id"] = torch.tensor([annot["image_id"]])
target["area"] = torch.as_tensor([annot["area"]])
if self.transforms is not None: target["iscrowd"] = torch.zeros((1,), dtype=torch.int64)
img, target = self.transforms(img, target)
return img, target if self.transform is not None:
img, target = self.transform(img, target)
return img, target
if False:
train_dataset = iNaturalistDataset(train=True)
loc_path = os.path.join(PATH_ROOT, "inat2017_locations", "train2017_locations.json")
with open(loc_path, "r") as lfile:
locs = json.load(lfile)
from bear_utils import get_distance_from_home
# %%
category_distances = dict()
inserts = 0
for loc in locs:
lat = loc["lat"]
lon = loc["lon"]
im_id = loc["id"]
if lat is None or lon is None:
continue
ff = get_distance_from_home(lat, lon)
if im_id in train_dataset.images:
inserts += 1
train_dataset.images[im_id]["distance"] = ff
category_id = train_dataset.images[im_id]["annotation"]["category_id"]
if category_id not in category_distances:
category_distances[category_id] = list()
category_distances[category_id].append(ff)
# %%
from EcoNameTranslator import to_common
for k, v in category_distances.items():
name = train_dataset.orig_id_to_name[k]
if np.average(np.asarray(v) < 250) > 0.1:
if name["supercategory"] == "Aves":
print(len(v), to_common([name["name"]]))
# %%
fc = sorted(
category_distances, key=lambda x: len(category_distances[x]), reverse=True
)
for x in fc:
cc = train_dataset.orig_id_to_name[x]
if cc["supercategory"] == "Aves":
ou = to_common([cc["name"]])
print(ou, len(category_distances[x]))
# %%

View File

@@ -1,11 +1,12 @@
# %% # %%
from torchvision.models.detection import fasterrcnn_resnet50_fpn import torchvision.models.detection
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
def Model(num_classes):
model = fasterrcnn_resnet50_fpn(pretrained=True) def Model(num_classes, model_type=None):
num_classes = 2 # 1 class (person) + background chosen_model = torchvision.models.detection.__dict__[model_type]
model = chosen_model(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
return model return model

View File

@@ -0,0 +1,22 @@
{
"categories": [
{
"supercategory": "Aves",
"id": 206,
"name": "Archilochus colubris",
"new_id": 1
},
{
"supercategory": "Aves",
"id": 4493,
"name": "Icterus galbula",
"new_id": 2
},
{
"supercategory": "Aves",
"id": 403,
"name": "Poecile atricapillus",
"new_id": 3
}
]
}

View File

@@ -0,0 +1,23 @@
{
"categories": [
{
"supercategory": "Aves",
"id": 206,
"name": "Archilochus colubris",
"new_id": 1
},
{
"supercategory": "Aves",
"id": 4493,
"name": "Icterus galbula",
"new_id": 2
},
{
"supercategory": "Aves",
"id": 403,
"name": "Poecile atricapillus",
"new_id": 3
}
],
"model_type": "fasterrcnn_mobilenet_v3_large_fpn"
}

View File

@@ -0,0 +1,23 @@
{
"categories": [
{
"supercategory": "Aves",
"id": 206,
"name": "Archilochus colubris",
"new_id": 1
},
{
"supercategory": "Aves",
"id": 4493,
"name": "Icterus galbula",
"new_id": 2
},
{
"supercategory": "Aves",
"id": 403,
"name": "Poecile atricapillus",
"new_id": 3
}
],
"model_type": "fasterrcnn_mobilenet_v3_large_fpn"
}

View File

@@ -0,0 +1,23 @@
{
"categories": [
{
"supercategory": "Aves",
"id": 206,
"name": "Archilochus colubris",
"new_id": 1
},
{
"supercategory": "Aves",
"id": 4493,
"name": "Icterus galbula",
"new_id": 2
},
{
"supercategory": "Aves",
"id": 403,
"name": "Poecile atricapillus",
"new_id": 3
}
],
"model_type": "fasterrcnn_mobilenet_v3_large_fpn"
}

View File

@@ -4,49 +4,79 @@ from model import Model
from data import iNaturalistDataset from data import iNaturalistDataset
import torch import torch
import os import os
import time import datetime as dt
import json
import utils
if not os.path.exists("models/"):
os.mkdir("models")
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
model_root = "models/" + dt.datetime.now().strftime("%Y%m%d_%H%M%S")
model_path = model_root + ".pth"
model_info = model_root + ".json"
species_list = set(["Poecile atricapillus", "Archilochus colubris", "Icterus galbula"])
model_type = "fasterrcnn_mobilenet_v3_large_fpn"
if not os.path.exists('models/'):
os.mkdirs('models')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def run(): def run():
val_dataset = iNaturalistDataset(validation=True, transforms = get_transform(train=True)) val_dataset = iNaturalistDataset(
train_dataset = iNaturalistDataset(train=True, transforms = get_transform(train=False)) validation=True,
species=species_list,
)
train_dataset = iNaturalistDataset(
train=True,
species=species_list,
)
with open(model_info, "w") as js_p:
json.dump(
{"categories": train_dataset.categories, "model_type": model_type},
js_p,
default=str,
indent=4,
)
train_data_loader = torch.utils.data.DataLoader( train_data_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=utils.collate_fn train_dataset,
) batch_size=8,
val_data_loader = torch.utils.data.DataLoader( shuffle=True,
val_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=utils.collate_fn num_workers=4,
collate_fn=utils.collate_fn,
) )
num_classes = 5 val_data_loader = torch.utils.data.DataLoader(
model = Model(num_classes) val_dataset,
batch_size=8,
shuffle=True,
num_workers=4,
collate_fn=utils.collate_fn,
)
num_classes = len(species_list) + 1
model = Model(num_classes, model_type)
model.to(device) model.to(device)
params = [p for p in model.parameters() if p.requires_grad] params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
step_size=3,
gamma=0.1)
num_epochs = 10 num_epochs = 10
for epoch in range(num_epochs): for epoch in range(num_epochs):
print(epoch) train_one_epoch(
torch.save(model.state_dict(), 'model_weights_start_'+str(epoch)+ '.pth') model, optimizer, train_data_loader, device, epoch, print_freq=10
# train for one epoch, printing every 10 iterations )
engine.train_one_epoch(model, optimizer, train_data_loader, device, epoch, print_freq=10)
torch.save(model.state_dict(), 'model_weights_post_train_'+str(epoch)+ '.pth')
# update the learning rate
lr_scheduler.step() lr_scheduler.step()
torch.save(model.state_dict(), 'model_weights_post_step_'+str(epoch)+ '.pth') torch.save(model.state_dict(), model_path)
# evaluate on the test dataset evaluate(model, val_data_loader, device=device)
engine.evaluate(model, val_data_loader, device=device)
if __name__ == "__main__": if __name__ == "__main__":
run() run()

View File

@@ -28,8 +28,9 @@ class Compose(object):
class RandomHorizontalFlip(T.RandomHorizontalFlip): class RandomHorizontalFlip(T.RandomHorizontalFlip):
def forward(self, image: Tensor, def forward(
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if torch.rand(1) < self.p: if torch.rand(1) < self.p:
image = F.hflip(image) image = F.hflip(image)
if target is not None: if target is not None:
@@ -45,15 +46,23 @@ class RandomHorizontalFlip(T.RandomHorizontalFlip):
class ToTensor(nn.Module): class ToTensor(nn.Module):
def forward(self, image: Tensor, def forward(
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
image = F.to_tensor(image) image = F.to_tensor(image)
return image, target return image, target
class RandomIoUCrop(nn.Module): class RandomIoUCrop(nn.Module):
def __init__(self, min_scale: float = 0.3, max_scale: float = 1.0, min_aspect_ratio: float = 0.5, def __init__(
max_aspect_ratio: float = 2.0, sampler_options: Optional[List[float]] = None, trials: int = 40): self,
min_scale: float = 0.3,
max_scale: float = 1.0,
min_aspect_ratio: float = 0.5,
max_aspect_ratio: float = 2.0,
sampler_options: Optional[List[float]] = None,
trials: int = 40,
):
super().__init__() super().__init__()
# Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174 # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
self.min_scale = min_scale self.min_scale = min_scale
@@ -65,14 +74,19 @@ class RandomIoUCrop(nn.Module):
self.options = sampler_options self.options = sampler_options
self.trials = trials self.trials = trials
def forward(self, image: Tensor, def forward(
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if target is None: if target is None:
raise ValueError("The targets can't be None for this transform.") raise ValueError("The targets can't be None for this transform.")
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}: if image.ndimension() not in {2, 3}:
raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension())) raise ValueError(
"image should be 2/3 dimensional. Got {} dimensions.".format(
image.ndimension()
)
)
elif image.ndimension() == 2: elif image.ndimension() == 2:
image = image.unsqueeze(0) image = image.unsqueeze(0)
@@ -82,7 +96,9 @@ class RandomIoUCrop(nn.Module):
# sample an option # sample an option
idx = int(torch.randint(low=0, high=len(self.options), size=(1,))) idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
min_jaccard_overlap = self.options[idx] min_jaccard_overlap = self.options[idx]
if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option if (
min_jaccard_overlap >= 1.0
): # a value larger than 1 encodes the leave as-is option
return image, target return image, target
for _ in range(self.trials): for _ in range(self.trials):
@@ -106,14 +122,22 @@ class RandomIoUCrop(nn.Module):
# check for any valid boxes with centers within the crop area # check for any valid boxes with centers within the crop area
cx = 0.5 * (target["boxes"][:, 0] + target["boxes"][:, 2]) cx = 0.5 * (target["boxes"][:, 0] + target["boxes"][:, 2])
cy = 0.5 * (target["boxes"][:, 1] + target["boxes"][:, 3]) cy = 0.5 * (target["boxes"][:, 1] + target["boxes"][:, 3])
is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom) is_within_crop_area = (
(left < cx) & (cx < right) & (top < cy) & (cy < bottom)
)
if not is_within_crop_area.any(): if not is_within_crop_area.any():
continue continue
# check at least 1 box with jaccard limitations # check at least 1 box with jaccard limitations
boxes = target["boxes"][is_within_crop_area] boxes = target["boxes"][is_within_crop_area]
ious = torchvision.ops.boxes.box_iou(boxes, torch.tensor([[left, top, right, bottom]], ious = torchvision.ops.boxes.box_iou(
dtype=boxes.dtype, device=boxes.device)) boxes,
torch.tensor(
[[left, top, right, bottom]],
dtype=boxes.dtype,
device=boxes.device,
),
)
if ious.max() < min_jaccard_overlap: if ious.max() < min_jaccard_overlap:
continue continue
@@ -130,14 +154,21 @@ class RandomIoUCrop(nn.Module):
class RandomZoomOut(nn.Module): class RandomZoomOut(nn.Module):
def __init__(self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1., 4.), p: float = 0.5): def __init__(
self,
fill: Optional[List[float]] = None,
side_range: Tuple[float, float] = (1.0, 4.0),
p: float = 0.5,
):
super().__init__() super().__init__()
if fill is None: if fill is None:
fill = [0., 0., 0.] fill = [0.0, 0.0, 0.0]
self.fill = fill self.fill = fill
self.side_range = side_range self.side_range = side_range
if side_range[0] < 1. or side_range[0] > side_range[1]: if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError("Invalid canvas side range provided {}.".format(side_range)) raise ValueError(
"Invalid canvas side range provided {}.".format(side_range)
)
self.p = p self.p = p
@torch.jit.unused @torch.jit.unused
@@ -146,11 +177,16 @@ class RandomZoomOut(nn.Module):
# We fake the type to make it work on JIT # We fake the type to make it work on JIT
return tuple(int(x) for x in self.fill) if is_pil else 0 return tuple(int(x) for x in self.fill) if is_pil else 0
def forward(self, image: Tensor, def forward(
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}: if image.ndimension() not in {2, 3}:
raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension())) raise ValueError(
"image should be 2/3 dimensional. Got {} dimensions.".format(
image.ndimension()
)
)
elif image.ndimension() == 2: elif image.ndimension() == 2:
image = image.unsqueeze(0) image = image.unsqueeze(0)
@@ -159,7 +195,9 @@ class RandomZoomOut(nn.Module):
orig_w, orig_h = F._get_image_size(image) orig_w, orig_h = F._get_image_size(image)
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) r = self.side_range[0] + torch.rand(1) * (
self.side_range[1] - self.side_range[0]
)
canvas_width = int(orig_w * r) canvas_width = int(orig_w * r)
canvas_height = int(orig_h * r) canvas_height = int(orig_h * r)
@@ -176,9 +214,12 @@ class RandomZoomOut(nn.Module):
image = F.pad(image, [left, top, right, bottom], fill=fill) image = F.pad(image, [left, top, right, bottom], fill=fill)
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1) v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(
image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h):, :] = \ -1, 1, 1
image[..., :, (left + orig_w):] = v )
image[..., :top, :] = image[..., :, :left] = image[
..., (top + orig_h) :, :
] = image[..., :, (left + orig_w) :] = v
if target is not None: if target is not None:
target["boxes"][:, 0::2] += left target["boxes"][:, 0::2] += left
@@ -188,8 +229,14 @@ class RandomZoomOut(nn.Module):
class RandomPhotometricDistort(nn.Module): class RandomPhotometricDistort(nn.Module):
def __init__(self, contrast: Tuple[float] = (0.5, 1.5), saturation: Tuple[float] = (0.5, 1.5), def __init__(
hue: Tuple[float] = (-0.05, 0.05), brightness: Tuple[float] = (0.875, 1.125), p: float = 0.5): self,
contrast: Tuple[float] = (0.5, 1.5),
saturation: Tuple[float] = (0.5, 1.5),
hue: Tuple[float] = (-0.05, 0.05),
brightness: Tuple[float] = (0.875, 1.125),
p: float = 0.5,
):
super().__init__() super().__init__()
self._brightness = T.ColorJitter(brightness=brightness) self._brightness = T.ColorJitter(brightness=brightness)
self._contrast = T.ColorJitter(contrast=contrast) self._contrast = T.ColorJitter(contrast=contrast)
@@ -197,11 +244,16 @@ class RandomPhotometricDistort(nn.Module):
self._saturation = T.ColorJitter(saturation=saturation) self._saturation = T.ColorJitter(saturation=saturation)
self.p = p self.p = p
def forward(self, image: Tensor, def forward(
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}: if image.ndimension() not in {2, 3}:
raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension())) raise ValueError(
"image should be 2/3 dimensional. Got {} dimensions.".format(
image.ndimension()
)
)
elif image.ndimension() == 2: elif image.ndimension() == 2:
image = image.unsqueeze(0) image = image.unsqueeze(0)

130
utils.py
View File

@@ -8,6 +8,8 @@ import torch
import torch.distributed as dist import torch.distributed as dist
class SmoothedValue(object): class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a """Track a series of values and provide access to smoothed values over a
window or the global series average. window or the global series average.
@@ -32,7 +34,7 @@ class SmoothedValue(object):
""" """
if not is_dist_avail_and_initialized(): if not is_dist_avail_and_initialized():
return return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier() dist.barrier()
dist.all_reduce(t) dist.all_reduce(t)
t = t.tolist() t = t.tolist()
@@ -67,7 +69,8 @@ class SmoothedValue(object):
avg=self.avg, avg=self.avg,
global_avg=self.global_avg, global_avg=self.global_avg,
max=self.max, max=self.max,
value=self.value) value=self.value,
)
def all_gather(data): def all_gather(data):
@@ -130,15 +133,14 @@ class MetricLogger(object):
return self.meters[attr] return self.meters[attr]
if attr in self.__dict__: if attr in self.__dict__:
return self.__dict__[attr] return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format( raise AttributeError(
type(self).__name__, attr)) "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
)
def __str__(self): def __str__(self):
loss_str = [] loss_str = []
for name, meter in self.meters.items(): for name, meter in self.meters.items():
loss_str.append( loss_str.append("{}: {}".format(name, str(meter)))
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str) return self.delimiter.join(loss_str)
def synchronize_between_processes(self): def synchronize_between_processes(self):
@@ -151,31 +153,35 @@ class MetricLogger(object):
def log_every(self, iterable, print_freq, header=None): def log_every(self, iterable, print_freq, header=None):
i = 0 i = 0
if not header: if not header:
header = '' header = ""
start_time = time.time() start_time = time.time()
end = time.time() end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}') iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt='{avg:.4f}') data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ':' + str(len(str(len(iterable)))) + 'd' space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available(): if torch.cuda.is_available():
log_msg = self.delimiter.join([ log_msg = self.delimiter.join(
header, [
'[{0' + space_fmt + '}/{1}]', header,
'eta: {eta}', "[{0" + space_fmt + "}/{1}]",
'{meters}', "eta: {eta}",
'time: {time}', "{meters}",
'data: {data}', "time: {time}",
'max mem: {memory:.0f}' "data: {data}",
]) "max mem: {memory:.0f}",
]
)
else: else:
log_msg = self.delimiter.join([ log_msg = self.delimiter.join(
header, [
'[{0' + space_fmt + '}/{1}]', header,
'eta: {eta}', "[{0" + space_fmt + "}/{1}]",
'{meters}', "eta: {eta}",
'time: {time}', "{meters}",
'data: {data}' "time: {time}",
]) "data: {data}",
]
)
MB = 1024.0 * 1024.0 MB = 1024.0 * 1024.0
for obj in iterable: for obj in iterable:
data_time.update(time.time() - end) data_time.update(time.time() - end)
@@ -185,22 +191,37 @@ class MetricLogger(object):
eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available(): if torch.cuda.is_available():
print(log_msg.format( print(
i, len(iterable), eta=eta_string, log_msg.format(
meters=str(self), i,
time=str(iter_time), data=str(data_time), len(iterable),
memory=torch.cuda.max_memory_allocated() / MB)) eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else: else:
print(log_msg.format( print(
i, len(iterable), eta=eta_string, log_msg.format(
meters=str(self), i,
time=str(iter_time), data=str(data_time))) len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
)
)
i += 1 i += 1
end = time.time() end = time.time()
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format( print(
header, total_time_str, total_time / len(iterable))) "{} Total time: {} ({:.4f} s / it)".format(
header, total_time_str, total_time / len(iterable)
)
)
def collate_fn(batch): def collate_fn(batch):
@@ -208,7 +229,6 @@ def collate_fn(batch):
def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor): def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
def f(x): def f(x):
if x >= warmup_iters: if x >= warmup_iters:
return 1 return 1
@@ -231,10 +251,11 @@ def setup_for_distributed(is_master):
This function disables printing when not in master process This function disables printing when not in master process
""" """
import builtins as __builtin__ import builtins as __builtin__
builtin_print = __builtin__.print builtin_print = __builtin__.print
def print(*args, **kwargs): def print(*args, **kwargs):
force = kwargs.pop('force', False) force = kwargs.pop("force", False)
if is_master or force: if is_master or force:
builtin_print(*args, **kwargs) builtin_print(*args, **kwargs)
@@ -271,25 +292,30 @@ def save_on_master(*args, **kwargs):
def init_distributed_mode(args): def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"]) args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE']) args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ['LOCAL_RANK']) args.gpu = int(os.environ["LOCAL_RANK"])
elif 'SLURM_PROCID' in os.environ: elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ['SLURM_PROCID']) args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count() args.gpu = args.rank % torch.cuda.device_count()
else: else:
print('Not using distributed mode') print("Not using distributed mode")
args.distributed = False args.distributed = False
return return
args.distributed = True args.distributed = True
torch.cuda.set_device(args.gpu) torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl' args.dist_backend = "nccl"
print('| distributed init (rank {}): {}'.format( print(
args.rank, args.dist_url), flush=True) "| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, )
world_size=args.world_size, rank=args.rank) torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank,
)
torch.distributed.barrier() torch.distributed.barrier()
setup_for_distributed(args.rank == 0) setup_for_distributed(args.rank == 0)