Fix YOLOv8 conversion

This commit is contained in:
Marcos Luciano
2023-02-08 01:02:35 -03:00
parent 842a22f6b9
commit 940d244ad4

View File

@@ -307,15 +307,11 @@ with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc:
layers.Concat(child) layers.Concat(child)
elif child._get_name() == 'Detect': elif child._get_name() == 'Detect':
layers.Detect(child) layers.Detect(child)
if child.anchors.nelement() > 0 and child.strides.nelement() > 0: x = []
layers.get_anchors(child.anchors.reshape([-1]), child.strides.reshape([-1])) for stride in model.stride.tolist():
else: x.append(torch.zeros([1, 1, int(layers.height / stride), int(layers.width / stride)], dtype=torch.float32))
x = [] anchor_points, stride_tensor = (x.transpose(0, 1) for x in make_anchors(x, child.stride, 0.5))
for stride in model.stride.tolist(): layers.get_anchors(anchor_points.reshape([-1]), stride_tensor.reshape([-1]))
x.append(torch.zeros([1, 1, int(layers.height / stride), int(layers.width / stride)],
dtype=torch.float32))
anchor_points, stride_tensor = (x.transpose(0, 1) for x in make_anchors(x, child.stride, 0.5))
layers.get_anchors(anchor_points.reshape([-1]), stride_tensor.reshape([-1]))
else: else:
raise SystemExit('Model not supported') raise SystemExit('Model not supported')