Fix YOLOv8 conversion
This commit is contained in:
@@ -3,6 +3,7 @@ import os
|
|||||||
import struct
|
import struct
|
||||||
import torch
|
import torch
|
||||||
from ultralytics.yolo.utils.torch_utils import select_device
|
from ultralytics.yolo.utils.torch_utils import select_device
|
||||||
|
from ultralytics.yolo.utils.tal import make_anchors
|
||||||
|
|
||||||
|
|
||||||
class Layers(object):
|
class Layers(object):
|
||||||
@@ -306,7 +307,15 @@ with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc:
|
|||||||
layers.Concat(child)
|
layers.Concat(child)
|
||||||
elif child._get_name() == 'Detect':
|
elif child._get_name() == 'Detect':
|
||||||
layers.Detect(child)
|
layers.Detect(child)
|
||||||
layers.get_anchors(child.anchors.reshape([-1]), child.strides.reshape([-1]))
|
if child.anchors.nelement() > 0 and child.strides.nelement() > 0:
|
||||||
|
layers.get_anchors(child.anchors.reshape([-1]), child.strides.reshape([-1]))
|
||||||
|
else:
|
||||||
|
x = []
|
||||||
|
for stride in model.stride.tolist():
|
||||||
|
x.append(torch.zeros([1, 1, int(layers.height / stride), int(layers.width / stride)],
|
||||||
|
dtype=torch.float32))
|
||||||
|
anchor_points, stride_tensor = (x.transpose(0, 1) for x in make_anchors(x, child.stride, 0.5))
|
||||||
|
layers.get_anchors(anchor_points.reshape([-1]), stride_tensor.reshape([-1]))
|
||||||
else:
|
else:
|
||||||
raise SystemExit('Model not supported')
|
raise SystemExit('Model not supported')
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user