From 940d244ad455627112c1cdb6b69e3fd5447e2a3d Mon Sep 17 00:00:00 2001 From: Marcos Luciano Date: Wed, 8 Feb 2023 01:02:35 -0300 Subject: [PATCH] Fix YOLOv8 conversion --- utils/gen_wts_yoloV8.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/utils/gen_wts_yoloV8.py b/utils/gen_wts_yoloV8.py index 8688e6e..1ba56e6 100644 --- a/utils/gen_wts_yoloV8.py +++ b/utils/gen_wts_yoloV8.py @@ -307,15 +307,11 @@ with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc: layers.Concat(child) elif child._get_name() == 'Detect': layers.Detect(child) - if child.anchors.nelement() > 0 and child.strides.nelement() > 0: - layers.get_anchors(child.anchors.reshape([-1]), child.strides.reshape([-1])) - else: - x = [] - for stride in model.stride.tolist(): - x.append(torch.zeros([1, 1, int(layers.height / stride), int(layers.width / stride)], - dtype=torch.float32)) - anchor_points, stride_tensor = (x.transpose(0, 1) for x in make_anchors(x, child.stride, 0.5)) - layers.get_anchors(anchor_points.reshape([-1]), stride_tensor.reshape([-1])) + x = [] + for stride in model.stride.tolist(): + x.append(torch.zeros([1, 1, int(layers.height / stride), int(layers.width / stride)], dtype=torch.float32)) + anchor_points, stride_tensor = (x.transpose(0, 1) for x in make_anchors(x, child.stride, 0.5)) + layers.get_anchors(anchor_points.reshape([-1]), stride_tensor.reshape([-1])) else: raise SystemExit('Model not supported')