diff --git a/utils/gen_wts_yoloV8.py b/utils/gen_wts_yoloV8.py index b583b91..8688e6e 100644 --- a/utils/gen_wts_yoloV8.py +++ b/utils/gen_wts_yoloV8.py @@ -3,6 +3,7 @@ import os import struct import torch from ultralytics.yolo.utils.torch_utils import select_device +from ultralytics.yolo.utils.tal import make_anchors class Layers(object): @@ -306,7 +307,15 @@ with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc: layers.Concat(child) elif child._get_name() == 'Detect': layers.Detect(child) - layers.get_anchors(child.anchors.reshape([-1]), child.strides.reshape([-1])) + if child.anchors.nelement() > 0 and child.strides.nelement() > 0: + layers.get_anchors(child.anchors.reshape([-1]), child.strides.reshape([-1])) + else: + x = [] + for stride in model.stride.tolist(): + x.append(torch.zeros([1, 1, int(layers.height / stride), int(layers.width / stride)], + dtype=torch.float32)) + anchor_points, stride_tensor = (x.transpose(0, 1) for x in make_anchors(x, child.stride, 0.5)) + layers.get_anchors(anchor_points.reshape([-1]), stride_tensor.reshape([-1])) else: raise SystemExit('Model not supported')