13 lines
446 B
Python
13 lines
446 B
Python
# %%
|
|
import torchvision.models.detection
|
|
|
|
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
|
|
|
|
|
def Model(num_classes, model_type=None, *, pretrained=True):
    """Build a torchvision R-CNN detector with a freshly sized box predictor.

    Looks up the builder function named ``model_type`` in
    ``torchvision.models.detection``, instantiates it, and replaces its
    ``roi_heads.box_predictor`` with a new ``FastRCNNPredictor`` that emits
    ``num_classes`` class scores.

    Args:
        num_classes: Number of classes for the new box predictor. Per the
            torchvision detection docs this count includes the background
            class, so it is usually ``number_of_object_classes + 1``.
        model_type: Name of a builder in ``torchvision.models.detection``
            (e.g. ``"fasterrcnn_resnet50_fpn"``). ``None`` selects
            ``"fasterrcnn_resnet50_fpn"``; previously ``None`` always failed.
        pretrained: Forwarded to the builder. ``True`` (the default and the
            previously hard-coded value) loads pretrained weights.

    Returns:
        The constructed model, with ``roi_heads.box_predictor`` replaced.

    Raises:
        KeyError: If ``model_type`` does not name a builder function in
            ``torchvision.models.detection``.
        AttributeError: If the built model has no ``roi_heads.box_predictor``
            (i.e. it is not an R-CNN-style detector).
    """
    if model_type is None:
        model_type = "fasterrcnn_resnet50_fpn"

    namespace = vars(torchvision.models.detection)
    builder = namespace.get(model_type)
    # The namespace also holds submodules and model classes (e.g. FasterRCNN);
    # only plain builder functions accept the ``pretrained`` keyword.
    if builder is None or isinstance(builder, type) or not callable(builder):
        available = sorted(
            name
            for name, obj in namespace.items()
            if callable(obj) and not isinstance(obj, type) and not name.startswith("_")
        )
        raise KeyError(
            f"Unknown detection model {model_type!r}; "
            f"expected one of: {', '.join(available)}"
        )

    # ``pretrained`` is kept (rather than ``weights=``) so the call works on
    # torchvision versions before 0.13; newer versions accept it with a
    # deprecation warning.
    model = builder(pretrained=pretrained)

    roi_heads = getattr(model, "roi_heads", None)
    if roi_heads is None or not hasattr(roi_heads, "box_predictor"):
        raise AttributeError(
            f"Model {model_type!r} has no roi_heads.box_predictor; only "
            "R-CNN-style detectors (Faster/Mask/Keypoint R-CNN) are supported"
        )

    # Reuse the feature width feeding the existing classifier so the new
    # predictor plugs into the same box head output.
    in_features = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model
|