Add RT-DETR

This commit is contained in:
Marcos Luciano
2023-11-01 18:42:56 -03:00
parent 1177624dd2
commit 5af9da189d

View File

@@ -20,17 +20,6 @@ class DeepStreamOutput(nn.Module):
scores, classes = torch.max(x['pred_logits'], 2, keepdim=True) scores, classes = torch.max(x['pred_logits'], 2, keepdim=True)
classes = classes.float() classes = classes.float()
return boxes, scores, classes return boxes, scores, classes
class DeepStreamInput(nn.Module):
def __init__(self, img_size, device):
self.img_size = img_size
self.device = device
super().__init__()
def forward(self, x):
size = torch.tensor([[*self.img_size]]).to(self.device)
return x, size
def suppress_warnings(): def suppress_warnings():