yawc
This commit is contained in:
@@ -20,15 +20,15 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
|
||||
if epoch == 0:
|
||||
warmup_factor = 1. / 1000
|
||||
warmup_iters = min(1000, len(data_loader) - 1)
|
||||
|
||||
lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
|
||||
|
||||
|
||||
for images, targets in metric_logger.log_every(data_loader, print_freq, header):
|
||||
images = list(image.to(device) for image in images)
|
||||
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
|
||||
|
||||
loss_dict = model(images, targets)
|
||||
|
||||
print('Hey I''m here')
|
||||
losses = sum(loss for loss in loss_dict.values())
|
||||
|
||||
# reduce losses over all GPUs for logging purposes
|
||||
|
||||
Reference in New Issue
Block a user