# Files
# inaturalist_pytorch_model/model.py
# 2021-07-01 15:41:04 -04:00
#
# 12 lines
# 451 B
# Python

# %%
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
def Model(num_classes, *, pretrained=True):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a new box predictor.

    Creates torchvision's ``fasterrcnn_resnet50_fpn`` model and replaces its
    ``roi_heads.box_predictor`` with a fresh ``FastRCNNPredictor`` that has
    ``num_classes`` outputs, so the detector can be fine-tuned on a custom
    label set.

    Args:
        num_classes: Number of output classes *including* the background
            class, e.g. ``2`` for one object class (person) + background.
        pretrained: Forwarded to ``fasterrcnn_resnet50_fpn``. Defaults to
            ``True``, matching the previous hard-coded behaviour.

    Returns:
        The torchvision Faster R-CNN model with the replaced box predictor.

    Raises:
        ValueError: If ``num_classes`` is less than 2 (background plus at
            least one object class).
    """
    # Validate before building the model so a bad argument fails fast,
    # without constructing (or downloading weights for) the network.
    if num_classes < 2:
        raise ValueError(
            "num_classes must be >= 2 (object classes + background), "
            f"got {num_classes}"
        )
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in
    # favour of `weights=`; kept here for compatibility with older versions.
    model = fasterrcnn_resnet50_fpn(pretrained=pretrained)
    # The replacement head must accept the same feature width as the head it
    # replaces, so read it from the existing classification layer.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model