From 842a22f6b96131c7fcc594aded79b2374499ead0 Mon Sep 17 00:00:00 2001 From: Marcos Luciano Date: Fri, 3 Feb 2023 12:55:46 -0300 Subject: [PATCH] Fix YOLOv8 conversion --- utils/gen_wts_yoloV8.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/utils/gen_wts_yoloV8.py b/utils/gen_wts_yoloV8.py index b583b91..8688e6e 100644 --- a/utils/gen_wts_yoloV8.py +++ b/utils/gen_wts_yoloV8.py @@ -3,6 +3,7 @@ import os import struct import torch from ultralytics.yolo.utils.torch_utils import select_device +from ultralytics.yolo.utils.tal import make_anchors class Layers(object): @@ -306,7 +307,15 @@ with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc: layers.Concat(child) elif child._get_name() == 'Detect': layers.Detect(child) - layers.get_anchors(child.anchors.reshape([-1]), child.strides.reshape([-1])) + if child.anchors.nelement() > 0 and child.strides.nelement() > 0: + layers.get_anchors(child.anchors.reshape([-1]), child.strides.reshape([-1])) + else: + x = [] + for stride in model.stride.tolist(): + x.append(torch.zeros([1, 1, int(layers.height / stride), int(layers.width / stride)], + dtype=torch.float32)) + anchor_points, stride_tensor = (x.transpose(0, 1) for x in make_anchors(x, child.stride, 0.5)) + layers.get_anchors(anchor_points.reshape([-1]), stride_tensor.reshape([-1])) else: raise SystemExit('Model not supported')