Add RT-DETR
This commit is contained in:
@@ -20,17 +20,6 @@ class DeepStreamOutput(nn.Module):
|
|||||||
scores, classes = torch.max(x['pred_logits'], 2, keepdim=True)
|
scores, classes = torch.max(x['pred_logits'], 2, keepdim=True)
|
||||||
classes = classes.float()
|
classes = classes.float()
|
||||||
return boxes, scores, classes
|
return boxes, scores, classes
|
||||||
|
|
||||||
|
|
||||||
class DeepStreamInput(nn.Module):
|
|
||||||
def __init__(self, img_size, device):
|
|
||||||
self.img_size = img_size
|
|
||||||
self.device = device
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
size = torch.tensor([[*self.img_size]]).to(self.device)
|
|
||||||
return x, size
|
|
||||||
|
|
||||||
|
|
||||||
def suppress_warnings():
|
def suppress_warnings():
|
||||||
|
|||||||
Reference in New Issue
Block a user