From d09879d557eead3c2c76e6f38239abef8577103d Mon Sep 17 00:00:00 2001 From: Marcos Luciano Date: Thu, 21 Jul 2022 11:30:20 -0300 Subject: [PATCH] Fix gen_wts_yoloV5.wts --- utils/gen_wts_yoloV5.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/utils/gen_wts_yoloV5.py b/utils/gen_wts_yoloV5.py index 1cb12d5..2990ce5 100644 --- a/utils/gen_wts_yoloV5.py +++ b/utils/gen_wts_yoloV5.py @@ -173,9 +173,9 @@ class Layers(object): self.fc.write('\n[convolutional]\n' + b + 'filters=%d\n' % filters + - 'size=%s\n' % (size[0] if type(size) != int and size[0] == size[1] else str(size)[1:-1]) + - 'stride=%s\n' % (stride[0] if type(stride) != int and stride[0] == stride[1] else str(stride)[1:-1]) + - 'pad=%s\n' % (pad[0] if type(pad) != int and pad[0] == pad[1] else str(pad)[1:-1]) + + 'size=%s\n' % self.get_value(size) + + 'stride=%s\n' % self.get_value(stride) + + 'pad=%s\n' % self.get_value(pad) + g + w + 'activation=%s\n' % act) @@ -265,6 +265,11 @@ class Layers(object): n += 1 self.masks.append(str(mask)[1:-1]) + def get_value(self, key): + if type(key) == int: + return key + return key[0] if key[0] == key[1] else str(key)[1:-1] + def get_route(self, n): r = 0 for i, b in enumerate(self.blocks):