yacwc
This commit is contained in:
11
model.py
11
model.py
@@ -1,11 +1,12 @@
|
||||
# %%
|
||||
from torchvision.models.detection import fasterrcnn_resnet50_fpn
|
||||
import torchvision.models.detection
|
||||
|
||||
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
||||
|
||||
def Model(num_classes):
|
||||
model = fasterrcnn_resnet50_fpn(pretrained=True)
|
||||
num_classes = 2 # 1 class (person) + background
|
||||
|
||||
def Model(num_classes, model_type=None):
|
||||
chosen_model = torchvision.models.detection.__dict__[model_type]
|
||||
model = chosen_model(pretrained=True)
|
||||
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
||||
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
|
||||
return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user