diff --git a/README.md b/README.md index e38a7e1..3227cb3 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,6 @@ NVIDIA DeepStream SDK 6.1 / 6.0.1 / 6.0 configuration for YOLO models * DeepStream tutorials * YOLOX support * YOLOv6 support -* YOLOv7 support * Dynamic batch-size ### Improvements on this repository @@ -27,6 +26,7 @@ NVIDIA DeepStream SDK 6.1 / 6.0.1 / 6.0 configuration for YOLO models * **GPU YOLO Decoder** [#138](https://github.com/marcoslucianops/DeepStream-Yolo/issues/138) * **GPU Batched NMS** [#142](https://github.com/marcoslucianops/DeepStream-Yolo/issues/142) * **PP-YOLOE support** +* **YOLOv7 support** ## @@ -42,6 +42,7 @@ NVIDIA DeepStream SDK 6.1 / 6.0.1 / 6.0 configuration for YOLO models * [YOLOv5 usage](docs/YOLOv5.md) * [YOLOR usage](docs/YOLOR.md) * [PP-YOLOE usage](docs/PPYOLOE.md) +* [YOLOv7 usage](docs/YOLOv7.md) * [Using your custom model](docs/customModels.md) * [Multiple YOLO GIEs](docs/multipleGIEs.md) @@ -89,6 +90,7 @@ NVIDIA DeepStream SDK 6.1 / 6.0.1 / 6.0 configuration for YOLO models * [YOLOv5 >= 2.0](https://github.com/ultralytics/yolov5) * [YOLOR](https://github.com/WongKinYiu/yolor) * [PP-YOLOE](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe) +* [YOLOv7](https://github.com/WongKinYiu/yolov7) * [MobileNet-YOLO](https://github.com/dog-qiuqiu/MobileNet-Yolo) * [YOLO-Fastest](https://github.com/dog-qiuqiu/Yolo-Fastest) diff --git a/config_infer_primary_yoloV7.txt b/config_infer_primary_yoloV7.txt new file mode 100644 index 0000000..c8050f0 --- /dev/null +++ b/config_infer_primary_yoloV7.txt @@ -0,0 +1,24 @@ +[property] +gpu-id=0 +net-scale-factor=0.0039215697906911373 +model-color-format=0 +custom-network-config=yolov7.cfg +model-file=yolov7.wts +model-engine-file=model_b1_gpu0_fp32.engine +#int8-calib-file=calib.table +labelfile-path=labels.txt +batch-size=1 +network-mode=0 +num-detected-classes=80 +interval=0 +gie-unique-id=1 +process-mode=1 +network-type=0 +cluster-mode=4 +maintain-aspect-ratio=0 +parse-bbox-func-name=NvDsInferParseYolo +custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so +engine-create-func-name=NvDsInferYoloCudaEngineGet + +[class-attrs-all] +pre-cluster-threshold=0 diff --git a/docs/PPYOLOE.md b/docs/PPYOLOE.md index 9e97d0c..35a8465 100644 --- a/docs/PPYOLOE.md +++ b/docs/PPYOLOE.md @@ -74,7 +74,7 @@ Open the `DeepStream-Yolo` folder and compile the lib ## -### Edit the config_infer_primary_yoloV5 file +### Edit the config_infer_primary_ppyoloe file Edit the `config_infer_primary_ppyoloe.txt` file according to your model (example for PP-YOLOE-s) @@ -97,7 +97,7 @@ offsets=123.675;116.28;103.53 ## -### Edit the deepstream_app_config.txt file +### Edit the deepstream_app_config file ``` ... diff --git a/docs/YOLOR.md b/docs/YOLOR.md index efa0037..f0e3f9e 100644 --- a/docs/YOLOR.md +++ b/docs/YOLOR.md @@ -2,7 +2,7 @@ **NOTE**: You need to use the main branch of the YOLOR repo to convert the model. -**NOTE**: The cfg is required. +**NOTE**: The cfg file is required. * [Convert model](#convert-model) * [Compile the lib](#compile-the-lib) @@ -92,7 +92,7 @@ model-file=yolor_csp.wts ## -### Edit the deepstream_app_config.txt file +### Edit the deepstream_app_config file ``` ... diff --git a/docs/YOLOv5.md b/docs/YOLOv5.md index 44a5d40..8030247 100644 --- a/docs/YOLOv5.md +++ b/docs/YOLOv5.md @@ -2,7 +2,7 @@ **NOTE**: You can use the main branch of the YOLOv5 repo to convert all model versions. -**NOTE**: The yaml is not required. +**NOTE**: The yaml file is not required. * [Convert model](#convert-model) * [Compile the lib](#compile-the-lib) @@ -117,7 +117,7 @@ model-file=yolov5s.wts ## -### Edit the deepstream_app_config.txt file +### Edit the deepstream_app_config file ``` ... diff --git a/docs/YOLOv7.md b/docs/YOLOv7.md new file mode 100644 index 0000000..15cd771 --- /dev/null +++ b/docs/YOLOv7.md @@ -0,0 +1,133 @@ +# YOLOv7 usage + +**NOTE**: The yaml file is not required. + +* [Convert model](#convert-model) +* [Compile the lib](#compile-the-lib) +* [Edit the config_infer_primary_yoloV7 file](#edit-the-config_infer_primary_yolov7-file) +* [Edit the deepstream_app_config file](#edit-the-deepstream_app_config-file) +* [Testing the model](#testing-the-model) + +## + +### Convert model + +#### 1. Download the YOLOv7 repo and install the requirements + +``` +git clone https://github.com/WongKinYiu/yolov7.git +cd yolov7 +pip3 install -r requirements.txt +``` + +**NOTE**: It is recommended to use Python virtualenv. + +#### 2. Copy conversor + +Copy the `gen_wts_yoloV7.py` file from `DeepStream-Yolo/utils` directory to the `yolov7` folder. + +#### 3. Download the model + +Download the `pt` file from [YOLOv7](https://github.com/WongKinYiu/yolov7/releases/) releases (example for YOLOv7) + +``` +wget hhttps://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7.pt +``` + +**NOTE**: You can use your custom model, but it is important to keep the YOLO model reference (`yolov7_`) in you `cfg` and `weights`/`wts` filenames to generate the engine correctly. + +#### 4. Convert model + +Generate the `cfg` and `wts` files (example for YOLOv7) + +``` +python3 gen_wts_yoloV7.py -w yolov7.pt +``` + +**NOTE**: To change the inference size (defaut: 640) + +``` +-s SIZE +--size SIZE +-s HEIGHT WIDTH +--size HEIGHT WIDTH +``` + +Example for 1280 + +``` +-s 1280 +``` + +or + +``` +-s 1280 1280 +``` + +#### 5. Copy generated files + +Copy the generated `cfg` and `wts` files to the `DeepStream-Yolo` folder. + +## + +### Compile the lib + +Open the `DeepStream-Yolo` folder and compile the lib + +* DeepStream 6.1 on x86 platform + + ``` + CUDA_VER=11.6 make -C nvdsinfer_custom_impl_Yolo + ``` + +* DeepStream 6.0.1 / 6.0 on x86 platform + + ``` + CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo + ``` + +* DeepStream 6.1 on Jetson platform + + ``` + CUDA_VER=11.4 make -C nvdsinfer_custom_impl_Yolo + ``` + +* DeepStream 6.0.1 / 6.0 on Jetson platform + + ``` + CUDA_VER=10.2 make -C nvdsinfer_custom_impl_Yolo + ``` + +## + +### Edit the config_infer_primary_yoloV7 file + +Edit the `config_infer_primary_yoloV7.txt` file according to your model (example for YOLOv7) + +``` +[property] +... +custom-network-config=yolov7.cfg +model-file=yolov7.wts +... +``` + +## + +### Edit the deepstream_app_config file + +``` +... +[primary-gie] +... +config-file=config_infer_primary_yoloV7.txt +``` + +## + +### Testing the model + +``` +deepstream-app -c deepstream_app_config.txt +``` diff --git a/nvdsinfer_custom_impl_Yolo/Makefile b/nvdsinfer_custom_impl_Yolo/Makefile index d71080a..518a1a0 100644 --- a/nvdsinfer_custom_impl_Yolo/Makefile +++ b/nvdsinfer_custom_impl_Yolo/Makefile @@ -49,37 +49,16 @@ LIBS+= -lnvinfer_plugin -lnvinfer -lnvparsers -L/usr/local/cuda-$(CUDA_VER)/lib6 LFLAGS:= -shared -Wl,--start-group $(LIBS) -Wl,--end-group INCS:= $(wildcard *.h) -SRCFILES:= nvdsinfer_yolo_engine.cpp \ - nvdsparsebbox_Yolo.cpp \ - yoloPlugins.cpp \ - layers/convolutional_layer.cpp \ - layers/batchnorm_layer.cpp \ - layers/implicit_layer.cpp \ - layers/channels_layer.cpp \ - layers/shortcut_layer.cpp \ - layers/route_layer.cpp \ - layers/upsample_layer.cpp \ - layers/pooling_layer.cpp \ - layers/activation_layer.cpp \ - layers/reorg_layer.cpp \ - layers/reduce_layer.cpp \ - layers/shuffle_layer.cpp \ - layers/softmax_layer.cpp \ - layers/cls_layer.cpp \ - layers/reg_layer.cpp \ - utils.cpp \ - yolo.cpp \ - yoloForward.cu \ - yoloForward_v2.cu \ - yoloForward_nc.cu \ - yoloForward_r.cu \ - yoloForward_e.cu \ - sortDetections.cu + +SRCFILES:= $(filter-out calibrator.cpp, $(wildcard *.cpp)) ifeq ($(OPENCV), 1) SRCFILES+= calibrator.cpp endif +SRCFILES+= $(wildcard layers/*.cpp) +SRCFILES+= $(wildcard *.cu) + TARGET_LIB:= libnvdsinfer_custom_impl_Yolo.so TARGET_OBJS:= $(SRCFILES:.cpp=.o) diff --git a/utils/gen_wts_ppyoloe.py b/utils/gen_wts_ppyoloe.py index 85fe20d..8c985b2 100644 --- a/utils/gen_wts_ppyoloe.py +++ b/utils/gen_wts_ppyoloe.py @@ -333,7 +333,6 @@ def export_model(): merge_config(FLAGS.opt) check_config(cfg) - check_gpu(cfg.use_gpu) check_version() trainer = Trainer(cfg, mode='test') diff --git a/utils/gen_wts_yoloV5.py b/utils/gen_wts_yoloV5.py index 2990ce5..be8d71f 100644 --- a/utils/gen_wts_yoloV5.py +++ b/utils/gen_wts_yoloV5.py @@ -115,8 +115,10 @@ class Layers(object): self.current = child.i self.fc.write('\n# Concat\n') - r = self.get_route(child.f[1]) - self.route('-1, %d' % (r - 1)) + r = [] + for i in range(1, len(child.f)): + r.append(self.get_route(child.f[i])) + self.route('-1, %s' % str(r)[1:-1]) def Detect(self, child): self.current = child.i @@ -126,7 +128,7 @@ class Layers(object): for i, m in enumerate(child.m): r = self.get_route(child.f[i]) - self.route('%d' % (r - 1)) + self.route('%d' % r) self.convolutional(m, detect=True) self.yolo(i) @@ -137,12 +139,36 @@ class Layers(object): 'channels=3\n' + 'letter_box=1\n') + def CBH(self, child): + self.current = child.i + self.fc.write('\n# CBH\n') + + self.convolutional(child.conv, act='hardswish') + + def LC_Block(self, child): + self.current = child.i + self.fc.write('\n# LC_Block\n') + + self.convolutional(child.dw_conv, act='hardswish') + if child.use_se: + self.avgpool() + self.convolutional(child.se.conv1, act='relu') + self.convolutional(child.se.conv2, act='silu') + self.shortcut(-4, ew='mul') + self.convolutional(child.pw_conv, act='hardswish') + + def Dense(self, child): + self.current = child.i + self.fc.write('\n# Dense\n') + + self.convolutional(child.dense_conv, act='hardswish') + def reorg(self): self.blocks[self.current] += 1 self.fc.write('\n[reorg]\n') - def convolutional(self, cv, detect=False): + def convolutional(self, cv, act=None, detect=False): self.blocks[self.current] += 1 self.get_state_dict(cv.state_dict()) @@ -164,7 +190,8 @@ class Layers(object): groups = cv.conv.groups bias = cv.conv.bias bn = True if hasattr(cv, 'bn') else False - act = self.get_activation(cv.act._get_name()) if hasattr(cv, 'act') else 'linear' + if act is None: + act = self.get_activation(cv.act._get_name()) if hasattr(cv, 'act') else 'linear' b = 'batch_normalize=1\n' if bn is True else '' g = 'groups=%d\n' % groups if groups > 1 else '' @@ -198,12 +225,15 @@ class Layers(object): self.fc.write('\n[route]\n' + 'layers=%s\n' % layers) - def shortcut(self, r, activation='linear'): + def shortcut(self, r, ew='add', act='linear'): self.blocks[self.current] += 1 + m = 'mode=mul\n' if ew == 'mul' else '' + self.fc.write('\n[shortcut]\n' + 'from=%d\n' % r + - 'activation=%s\n' % activation) + m + + 'activation=%s\n' % act) def maxpool(self, m): self.blocks[self.current] += 1 @@ -226,6 +256,11 @@ class Layers(object): self.fc.write('\n[upsample]\n' + 'stride=%d\n' % stride) + def avgpool(self): + self.blocks[self.current] += 1 + + self.fc.write('\n[avgpool]\n') + def yolo(self, i): self.blocks[self.current] += 1 @@ -272,12 +307,19 @@ class Layers(object): def get_route(self, n): r = 0 - for i, b in enumerate(self.blocks): - if i <= n: - r += b - else: - break - return r + if n < 0: + for i, b in enumerate(self.blocks[self.current-1::-1]): + if i < abs(n) - 1: + r -= b + else: + break + else: + for i, b in enumerate(self.blocks): + if i <= n: + r += b + else: + break + return r - 1 def get_activation(self, act): if act == 'Hardswish': @@ -286,6 +328,7 @@ class Layers(object): return 'leaky' elif act == 'SiLU': return 'silu' + return 'linear' def parse_args(): @@ -336,6 +379,12 @@ with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc: layers.Concat(child) elif child._get_name() == 'Detect': layers.Detect(child) + elif child._get_name() == 'CBH': + layers.CBH(child) + elif child._get_name() == 'LC_Block': + layers.LC_Block(child) + elif child._get_name() == 'Dense': + layers.Dense(child) else: raise SystemExit('Model not supported') diff --git a/utils/gen_wts_yoloV7.py b/utils/gen_wts_yoloV7.py new file mode 100644 index 0000000..d827fac --- /dev/null +++ b/utils/gen_wts_yoloV7.py @@ -0,0 +1,353 @@ +import argparse +import os +import struct +import torch +from utils.torch_utils import select_device + + +class Layers(object): + def __init__(self, n, size, fw, fc): + self.blocks = [0 for _ in range(n)] + self.current = 0 + + self.width = size[0] if len(size) == 1 else size[1] + self.height = size[0] + + self.num = 0 + self.nc = 0 + self.anchors = '' + self.masks = [] + + self.fw = fw + self.fc = fc + self.wc = 0 + + self.net() + + def ReOrg(self, child): + self.current = child.i + self.fc.write('\n# ReOrg\n') + + self.reorg() + + def Conv(self, child): + self.current = child.i + self.fc.write('\n# Conv\n') + + if child.f != -1: + r = self.get_route(child.f) + self.route('%d' % r) + self.convolutional(child) + + def DownC(self, child): + self.current = child.i + self.fc.write('\n# DownC\n') + + self.maxpool(child.mp) + self.convolutional(child.cv3) + self.route('-3') + self.convolutional(child.cv1) + self.convolutional(child.cv2) + self.route('-1, -4') + + def MP(self, child): + self.current = child.i + self.fc.write('\n# MP\n') + + self.maxpool(child.m) + + def SP(self, child): + self.current = child.i + self.fc.write('\n# SP\n') + + if child.f != -1: + r = self.get_route(child.f) + self.route('%d' % r) + self.maxpool(child.m) + + def SPPCSPC(self, child): + self.current = child.i + self.fc.write('\n# SPPCSPC\n') + + self.convolutional(child.cv2) + self.route('-2') + self.convolutional(child.cv1) + self.convolutional(child.cv3) + self.convolutional(child.cv4) + self.maxpool(child.m[0]) + self.route('-2') + self.maxpool(child.m[1]) + self.route('-4') + self.maxpool(child.m[2]) + self.route('-6, -5, -3, -1') + self.convolutional(child.cv5) + self.convolutional(child.cv6) + self.route('-1, -13') + self.convolutional(child.cv7) + + def RepConv(self, child): + self.current = child.i + self.fc.write('\n# RepConv\n') + + if child.f != -1: + r = self.get_route(child.f) + self.route('%d' % r) + self.convolutional(child.rbr_1x1) + self.route('-2') + self.convolutional(child.rbr_dense) + self.shortcut(-3, act=self.get_activation(child.act._get_name())) + + def Upsample(self, child): + self.current = child.i + self.fc.write('\n# Upsample\n') + + self.upsample(child) + + def Concat(self, child): + self.current = child.i + self.fc.write('\n# Concat\n') + + r = [] + for i in range(1, len(child.f)): + r.append(self.get_route(child.f[i])) + self.route('-1, %s' % str(r)[1:-1]) + + def Shortcut(self, child): + self.current = child.i + self.fc.write('\n# Shortcut\n') + + r = self.get_route(child.f[1]) + self.shortcut(r) + + def Detect(self, child): + self.current = child.i + self.fc.write('\n# Detect\n') + + self.get_anchors(child.state_dict(), child.m[0].out_channels) + + for i, m in enumerate(child.m): + r = self.get_route(child.f[i]) + self.route('%d' % r) + self.convolutional(m, detect=True) + self.yolo(i) + + def net(self): + self.fc.write('[net]\n' + + 'width=%d\n' % self.width + + 'height=%d\n' % self.height + + 'channels=3\n') + + def reorg(self): + self.blocks[self.current] += 1 + + self.fc.write('\n[reorg]\n') + + def convolutional(self, cv, act=None, detect=False): + self.blocks[self.current] += 1 + + self.get_state_dict(cv.state_dict()) + + if cv._get_name() == 'Conv2d': + filters = cv.out_channels + size = cv.kernel_size + stride = cv.stride + pad = cv.padding + groups = cv.groups + bias = cv.bias + bn = False + act = 'linear' if not detect else 'logistic' + elif cv._get_name() == 'Sequential': + filters = cv[0].out_channels + size = cv[0].kernel_size + stride = cv[0].stride + pad = cv[0].padding + groups = cv[0].groups + bias = cv[0].bias + bn = True if cv[1]._get_name() == 'BatchNorm2d' else False + act = 'linear' + else: + filters = cv.conv.out_channels + size = cv.conv.kernel_size + stride = cv.conv.stride + pad = cv.conv.padding + groups = cv.conv.groups + bias = cv.conv.bias + bn = True if hasattr(cv, 'bn') else False + if act is None: + act = self.get_activation(cv.act._get_name()) if hasattr(cv, 'act') else 'linear' + + b = 'batch_normalize=1\n' if bn is True else '' + g = 'groups=%d\n' % groups if groups > 1 else '' + w = 'bias=0\n' if bias is None and bn is False else '' + + self.fc.write('\n[convolutional]\n' + + b + + 'filters=%d\n' % filters + + 'size=%s\n' % self.get_value(size) + + 'stride=%s\n' % self.get_value(stride) + + 'pad=%s\n' % self.get_value(pad) + + g + + w + + 'activation=%s\n' % act) + + def route(self, layers): + self.blocks[self.current] += 1 + + self.fc.write('\n[route]\n' + + 'layers=%s\n' % layers) + + def shortcut(self, r, act='linear'): + self.blocks[self.current] += 1 + + self.fc.write('\n[shortcut]\n' + + 'from=%d\n' % r + + 'activation=%s\n' % act) + + def maxpool(self, m): + self.blocks[self.current] += 1 + + stride = m.stride + size = m.kernel_size + mode = m.ceil_mode + + m = 'maxpool_up' if mode else 'maxpool' + + self.fc.write('\n[%s]\n' % m + + 'stride=%d\n' % stride + + 'size=%d\n' % size) + + def upsample(self, child): + self.blocks[self.current] += 1 + + stride = child.scale_factor + + self.fc.write('\n[upsample]\n' + + 'stride=%d\n' % stride) + + def yolo(self, i): + self.blocks[self.current] += 1 + + self.fc.write('\n[yolo]\n' + + 'mask=%s\n' % self.masks[i] + + 'anchors=%s\n' % self.anchors + + 'classes=%d\n' % self.nc + + 'num=%d\n' % self.num + + 'scale_x_y=2.0\n' + + 'new_coords=1\n') + + def get_state_dict(self, state_dict): + for k, v in state_dict.items(): + if 'num_batches_tracked' not in k: + vr = v.reshape(-1).numpy() + self.fw.write('{} {} '.format(k, len(vr))) + for vv in vr: + self.fw.write(' ') + self.fw.write(struct.pack('>f', float(vv)).hex()) + self.fw.write('\n') + self.wc += 1 + + def get_anchors(self, state_dict, out_channels): + anchor_grid = state_dict['anchor_grid'] + aa = anchor_grid.reshape(-1).tolist() + am = anchor_grid.tolist() + + self.num = (len(aa) / 2) + self.nc = int((out_channels / (self.num / len(am))) - 5) + self.anchors = str(aa)[1:-1] + + n = 0 + for m in am: + mask = [] + for _ in range(len(m)): + mask.append(n) + n += 1 + self.masks.append(str(mask)[1:-1]) + + def get_value(self, key): + if type(key) == int: + return key + return key[0] if key[0] == key[1] else str(key)[1:-1] + + def get_route(self, n): + r = 0 + if n < 0: + for i, b in enumerate(self.blocks[self.current-1::-1]): + if i < abs(n) - 1: + r -= b + else: + break + else: + for i, b in enumerate(self.blocks): + if i <= n: + r += b + else: + break + return r - 1 + + def get_activation(self, act): + if act == 'Hardswish': + return 'hardswish' + elif act == 'LeakyReLU': + return 'leaky' + elif act == 'SiLU': + return 'silu' + return 'linear' + + +def parse_args(): + parser = argparse.ArgumentParser(description='PyTorch YOLOv7 conversion') + parser.add_argument('-w', '--weights', required=True, help='Input weights (.pt) file path (required)') + parser.add_argument( + '-s', '--size', nargs='+', type=int, default=[640], help='Inference size [H,W] (default [640])') + args = parser.parse_args() + if not os.path.isfile(args.weights): + raise SystemExit('Invalid weights file') + return args.weights, args.size + + +pt_file, inference_size = parse_args() + +model_name = os.path.basename(pt_file).split('.pt')[0] +wts_file = model_name + '.wts' if 'yolov7' in model_name else 'yolov7_' + model_name + '.wts' +cfg_file = model_name + '.cfg' if 'yolov7' in model_name else 'yolov7_' + model_name + '.cfg' + +device = select_device('cpu') +model = torch.load(pt_file, map_location=device) +model = model['ema' if model.get('ema') else 'model'].float() + +anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None] +delattr(model.model[-1], 'anchor_grid') +model.model[-1].register_buffer('anchor_grid', anchor_grid) + +model.to(device).eval() + +with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc: + layers = Layers(len(model.model), inference_size, fw, fc) + + for child in model.model.children(): + if child._get_name() == 'ReOrg': + layers.ReOrg(child) + elif child._get_name() == 'Conv': + layers.Conv(child) + elif child._get_name() == 'DownC': + layers.DownC(child) + elif child._get_name() == 'MP': + layers.MP(child) + elif child._get_name() == 'SP': + layers.SP(child) + elif child._get_name() == 'SPPCSPC': + layers.SPPCSPC(child) + elif child._get_name() == 'RepConv': + layers.RepConv(child) + elif child._get_name() == 'Upsample': + layers.Upsample(child) + elif child._get_name() == 'Concat': + layers.Concat(child) + elif child._get_name() == 'Shortcut': + layers.Shortcut(child) + elif child._get_name() == 'Detect': + layers.Detect(child) + else: + raise SystemExit('Model not supported') + +os.system('echo "%d" | cat - %s > temp && mv temp %s' % (layers.wc, wts_file, wts_file))