# %%
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor


def Model(num_classes=2):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a new box head.

    The model is created with pretrained weights and its box predictor is
    replaced by a fresh ``FastRCNNPredictor`` sized for ``num_classes``
    outputs, so the head can be fine-tuned on a custom label set.

    Args:
        num_classes: Number of output classes for the box predictor,
            counting the background class. Defaults to 2
            (1 class (person) + background).

    Returns:
        The Faster R-CNN model with the replaced box predictor.
    """
    # NOTE(review): newer torchvision releases prefer the ``weights=`` argument
    # over ``pretrained=True``; kept as-is for compatibility -- confirm against
    # the torchvision version this project pins.
    model = fasterrcnn_resnet50_fpn(pretrained=True)

    # Reuse the input width of the existing classification layer so the new
    # predictor plugs into the same ROI-head features.
    in_features = model.roi_heads.box_predictor.cls_score.in_features

    # Honor the caller-supplied class count (it was previously overwritten
    # with a hard-coded 2, silently ignoring the argument).
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model