Fix route layer

This commit is contained in:
Marcos Luciano
2022-08-16 14:36:52 -03:00
parent ab082fc292
commit dc023b308e
6 changed files with 33 additions and 21 deletions

View File

@@ -46,20 +46,21 @@ nvinfer1::ITensor* routeLayer(
layers += std::to_string(idxLayers[idxLayers.size() - 1]);
if (concatInputs.size() == 1)
return concatInputs[0];
output = concatInputs[0];
else {
int axis = 0;
if (block.find("axis") != block.end())
axis = std::stoi(block.at("axis"));
if (axis < 0)
axis = concatInputs[0]->getDimensions().nbDims + axis;
int axis = 0;
if (block.find("axis") != block.end())
axis = std::stoi(block.at("axis"));
if (axis < 0)
axis = concatInputs[0]->getDimensions().nbDims + axis;
nvinfer1::IConcatenationLayer* concat = network->addConcatenation(concatInputs.data(), concatInputs.size());
assert(concat != nullptr);
std::string concatLayerName = "route_" + std::to_string(layerIdx);
concat->setName(concatLayerName.c_str());
concat->setAxis(axis);
output = concat->getOutput(0);
nvinfer1::IConcatenationLayer* concat = network->addConcatenation(concatInputs.data(), concatInputs.size());
assert(concat != nullptr);
std::string concatLayerName = "route_" + std::to_string(layerIdx);
concat->setName(concatLayerName.c_str());
concat->setAxis(axis);
output = concat->getOutput(0);
}
if (block.find("groups") != block.end())
{