Fix ONNX export

This commit is contained in:
Marcos Luciano
2023-05-29 21:54:43 -03:00
parent 141c0f2fee
commit b2c4bee8dc
9 changed files with 9 additions and 9 deletions

View File

@@ -18,7 +18,7 @@ class DeepStreamOutput(nn.Module):
def forward(self, x): def forward(self, x):
boxes = x[1] boxes = x[1]
scores, classes = torch.max(x[0], 2, keepdim=True) scores, classes = torch.max(x[0], 2, keepdim=True)
return torch.cat((boxes, scores, classes), dim=2) return torch.cat((boxes, scores, classes.float()), dim=2)
def suppress_warnings(): def suppress_warnings():

View File

@@ -19,7 +19,7 @@ class DeepStreamOutput(nn.Module):
boxes = x[:, :, :4] boxes = x[:, :, :4]
objectness = x[:, :, 4:5] objectness = x[:, :, 4:5]
scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True) scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True)
return torch.cat((boxes, scores * objectness, classes), dim=2) return torch.cat((boxes, scores * objectness, classes.float()), dim=2)
def suppress_warnings(): def suppress_warnings():

View File

@@ -23,7 +23,7 @@ class DeepStreamOutput(nn.Module):
boxes = x[:, :, :4] boxes = x[:, :, :4]
objectness = x[:, :, 4:5] objectness = x[:, :, 4:5]
scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True) scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True)
return torch.cat((boxes, scores * objectness, classes), dim=2) return torch.cat((boxes, scores * objectness, classes.float()), dim=2)
def suppress_warnings(): def suppress_warnings():

View File

@@ -19,7 +19,7 @@ class DeepStreamOutput(nn.Module):
boxes = x[:, :, :4] boxes = x[:, :, :4]
objectness = x[:, :, 4:5] objectness = x[:, :, 4:5]
scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True) scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True)
return torch.cat((boxes, scores * objectness, classes), dim=2) return torch.cat((boxes, scores * objectness, classes.float()), dim=2)
def suppress_warnings(): def suppress_warnings():

View File

@@ -18,7 +18,7 @@ class DeepStreamOutput(nn.Module):
x = x.transpose(1, 2) x = x.transpose(1, 2)
boxes = x[:, :, :4] boxes = x[:, :, :4]
scores, classes = torch.max(x[:, :, 4:], 2, keepdim=True) scores, classes = torch.max(x[:, :, 4:], 2, keepdim=True)
return torch.cat((boxes, scores, classes), dim=2) return torch.cat((boxes, scores, classes.float()), dim=2)
def suppress_warnings(): def suppress_warnings():

View File

@@ -19,7 +19,7 @@ class DeepStreamOutput(nn.Module):
x = x.transpose(1, 2) x = x.transpose(1, 2)
boxes = x[:, :, :4] boxes = x[:, :, :4]
scores, classes = torch.max(x[:, :, 4:], 2, keepdim=True) scores, classes = torch.max(x[:, :, 4:], 2, keepdim=True)
return torch.cat((boxes, scores, classes), dim=2) return torch.cat((boxes, scores, classes.float()), dim=2)
def suppress_warnings(): def suppress_warnings():

View File

@@ -15,7 +15,7 @@ class DeepStreamOutput(nn.Module):
def forward(self, x): def forward(self, x):
boxes = x[0] boxes = x[0]
scores, classes = torch.max(x[1], 2, keepdim=True) scores, classes = torch.max(x[1], 2, keepdim=True)
return torch.cat((boxes, scores, classes), dim=2) return torch.cat((boxes, scores, classes.float()), dim=2)
def suppress_warnings(): def suppress_warnings():

View File

@@ -16,7 +16,7 @@ class DeepStreamOutput(nn.Module):
boxes = x[:, :, :4] boxes = x[:, :, :4]
objectness = x[:, :, 4:5] objectness = x[:, :, 4:5]
scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True) scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True)
return torch.cat((boxes, scores * objectness, classes), dim=2) return torch.cat((boxes, scores * objectness, classes.float()), dim=2)
def suppress_warnings(): def suppress_warnings():

View File

@@ -18,7 +18,7 @@ class DeepStreamOutput(nn.Module):
boxes = x[:, :, :4] boxes = x[:, :, :4]
objectness = x[:, :, 4:5] objectness = x[:, :, 4:5]
scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True) scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True)
return torch.cat((boxes, scores * objectness, classes), dim=2) return torch.cat((boxes, scores * objectness, classes.float()), dim=2)
def suppress_warnings(): def suppress_warnings():