12 lines
451 B
Python
12 lines
451 B
Python
# %%
|
|
from torchvision.models.detection import fasterrcnn_resnet50_fpn
|
|
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
|
|
|
def Model(num_classes: int = 2, *, pretrained: bool = True):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a fresh box predictor.

    Creates torchvision's ``fasterrcnn_resnet50_fpn`` model and replaces its
    ROI-heads box predictor with a new ``FastRCNNPredictor`` that outputs
    ``num_classes`` classes. The new predictor keeps the input feature size
    of the original one.

    Args:
        num_classes: Number of output classes *including* the background
            class. The default of 2 corresponds to one object class (e.g.
            "person") plus background.
        pretrained: Forwarded unchanged to ``fasterrcnn_resnet50_fpn``.
            Keyword-only; defaults to True.

    Returns:
        The Faster R-CNN model with its box predictor replaced.

    Raises:
        ValueError: If ``num_classes`` is smaller than 2, i.e. no foreground
            class besides background.
    """
    # Validate before building the network so bad input fails fast and
    # without loading any weights.
    if num_classes < 2:
        raise ValueError(
            f"num_classes must be >= 2 (background + at least one class), "
            f"got {num_classes}"
        )

    # NOTE(review): newer torchvision releases deprecate `pretrained=` in
    # favour of `weights=`. It is kept here for compatibility with older
    # versions; confirm the targeted torchvision version.
    model = fasterrcnn_resnet50_fpn(pretrained=pretrained)

    # Reuse the feature width feeding the existing classifier so the
    # replacement predictor plugs into the same ROI-head output.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model
|
|
|