"""Build Faster R-CNN detectors for a single foreground class (person).

Two alternative constructions are shown. The second one is assigned to
``model`` last, so it is the model this module ends up exposing:

1. Fine-tune a pretrained ``fasterrcnn_resnet50_fpn`` by swapping its box
   predictor for one sized to ``num_classes``.
2. Assemble a ``FasterRCNN`` from a pretrained MobileNetV2 feature extractor
   with a custom anchor generator and RoI pooler.
"""
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.rpn import AnchorGenerator

# torchvision detection heads count the background as a class, so one
# foreground class (person) needs num_classes == 2. The previous value of 1
# left the predictor with no foreground output.
num_classes = 2  # 1 class (person) + background

# --- Option 1: replace the box predictor of a pretrained Faster R-CNN -------
# NOTE(review): ``pretrained=True`` is deprecated in newer torchvision in favour
# of ``weights=...``. It is kept here for compatibility with older releases.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
# Reuse the feature width feeding the existing classifier so the new
# predictor plugs into the same RoI head.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# --- Option 2: Faster R-CNN on a MobileNetV2 backbone -----------------------
# This reassigns ``model`` and discards the Option 1 model above.
backbone = torchvision.models.mobilenet_v2(pretrained=True).features
# FasterRCNN requires the backbone to advertise its output channel count.
# NOTE(review): this reads ``out_channels`` from the third-to-last submodule,
# presumably the final Conv2d of the feature extractor (expected 1280).
# Confirm this if torchvision changes the MobileNetV2 layer layout.
backbone.out_channels = list(backbone.modules())[-3].out_channels

# A single tuple of sizes/ratios means anchors for one feature map, which
# matches the single feature map named '0' that the RoI pooler reads below.
anchor_generator = AnchorGenerator(
    sizes=((32, 64, 128, 256, 512),),
    aspect_ratios=((0.5, 1.0, 2.0),),
)
roi_pooler = torchvision.ops.MultiScaleRoIAlign(
    featmap_names=['0'],
    output_size=7,
    sampling_ratio=2,
)
model = FasterRCNN(
    backbone,
    num_classes=num_classes,
    rpn_anchor_generator=anchor_generator,
    box_roi_pool=roi_pooler,
)
# %%