Fix route layer

This commit is contained in:
Marcos Luciano
2022-08-16 14:36:52 -03:00
parent ab082fc292
commit dc023b308e
6 changed files with 33 additions and 21 deletions

View File

@@ -25,14 +25,14 @@ nvinfer1::ITensor* reorgLayer(
slice1->setName(slice1LayerName.c_str());
nvinfer1::ISliceLayer *slice2 = network->addSlice(
*input, nvinfer1::Dims{3, {0, 0, 1}}, nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}},
*input, nvinfer1::Dims{3, {0, 1, 0}}, nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}},
nvinfer1::Dims{3, {1, 2, 2}});
assert(slice2 != nullptr);
std::string slice2LayerName = "slice2_" + std::to_string(layerIdx);
slice2->setName(slice2LayerName.c_str());
nvinfer1::ISliceLayer *slice3 = network->addSlice(
*input, nvinfer1::Dims{3, {0, 1, 0}}, nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}},
*input, nvinfer1::Dims{3, {0, 0, 1}}, nvinfer1::Dims{3, {inputDims.d[0], inputDims.d[1] / 2, inputDims.d[2] / 2}},
nvinfer1::Dims{3, {1, 2, 2}});
assert(slice3 != nullptr);
std::string slice3LayerName = "slice3_" + std::to_string(layerIdx);

View File

@@ -46,20 +46,21 @@ nvinfer1::ITensor* routeLayer(
layers += std::to_string(idxLayers[idxLayers.size() - 1]);
if (concatInputs.size() == 1)
return concatInputs[0];
output = concatInputs[0];
else {
int axis = 0;
if (block.find("axis") != block.end())
axis = std::stoi(block.at("axis"));
if (axis < 0)
axis = concatInputs[0]->getDimensions().nbDims + axis;
int axis = 0;
if (block.find("axis") != block.end())
axis = std::stoi(block.at("axis"));
if (axis < 0)
axis = concatInputs[0]->getDimensions().nbDims + axis;
nvinfer1::IConcatenationLayer* concat = network->addConcatenation(concatInputs.data(), concatInputs.size());
assert(concat != nullptr);
std::string concatLayerName = "route_" + std::to_string(layerIdx);
concat->setName(concatLayerName.c_str());
concat->setAxis(axis);
output = concat->getOutput(0);
nvinfer1::IConcatenationLayer* concat = network->addConcatenation(concatInputs.data(), concatInputs.size());
assert(concat != nullptr);
std::string concatLayerName = "route_" + std::to_string(layerIdx);
concat->setName(concatLayerName.c_str());
concat->setAxis(axis);
output = concat->getOutput(0);
}
if (block.find("groups") != block.end())
{