yacwc
This commit is contained in:
233
train.py
233
train.py
@@ -1,233 +0,0 @@
|
||||
r"""PyTorch Detection Training.
|
||||
|
||||
To run in a multi-gpu environment, use the distributed launcher::
|
||||
|
||||
python -m torch.distributed.launch --nproc_per_node=$NGPU --use_env \
|
||||
train.py ... --world-size $NGPU
|
||||
|
||||
The default hyperparameters are tuned for training on 8 gpus and 2 images per gpu.
|
||||
--lr 0.02 --batch-size 2 --world-size 8
|
||||
If you use a different number of GPUs, the learning rate should be changed to 0.02/8*$NGPU.
|
||||
|
||||
On top of that, for training Faster/Mask R-CNN, the default hyperparameters are
|
||||
--epochs 26 --lr-steps 16 22 --aspect-ratio-group-factor 3
|
||||
|
||||
Also, if you train Keypoint R-CNN, the default hyperparameters are
|
||||
--epochs 46 --lr-steps 36 43 --aspect-ratio-group-factor 3
|
||||
Because the number of images is smaller in the person keypoint subset of COCO,
|
||||
the number of epochs should be adapted so that we have the same number of iterations.
|
||||
"""
|
||||
import datetime
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import torchvision
|
||||
import torchvision.models.detection
|
||||
import torchvision.models.detection.mask_rcnn
|
||||
|
||||
from coco_utils import get_coco, get_coco_kp
|
||||
|
||||
from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
|
||||
from engine import train_one_epoch, evaluate
|
||||
|
||||
import presets
|
||||
import utils
|
||||
|
||||
|
||||
def get_dataset(name, image_set, transform, data_path):
    """Build a detection dataset by name and report its class count.

    Supported names are "coco" (91 classes) and "coco_kp" (person
    keypoints, 2 classes); any other name raises KeyError.

    Args:
        name: dataset key, "coco" or "coco_kp".
        image_set: split to load, e.g. "train" or "val".
        transform: transforms applied to each sample.
        data_path: root directory of the dataset on disk.

    Returns:
        (dataset, num_classes) tuple.
    """
    builders = {
        "coco": (data_path, get_coco, 91),
        "coco_kp": (data_path, get_coco_kp, 2),
    }
    root, build_fn, num_classes = builders[name]

    dataset = build_fn(root, image_set=image_set, transforms=transform)
    return dataset, num_classes
|
||||
|
||||
|
||||
def get_transform(train, data_augmentation):
    """Return the training preset (with augmentation) or the eval preset."""
    if train:
        return presets.DetectionPresetTrain(data_augmentation)
    return presets.DetectionPresetEval()
|
||||
|
||||
|
||||
def get_args_parser(add_help=True):
    """Create the CLI argument parser for detection training.

    Args:
        add_help: whether the parser exposes ``-h/--help``; pass ``False``
            when composing this parser into a larger one.

    Returns:
        argparse.ArgumentParser with every training option registered.
    """
    import argparse
    parser = argparse.ArgumentParser(description='PyTorch Detection Training', add_help=add_help)
    add = parser.add_argument

    # Dataset / model / device selection.
    add('--data-path', default='/datasets01/COCO/022719/', help='dataset')
    add('--dataset', default='coco', help='dataset')
    add('--model', default='maskrcnn_resnet50_fpn', help='model')
    add('--device', default='cuda', help='device')

    # Core optimization hyperparameters.
    add('-b', '--batch-size', default=2, type=int,
        help='images per gpu, the total batch size is $NGPU x batch_size')
    add('--epochs', default=26, type=int, metavar='N',
        help='number of total epochs to run')
    add('-j', '--workers', default=4, type=int, metavar='N',
        help='number of data loading workers (default: 4)')
    add('--lr', default=0.02, type=float,
        help='initial learning rate, 0.02 is the default value for training '
             'on 8 gpus and 2 images_per_gpu')
    add('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    add('--wd', '--weight-decay', default=1e-4, type=float, metavar='W',
        help='weight decay (default: 1e-4)', dest='weight_decay')

    # Learning-rate schedule.
    add('--lr-scheduler', default="multisteplr", help='the lr scheduler (default: multisteplr)')
    add('--lr-step-size', default=8, type=int,
        help='decrease lr every step-size epochs (multisteplr scheduler only)')
    add('--lr-steps', default=[16, 22], nargs='+', type=int,
        help='decrease lr every step-size epochs (multisteplr scheduler only)')
    add('--lr-gamma', default=0.1, type=float,
        help='decrease lr by a factor of lr-gamma (multisteplr scheduler only)')

    # Logging and checkpointing.
    add('--print-freq', default=20, type=int, help='print frequency')
    add('--output-dir', default='.', help='path where to save')
    add('--resume', default='', help='resume from checkpoint')
    add('--start_epoch', default=0, type=int, help='start epoch')

    # Model / data-pipeline tuning knobs.
    add('--aspect-ratio-group-factor', default=3, type=int)
    add('--rpn-score-thresh', default=None, type=float, help='rpn score threshold for faster-rcnn')
    add('--trainable-backbone-layers', default=None, type=int,
        help='number of trainable layers of backbone')
    add('--data-augmentation', default="hflip", help='data augmentation policy (default: hflip)')

    # Boolean switches.
    add("--sync-bn", dest="sync_bn", help="Use sync batch norm", action="store_true")
    add("--test-only", dest="test_only", help="Only test the model", action="store_true")
    add("--pretrained", dest="pretrained",
        help="Use pre-trained models from the modelzoo", action="store_true")

    # distributed training parameters
    add('--world-size', default=1, type=int, help='number of distributed processes')
    add('--dist-url', default='env://', help='url used to set up distributed training')

    return parser
|
||||
|
||||
|
||||
def main(args):
    """Run the full detection training (or evaluation-only) loop.

    Builds datasets, samplers, data loaders, the model, optimizer and LR
    scheduler from the parsed CLI ``args``, optionally resumes from a
    checkpoint, then trains for ``args.epochs`` epochs, checkpointing and
    evaluating after every epoch. With ``--test-only`` it evaluates once
    and returns.

    Args:
        args: argparse.Namespace from ``get_args_parser()``. Mutated in
            place: ``init_distributed_mode`` must populate
            ``args.distributed`` (and ``args.gpu``, read below), and
            resuming overwrites ``args.start_epoch``.
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)

    # NOTE(review): presumably sets args.distributed/args.gpu/rank from the
    # launcher environment — those attributes are read below but never
    # defined by the parser; verify against utils.init_distributed_mode.
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Data loading code
    print("Loading data")

    dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args.data_augmentation),
                                       args.data_path)
    dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args.data_augmentation), args.data_path)

    print("Creating data loaders")
    # Distributed runs shard both splits across processes; single-process
    # runs shuffle train and read val sequentially.
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    # Group images of similar aspect ratio into the same batch (reduces
    # padding); a negative factor disables grouping.
    if args.aspect_ratio_group_factor >= 0:
        group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(
            train_sampler, args.batch_size, drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler, num_workers=args.workers,
        collate_fn=utils.collate_fn)

    # Evaluation always runs with batch_size=1.
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1,
        sampler=test_sampler, num_workers=args.workers,
        collate_fn=utils.collate_fn)

    print("Creating model")
    kwargs = {
        "trainable_backbone_layers": args.trainable_backbone_layers
    }
    # rpn_score_thresh only exists on the R-CNN family of detectors.
    if "rcnn" in args.model:
        if args.rpn_score_thresh is not None:
            kwargs["rpn_score_thresh"] = args.rpn_score_thresh
    # Look the constructor up by name in torchvision.models.detection.
    model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained,
                                                              **kwargs)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Keep a handle on the raw module so checkpoints are DDP-agnostic.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(
        params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # Scheduler name is matched case-insensitively.
    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == 'multisteplr':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
    elif args.lr_scheduler == 'cosineannealinglr':
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    else:
        raise RuntimeError("Invalid lr scheduler '{}'. Only MultiStepLR and CosineAnnealingLR "
                           "are supported.".format(args.lr_scheduler))

    # Resume model/optimizer/scheduler state and continue at the next epoch.
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.test_only:
        evaluate(model, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reseed the sampler so each epoch gets a different shard order.
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
        lr_scheduler.step()
        if args.output_dir:
            checkpoint = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'args': args,
                'epoch': epoch
            }
            # Write both a per-epoch snapshot and a rolling "latest".
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'checkpoint.pth'))

        # evaluate after every epoch
        evaluate(model, data_loader_test, device=device)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: parse the command line, then launch training.
    cli_args = get_args_parser().parse_args()
    main(cli_args)
|
||||
|
||||
Reference in New Issue
Block a user