Fix logger error in DeepStream 6.0 / 6.0.1 + Change output classes format + Fixes

This commit is contained in:
Marcos Luciano
2023-06-08 13:47:43 -03:00
parent 9fd80c5248
commit 64fa573f72
23 changed files with 233 additions and 258 deletions

View File

@@ -43,30 +43,6 @@ Generate the ONNX model file (example for DAMO-YOLO-S*)
python3 export_damoyolo.py -w damoyolo_tinynasL25_S_477.pth -c configs/damoyolo_tinynasL25_S.py --dynamic
```
**NOTE**: To simplify the ONNX model (DeepStream >= 6)
```
--simplify
```
**NOTE**: To use dynamic batch-size (DeepStream >= 6)
```
--dynamic
```
**NOTE**: To use implicit batch-size (example for batch-size = 4)
```
--batch 4
```
**NOTE**: If you are using DeepStream 5.1, remove the `--dynamic` arg and use opset 11 or lower. The default opset is 11.
```
--opset 11
```
**NOTE**: To change the inference size (default: 640)
```
@@ -88,6 +64,30 @@ or
-s 1280 1280
```
**NOTE**: To simplify the ONNX model (DeepStream >= 6.0)
```
--simplify
```
**NOTE**: To use dynamic batch-size (DeepStream >= 6.1)
```
--dynamic
```
**NOTE**: To use implicit batch-size (example for batch-size = 4)
```
--batch 4
```
**NOTE**: If you are using DeepStream 5.1, remove the `--dynamic` arg and use opset 11 or lower. The default opset is 11.
```
--opset 11
```
#### 5. Copy generated files
Copy the generated ONNX model file and labels.txt file (if generated) to the `DeepStream-Yolo` folder.
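For orientation, these exporter flags typically boil down to a few `torch.onnx.export` arguments. A minimal sketch, assuming a torch-based exporter; the function name and defaults are illustrative, not the actual `export_damoyolo.py` code:
```
import onnx
import onnxsim
import torch

def export_onnx(model, path, size=640, opset=11, dynamic=True, batch=1, simplify=False):
    # --dynamic marks the batch axis as dynamic (DeepStream >= 6.1);
    # --batch N instead bakes a fixed (implicit) batch size into the graph.
    dummy = torch.zeros(1 if dynamic else batch, 3, size, size)
    torch.onnx.export(model, dummy, path,
                      opset_version=opset,  # --opset (11 or lower for DeepStream 5.1)
                      input_names=['input'],
                      output_names=['boxes', 'scores', 'classes'],
                      dynamic_axes={'input': {0: 'batch'}} if dynamic else None)
    if simplify:  # --simplify (DeepStream >= 6.0)
        simplified, check = onnxsim.simplify(onnx.load(path))
        assert check, 'onnxsim could not validate the simplified model'
        onnx.save(simplified, path)
```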

View File

@@ -41,13 +41,13 @@ pip3 install onnx onnxsim onnxruntime
python3 export_ppyoloe.py -w ppyoloe_plus_crn_s_80e_coco.pdparams -c configs/ppyoloe/ppyoloe_plus_crn_s_80e_coco.yml --dynamic
```
**NOTE**: To simplify the ONNX model (DeepStream >= 6)
**NOTE**: To simplify the ONNX model (DeepStream >= 6.0)
```
--simplify
```
**NOTE**: To use dynamic batch-size (DeepStream >= 6)
**NOTE**: To use dynamic batch-size (DeepStream >= 6.1)
```
--dynamic

View File

@@ -46,30 +46,6 @@ Generate the ONNX model file (example for YOLO-NAS S)
python3 export_yolonas.py -m yolo_nas_s -w yolo_nas_s_coco.pth --dynamic
```
**NOTE**: To simplify the ONNX model (DeepStream >= 6)
```
--simplify
```
**NOTE**: To use dynamic batch-size (DeepStream >= 6)
```
--dynamic
```
**NOTE**: To use implicit batch-size (example for batch-size = 4)
```
--batch 4
```
**NOTE**: If you are using DeepStream 5.1, remove the `--dynamic` arg and use opset 12 or lower. The default opset is 14.
```
--opset 12
```
**NOTE**: Model names
```
@@ -88,6 +64,18 @@ or
-m yolo_nas_l
```
**NOTE**: Number of classes (example for 80 classes)
```
-n 80
```
or
```
--classes 80
```
**NOTE**: To change the inference size (default: 640)
```
@@ -109,6 +97,30 @@ or
-s 1280 1280
```
**NOTE**: To simplify the ONNX model (DeepStream >= 6.0)
```
--simplify
```
**NOTE**: To use dynamic batch-size (DeepStream >= 6.1)
```
--dynamic
```
**NOTE**: To use implicit batch-size (example for batch-size = 4)
```
--batch 4
```
**NOTE**: If you are using DeepStream 5.1, remove the `--dynamic` arg and use opset 12 or lower. The default opset is 14.
```
--opset 12
```
#### 5. Copy generated file
Copy the generated ONNX model file to the `DeepStream-Yolo` folder.
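Since the value passed to `-n`/`--classes` must agree with the label file DeepStream loads, a quick sanity check helps (the `labels.txt` path here is an assumption):
```
# Hypothetical check: the class count given to -n/--classes should match
# the number of entries in the labels file DeepStream will use.
with open('labels.txt') as f:
    labels = [line for line in f.read().splitlines() if line]
print(len(labels))  # expect 80 for COCO-trained YOLO-NAS weights
```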

View File

@@ -55,37 +55,13 @@ Generate the ONNX model file
python3 export_yolor.py -w yolor-p6.pt --dynamic
```
**NOTE**: To simplify the ONNX model (DeepStream >= 6)
```
--simplify
```
**NOTE**: To use dynamic batch-size (DeepStream >= 6)
```
--dynamic
```
**NOTE**: To use implicit batch-size (example for batch-size = 4)
```
--batch 4
```
**NOTE**: If you are using DeepStream 5.1, remove the `--dynamic` arg and use opset 12 or lower. The default opset is 12.
```
--opset 12
```
**NOTE**: To convert a P6 model
```
--p6
```
**NOTE**: To change the inference size (default: 640)
**NOTE**: To change the inference size (default: 640 / 1280 for `--p6` models)
```
-s SIZE
@@ -106,6 +82,30 @@ or
-s 1280 1280
```
**NOTE**: To simplify the ONNX model (DeepStream >= 6.0)
```
--simplify
```
**NOTE**: To use dynamic batch-size (DeepStream >= 6.1)
```
--dynamic
```
**NOTE**: To use implicit batch-size (example for batch-size = 4)
```
--batch 4
```
**NOTE**: If you are using DeepStream 5.1, remove the `--dynamic` arg and use opset 12 or lower. The default opset is 12.
```
--opset 12
```
#### 5. Copy generated files
Copy the generated ONNX model file and labels.txt file (if generated) to the `DeepStream-Yolo` folder.

View File

@@ -46,13 +46,13 @@ Generate the ONNX model file (example for YOLOX-s)
python3 export_yolox.py -w yolox_s.pth -c exps/default/yolox_s.py --dynamic
```
**NOTE**: To simplify the ONNX model (DeepStream >= 6)
**NOTE**: To simplify the ONNX model (DeepStream >= 6.0)
```
--simplify
```
**NOTE**: To use dynamic batch-size (DeepStream >= 6)
**NOTE**: To use dynamic batch-size (DeepStream >= 6.1)
```
--dynamic

View File

@@ -47,37 +47,13 @@ Generate the ONNX model file (example for YOLOv5s)
python3 export_yoloV5.py -w yolov5s.pt --dynamic
```
**NOTE**: To simplify the ONNX model (DeepStream >= 6)
```
--simplify
```
**NOTE**: To use dynamic batch-size (DeepStream >= 6)
```
--dynamic
```
**NOTE**: To use implicit batch-size (example for batch-size = 4)
```
--batch 4
```
**NOTE**: If you are using DeepStream 5.1, remove the `--dynamic` arg and use opset 12 or lower. The default opset is 17.
```
--opset 12
```
**NOTE**: To convert a P6 model
```
--p6
```
**NOTE**: To change the inference size (default: 640)
**NOTE**: To change the inference size (default: 640 / 1280 for `--p6` models)
```
-s SIZE
@@ -98,6 +74,30 @@ or
-s 1280 1280
```
**NOTE**: To simplify the ONNX model (DeepStream >= 6.0)
```
--simplify
```
**NOTE**: To use dynamic batch-size (DeepStream >= 6.1)
```
--dynamic
```
**NOTE**: To use implicit batch-size (example for batch-size = 4)
```
--batch 4
```
**NOTE**: If you are using DeepStream 5.1, remove the `--dynamic` arg and use opset 12 or lower. The default opset is 17.
```
--opset 12
```
#### 5. Copy generated files
Copy the generated ONNX model file and labels.txt file (if generated) to the `DeepStream-Yolo` folder.

View File

@@ -47,37 +47,13 @@ Generate the ONNX model file (example for YOLOv6-S 4.0)
python3 export_yoloV6.py -w yolov6s.pt --dynamic
```
**NOTE**: To simplify the ONNX model (DeepStream >= 6)
```
--simplify
```
**NOTE**: To use dynamic batch-size (DeepStream >= 6)
```
--dynamic
```
**NOTE**: To use implicit batch-size (example for batch-size = 4)
```
--batch 4
```
**NOTE**: If you are using DeepStream 5.1, remove the `--dynamic` arg and use opset 12 or lower. The default opset is 13.
```
--opset 12
```
**NOTE**: To convert a P6 model
```
--p6
```
**NOTE**: To change the inference size (default: 640)
**NOTE**: To change the inference size (default: 640 / 1280 for `--p6` models)
```
-s SIZE
@@ -98,6 +74,30 @@ or
-s 1280 1280
```
**NOTE**: To simplify the ONNX model (DeepStream >= 6.0)
```
--simplify
```
**NOTE**: To use dynamic batch-size (DeepStream >= 6.1)
```
--dynamic
```
**NOTE**: To use implicit batch-size (example for batch-size = 4)
```
--batch 4
```
**NOTE**: If you are using DeepStream 5.1, remove the `--dynamic` arg and use opset 12 or lower. The default opset is 13.
```
--opset 12
```
#### 5. Copy generated file
Copy the generated ONNX model file to the `DeepStream-Yolo` folder.

View File

@@ -49,37 +49,13 @@ Generate the ONNX model file (example for YOLOv7)
python3 export_yoloV7.py -w yolov7.pt --dynamic
```
**NOTE**: To simplify the ONNX model (DeepStream >= 6)
```
--simplify
```
**NOTE**: To use dynamic batch-size (DeepStream >= 6)
```
--dynamic
```
**NOTE**: To use implicit batch-size (example for batch-size = 4)
```
--batch 4
```
**NOTE**: If you are using DeepStream 5.1, remove the `--dynamic` arg and use opset 12 or lower. The default opset is 12.
```
--opset 12
```
**NOTE**: To convert a P6 model
```
--p6
```
**NOTE**: To change the inference size (default: 640)
**NOTE**: To change the inference size (default: 640 / 1280 for `--p6` models)
```
-s SIZE
@@ -100,6 +76,30 @@ or
-s 1280 1280
```
**NOTE**: To simplify the ONNX model (DeepStream >= 6.0)
```
--simplify
```
**NOTE**: To use dynamic batch-size (DeepStream >= 6.1)
```
--dynamic
```
**NOTE**: To use implicit batch-size (example for batch-size = 4)
```
--batch 4
```
**NOTE**: If you are using DeepStream 5.1, remove the `--dynamic` arg and use opset 12 or lower. The default opset is 12.
```
--opset 12
```
#### 6. Copy generated files
Copy the generated ONNX model file and labels.txt file (if generated) to the `DeepStream-Yolo` folder.

View File

@@ -46,30 +46,6 @@ Generate the ONNX model file (example for YOLOv8s)
python3 export_yoloV8.py -w yolov8s.pt --dynamic
```
**NOTE**: To simplify the ONNX model (DeepStream >= 6)
```
--simplify
```
**NOTE**: To use dynamic batch-size (DeepStream >= 6)
```
--dynamic
```
**NOTE**: To use implicit batch-size (example for batch-size = 4)
```
--batch 4
```
**NOTE**: If you are using DeepStream 5.1, remove the `--dynamic` arg and use opset 12 or lower. The default opset is 16.
```
--opset 12
```
**NOTE**: To change the inference size (default: 640)
```
@@ -91,6 +67,30 @@ or
-s 1280 1280
```
**NOTE**: To simplify the ONNX model (DeepStream >= 6.0)
```
--simplify
```
**NOTE**: To use dynamic batch-size (DeepStream >= 6.1)
```
--dynamic
```
**NOTE**: To use implicit batch-size (example for batch-size = 4)
```
--batch 4
```
**NOTE**: If you are using DeepStream 5.1, remove the `--dynamic` arg and use opset 12 or lower. The default opset is 16.
```
--opset 12
```
#### 5. Copy generated files
Copy the generated ONNX model file and labels.txt file (if generated) to the `DeepStream-Yolo` folder.

View File

@@ -73,14 +73,14 @@ addBBoxProposal(const float bx1, const float by1, const float bx2, const float b
}
static std::vector<NvDsInferParseObjectInfo>
decodeTensorYolo(const float* boxes, const float* scores, const int* classes, const uint& outputSize, const uint& netW,
decodeTensorYolo(const float* boxes, const float* scores, const float* classes, const uint& outputSize, const uint& netW,
const uint& netH, const std::vector<float>& preclusterThreshold)
{
std::vector<NvDsInferParseObjectInfo> binfo;
for (uint b = 0; b < outputSize; ++b) {
float maxProb = scores[b];
int maxIndex = classes[b];
int maxIndex = (int) classes[b];
if (maxProb < preclusterThreshold[maxIndex])
continue;
@@ -102,14 +102,14 @@ decodeTensorYolo(const float* boxes, const float* scores, const int* classes, co
}
static std::vector<NvDsInferParseObjectInfo>
decodeTensorYoloE(const float* boxes, const float* scores, const int* classes, const uint& outputSize, const uint& netW,
decodeTensorYoloE(const float* boxes, const float* scores, const float* classes, const uint& outputSize, const uint& netW,
const uint& netH, const std::vector<float>& preclusterThreshold)
{
std::vector<NvDsInferParseObjectInfo> binfo;
for (uint b = 0; b < outputSize; ++b) {
float maxProb = scores[b];
int maxIndex = classes[b];
int maxIndex = (int) classes[b];
if (maxProb < preclusterThreshold[maxIndex])
continue;
@@ -136,26 +136,14 @@ NvDsInferParseCustomYolo(std::vector<NvDsInferLayerInfo> const& outputLayersInfo
std::vector<NvDsInferParseObjectInfo> objects;
NvDsInferLayerInfo* boxes;
NvDsInferLayerInfo* scores;
NvDsInferLayerInfo* classes;
const NvDsInferLayerInfo& boxes = outputLayersInfo[0];
const NvDsInferLayerInfo& scores = outputLayersInfo[1];
const NvDsInferLayerInfo& classes = outputLayersInfo[2];
for (uint i = 0; i < 3; ++i) {
if (outputLayersInfo[i].dataType == NvDsInferDataType::INT32) {
classes = (NvDsInferLayerInfo*) &outputLayersInfo[i];
}
else if (outputLayersInfo[i].inferDims.d[1] == 4) {
boxes = (NvDsInferLayerInfo*) &outputLayersInfo[i];
}
else {
scores = (NvDsInferLayerInfo*) &outputLayersInfo[i];
}
}
const uint outputSize = boxes.inferDims.d[0];
const uint outputSize = boxes->inferDims.d[0];
std::vector<NvDsInferParseObjectInfo> outObjs = decodeTensorYolo((const float*) (boxes->buffer),
(const float*) (scores->buffer), (const int*) (classes->buffer), outputSize, networkInfo.width, networkInfo.height,
std::vector<NvDsInferParseObjectInfo> outObjs = decodeTensorYolo((const float*) (boxes.buffer),
(const float*) (scores.buffer), (const float*) (classes.buffer), outputSize, networkInfo.width, networkInfo.height,
detectionParams.perClassPreclusterThreshold);
objects.insert(objects.end(), outObjs.begin(), outObjs.end());
@@ -176,26 +164,14 @@ NvDsInferParseCustomYoloE(std::vector<NvDsInferLayerInfo> const& outputLayersInf
std::vector<NvDsInferParseObjectInfo> objects;
NvDsInferLayerInfo* boxes;
NvDsInferLayerInfo* scores;
NvDsInferLayerInfo* classes;
const NvDsInferLayerInfo& boxes = outputLayersInfo[0];
const NvDsInferLayerInfo& scores = outputLayersInfo[1];
const NvDsInferLayerInfo& classes = outputLayersInfo[2];
for (uint i = 0; i < 3; ++i) {
if (outputLayersInfo[i].dataType == NvDsInferDataType::INT32) {
classes = (NvDsInferLayerInfo*) &outputLayersInfo[i];
}
else if (outputLayersInfo[i].inferDims.d[1] == 4) {
boxes = (NvDsInferLayerInfo*) &outputLayersInfo[i];
}
else {
scores = (NvDsInferLayerInfo*) &outputLayersInfo[i];
}
}
const uint outputSize = boxes.inferDims.d[0];
const uint outputSize = boxes->inferDims.d[0];
std::vector<NvDsInferParseObjectInfo> outObjs = decodeTensorYoloE((const float*) (boxes->buffer),
(const float*) (scores->buffer), (const int*) (classes->buffer), outputSize, networkInfo.width, networkInfo.height,
std::vector<NvDsInferParseObjectInfo> outObjs = decodeTensorYoloE((const float*) (boxes.buffer),
(const float*) (scores.buffer), (const float*) (classes.buffer), outputSize, networkInfo.width, networkInfo.height,
detectionParams.perClassPreclusterThreshold);
objects.insert(objects.end(), outObjs.begin(), outObjs.end());
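With this change the parser no longer sniffs the three output layers by dtype and shape; it assumes a fixed order (boxes, scores, classes) with every tensor in FP32. A hedged way to confirm an exported model matches that contract, using onnxruntime (not part of the repo):
```
import onnxruntime as ort

sess = ort.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])
for out in sess.get_outputs():
    # Expected: boxes [batch, N, 4], scores [batch, N, 1], classes [batch, N, 1],
    # all reported as tensor(float) now that class indices are exported as FP32.
    print(out.name, out.type, out.shape)
```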

View File

@@ -37,7 +37,7 @@ extern "C" bool
NvDsInferParseYoloECuda(std::vector<NvDsInferLayerInfo> const& outputLayersInfo, NvDsInferNetworkInfo const& networkInfo,
NvDsInferParseDetectionParams const& detectionParams, std::vector<NvDsInferParseObjectInfo>& objectList);
__global__ void decodeTensorYoloCuda(NvDsInferParseObjectInfo *binfo, float* boxes, float* scores, int* classes,
__global__ void decodeTensorYoloCuda(NvDsInferParseObjectInfo *binfo, float* boxes, float* scores, float* classes,
int outputSize, int netW, int netH, float minPreclusterThreshold)
{
int x_id = blockIdx.x * blockDim.x + threadIdx.x;
@@ -46,7 +46,7 @@ __global__ void decodeTensorYoloCuda(NvDsInferParseObjectInfo *binfo, float* box
return;
float maxProb = scores[x_id];
int maxIndex = classes[x_id];
int maxIndex = (int) classes[x_id];
if (maxProb < minPreclusterThreshold) {
binfo[x_id].detectionConfidence = 0.0;
@@ -76,7 +76,7 @@ __global__ void decodeTensorYoloCuda(NvDsInferParseObjectInfo *binfo, float* box
binfo[x_id].classId = maxIndex;
}
__global__ void decodeTensorYoloECuda(NvDsInferParseObjectInfo *binfo, float* boxes, float* scores, int* classes,
__global__ void decodeTensorYoloECuda(NvDsInferParseObjectInfo *binfo, float* boxes, float* scores, float* classes,
int outputSize, int netW, int netH, float minPreclusterThreshold)
{
int x_id = blockIdx.x * blockDim.x + threadIdx.x;
@@ -85,7 +85,7 @@ __global__ void decodeTensorYoloECuda(NvDsInferParseObjectInfo *binfo, float* bo
return;
float maxProb = scores[x_id];
int maxIndex = classes[x_id];
int maxIndex = (int) classes[x_id];
if (maxProb < minPreclusterThreshold) {
binfo[x_id].detectionConfidence = 0.0;
@@ -119,23 +119,11 @@ static bool NvDsInferParseCustomYoloCuda(std::vector<NvDsInferLayerInfo> const&
return false;
}
NvDsInferLayerInfo* boxes;
NvDsInferLayerInfo* scores;
NvDsInferLayerInfo* classes;
const NvDsInferLayerInfo& boxes = outputLayersInfo[0];
const NvDsInferLayerInfo& scores = outputLayersInfo[1];
const NvDsInferLayerInfo& classes = outputLayersInfo[2];
for (uint i = 0; i < 3; ++i) {
if (outputLayersInfo[i].dataType == NvDsInferDataType::INT32) {
classes = (NvDsInferLayerInfo*) &outputLayersInfo[i];
}
else if (outputLayersInfo[i].inferDims.d[1] == 4) {
boxes = (NvDsInferLayerInfo*) &outputLayersInfo[i];
}
else {
scores = (NvDsInferLayerInfo*) &outputLayersInfo[i];
}
}
const int outputSize = boxes->inferDims.d[0];
const int outputSize = boxes.inferDims.d[0];
thrust::device_vector<NvDsInferParseObjectInfo> objects(outputSize);
@@ -146,8 +134,8 @@ static bool NvDsInferParseCustomYoloCuda(std::vector<NvDsInferLayerInfo> const&
int number_of_blocks = ((outputSize - 1) / threads_per_block) + 1;
decodeTensorYoloCuda<<<number_of_blocks, threads_per_block>>>(
thrust::raw_pointer_cast(objects.data()), (float*) (boxes->buffer), (float*) (scores->buffer),
(int*) (classes->buffer), outputSize, networkInfo.width, networkInfo.height, minPreclusterThreshold);
thrust::raw_pointer_cast(objects.data()), (float*) (boxes.buffer), (float*) (scores.buffer),
(float*) (classes.buffer), outputSize, networkInfo.width, networkInfo.height, minPreclusterThreshold);
objectList.resize(outputSize);
thrust::copy(objects.begin(), objects.end(), objectList.begin());
@@ -164,23 +152,11 @@ static bool NvDsInferParseCustomYoloECuda(std::vector<NvDsInferLayerInfo> const&
return false;
}
NvDsInferLayerInfo* boxes;
NvDsInferLayerInfo* scores;
NvDsInferLayerInfo* classes;
const NvDsInferLayerInfo& boxes = outputLayersInfo[0];
const NvDsInferLayerInfo& scores = outputLayersInfo[1];
const NvDsInferLayerInfo& classes = outputLayersInfo[2];
for (uint i = 0; i < 3; ++i) {
if (outputLayersInfo[i].dataType == NvDsInferDataType::INT32) {
classes = (NvDsInferLayerInfo*) &outputLayersInfo[i];
}
else if (outputLayersInfo[i].inferDims.d[1] == 4) {
boxes = (NvDsInferLayerInfo*) &outputLayersInfo[i];
}
else {
scores = (NvDsInferLayerInfo*) &outputLayersInfo[i];
}
}
const int outputSize = boxes->inferDims.d[0];
const int outputSize = boxes.inferDims.d[0];
thrust::device_vector<NvDsInferParseObjectInfo> objects(outputSize);
@@ -191,8 +167,8 @@ static bool NvDsInferParseCustomYoloECuda(std::vector<NvDsInferLayerInfo> const&
int number_of_blocks = ((outputSize - 1) / threads_per_block) + 1;
decodeTensorYoloECuda<<<number_of_blocks, threads_per_block>>>(
thrust::raw_pointer_cast(objects.data()), (float*) (boxes->buffer), (float*) (scores->buffer),
(int*) (classes->buffer), outputSize, networkInfo.width, networkInfo.height, minPreclusterThreshold);
thrust::raw_pointer_cast(objects.data()), (float*) (boxes.buffer), (float*) (scores.buffer),
(float*) (classes.buffer), outputSize, networkInfo.width, networkInfo.height, minPreclusterThreshold);
objectList.resize(outputSize);
thrust::copy(objects.begin(), objects.end(), objectList.begin());
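The launch configuration is the usual ceiling division over the number of candidate boxes; for reference (both values below are assumptions for illustration):
```
output_size = 8400            # e.g. candidate boxes for a 640x640 input
threads_per_block = 1024
number_of_blocks = (output_size - 1) // threads_per_block + 1
print(number_of_blocks)       # 9 blocks: enough threads to cover every box
```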

View File

@@ -76,7 +76,7 @@ Yolo::createEngine(nvinfer1::IBuilder* builder)
if (m_NetworkType == "onnx") {
#if NV_TENSORRT_MAJOR >= 8
#if NV_TENSORRT_MAJOR >= 8 && NV_TENSORRT_MINOR > 0
parser = nvonnxparser::createParser(*network, *builder->getLogger());
#else
parser = nvonnxparser::createParser(*network, logger);
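This is the logger fix from the commit title: the new guard only calls `builder->getLogger()` on TensorRT releases newer than 8.0, and DeepStream 6.0 / 6.0.1 bundle TensorRT 8.0.x, hence the fallback to the static logger there. To check which TensorRT a DeepStream install ships (assuming the Python bindings are present):
```
import tensorrt as trt

# 8.0.x (DeepStream 6.0 / 6.0.1) takes the fallback branch of the guard;
# 8.1+ can pass *builder->getLogger() to the ONNX parser directly.
print(trt.__version__)
```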

View File

@@ -45,7 +45,9 @@
#define INT int32_t
#else
#define INT int
#endif
#if NV_TENSORRT_MAJOR < 8 || (NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 0)
static class Logger : public nvinfer1::ILogger {
void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override {
if (severity <= nvinfer1::ILogger::Severity::kWARNING)

View File

@@ -18,6 +18,7 @@ class DeepStreamOutput(nn.Module):
    def forward(self, x):
        boxes = x[1]
        scores, classes = torch.max(x[0], 2, keepdim=True)
        classes = classes.float()
        return boxes, scores, classes
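`torch.max` returns its indices as int64, so without the cast the ONNX graph would carry a mixed-dtype output set; casting keeps boxes, scores, and classes uniformly FP32, which is what the updated parsers expect. A standalone illustration (shapes invented):
```
import torch

x = torch.rand(1, 100, 80)                       # fake per-class scores
scores, classes = torch.max(x, 2, keepdim=True)
print(classes.dtype)                             # torch.int64 before the cast
classes = classes.float()
print(classes.dtype)                             # torch.float32, uniform with boxes/scores
```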

View File

@@ -19,7 +19,7 @@ class DeepStreamOutput(nn.Layer):
        boxes = x['bbox']
        x['bbox_num'] = x['bbox_num'].transpose([0, 2, 1])
        scores = paddle.max(x['bbox_num'], 2, keepdim=True)
        classes = paddle.argmax(x['bbox_num'], 2, keepdim=True)
        classes = paddle.cast(paddle.argmax(x['bbox_num'], 2, keepdim=True), dtype='float32')
        return boxes, scores, classes
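The Paddle exporter needs the same treatment, since `paddle.argmax` also yields int64 indices; `paddle.cast` mirrors the `.float()` call on the torch side:
```
import paddle

x = paddle.rand([1, 100, 80])
classes = paddle.argmax(x, 2, keepdim=True)
print(classes.dtype)                             # paddle.int64
classes = paddle.cast(classes, dtype='float32')
print(classes.dtype)                             # paddle.float32
```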

View File

@@ -20,6 +20,7 @@ class DeepStreamOutput(nn.Module):
        objectness = x[:, :, 4:5]
        scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True)
        scores *= objectness
        classes = classes.float()
        return boxes, scores, classes

View File

@@ -24,6 +24,7 @@ class DeepStreamOutput(nn.Module):
        objectness = x[:, :, 4:5]
        scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True)
        scores *= objectness
        classes = classes.float()
        return boxes, scores, classes

View File

@@ -20,6 +20,7 @@ class DeepStreamOutput(nn.Module):
        objectness = x[:, :, 4:5]
        scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True)
        scores *= objectness
        classes = classes.float()
        return boxes, scores, classes

View File

@@ -18,6 +18,7 @@ class DeepStreamOutput(nn.Module):
        x = x.transpose(1, 2)
        boxes = x[:, :, :4]
        scores, classes = torch.max(x[:, :, 4:], 2, keepdim=True)
        classes = classes.float()
        return boxes, scores, classes

View File

@@ -19,6 +19,7 @@ class DeepStreamOutput(nn.Module):
        x = x.transpose(1, 2)
        boxes = x[:, :, :4]
        scores, classes = torch.max(x[:, :, 4:], 2, keepdim=True)
        classes = classes.float()
        return boxes, scores, classes

View File

@@ -15,6 +15,7 @@ class DeepStreamOutput(nn.Module):
    def forward(self, x):
        boxes = x[0]
        scores, classes = torch.max(x[1], 2, keepdim=True)
        classes = classes.float()
        return boxes, scores, classes

View File

@@ -17,6 +17,7 @@ class DeepStreamOutput(nn.Module):
        objectness = x[:, :, 4:5]
        scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True)
        scores *= objectness
        classes = classes.float()
        return boxes, scores, classes

View File

@@ -19,6 +19,7 @@ class DeepStreamOutput(nn.Module):
        objectness = x[:, :, 4:5]
        scores, classes = torch.max(x[:, :, 5:], 2, keepdim=True)
        scores *= objectness
        classes = classes.float()
        return boxes, scores, classes