This commit is contained in:
2021-07-01 20:26:24 -04:00
parent 8b02bf9a8c
commit f46d193826
16 changed files with 433 additions and 146 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

24
bear_utils.py Normal file
View File

@@ -0,0 +1,24 @@
from math import sin, cos, sqrt, atan2, radians
def get_distance_from_home(lat_b, lon_b):
    """Return the distance from the hard-coded home location to a point.

    Args:
        lat_b: Latitude of the target point, in decimal degrees.
        lon_b: Longitude of the target point, in decimal degrees.

    Returns:
        Whatever ``distance_lat_lon`` returns for the home point and the
        target point (a great-circle distance).
    """
    # Fixed reference point (latitude, longitude) in decimal degrees.
    home = (42.295940, -83.751960)
    return distance_lat_lon(*home, lat_b, lon_b)
def distance_lat_lon(lat_a, lon_a, lat_b, lon_b):
    """Return the great-circle distance between two points (haversine formula).

    Args:
        lat_a: Latitude of the first point, in decimal degrees.
        lon_a: Longitude of the first point, in decimal degrees.
        lat_b: Latitude of the second point, in decimal degrees.
        lon_b: Longitude of the second point, in decimal degrees.

    Returns:
        The distance as a float, in the units of the radius constant below
        (6373.0, an approximate Earth radius in kilometres).
    """
    earth_radius = 6373.0

    lat1 = radians(lat_a)
    lon1 = radians(lon_a)
    lat2 = radians(lat_b)
    lon2 = radians(lon_b)

    dlon = lon2 - lon1
    dlat = lat2 - lat1

    a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
    # Floating-point rounding can push `a` marginally above 1 for
    # (near-)antipodal points, which would make sqrt(1 - a) raise
    # ValueError. Clamp to the mathematical maximum instead.
    a = min(1.0, a)
    c = 2 * atan2(sqrt(a), sqrt(1 - a))
    return earth_radius * c

0
config.py Normal file
View File

115
data.py
View File

@@ -1,27 +1,29 @@
# %%
import os
from unicodedata import category
import torch
from PIL import Image
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import sys
import json
import torch
from torchvision import transforms as T
import transforms as T
import os
import numpy as np
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if sys.platform == 'win32':
if sys.platform == "win32":
PATH_ROOT = r"D:\ishan\ml\inaturalist\\"
else:
raise NotImplementedError("Not defined for this platform")
def get_transform(train):
transforms = []
transforms.append(T.ToTensor())
trsf = []
trsf.append(T.ToTensor())
if train:
transforms.append(T.RandomHorizontalFlip(0.5))
return T.Compose(transforms)
trsf.append(T.RandomHorizontalFlip(0.5))
return T.Compose(trsf)
def create_map(list_in, from_key, to_key):
@@ -32,38 +34,43 @@ def create_map(list_in, from_key, to_key):
class iNaturalistDataset(torch.utils.data.Dataset):
def __init__(self, validation=False, train=False, transforms = None, species = None):
def __init__(self, validation=False, train=False, species=None):
self.validation = validation
self.train = train
if (self.train or self.validation) or (self.train and self.validation)
if (not self.train and not self.validation) or (self.train and self.validation):
raise Exception("Need to do either train or validation")
self.transforms = get_transform(self.train)
self.transform = get_transform(self.train)
if validation:
json_path = os.path.join(PATH_ROOT, "val_2017_bboxes","val_2017_bboxes.json")
json_path = os.path.join(
PATH_ROOT, "val_2017_bboxes", "val_2017_bboxes.json"
)
elif train:
json_path = os.path.join(
PATH_ROOT, "train_2017_bboxes","train_2017_bboxes.json"
PATH_ROOT, "train_2017_bboxes", "train_2017_bboxes.json"
)
with open(json_path, "r") as rj:
f = json.load(rj)
self.raw_data = f
categories = list()
image_info = dict()
orig_id_to_name = dict()
for category in f["categories"]:
do_add = False
orig_id_to_name[category["id"]] = category
if species is None:
do_add = True
if category['name'] in species:
print(category['name'])
elif category["name"] in species:
print(category["name"])
do_add = True
if do_add:
categories.append(category)
categories = sorted(categories, key=lambda k: k["name"])
@@ -96,11 +103,11 @@ class iNaturalistDataset(torch.utils.data.Dataset):
self.images = image_info
self.categories = categories
self.orig_id_to_name = orig_id_to_name
self.idx_to_id = [x for x in self.images]
self.num_classes = len(self.categories) + 1
self.num_samples = len(self.images)
def __len__(self):
return self.num_samples
@@ -115,12 +122,68 @@ class iNaturalistDataset(torch.utils.data.Dataset):
boxes = bbox
target = dict()
target["boxes"] = torch.as_tensor([boxes])
target["labels"] = torch.as_tensor([annot["new_category_id"]], dtype=torch.int64)
target['image_id'] = torch.tensor([annot['image_id']])
target['area'] = torch.as_tensor([annot['area']])
target['iscrowd'] = torch.zeros((1,), dtype=torch.int64)
target["labels"] = torch.as_tensor(
[annot["new_category_id"]], dtype=torch.int64
)
target["image_id"] = torch.tensor([annot["image_id"]])
target["area"] = torch.as_tensor([annot["area"]])
target["iscrowd"] = torch.zeros((1,), dtype=torch.int64)
if self.transforms is not None:
img, target = self.transforms(img, target)
if self.transform is not None:
img, target = self.transform(img, target)
return img, target
# Exploratory notebook-style cells: relate training images to their distance
# from the home location and list bird (Aves) categories by observation count.
# NOTE(review): this view has lost indentation, so how many of the following
# statements are guarded by `if False:` cannot be determined here -- confirm
# against the real file.
if False:
train_dataset = iNaturalistDataset(train=True)
loc_path = os.path.join(PATH_ROOT, "inat2017_locations", "train2017_locations.json")
with open(loc_path, "r") as lfile:
locs = json.load(lfile)
from bear_utils import get_distance_from_home
# %%
# Maps an original category id to the list of distances computed for its images.
category_distances = dict()
# Counts location records whose id matched an image in the training set.
inserts = 0
for loc in locs:
lat = loc["lat"]
lon = loc["lon"]
im_id = loc["id"]
# Records without coordinates cannot be placed; skip them.
if lat is None or lon is None:
continue
ff = get_distance_from_home(lat, lon)
if im_id in train_dataset.images:
inserts += 1
# Store the distance on the image entry and group it by category.
train_dataset.images[im_id]["distance"] = ff
category_id = train_dataset.images[im_id]["annotation"]["category_id"]
if category_id not in category_distances:
category_distances[category_id] = list()
category_distances[category_id].append(ff)
# %%
from EcoNameTranslator import to_common
# Print Aves categories where more than 10% of the recorded distances are
# below 250 (same units as get_distance_from_home returns -- presumably km;
# verify against bear_utils).
for k, v in category_distances.items():
name = train_dataset.orig_id_to_name[k]
if np.average(np.asarray(v) < 250) > 0.1:
if name["supercategory"] == "Aves":
print(len(v), to_common([name["name"]]))
# %%
# Category ids ordered by number of recorded distances, most first.
fc = sorted(
category_distances, key=lambda x: len(category_distances[x]), reverse=True
)
# Print the common name(s) and count for every Aves category in that order.
for x in fc:
cc = train_dataset.orig_id_to_name[x]
if cc["supercategory"] == "Aves":
ou = to_common([cc["name"]])
print(ou, len(category_distances[x]))
# %%

View File

@@ -1,11 +1,12 @@
# %%
from torchvision.models.detection import fasterrcnn_resnet50_fpn
import torchvision.models.detection
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
def Model(num_classes):
model = fasterrcnn_resnet50_fpn(pretrained=True)
num_classes = 2 # 1 class (person) + background
def Model(num_classes, model_type=None):
    """Build a torchvision detection model with a resized box-predictor head.

    Args:
        num_classes: Number of outputs for the replacement
            ``FastRCNNPredictor`` head. Presumably this includes the
            background class -- verify against the caller.
        model_type: Name of a model constructor in
            ``torchvision.models.detection`` (for example
            ``"fasterrcnn_mobilenet_v3_large_fpn"``). ``None`` selects
            ``"fasterrcnn_resnet50_fpn"``, the constructor this function
            used before ``model_type`` was introduced.

    Returns:
        The constructed model, created with ``pretrained=True`` and with
        ``roi_heads.box_predictor`` replaced by a new ``FastRCNNPredictor``.

    Raises:
        KeyError: If ``model_type`` does not name a callable in
            ``torchvision.models.detection``.
    """
    if model_type is None:
        # Without this, the lookup below would fail with KeyError(None)
        # for callers that rely on the default.
        model_type = "fasterrcnn_resnet50_fpn"

    chosen_model = torchvision.models.detection.__dict__.get(model_type)
    # The module namespace also holds submodules and other non-constructors;
    # reject those here instead of failing later on the call.
    if not callable(chosen_model):
        raise KeyError(
            "unknown torchvision detection model type: {!r}".format(model_type)
        )

    model = chosen_model(pretrained=True)
    # Swap the box head so its output size matches num_classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

View File

@@ -0,0 +1,22 @@
{
"categories": [
{
"supercategory": "Aves",
"id": 206,
"name": "Archilochus colubris",
"new_id": 1
},
{
"supercategory": "Aves",
"id": 4493,
"name": "Icterus galbula",
"new_id": 2
},
{
"supercategory": "Aves",
"id": 403,
"name": "Poecile atricapillus",
"new_id": 3
}
]
}

View File

@@ -0,0 +1,23 @@
{
"categories": [
{
"supercategory": "Aves",
"id": 206,
"name": "Archilochus colubris",
"new_id": 1
},
{
"supercategory": "Aves",
"id": 4493,
"name": "Icterus galbula",
"new_id": 2
},
{
"supercategory": "Aves",
"id": 403,
"name": "Poecile atricapillus",
"new_id": 3
}
],
"model_type": "fasterrcnn_mobilenet_v3_large_fpn"
}

View File

@@ -0,0 +1,23 @@
{
"categories": [
{
"supercategory": "Aves",
"id": 206,
"name": "Archilochus colubris",
"new_id": 1
},
{
"supercategory": "Aves",
"id": 4493,
"name": "Icterus galbula",
"new_id": 2
},
{
"supercategory": "Aves",
"id": 403,
"name": "Poecile atricapillus",
"new_id": 3
}
],
"model_type": "fasterrcnn_mobilenet_v3_large_fpn"
}

View File

@@ -0,0 +1,23 @@
{
"categories": [
{
"supercategory": "Aves",
"id": 206,
"name": "Archilochus colubris",
"new_id": 1
},
{
"supercategory": "Aves",
"id": 4493,
"name": "Icterus galbula",
"new_id": 2
},
{
"supercategory": "Aves",
"id": 403,
"name": "Poecile atricapillus",
"new_id": 3
}
],
"model_type": "fasterrcnn_mobilenet_v3_large_fpn"
}

View File

@@ -4,48 +4,78 @@ from model import Model
from data import iNaturalistDataset
import torch
import os
import time
import datetime as dt
import json
import utils
if not os.path.exists("models/"):
os.mkdir("models")
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
model_root = "models/" + dt.datetime.now().strftime("%Y%m%d_%H%M%S")
model_path = model_root + ".pth"
model_info = model_root + ".json"
species_list = set(["Poecile atricapillus", "Archilochus colubris", "Icterus galbula"])
model_type = "fasterrcnn_mobilenet_v3_large_fpn"
if not os.path.exists('models/'):
os.mkdirs('models')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def run():
val_dataset = iNaturalistDataset(validation=True, transforms = get_transform(train=True))
train_dataset = iNaturalistDataset(train=True, transforms = get_transform(train=False))
val_dataset = iNaturalistDataset(
validation=True,
species=species_list,
)
train_dataset = iNaturalistDataset(
train=True,
species=species_list,
)
with open(model_info, "w") as js_p:
json.dump(
{"categories": train_dataset.categories, "model_type": model_type},
js_p,
default=str,
indent=4,
)
train_data_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=utils.collate_fn
)
val_data_loader = torch.utils.data.DataLoader(
val_dataset, batch_size=8, shuffle=True, num_workers=1, collate_fn=utils.collate_fn
train_dataset,
batch_size=8,
shuffle=True,
num_workers=4,
collate_fn=utils.collate_fn,
)
num_classes = 5
model = Model(num_classes)
val_data_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=8,
shuffle=True,
num_workers=4,
collate_fn=utils.collate_fn,
)
num_classes = len(species_list) + 1
model = Model(num_classes, model_type)
model.to(device)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
momentum=0.9, weight_decay=0.0005)
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
step_size=3,
gamma=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
num_epochs = 10
for epoch in range(num_epochs):
print(epoch)
torch.save(model.state_dict(), 'model_weights_start_'+str(epoch)+ '.pth')
# train for one epoch, printing every 10 iterations
engine.train_one_epoch(model, optimizer, train_data_loader, device, epoch, print_freq=10)
torch.save(model.state_dict(), 'model_weights_post_train_'+str(epoch)+ '.pth')
# update the learning rate
train_one_epoch(
model, optimizer, train_data_loader, device, epoch, print_freq=10
)
lr_scheduler.step()
torch.save(model.state_dict(), 'model_weights_post_step_'+str(epoch)+ '.pth')
# evaluate on the test dataset
engine.evaluate(model, val_data_loader, device=device)
torch.save(model.state_dict(), model_path)
evaluate(model, val_data_loader, device=device)
if __name__ == "__main__":

View File

@@ -28,8 +28,9 @@ class Compose(object):
class RandomHorizontalFlip(T.RandomHorizontalFlip):
def forward(self, image: Tensor,
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if torch.rand(1) < self.p:
image = F.hflip(image)
if target is not None:
@@ -45,15 +46,23 @@ class RandomHorizontalFlip(T.RandomHorizontalFlip):
class ToTensor(nn.Module):
def forward(self, image: Tensor,
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
image = F.to_tensor(image)
return image, target
class RandomIoUCrop(nn.Module):
def __init__(self, min_scale: float = 0.3, max_scale: float = 1.0, min_aspect_ratio: float = 0.5,
max_aspect_ratio: float = 2.0, sampler_options: Optional[List[float]] = None, trials: int = 40):
def __init__(
self,
min_scale: float = 0.3,
max_scale: float = 1.0,
min_aspect_ratio: float = 0.5,
max_aspect_ratio: float = 2.0,
sampler_options: Optional[List[float]] = None,
trials: int = 40,
):
super().__init__()
# Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
self.min_scale = min_scale
@@ -65,14 +74,19 @@ class RandomIoUCrop(nn.Module):
self.options = sampler_options
self.trials = trials
def forward(self, image: Tensor,
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if target is None:
raise ValueError("The targets can't be None for this transform.")
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension()))
raise ValueError(
"image should be 2/3 dimensional. Got {} dimensions.".format(
image.ndimension()
)
)
elif image.ndimension() == 2:
image = image.unsqueeze(0)
@@ -82,7 +96,9 @@ class RandomIoUCrop(nn.Module):
# sample an option
idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
min_jaccard_overlap = self.options[idx]
if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option
if (
min_jaccard_overlap >= 1.0
): # a value larger than 1 encodes the leave as-is option
return image, target
for _ in range(self.trials):
@@ -106,14 +122,22 @@ class RandomIoUCrop(nn.Module):
# check for any valid boxes with centers within the crop area
cx = 0.5 * (target["boxes"][:, 0] + target["boxes"][:, 2])
cy = 0.5 * (target["boxes"][:, 1] + target["boxes"][:, 3])
is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
is_within_crop_area = (
(left < cx) & (cx < right) & (top < cy) & (cy < bottom)
)
if not is_within_crop_area.any():
continue
# check at least 1 box with jaccard limitations
boxes = target["boxes"][is_within_crop_area]
ious = torchvision.ops.boxes.box_iou(boxes, torch.tensor([[left, top, right, bottom]],
dtype=boxes.dtype, device=boxes.device))
ious = torchvision.ops.boxes.box_iou(
boxes,
torch.tensor(
[[left, top, right, bottom]],
dtype=boxes.dtype,
device=boxes.device,
),
)
if ious.max() < min_jaccard_overlap:
continue
@@ -130,14 +154,21 @@ class RandomIoUCrop(nn.Module):
class RandomZoomOut(nn.Module):
def __init__(self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1., 4.), p: float = 0.5):
def __init__(
self,
fill: Optional[List[float]] = None,
side_range: Tuple[float, float] = (1.0, 4.0),
p: float = 0.5,
):
super().__init__()
if fill is None:
fill = [0., 0., 0.]
fill = [0.0, 0.0, 0.0]
self.fill = fill
self.side_range = side_range
if side_range[0] < 1. or side_range[0] > side_range[1]:
raise ValueError("Invalid canvas side range provided {}.".format(side_range))
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError(
"Invalid canvas side range provided {}.".format(side_range)
)
self.p = p
@torch.jit.unused
@@ -146,11 +177,16 @@ class RandomZoomOut(nn.Module):
# We fake the type to make it work on JIT
return tuple(int(x) for x in self.fill) if is_pil else 0
def forward(self, image: Tensor,
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension()))
raise ValueError(
"image should be 2/3 dimensional. Got {} dimensions.".format(
image.ndimension()
)
)
elif image.ndimension() == 2:
image = image.unsqueeze(0)
@@ -159,7 +195,9 @@ class RandomZoomOut(nn.Module):
orig_w, orig_h = F._get_image_size(image)
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
r = self.side_range[0] + torch.rand(1) * (
self.side_range[1] - self.side_range[0]
)
canvas_width = int(orig_w * r)
canvas_height = int(orig_h * r)
@@ -176,9 +214,12 @@ class RandomZoomOut(nn.Module):
image = F.pad(image, [left, top, right, bottom], fill=fill)
if isinstance(image, torch.Tensor):
v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1)
image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h):, :] = \
image[..., :, (left + orig_w):] = v
v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(
-1, 1, 1
)
image[..., :top, :] = image[..., :, :left] = image[
..., (top + orig_h) :, :
] = image[..., :, (left + orig_w) :] = v
if target is not None:
target["boxes"][:, 0::2] += left
@@ -188,8 +229,14 @@ class RandomZoomOut(nn.Module):
class RandomPhotometricDistort(nn.Module):
def __init__(self, contrast: Tuple[float] = (0.5, 1.5), saturation: Tuple[float] = (0.5, 1.5),
hue: Tuple[float] = (-0.05, 0.05), brightness: Tuple[float] = (0.875, 1.125), p: float = 0.5):
def __init__(
self,
contrast: Tuple[float] = (0.5, 1.5),
saturation: Tuple[float] = (0.5, 1.5),
hue: Tuple[float] = (-0.05, 0.05),
brightness: Tuple[float] = (0.875, 1.125),
p: float = 0.5,
):
super().__init__()
self._brightness = T.ColorJitter(brightness=brightness)
self._contrast = T.ColorJitter(contrast=contrast)
@@ -197,11 +244,16 @@ class RandomPhotometricDistort(nn.Module):
self._saturation = T.ColorJitter(saturation=saturation)
self.p = p
def forward(self, image: Tensor,
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if isinstance(image, torch.Tensor):
if image.ndimension() not in {2, 3}:
raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension()))
raise ValueError(
"image should be 2/3 dimensional. Got {} dimensions.".format(
image.ndimension()
)
)
elif image.ndimension() == 2:
image = image.unsqueeze(0)

122
utils.py
View File

@@ -8,6 +8,8 @@ import torch
import torch.distributed as dist
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
@@ -32,7 +34,7 @@ class SmoothedValue(object):
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
@@ -67,7 +69,8 @@ class SmoothedValue(object):
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
value=self.value,
)
def all_gather(data):
@@ -130,15 +133,14 @@ class MetricLogger(object):
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
raise AttributeError(
"'{}' object has no attribute '{}'".format(type(self).__name__, attr)
)
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
loss_str.append("{}: {}".format(name, str(meter)))
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
@@ -151,31 +153,35 @@ class MetricLogger(object):
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join([
log_msg = self.delimiter.join(
[
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}',
'max mem: {memory:.0f}'
])
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join([
log_msg = self.delimiter.join(
[
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
])
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
]
)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
@@ -185,22 +191,37 @@ class MetricLogger(object):
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
time=str(iter_time),
data=str(data_time),
)
)
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
print(
"{} Total time: {} ({:.4f} s / it)".format(
header, total_time_str, total_time / len(iterable)
)
)
def collate_fn(batch):
@@ -208,7 +229,6 @@ def collate_fn(batch):
def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
def f(x):
if x >= warmup_iters:
return 1
@@ -231,10 +251,11 @@ def setup_for_distributed(is_master):
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
@@ -271,25 +292,30 @@ def save_on_master(*args, **kwargs):
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
else:
print('Not using distributed mode')
print("Not using distributed mode")
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
args.dist_backend = "nccl"
print(
"| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True
)
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank,
)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)