New features
- Added support for INT8 calibration - Added support for non square models - Updated mAP comparison between models
This commit is contained in:
@@ -20,20 +20,20 @@
|
||||
|
||||
inline __device__ float sigmoidGPU(const float& x) { return 1.0f / (1.0f + __expf(-x)); }
|
||||
|
||||
__global__ void gpuYoloLayer(const float* input, float* output, const uint gridSize, const uint numOutputClasses,
|
||||
__global__ void gpuYoloLayer(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses,
|
||||
const uint numBBoxes, const uint new_coords, const float scale_x_y)
|
||||
{
|
||||
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
uint z_id = blockIdx.z * blockDim.z + threadIdx.z;
|
||||
|
||||
if ((x_id >= gridSize) || (y_id >= gridSize) || (z_id >= numBBoxes))
|
||||
if ((x_id >= gridSizeX) || (y_id >= gridSizeY) || (z_id >= numBBoxes))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const int numGridCells = gridSize * gridSize;
|
||||
const int bbindex = y_id * gridSize + x_id;
|
||||
const int numGridCells = gridSizeX * gridSizeY;
|
||||
const int bbindex = y_id * gridSizeX + x_id;
|
||||
|
||||
float alpha = scale_x_y;
|
||||
float beta = -0.5 * (scale_x_y - 1);
|
||||
@@ -84,20 +84,20 @@ __global__ void gpuYoloLayer(const float* input, float* output, const uint gridS
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void gpuRegionLayer(const float* input, float* output, const uint gridSize, const uint numOutputClasses,
|
||||
__global__ void gpuRegionLayer(const float* input, float* output, const uint gridSizeX, const uint gridSizeY, const uint numOutputClasses,
|
||||
const uint numBBoxes)
|
||||
{
|
||||
uint x_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
uint y_id = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
uint z_id = blockIdx.z * blockDim.z + threadIdx.z;
|
||||
|
||||
if ((x_id >= gridSize) || (y_id >= gridSize) || (z_id >= numBBoxes))
|
||||
if ((x_id >= gridSizeX) || (y_id >= gridSizeY) || (z_id >= numBBoxes))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const int numGridCells = gridSize * gridSize;
|
||||
const int bbindex = y_id * gridSize + x_id;
|
||||
const int numGridCells = gridSizeX * gridSizeY;
|
||||
const int bbindex = y_id * gridSizeX + x_id;
|
||||
|
||||
output[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]
|
||||
= sigmoidGPU(input[bbindex + numGridCells * (z_id * (5 + numOutputClasses) + 0)]);
|
||||
@@ -132,24 +132,24 @@ __global__ void gpuRegionLayer(const float* input, float* output, const uint gri
|
||||
}
|
||||
}
|
||||
|
||||
cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSize,
|
||||
cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||
const uint& numOutputClasses, const uint& numBBoxes,
|
||||
uint64_t outputSize, cudaStream_t stream, const uint modelCoords, const float modelScale, const uint modelType);
|
||||
|
||||
cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSize,
|
||||
cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize, const uint& gridSizeX, const uint& gridSizeY,
|
||||
const uint& numOutputClasses, const uint& numBBoxes,
|
||||
uint64_t outputSize, cudaStream_t stream, const uint modelCoords, const float modelScale, const uint modelType)
|
||||
{
|
||||
dim3 threads_per_block(16, 16, 4);
|
||||
dim3 number_of_blocks((gridSize / threads_per_block.x) + 1,
|
||||
(gridSize / threads_per_block.y) + 1,
|
||||
dim3 number_of_blocks((gridSizeX / threads_per_block.x) + 1,
|
||||
(gridSizeY / threads_per_block.y) + 1,
|
||||
(numBBoxes / threads_per_block.z) + 1);
|
||||
if (modelType == 1) {
|
||||
for (unsigned int batch = 0; batch < batchSize; ++batch)
|
||||
{
|
||||
gpuYoloLayer<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
||||
reinterpret_cast<const float*>(input) + (batch * outputSize),
|
||||
reinterpret_cast<float*>(output) + (batch * outputSize), gridSize, numOutputClasses,
|
||||
reinterpret_cast<float*>(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
|
||||
numBBoxes, modelCoords, modelScale);
|
||||
}
|
||||
}
|
||||
@@ -158,7 +158,7 @@ cudaError_t cudaYoloLayer(const void* input, void* output, const uint& batchSize
|
||||
{
|
||||
gpuRegionLayer<<<number_of_blocks, threads_per_block, 0, stream>>>(
|
||||
reinterpret_cast<const float*>(input) + (batch * outputSize),
|
||||
reinterpret_cast<float*>(output) + (batch * outputSize), gridSize, numOutputClasses,
|
||||
reinterpret_cast<float*>(output) + (batch * outputSize), gridSizeX, gridSizeY, numOutputClasses,
|
||||
numBBoxes);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user