Add YOLOv7 support

Marcos Luciano
2022-08-12 16:33:26 -03:00
parent 23547b19b2
commit 80d08990a0
10 changed files with 586 additions and 47 deletions


@@ -8,7 +8,6 @@ NVIDIA DeepStream SDK 6.1 / 6.0.1 / 6.0 configuration for YOLO models
* DeepStream tutorials * DeepStream tutorials
* YOLOX support * YOLOX support
* YOLOv6 support * YOLOv6 support
* YOLOv7 support
* Dynamic batch-size * Dynamic batch-size
### Improvements on this repository ### Improvements on this repository
@@ -27,6 +26,7 @@ NVIDIA DeepStream SDK 6.1 / 6.0.1 / 6.0 configuration for YOLO models
* **GPU YOLO Decoder** [#138](https://github.com/marcoslucianops/DeepStream-Yolo/issues/138) * **GPU YOLO Decoder** [#138](https://github.com/marcoslucianops/DeepStream-Yolo/issues/138)
* **GPU Batched NMS** [#142](https://github.com/marcoslucianops/DeepStream-Yolo/issues/142) * **GPU Batched NMS** [#142](https://github.com/marcoslucianops/DeepStream-Yolo/issues/142)
* **PP-YOLOE support** * **PP-YOLOE support**
* **YOLOv7 support**
## ##
@@ -42,6 +42,7 @@ NVIDIA DeepStream SDK 6.1 / 6.0.1 / 6.0 configuration for YOLO models
* [YOLOv5 usage](docs/YOLOv5.md) * [YOLOv5 usage](docs/YOLOv5.md)
* [YOLOR usage](docs/YOLOR.md) * [YOLOR usage](docs/YOLOR.md)
* [PP-YOLOE usage](docs/PPYOLOE.md) * [PP-YOLOE usage](docs/PPYOLOE.md)
* [YOLOv7 usage](docs/YOLOv7.md)
* [Using your custom model](docs/customModels.md) * [Using your custom model](docs/customModels.md)
* [Multiple YOLO GIEs](docs/multipleGIEs.md) * [Multiple YOLO GIEs](docs/multipleGIEs.md)
@@ -89,6 +90,7 @@ NVIDIA DeepStream SDK 6.1 / 6.0.1 / 6.0 configuration for YOLO models
* [YOLOv5 >= 2.0](https://github.com/ultralytics/yolov5) * [YOLOv5 >= 2.0](https://github.com/ultralytics/yolov5)
* [YOLOR](https://github.com/WongKinYiu/yolor) * [YOLOR](https://github.com/WongKinYiu/yolor)
* [PP-YOLOE](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe) * [PP-YOLOE](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe)
* [YOLOv7](https://github.com/WongKinYiu/yolov7)
* [MobileNet-YOLO](https://github.com/dog-qiuqiu/MobileNet-Yolo) * [MobileNet-YOLO](https://github.com/dog-qiuqiu/MobileNet-Yolo)
* [YOLO-Fastest](https://github.com/dog-qiuqiu/Yolo-Fastest) * [YOLO-Fastest](https://github.com/dog-qiuqiu/Yolo-Fastest)


@@ -0,0 +1,24 @@
[property]
gpu-id=0
net-scale-factor=0.0039215697906911373
model-color-format=0
custom-network-config=yolov7.cfg
model-file=yolov7.wts
model-engine-file=model_b1_gpu0_fp32.engine
#int8-calib-file=calib.table
labelfile-path=labels.txt
batch-size=1
network-mode=0
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
cluster-mode=4
maintain-aspect-ratio=0
parse-bbox-func-name=NvDsInferParseYolo
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
engine-create-func-name=NvDsInferYoloCudaEngineGet
[class-attrs-all]
pre-cluster-threshold=0
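
For context, `net-scale-factor` is the multiplier applied to input pixel values before inference, and the value used above is approximately 1/255, which rescales 8-bit pixels to roughly the 0 to 1 range. A minimal sketch of that check follows; the variable name is illustrative and not part of the repository.

```python
# Value copied from the [property] section above.
net_scale_factor = 0.0039215697906911373

# 1/255 maps 8-bit pixel values (0..255) to approximately 0..1.
# The config value matches 1/255 to within single-precision rounding.
print(abs(net_scale_factor - 1 / 255) < 1e-8)

# A fully saturated pixel (255) ends up very close to 1.0 after scaling.
print(255 * net_scale_factor)
```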


@@ -74,7 +74,7 @@ Open the `DeepStream-Yolo` folder and compile the lib
## ##
### Edit the config_infer_primary_yoloV5 file ### Edit the config_infer_primary_ppyoloe file
Edit the `config_infer_primary_ppyoloe.txt` file according to your model (example for PP-YOLOE-s) Edit the `config_infer_primary_ppyoloe.txt` file according to your model (example for PP-YOLOE-s)
@@ -97,7 +97,7 @@ offsets=123.675;116.28;103.53
## ##
### Edit the deepstream_app_config.txt file ### Edit the deepstream_app_config file
``` ```
... ...


@@ -2,7 +2,7 @@
**NOTE**: You need to use the main branch of the YOLOR repo to convert the model. **NOTE**: You need to use the main branch of the YOLOR repo to convert the model.
**NOTE**: The cfg is required. **NOTE**: The cfg file is required.
* [Convert model](#convert-model) * [Convert model](#convert-model)
* [Compile the lib](#compile-the-lib) * [Compile the lib](#compile-the-lib)
@@ -92,7 +92,7 @@ model-file=yolor_csp.wts
## ##
### Edit the deepstream_app_config.txt file ### Edit the deepstream_app_config file
``` ```
... ...


@@ -2,7 +2,7 @@
**NOTE**: You can use the main branch of the YOLOv5 repo to convert all model versions. **NOTE**: You can use the main branch of the YOLOv5 repo to convert all model versions.
**NOTE**: The yaml is not required. **NOTE**: The yaml file is not required.
* [Convert model](#convert-model) * [Convert model](#convert-model)
* [Compile the lib](#compile-the-lib) * [Compile the lib](#compile-the-lib)
@@ -117,7 +117,7 @@ model-file=yolov5s.wts
## ##
### Edit the deepstream_app_config.txt file ### Edit the deepstream_app_config file
``` ```
... ...

133
docs/YOLOv7.md Normal file

@@ -0,0 +1,133 @@
# YOLOv7 usage
**NOTE**: The yaml file is not required.
* [Convert model](#convert-model)
* [Compile the lib](#compile-the-lib)
* [Edit the config_infer_primary_yoloV7 file](#edit-the-config_infer_primary_yolov7-file)
* [Edit the deepstream_app_config file](#edit-the-deepstream_app_config-file)
* [Testing the model](#testing-the-model)
##
### Convert model
#### 1. Download the YOLOv7 repo and install the requirements
```
git clone https://github.com/WongKinYiu/yolov7.git
cd yolov7
pip3 install -r requirements.txt
```
**NOTE**: It is recommended to use a Python virtualenv.
#### 2. Copy the converter
Copy the `gen_wts_yoloV7.py` file from the `DeepStream-Yolo/utils` directory to the `yolov7` folder.
#### 3. Download the model
Download the `pt` file from [YOLOv7](https://github.com/WongKinYiu/yolov7/releases/) releases (example for YOLOv7)
```
wget https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7.pt
```
**NOTE**: You can use your custom model, but it is important to keep the YOLO model reference (`yolov7_`) in your `cfg` and `weights`/`wts` filenames to generate the engine correctly.
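
The naming rule behind this note can be sketched as follows. The logic mirrors the filename handling in `gen_wts_yoloV7.py` from this commit; the helper name `output_names` and the example paths are illustrative assumptions, not part of the script.

```python
import os


def output_names(pt_file):
    """Hypothetical helper mirroring how the converter names its outputs."""
    # Strip the directory and the .pt extension, as the converter does.
    model_name = os.path.basename(pt_file).split('.pt')[0]
    # Add the 'yolov7_' prefix only when the name lacks the 'yolov7' reference.
    prefix = '' if 'yolov7' in model_name else 'yolov7_'
    return prefix + model_name + '.wts', prefix + model_name + '.cfg'


print(output_names('yolov7.pt'))
print(output_names('runs/train/exp/weights/best.pt'))  # example path, assumed
```

A custom checkpoint named `best.pt` would therefore produce `yolov7_best.wts` and `yolov7_best.cfg`.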
#### 4. Convert model
Generate the `cfg` and `wts` files (example for YOLOv7)
```
python3 gen_wts_yoloV7.py -w yolov7.pt
```
**NOTE**: To change the inference size (default: 640)
```
-s SIZE
--size SIZE
-s HEIGHT WIDTH
--size HEIGHT WIDTH
```
Example for 1280
```
-s 1280
```
or
```
-s 1280 1280
```
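How the one-value and two-value forms are interpreted can be sketched as below. The argument definition and the height/width selection follow `gen_wts_yoloV7.py` from this commit; the wrapper function `parse_size` is an illustrative assumption.

```python
import argparse


def parse_size(argv):
    """Hypothetical wrapper showing how -s/--size maps to (height, width)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--size', nargs='+', type=int, default=[640])
    size = parser.parse_args(argv).size
    # One value means a square input; two values are read as HEIGHT WIDTH.
    height = size[0]
    width = size[0] if len(size) == 1 else size[1]
    return height, width


print(parse_size([]))
print(parse_size(['-s', '1280']))
print(parse_size(['-s', '736', '1280']))
```

The resulting values are written to the `[net]` section of the generated `cfg` file as `height` and `width`.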
#### 5. Copy generated files
Copy the generated `cfg` and `wts` files to the `DeepStream-Yolo` folder.
##
### Compile the lib
Open the `DeepStream-Yolo` folder and compile the lib
* DeepStream 6.1 on x86 platform
```
CUDA_VER=11.6 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.0.1 / 6.0 on x86 platform
```
CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.1 on Jetson platform
```
CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo
```
* DeepStream 6.0.1 / 6.0 on Jetson platform
```
CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo
```
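The version matrix above can also be kept as a small lookup when scripting the build. This is only a sketch: the dictionary and the `make_command` helper are illustrative assumptions, with the values copied from the list above.

```python
# (DeepStream version, platform) -> CUDA_VER, copied from the list above.
CUDA_VER = {
    ('6.1', 'x86'): '11.6',
    ('6.0.1', 'x86'): '11.4',
    ('6.0', 'x86'): '11.4',
    ('6.1', 'jetson'): '11.4',
    ('6.0.1', 'jetson'): '10.2',
    ('6.0', 'jetson'): '10.2',
}


def make_command(deepstream, platform):
    """Hypothetical helper that builds the compile command for a given setup."""
    return 'CUDA_VER=%s make -C nvdsinfer_custom_impl_Yolo' % CUDA_VER[(deepstream, platform)]


print(make_command('6.1', 'jetson'))
```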
##
### Edit the config_infer_primary_yoloV7 file
Edit the `config_infer_primary_yoloV7.txt` file according to your model (example for YOLOv7)
```
[property]
...
custom-network-config=yolov7.cfg
model-file=yolov7.wts
...
```
##
### Edit the deepstream_app_config file
```
...
[primary-gie]
...
config-file=config_infer_primary_yoloV7.txt
```
##
### Testing the model
```
deepstream-app -c deepstream_app_config.txt
```


@@ -49,37 +49,16 @@ LIBS+= -lnvinfer_plugin -lnvinfer -lnvparsers -L/usr/local/cuda-$(CUDA_VER)/lib6
LFLAGS:= -shared -Wl,--start-group $(LIBS) -Wl,--end-group LFLAGS:= -shared -Wl,--start-group $(LIBS) -Wl,--end-group
INCS:= $(wildcard *.h) INCS:= $(wildcard *.h)
SRCFILES:= nvdsinfer_yolo_engine.cpp \
nvdsparsebbox_Yolo.cpp \ SRCFILES:= $(filter-out calibrator.cpp, $(wildcard *.cpp))
yoloPlugins.cpp \
layers/convolutional_layer.cpp \
layers/batchnorm_layer.cpp \
layers/implicit_layer.cpp \
layers/channels_layer.cpp \
layers/shortcut_layer.cpp \
layers/route_layer.cpp \
layers/upsample_layer.cpp \
layers/pooling_layer.cpp \
layers/activation_layer.cpp \
layers/reorg_layer.cpp \
layers/reduce_layer.cpp \
layers/shuffle_layer.cpp \
layers/softmax_layer.cpp \
layers/cls_layer.cpp \
layers/reg_layer.cpp \
utils.cpp \
yolo.cpp \
yoloForward.cu \
yoloForward_v2.cu \
yoloForward_nc.cu \
yoloForward_r.cu \
yoloForward_e.cu \
sortDetections.cu
ifeq ($(OPENCV), 1) ifeq ($(OPENCV), 1)
SRCFILES+= calibrator.cpp SRCFILES+= calibrator.cpp
endif endif
SRCFILES+= $(wildcard layers/*.cpp)
SRCFILES+= $(wildcard *.cu)
TARGET_LIB:= libnvdsinfer_custom_impl_Yolo.so TARGET_LIB:= libnvdsinfer_custom_impl_Yolo.so
TARGET_OBJS:= $(SRCFILES:.cpp=.o) TARGET_OBJS:= $(SRCFILES:.cpp=.o)
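
The effect of replacing the explicit file list with `wildcard` patterns can be approximated in Python. This is a rough sketch, not the build logic itself: `collect_sources` and the sample file list are illustrative assumptions, and the ordering here simply follows the input list rather than whatever order `make` produces.

```python
import os


def collect_sources(files, opencv=False):
    """Hypothetical model of the SRCFILES assignments in the Makefile."""
    def in_dir(path, directory, ext):
        return os.path.dirname(path) == directory and path.endswith(ext)

    # $(filter-out calibrator.cpp, $(wildcard *.cpp))
    srcs = [f for f in files if in_dir(f, '', '.cpp') and f != 'calibrator.cpp']
    # ifeq ($(OPENCV), 1): SRCFILES+= calibrator.cpp
    if opencv:
        srcs.append('calibrator.cpp')
    # $(wildcard layers/*.cpp)
    srcs += [f for f in files if in_dir(f, 'layers', '.cpp')]
    # $(wildcard *.cu)
    srcs += [f for f in files if in_dir(f, '', '.cu')]
    return srcs


# Sample directory listing (assumed for illustration).
files = ['yolo.cpp', 'utils.cpp', 'calibrator.cpp',
         'layers/route_layer.cpp', 'yoloForward.cu', 'README.md']
print(collect_sources(files))
print(collect_sources(files, opencv=True))
```

With patterns, new layer or kernel files are picked up without editing the Makefile, while `calibrator.cpp` is still compiled only when `OPENCV=1`.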


@@ -333,7 +333,6 @@ def export_model():
merge_config(FLAGS.opt) merge_config(FLAGS.opt)
check_config(cfg) check_config(cfg)
check_gpu(cfg.use_gpu)
check_version() check_version()
trainer = Trainer(cfg, mode='test') trainer = Trainer(cfg, mode='test')


@@ -115,8 +115,10 @@ class Layers(object):
self.current = child.i self.current = child.i
self.fc.write('\n# Concat\n') self.fc.write('\n# Concat\n')
r = self.get_route(child.f[1]) r = []
self.route('-1, %d' % (r - 1)) for i in range(1, len(child.f)):
r.append(self.get_route(child.f[i]))
self.route('-1, %s' % str(r)[1:-1])
def Detect(self, child): def Detect(self, child):
self.current = child.i self.current = child.i
@@ -126,7 +128,7 @@ class Layers(object):
for i, m in enumerate(child.m): for i, m in enumerate(child.m):
r = self.get_route(child.f[i]) r = self.get_route(child.f[i])
self.route('%d' % (r - 1)) self.route('%d' % r)
self.convolutional(m, detect=True) self.convolutional(m, detect=True)
self.yolo(i) self.yolo(i)
@@ -137,12 +139,36 @@ class Layers(object):
'channels=3\n' + 'channels=3\n' +
'letter_box=1\n') 'letter_box=1\n')
def CBH(self, child):
self.current = child.i
self.fc.write('\n# CBH\n')
self.convolutional(child.conv, act='hardswish')
def LC_Block(self, child):
self.current = child.i
self.fc.write('\n# LC_Block\n')
self.convolutional(child.dw_conv, act='hardswish')
if child.use_se:
self.avgpool()
self.convolutional(child.se.conv1, act='relu')
self.convolutional(child.se.conv2, act='silu')
self.shortcut(-4, ew='mul')
self.convolutional(child.pw_conv, act='hardswish')
def Dense(self, child):
self.current = child.i
self.fc.write('\n# Dense\n')
self.convolutional(child.dense_conv, act='hardswish')
def reorg(self): def reorg(self):
self.blocks[self.current] += 1 self.blocks[self.current] += 1
self.fc.write('\n[reorg]\n') self.fc.write('\n[reorg]\n')
def convolutional(self, cv, detect=False): def convolutional(self, cv, act=None, detect=False):
self.blocks[self.current] += 1 self.blocks[self.current] += 1
self.get_state_dict(cv.state_dict()) self.get_state_dict(cv.state_dict())
@@ -164,7 +190,8 @@ class Layers(object):
groups = cv.conv.groups groups = cv.conv.groups
bias = cv.conv.bias bias = cv.conv.bias
bn = True if hasattr(cv, 'bn') else False bn = True if hasattr(cv, 'bn') else False
act = self.get_activation(cv.act._get_name()) if hasattr(cv, 'act') else 'linear' if act is None:
act = self.get_activation(cv.act._get_name()) if hasattr(cv, 'act') else 'linear'
b = 'batch_normalize=1\n' if bn is True else '' b = 'batch_normalize=1\n' if bn is True else ''
g = 'groups=%d\n' % groups if groups > 1 else '' g = 'groups=%d\n' % groups if groups > 1 else ''
@@ -198,12 +225,15 @@ class Layers(object):
self.fc.write('\n[route]\n' + self.fc.write('\n[route]\n' +
'layers=%s\n' % layers) 'layers=%s\n' % layers)
def shortcut(self, r, activation='linear'): def shortcut(self, r, ew='add', act='linear'):
self.blocks[self.current] += 1 self.blocks[self.current] += 1
m = 'mode=mul\n' if ew == 'mul' else ''
self.fc.write('\n[shortcut]\n' + self.fc.write('\n[shortcut]\n' +
'from=%d\n' % r + 'from=%d\n' % r +
'activation=%s\n' % activation) m +
'activation=%s\n' % act)
def maxpool(self, m): def maxpool(self, m):
self.blocks[self.current] += 1 self.blocks[self.current] += 1
@@ -226,6 +256,11 @@ class Layers(object):
self.fc.write('\n[upsample]\n' + self.fc.write('\n[upsample]\n' +
'stride=%d\n' % stride) 'stride=%d\n' % stride)
def avgpool(self):
self.blocks[self.current] += 1
self.fc.write('\n[avgpool]\n')
def yolo(self, i): def yolo(self, i):
self.blocks[self.current] += 1 self.blocks[self.current] += 1
@@ -272,12 +307,19 @@ class Layers(object):
def get_route(self, n): def get_route(self, n):
r = 0 r = 0
for i, b in enumerate(self.blocks): if n < 0:
if i <= n: for i, b in enumerate(self.blocks[self.current-1::-1]):
r += b if i < abs(n) - 1:
else: r -= b
break else:
return r break
else:
for i, b in enumerate(self.blocks):
if i <= n:
r += b
else:
break
return r - 1
def get_activation(self, act): def get_activation(self, act):
if act == 'Hardswish': if act == 'Hardswish':
@@ -286,6 +328,7 @@ class Layers(object):
return 'leaky' return 'leaky'
elif act == 'SiLU': elif act == 'SiLU':
return 'silu' return 'silu'
return 'linear'
def parse_args(): def parse_args():
@@ -336,6 +379,12 @@ with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc:
layers.Concat(child) layers.Concat(child)
elif child._get_name() == 'Detect': elif child._get_name() == 'Detect':
layers.Detect(child) layers.Detect(child)
elif child._get_name() == 'CBH':
layers.CBH(child)
elif child._get_name() == 'LC_Block':
layers.LC_Block(child)
elif child._get_name() == 'Dense':
layers.Dense(child)
else: else:
raise SystemExit('Model not supported') raise SystemExit('Model not supported')
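
The reworked `get_route` converts a PyTorch module index into a cfg layer index, and it now also handles negative (relative) indices. A standalone sketch of the same logic follows; it takes `blocks` and `current` as plain arguments instead of reading them from `self`, and the sample values are assumptions chosen for illustration.

```python
def get_route(blocks, current, n):
    """Standalone copy of the get_route logic from this commit.

    blocks[i] is the number of cfg layers emitted for model module i,
    and current is the index of the module being converted.
    """
    r = 0
    if n < 0:
        # Relative index: walk backwards over the modules before `current`.
        for i, b in enumerate(blocks[current - 1::-1]):
            if i < abs(n) - 1:
                r -= b
            else:
                break
    else:
        # Absolute index: count the layers emitted up to and including module n.
        for i, b in enumerate(blocks):
            if i <= n:
                r += b
            else:
                break
    return r - 1


# Assumed example: modules 0..3 emitted 1, 1, 3 and 2 cfg layers; module 4 is current.
blocks = [1, 1, 3, 2, 0]
print(get_route(blocks, 4, 2))    # absolute cfg index of the last layer of module 2
print(get_route(blocks, 4, -2))   # relative offset to the last layer of module 3 - 1 = 2
print(get_route(blocks, 4, -3))   # relative offset to the last layer of module 1

# Concat builds its route string from every input after the first.
r = [get_route(blocks, 4, f) for f in (2, -3)]
print('-1, %s' % str(r)[1:-1])
```

In this example the absolute lookup returns 4 and the relative lookups return -3 and -6, so the Concat route line becomes `layers=-1, 4, -6`.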

353
utils/gen_wts_yoloV7.py Normal file

@@ -0,0 +1,353 @@
import argparse
import os
import struct
import torch
from utils.torch_utils import select_device
class Layers(object):
def __init__(self, n, size, fw, fc):
self.blocks = [0 for _ in range(n)]
self.current = 0
self.width = size[0] if len(size) == 1 else size[1]
self.height = size[0]
self.num = 0
self.nc = 0
self.anchors = ''
self.masks = []
self.fw = fw
self.fc = fc
self.wc = 0
self.net()
def ReOrg(self, child):
self.current = child.i
self.fc.write('\n# ReOrg\n')
self.reorg()
def Conv(self, child):
self.current = child.i
self.fc.write('\n# Conv\n')
if child.f != -1:
r = self.get_route(child.f)
self.route('%d' % r)
self.convolutional(child)
def DownC(self, child):
self.current = child.i
self.fc.write('\n# DownC\n')
self.maxpool(child.mp)
self.convolutional(child.cv3)
self.route('-3')
self.convolutional(child.cv1)
self.convolutional(child.cv2)
self.route('-1, -4')
def MP(self, child):
self.current = child.i
self.fc.write('\n# MP\n')
self.maxpool(child.m)
def SP(self, child):
self.current = child.i
self.fc.write('\n# SP\n')
if child.f != -1:
r = self.get_route(child.f)
self.route('%d' % r)
self.maxpool(child.m)
def SPPCSPC(self, child):
self.current = child.i
self.fc.write('\n# SPPCSPC\n')
self.convolutional(child.cv2)
self.route('-2')
self.convolutional(child.cv1)
self.convolutional(child.cv3)
self.convolutional(child.cv4)
self.maxpool(child.m[0])
self.route('-2')
self.maxpool(child.m[1])
self.route('-4')
self.maxpool(child.m[2])
self.route('-6, -5, -3, -1')
self.convolutional(child.cv5)
self.convolutional(child.cv6)
self.route('-1, -13')
self.convolutional(child.cv7)
def RepConv(self, child):
self.current = child.i
self.fc.write('\n# RepConv\n')
if child.f != -1:
r = self.get_route(child.f)
self.route('%d' % r)
self.convolutional(child.rbr_1x1)
self.route('-2')
self.convolutional(child.rbr_dense)
self.shortcut(-3, act=self.get_activation(child.act._get_name()))
def Upsample(self, child):
self.current = child.i
self.fc.write('\n# Upsample\n')
self.upsample(child)
def Concat(self, child):
self.current = child.i
self.fc.write('\n# Concat\n')
r = []
for i in range(1, len(child.f)):
r.append(self.get_route(child.f[i]))
self.route('-1, %s' % str(r)[1:-1])
def Shortcut(self, child):
self.current = child.i
self.fc.write('\n# Shortcut\n')
r = self.get_route(child.f[1])
self.shortcut(r)
def Detect(self, child):
self.current = child.i
self.fc.write('\n# Detect\n')
self.get_anchors(child.state_dict(), child.m[0].out_channels)
for i, m in enumerate(child.m):
r = self.get_route(child.f[i])
self.route('%d' % r)
self.convolutional(m, detect=True)
self.yolo(i)
def net(self):
self.fc.write('[net]\n' +
'width=%d\n' % self.width +
'height=%d\n' % self.height +
'channels=3\n')
def reorg(self):
self.blocks[self.current] += 1
self.fc.write('\n[reorg]\n')
def convolutional(self, cv, act=None, detect=False):
self.blocks[self.current] += 1
self.get_state_dict(cv.state_dict())
if cv._get_name() == 'Conv2d':
filters = cv.out_channels
size = cv.kernel_size
stride = cv.stride
pad = cv.padding
groups = cv.groups
bias = cv.bias
bn = False
act = 'linear' if not detect else 'logistic'
elif cv._get_name() == 'Sequential':
filters = cv[0].out_channels
size = cv[0].kernel_size
stride = cv[0].stride
pad = cv[0].padding
groups = cv[0].groups
bias = cv[0].bias
bn = True if cv[1]._get_name() == 'BatchNorm2d' else False
act = 'linear'
else:
filters = cv.conv.out_channels
size = cv.conv.kernel_size
stride = cv.conv.stride
pad = cv.conv.padding
groups = cv.conv.groups
bias = cv.conv.bias
bn = True if hasattr(cv, 'bn') else False
if act is None:
act = self.get_activation(cv.act._get_name()) if hasattr(cv, 'act') else 'linear'
b = 'batch_normalize=1\n' if bn is True else ''
g = 'groups=%d\n' % groups if groups > 1 else ''
w = 'bias=0\n' if bias is None and bn is False else ''
self.fc.write('\n[convolutional]\n' +
b +
'filters=%d\n' % filters +
'size=%s\n' % self.get_value(size) +
'stride=%s\n' % self.get_value(stride) +
'pad=%s\n' % self.get_value(pad) +
g +
w +
'activation=%s\n' % act)
def route(self, layers):
self.blocks[self.current] += 1
self.fc.write('\n[route]\n' +
'layers=%s\n' % layers)
def shortcut(self, r, act='linear'):
self.blocks[self.current] += 1
self.fc.write('\n[shortcut]\n' +
'from=%d\n' % r +
'activation=%s\n' % act)
def maxpool(self, m):
self.blocks[self.current] += 1
stride = m.stride
size = m.kernel_size
mode = m.ceil_mode
m = 'maxpool_up' if mode else 'maxpool'
self.fc.write('\n[%s]\n' % m +
'stride=%d\n' % stride +
'size=%d\n' % size)
def upsample(self, child):
self.blocks[self.current] += 1
stride = child.scale_factor
self.fc.write('\n[upsample]\n' +
'stride=%d\n' % stride)
def yolo(self, i):
self.blocks[self.current] += 1
self.fc.write('\n[yolo]\n' +
'mask=%s\n' % self.masks[i] +
'anchors=%s\n' % self.anchors +
'classes=%d\n' % self.nc +
'num=%d\n' % self.num +
'scale_x_y=2.0\n' +
'new_coords=1\n')
def get_state_dict(self, state_dict):
for k, v in state_dict.items():
if 'num_batches_tracked' not in k:
vr = v.reshape(-1).numpy()
self.fw.write('{} {} '.format(k, len(vr)))
for vv in vr:
self.fw.write(' ')
self.fw.write(struct.pack('>f', float(vv)).hex())
self.fw.write('\n')
self.wc += 1
def get_anchors(self, state_dict, out_channels):
anchor_grid = state_dict['anchor_grid']
aa = anchor_grid.reshape(-1).tolist()
am = anchor_grid.tolist()
self.num = (len(aa) / 2)
self.nc = int((out_channels / (self.num / len(am))) - 5)
self.anchors = str(aa)[1:-1]
n = 0
for m in am:
mask = []
for _ in range(len(m)):
mask.append(n)
n += 1
self.masks.append(str(mask)[1:-1])
def get_value(self, key):
if type(key) == int:
return key
return key[0] if key[0] == key[1] else str(key)[1:-1]
def get_route(self, n):
r = 0
if n < 0:
for i, b in enumerate(self.blocks[self.current-1::-1]):
if i < abs(n) - 1:
r -= b
else:
break
else:
for i, b in enumerate(self.blocks):
if i <= n:
r += b
else:
break
return r - 1
def get_activation(self, act):
if act == 'Hardswish':
return 'hardswish'
elif act == 'LeakyReLU':
return 'leaky'
elif act == 'SiLU':
return 'silu'
return 'linear'
def parse_args():
parser = argparse.ArgumentParser(description='PyTorch YOLOv7 conversion')
parser.add_argument('-w', '--weights', required=True, help='Input weights (.pt) file path (required)')
parser.add_argument(
'-s', '--size', nargs='+', type=int, default=[640], help='Inference size [H,W] (default [640])')
args = parser.parse_args()
if not os.path.isfile(args.weights):
raise SystemExit('Invalid weights file')
return args.weights, args.size
pt_file, inference_size = parse_args()
model_name = os.path.basename(pt_file).split('.pt')[0]
wts_file = model_name + '.wts' if 'yolov7' in model_name else 'yolov7_' + model_name + '.wts'
cfg_file = model_name + '.cfg' if 'yolov7' in model_name else 'yolov7_' + model_name + '.cfg'
device = select_device('cpu')
model = torch.load(pt_file, map_location=device)
model = model['ema' if model.get('ema') else 'model'].float()
anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None]
delattr(model.model[-1], 'anchor_grid')
model.model[-1].register_buffer('anchor_grid', anchor_grid)
model.to(device).eval()
with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc:
layers = Layers(len(model.model), inference_size, fw, fc)
for child in model.model.children():
if child._get_name() == 'ReOrg':
layers.ReOrg(child)
elif child._get_name() == 'Conv':
layers.Conv(child)
elif child._get_name() == 'DownC':
layers.DownC(child)
elif child._get_name() == 'MP':
layers.MP(child)
elif child._get_name() == 'SP':
layers.SP(child)
elif child._get_name() == 'SPPCSPC':
layers.SPPCSPC(child)
elif child._get_name() == 'RepConv':
layers.RepConv(child)
elif child._get_name() == 'Upsample':
layers.Upsample(child)
elif child._get_name() == 'Concat':
layers.Concat(child)
elif child._get_name() == 'Shortcut':
layers.Shortcut(child)
elif child._get_name() == 'Detect':
layers.Detect(child)
else:
raise SystemExit('Model not supported')
os.system('echo "%d" | cat - %s > temp && mv temp %s' % (layers.wc, wts_file, wts_file))
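
The `wts` file written by `get_state_dict` stores each tensor on one line: the key, the element count, then every value as the hex form of a big-endian 32-bit float. The final `os.system` call prepends the number of tensors as the first line. A minimal round-trip sketch of one line is shown below; the helper names and the sample key and values are illustrative assumptions.

```python
import struct


def encode_value(v):
    # Same encoding as the converter: big-endian float32, hex string.
    return struct.pack('>f', float(v)).hex()


def decode_value(h):
    return struct.unpack('>f', bytes.fromhex(h))[0]


def encode_line(name, values):
    # Mirrors the writer: "<key> <count> " followed by " <hex>" per value.
    line = '{} {} '.format(name, len(values))
    for v in values:
        line += ' ' + encode_value(v)
    return line


def decode_line(line):
    parts = line.split()
    name, count = parts[0], int(parts[1])
    values = [decode_value(h) for h in parts[2:]]
    if len(values) != count:
        raise ValueError('value count does not match the header')
    return name, values


# Sample key and values (assumed for illustration).
line = encode_line('conv.weight', [1.0, -0.5, 0.25])
print(line)
print(decode_line(line))
```

Here `1.0`, `-0.5` and `0.25` encode to `3f800000`, `bf000000` and `3e800000`, which is a quick way to sanity-check a generated file by eye.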