# %%
import torchvision.models.detection
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Builder used when the caller does not name one explicitly.
_DEFAULT_MODEL_TYPE = "fasterrcnn_resnet50_fpn"


def Model(num_classes, model_type=None, pretrained=True):
    """Build a torchvision R-CNN detector with a fresh box-predictor head.

    The named builder from ``torchvision.models.detection`` is instantiated.
    Its ROI box predictor is then replaced by a new ``FastRCNNPredictor``
    sized for ``num_classes``. The backbone and other heads stay as built,
    so they can be fine-tuned.

    Args:
        num_classes: Number of output classes for the new box predictor.
            Per torchvision's ``FastRCNNPredictor`` docs, this count includes
            the background class.
        model_type: Name of a builder function in
            ``torchvision.models.detection`` (e.g. ``"fasterrcnn_resnet50_fpn"``
            or ``"maskrcnn_resnet50_fpn"``). Defaults to
            ``"fasterrcnn_resnet50_fpn"`` when ``None``.
        pretrained: Forwarded to the builder as ``pretrained=``. Defaults to
            ``True``, matching the previous hard-coded behavior.

    Returns:
        The constructed model, with ``model.roi_heads.box_predictor``
        replaced.

    Raises:
        ValueError: If ``model_type`` does not name a callable in
            ``torchvision.models.detection``. Also raised if the resulting
            model has no ``roi_heads.box_predictor`` (e.g. RetinaNet or SSD,
            which this head swap does not apply to).
    """
    if model_type is None:
        model_type = _DEFAULT_MODEL_TYPE

    # getattr + callable check rejects submodules and other non-builder
    # attributes, which a raw __dict__ lookup would have accepted.
    builder = getattr(torchvision.models.detection, model_type, None)
    if not callable(builder):
        raise ValueError(
            f"Unknown detection model type {model_type!r}; expected the name "
            "of a builder in torchvision.models.detection"
        )

    # NOTE(review): newer torchvision releases (0.13+) deprecate `pretrained=`
    # in favor of `weights=`. It is kept here for compatibility with older
    # versions — revisit when the minimum torchvision version is pinned.
    model = builder(pretrained=pretrained)

    roi_heads = getattr(model, "roi_heads", None)
    box_predictor = getattr(roi_heads, "box_predictor", None)
    if box_predictor is None:
        raise ValueError(
            f"Model type {model_type!r} has no roi_heads.box_predictor; "
            "only R-CNN style detectors are supported"
        )

    # Reuse the incoming feature width so the new head plugs into the
    # existing box head unchanged.
    in_features = box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model