This commit is contained in:
2021-07-01 20:26:24 -04:00
parent 8b02bf9a8c
commit f46d193826
16 changed files with 433 additions and 146 deletions

View File

@@ -1,11 +1,12 @@
# %%
from torchvision.models.detection import fasterrcnn_resnet50_fpn
import torchvision.models.detection
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
def Model(num_classes):
model = fasterrcnn_resnet50_fpn(pretrained=True)
num_classes = 2 # 1 class (person) + background
def Model(num_classes, model_type=None):
chosen_model = torchvision.models.detection.__dict__[model_type]
model = chosen_model(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
return model