Files
2021-07-01 20:26:24 -04:00

13 lines
446 B
Python

# %%
import torchvision.models.detection
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
def Model(num_classes, model_type=None):
    """Build a pretrained torchvision detection model with a fresh box head.

    The builder named by ``model_type`` is looked up in the
    ``torchvision.models.detection`` namespace and called with
    ``pretrained=True``. The model's ``roi_heads.box_predictor`` is then
    replaced by a new ``FastRCNNPredictor`` sized for ``num_classes``.

    Args:
        num_classes: Number of output classes passed to ``FastRCNNPredictor``.
            (Presumably this count includes the background class, as is usual
            for Faster R-CNN heads -- confirm against the training code.)
        model_type: Name of a builder function in
            ``torchvision.models.detection``, e.g. ``"fasterrcnn_resnet50_fpn"``.
            Required in practice: the ``None`` default is kept only for
            signature compatibility and is rejected.

    Returns:
        The constructed model with its box predictor replaced.

    Raises:
        KeyError: If ``model_type`` is ``None`` or does not name anything in
            ``torchvision.models.detection``.
        TypeError: If ``model_type`` names something that is not callable
            (for example a submodule rather than a builder function).
        AttributeError: If the built model has no
            ``roi_heads.box_predictor.cls_score`` (i.e. it is not a
            Faster R-CNN style architecture).
    """
    # Fail early with an actionable message instead of the opaque
    # ``KeyError: None`` that the bare dictionary lookup would produce.
    if model_type is None:
        raise KeyError(
            "model_type is required: pass the name of a "
            "torchvision.models.detection builder, "
            "e.g. 'fasterrcnn_resnet50_fpn'"
        )

    registry = torchvision.models.detection.__dict__
    try:
        chosen_model = registry[model_type]
    except KeyError:
        raise KeyError(
            f"unknown detection model {model_type!r}: no such name in "
            "torchvision.models.detection"
        ) from None

    # The module namespace also contains submodules and other non-callables.
    if not callable(chosen_model):
        raise TypeError(
            f"torchvision.models.detection.{model_type} is not a callable "
            "model builder"
        )

    # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
    # favour of ``weights=`` -- revisit when upgrading torchvision.
    model = chosen_model(pretrained=True)

    # Reuse the existing classifier's input width so the new head fits the
    # features produced by the pretrained ROI heads.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model