New YOLOv5 conversion and support (>= v2.0)
This commit is contained in:
@@ -1,108 +1,305 @@
|
||||
import argparse
|
||||
import yaml
|
||||
import math
|
||||
import os
|
||||
import struct
|
||||
import torch
|
||||
from utils.torch_utils import select_device
|
||||
|
||||
|
||||
class YoloLayers():
|
||||
def get_route(self, n, layers):
|
||||
route = 0
|
||||
for i, layer in enumerate(layers):
|
||||
if i <= n:
|
||||
route += layer[1]
|
||||
else:
|
||||
break
|
||||
return route
|
||||
class Layers(object):
|
||||
def __init__(self, n, size, fw, fc):
|
||||
self.blocks = [0 for _ in range(n)]
|
||||
self.current = 0
|
||||
|
||||
def route(self, layers=''):
|
||||
return '\n[route]\n' + \
|
||||
'layers=%s\n' % layers
|
||||
self.width = size[0] if len(size) == 1 else size[1]
|
||||
self.height = size[0]
|
||||
|
||||
self.num = 0
|
||||
self.nc = 0
|
||||
self.anchors = ''
|
||||
self.masks = []
|
||||
|
||||
self.fw = fw
|
||||
self.fc = fc
|
||||
self.wc = 0
|
||||
|
||||
self.net()
|
||||
|
||||
def Focus(self, child):
|
||||
self.current = child.i
|
||||
self.fc.write('\n# Focus\n')
|
||||
|
||||
self.reorg()
|
||||
self.convolutional(child.conv)
|
||||
|
||||
def Conv(self, child):
|
||||
self.current = child.i
|
||||
self.fc.write('\n# Conv\n')
|
||||
|
||||
self.convolutional(child)
|
||||
|
||||
def BottleneckCSP(self, child):
|
||||
self.current = child.i
|
||||
self.fc.write('\n# BottleneckCSP\n')
|
||||
|
||||
self.convolutional(child.cv2)
|
||||
self.route('-2')
|
||||
self.convolutional(child.cv1)
|
||||
idx = -3
|
||||
for m in child.m:
|
||||
if m.add:
|
||||
self.convolutional(m.cv1)
|
||||
self.convolutional(m.cv2)
|
||||
self.shortcut(-3)
|
||||
idx -= 3
|
||||
else:
|
||||
self.convolutional(m.cv1)
|
||||
self.convolutional(m.cv2)
|
||||
idx -= 2
|
||||
self.convolutional(child.cv3)
|
||||
self.route('-1, %d' % (idx - 1))
|
||||
self.batchnorm(child.bn, child.act)
|
||||
self.convolutional(child.cv4)
|
||||
|
||||
def C3(self, child):
|
||||
self.current = child.i
|
||||
self.fc.write('\n# C3\n')
|
||||
|
||||
self.convolutional(child.cv2)
|
||||
self.route('-2')
|
||||
self.convolutional(child.cv1)
|
||||
idx = -3
|
||||
for m in child.m:
|
||||
if m.add:
|
||||
self.convolutional(m.cv1)
|
||||
self.convolutional(m.cv2)
|
||||
self.shortcut(-3)
|
||||
idx -= 3
|
||||
else:
|
||||
self.convolutional(m.cv1)
|
||||
self.convolutional(m.cv2)
|
||||
idx -= 2
|
||||
self.route('-1, %d' % idx)
|
||||
self.convolutional(child.cv3)
|
||||
|
||||
def SPP(self, child):
|
||||
self.current = child.i
|
||||
self.fc.write('\n# SPP\n')
|
||||
|
||||
self.convolutional(child.cv1)
|
||||
self.maxpool(child.m[0])
|
||||
self.route('-2')
|
||||
self.maxpool(child.m[1])
|
||||
self.route('-4')
|
||||
self.maxpool(child.m[2])
|
||||
self.route('-6, -5, -3, -1')
|
||||
self.convolutional(child.cv2)
|
||||
|
||||
def SPPF(self, child):
|
||||
self.current = child.i
|
||||
self.fc.write('\n# SPPF\n')
|
||||
|
||||
self.convolutional(child.cv1)
|
||||
self.maxpool(child.m)
|
||||
self.maxpool(child.m)
|
||||
self.maxpool(child.m)
|
||||
self.route('-4, -3, -2, -1')
|
||||
self.convolutional(child.cv2)
|
||||
|
||||
def Upsample(self, child):
|
||||
self.current = child.i
|
||||
self.fc.write('\n# Upsample\n')
|
||||
|
||||
self.upsample(child)
|
||||
|
||||
def Concat(self, child):
|
||||
self.current = child.i
|
||||
self.fc.write('\n# Concat\n')
|
||||
|
||||
r = self.get_route(child.f[1])
|
||||
self.route('-1, %d' % (r - 1))
|
||||
|
||||
def Detect(self, child):
|
||||
self.current = child.i
|
||||
self.fc.write('\n# Detect\n')
|
||||
|
||||
self.get_anchors(child.state_dict(), child.m[0].out_channels)
|
||||
|
||||
for i, m in enumerate(child.m):
|
||||
r = self.get_route(child.f[i])
|
||||
self.route('%d' % (r - 1))
|
||||
self.convolutional(m, detect=True)
|
||||
self.yolo(i)
|
||||
|
||||
def net(self):
|
||||
self.fc.write('[net]\n' +
|
||||
'width=%d\n' % self.width +
|
||||
'height=%d\n' % self.height +
|
||||
'channels=3\n' +
|
||||
'letter_box=1\n')
|
||||
|
||||
def reorg(self):
|
||||
return '\n[reorg]\n'
|
||||
self.blocks[self.current] += 1
|
||||
|
||||
def shortcut(self, route=-1, activation='linear'):
|
||||
return '\n[shortcut]\n' + \
|
||||
'from=%d\n' % route + \
|
||||
'activation=%s\n' % activation
|
||||
self.fc.write('\n[reorg]\n')
|
||||
|
||||
def maxpool(self, stride=1, size=1):
|
||||
return '\n[maxpool]\n' + \
|
||||
'stride=%d\n' % stride + \
|
||||
'size=%d\n' % size
|
||||
def convolutional(self, cv, detect=False):
|
||||
self.blocks[self.current] += 1
|
||||
|
||||
def upsample(self, stride=1):
|
||||
return '\n[upsample]\n' + \
|
||||
'stride=%d\n' % stride
|
||||
self.get_state_dict(cv.state_dict())
|
||||
|
||||
if cv._get_name() == 'Conv2d':
|
||||
filters = cv.out_channels
|
||||
size = cv.kernel_size
|
||||
stride = cv.stride
|
||||
pad = cv.padding
|
||||
groups = cv.groups
|
||||
bias = cv.bias
|
||||
bn = False
|
||||
act = 'linear' if not detect else 'logistic'
|
||||
else:
|
||||
filters = cv.conv.out_channels
|
||||
size = cv.conv.kernel_size
|
||||
stride = cv.conv.stride
|
||||
pad = cv.conv.padding
|
||||
groups = cv.conv.groups
|
||||
bias = cv.conv.bias
|
||||
bn = True if hasattr(cv, 'bn') else False
|
||||
act = self.get_activation(cv.act._get_name()) if hasattr(cv, 'act') else 'linear'
|
||||
|
||||
def convolutional(self, bn=False, size=1, stride=1, pad=1, filters=1, groups=1, activation='linear'):
|
||||
b = 'batch_normalize=1\n' if bn is True else ''
|
||||
g = 'groups=%d\n' % groups if groups > 1 else ''
|
||||
return '\n[convolutional]\n' + \
|
||||
b + \
|
||||
'filters=%d\n' % filters + \
|
||||
'size=%d\n' % size + \
|
||||
'stride=%d\n' % stride + \
|
||||
'pad=%d\n' % pad + \
|
||||
g + \
|
||||
'activation=%s\n' % activation
|
||||
w = 'bias=0\n' if bias is None and bn is False else ''
|
||||
|
||||
def yolo(self, mask='', anchors='', classes=80, num=3):
|
||||
return '\n[yolo]\n' + \
|
||||
'mask=%s\n' % mask + \
|
||||
'anchors=%s\n' % anchors + \
|
||||
'classes=%d\n' % classes + \
|
||||
'num=%d\n' % num + \
|
||||
'scale_x_y=2.0\n' + \
|
||||
'beta_nms=0.6\n' + \
|
||||
'new_coords=1\n'
|
||||
self.fc.write('\n[convolutional]\n' +
|
||||
b +
|
||||
'filters=%d\n' % filters +
|
||||
'size=%s\n' % (size[0] if len(size) == 2 and size[0] == size[1] else str(size)[1:-1]) +
|
||||
'stride=%s\n' % (stride[0] if len(stride) == 2 and stride[0] == stride[1] else str(stride)[1:-1]) +
|
||||
'pad=%s\n' % (pad[0] if len(pad) == 2 and pad[0] == pad[1] else str(pad)[1:-1]) +
|
||||
g +
|
||||
w +
|
||||
'activation=%s\n' % act)
|
||||
|
||||
def batchnorm(self, bn, act):
|
||||
self.blocks[self.current] += 1
|
||||
|
||||
self.get_state_dict(bn.state_dict())
|
||||
|
||||
filters = bn.num_features
|
||||
act = self.get_activation(act._get_name())
|
||||
|
||||
self.fc.write('\n[batchnorm]\n' +
|
||||
'filters=%d\n' % filters +
|
||||
'activation=%s\n' % act)
|
||||
|
||||
def route(self, layers):
|
||||
self.blocks[self.current] += 1
|
||||
|
||||
self.fc.write('\n[route]\n' +
|
||||
'layers=%s\n' % layers)
|
||||
|
||||
def shortcut(self, r, activation='linear'):
|
||||
self.blocks[self.current] += 1
|
||||
|
||||
self.fc.write('\n[shortcut]\n' +
|
||||
'from=%d\n' % r +
|
||||
'activation=%s\n' % activation)
|
||||
|
||||
def maxpool(self, m):
|
||||
self.blocks[self.current] += 1
|
||||
|
||||
stride = m.stride
|
||||
size = m.kernel_size
|
||||
mode = m.ceil_mode
|
||||
|
||||
m = 'maxpool_up' if mode else 'maxpool'
|
||||
|
||||
self.fc.write('\n[%s]\n' % m +
|
||||
'stride=%d\n' % stride +
|
||||
'size=%d\n' % size)
|
||||
|
||||
def upsample(self, child):
|
||||
self.blocks[self.current] += 1
|
||||
|
||||
stride = child.scale_factor
|
||||
|
||||
self.fc.write('\n[upsample]\n' +
|
||||
'stride=%d\n' % stride)
|
||||
|
||||
def yolo(self, i):
|
||||
self.blocks[self.current] += 1
|
||||
|
||||
self.fc.write('\n[yolo]\n' +
|
||||
'mask=%s\n' % self.masks[i] +
|
||||
'anchors=%s\n' % self.anchors +
|
||||
'classes=%d\n' % self.nc +
|
||||
'num=%d\n' % self.num +
|
||||
'scale_x_y=2.0\n' +
|
||||
'new_coords=1\n')
|
||||
|
||||
def get_state_dict(self, state_dict):
|
||||
for k, v in state_dict.items():
|
||||
if 'num_batches_tracked' not in k:
|
||||
vr = v.reshape(-1).numpy()
|
||||
self.fw.write('{} {} '.format(k, len(vr)))
|
||||
for vv in vr:
|
||||
self.fw.write(' ')
|
||||
self.fw.write(struct.pack('>f', float(vv)).hex())
|
||||
self.fw.write('\n')
|
||||
self.wc += 1
|
||||
|
||||
def get_anchors(self, state_dict, out_channels):
|
||||
anchor_grid = state_dict['anchor_grid']
|
||||
aa = anchor_grid.reshape(-1).tolist()
|
||||
am = anchor_grid.tolist()
|
||||
|
||||
self.num = (len(aa) / 2)
|
||||
self.nc = int((out_channels / (self.num / len(am))) - 5)
|
||||
self.anchors = str(aa)[1:-1]
|
||||
|
||||
n = 0
|
||||
for m in am:
|
||||
mask = []
|
||||
for _ in range(len(m)):
|
||||
mask.append(n)
|
||||
n += 1
|
||||
self.masks.append(str(mask)[1:-1])
|
||||
|
||||
def get_route(self, n):
|
||||
r = 0
|
||||
for i, b in enumerate(self.blocks):
|
||||
if i <= n:
|
||||
r += b
|
||||
else:
|
||||
break
|
||||
return r
|
||||
|
||||
def get_activation(self, act):
|
||||
if act == 'Hardswish':
|
||||
return 'hardswish'
|
||||
elif act == 'LeakyReLU':
|
||||
return 'leaky'
|
||||
elif act == 'SiLU':
|
||||
return 'silu'
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='PyTorch YOLOv5 conversion')
|
||||
parser.add_argument('-w', '--weights', required=True, help='Input weights (.pt) file path (required)')
|
||||
parser.add_argument('-c', '--yaml', help='Input cfg (.yaml) file path')
|
||||
parser.add_argument(
|
||||
'-s', '--size', nargs='+', type=int, default=[640], help='Inference size [H,W] (default [640])')
|
||||
args = parser.parse_args()
|
||||
if not os.path.isfile(args.weights):
|
||||
raise SystemExit('Invalid weights file')
|
||||
if not args.yaml:
|
||||
args.yaml = ''
|
||||
return args.weights, args.yaml, args.size
|
||||
return args.weights, args.size
|
||||
|
||||
|
||||
def get_width(x, gw, divisor=8):
|
||||
return int(math.ceil((x * gw) / divisor)) * divisor
|
||||
|
||||
|
||||
def get_depth(x, gd):
|
||||
if x == 1:
|
||||
return 1
|
||||
r = int(round(x * gd))
|
||||
if x * gd - int(x * gd) == 0.5 and int(x * gd) % 2 == 0:
|
||||
r -= 1
|
||||
return max(r, 1)
|
||||
|
||||
|
||||
pt_file, yaml_file, inference_size = parse_args()
|
||||
pt_file, inference_size = parse_args()
|
||||
|
||||
model_name = os.path.basename(pt_file).split('.pt')[0]
|
||||
wts_file = model_name + '.wts' if 'yolov5' in model_name else 'yolov5_' + model_name + '.wts'
|
||||
cfg_file = model_name + '.cfg' if 'yolov5' in model_name else 'yolov5_' + model_name + '.cfg'
|
||||
|
||||
if yaml_file == '':
|
||||
yaml_file = 'models/' + model_name + '.yaml'
|
||||
if not os.path.isfile(yaml_file):
|
||||
yaml_file = 'models/hub/' + model_name + '.yaml'
|
||||
if not os.path.isfile(yaml_file):
|
||||
raise SystemExit('YAML file not found')
|
||||
elif not os.path.isfile(yaml_file):
|
||||
raise SystemExit('Invalid YAML file')
|
||||
|
||||
device = select_device('cpu')
|
||||
model = torch.load(pt_file, map_location=device)['model'].float()
|
||||
|
||||
@@ -112,217 +309,29 @@ model.model[-1].register_buffer('anchor_grid', anchor_grid)
|
||||
|
||||
model.to(device).eval()
|
||||
|
||||
nc = 0
|
||||
anchors = ''
|
||||
masks = []
|
||||
with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc:
|
||||
layers = Layers(len(model.model), inference_size, fw, fc)
|
||||
|
||||
yolo_idx = 0
|
||||
spp_idx = 0
|
||||
for child in model.model.children():
|
||||
if child._get_name() == 'Focus':
|
||||
layers.Focus(child)
|
||||
elif child._get_name() == 'Conv':
|
||||
layers.Conv(child)
|
||||
elif child._get_name() == 'BottleneckCSP':
|
||||
layers.BottleneckCSP(child)
|
||||
elif child._get_name() == 'C3':
|
||||
layers.C3(child)
|
||||
elif child._get_name() == 'SPP':
|
||||
layers.SPP(child)
|
||||
elif child._get_name() == 'SPPF':
|
||||
layers.SPPF(child)
|
||||
elif child._get_name() == 'Upsample':
|
||||
layers.Upsample(child)
|
||||
elif child._get_name() == 'Concat':
|
||||
layers.Concat(child)
|
||||
elif child._get_name() == 'Detect':
|
||||
layers.Detect(child)
|
||||
else:
|
||||
raise SystemExit('Model not supported')
|
||||
|
||||
for k, v in model.state_dict().items():
|
||||
if 'anchor_grid' in k:
|
||||
yolo_idx = int(k.split('.')[1])
|
||||
vr = v.cpu().numpy().tolist()
|
||||
a = v.reshape(-1).cpu().numpy().astype(float).tolist()
|
||||
anchors = str(a)[1:-1]
|
||||
num = 0
|
||||
for m in vr:
|
||||
mask = []
|
||||
for _ in range(len(m)):
|
||||
mask.append(num)
|
||||
num += 1
|
||||
masks.append(mask)
|
||||
elif '.%d.m.0.weight' % yolo_idx in k:
|
||||
vr = v.cpu().numpy().tolist()
|
||||
nc = int((len(vr) / len(masks[0])) - 5)
|
||||
|
||||
with open(cfg_file, 'w') as c:
|
||||
with open(yaml_file, 'r', encoding='utf-8') as f:
|
||||
c.write('[net]\n')
|
||||
c.write('width=%d\n' % (inference_size[0] if len(inference_size) == 1 else inference_size[1]))
|
||||
c.write('height=%d\n' % inference_size[0])
|
||||
c.write('channels=3\n')
|
||||
c.write('letter_box=1\n')
|
||||
depth_multiple = 0
|
||||
width_multiple = 0
|
||||
layers = []
|
||||
yoloLayers = YoloLayers()
|
||||
f = yaml.load(f, Loader=yaml.FullLoader)
|
||||
for topic in f:
|
||||
if topic == 'depth_multiple':
|
||||
depth_multiple = f[topic]
|
||||
elif topic == 'width_multiple':
|
||||
width_multiple = f[topic]
|
||||
elif topic == 'backbone' or topic == 'head':
|
||||
for v in f[topic]:
|
||||
if v[2] == 'Focus':
|
||||
layer = '\n# Focus\n'
|
||||
blocks = 0
|
||||
layer += yoloLayers.reorg()
|
||||
blocks += 1
|
||||
layer += yoloLayers.convolutional(bn=True, filters=get_width(v[3][0], width_multiple), size=v[3][1],
|
||||
activation='silu')
|
||||
blocks += 1
|
||||
layers.append([layer, blocks])
|
||||
if v[2] == 'Conv':
|
||||
layer = '\n# Conv\n'
|
||||
blocks = 0
|
||||
layer += yoloLayers.convolutional(bn=True, filters=get_width(v[3][0], width_multiple), size=v[3][1],
|
||||
stride=v[3][2], activation='silu')
|
||||
blocks += 1
|
||||
layers.append([layer, blocks])
|
||||
elif v[2] == 'C3':
|
||||
layer = '\n# C3\n'
|
||||
blocks = 0
|
||||
# SPLIT
|
||||
layer += yoloLayers.convolutional(bn=True, filters=get_width(v[3][0], width_multiple) / 2,
|
||||
activation='silu')
|
||||
blocks += 1
|
||||
layer += yoloLayers.route(layers='-2')
|
||||
blocks += 1
|
||||
layer += yoloLayers.convolutional(bn=True, filters=get_width(v[3][0], width_multiple) / 2,
|
||||
activation='silu')
|
||||
blocks += 1
|
||||
# Residual Block
|
||||
if len(v[3]) == 1 or v[3][1] is True:
|
||||
for _ in range(get_depth(v[1], depth_multiple)):
|
||||
layer += yoloLayers.convolutional(bn=True, filters=get_width(v[3][0], width_multiple) / 2,
|
||||
activation='silu')
|
||||
blocks += 1
|
||||
layer += yoloLayers.convolutional(bn=True, filters=get_width(v[3][0], width_multiple) / 2,
|
||||
size=3, activation='silu')
|
||||
blocks += 1
|
||||
layer += yoloLayers.shortcut(route=-3)
|
||||
blocks += 1
|
||||
# Merge
|
||||
layer += yoloLayers.route(layers='-1, -%d' % (3 * get_depth(v[1], depth_multiple) + 3))
|
||||
blocks += 1
|
||||
else:
|
||||
for _ in range(get_depth(v[1], depth_multiple)):
|
||||
layer += yoloLayers.convolutional(bn=True, filters=get_width(v[3][0], width_multiple) / 2,
|
||||
activation='silu')
|
||||
blocks += 1
|
||||
layer += yoloLayers.convolutional(bn=True, filters=get_width(v[3][0], width_multiple) / 2,
|
||||
size=3, activation='silu')
|
||||
blocks += 1
|
||||
# Merge
|
||||
layer += yoloLayers.route(layers='-1, -%d' % (2 * get_depth(v[1], depth_multiple) + 3))
|
||||
blocks += 1
|
||||
# Transition
|
||||
layer += yoloLayers.convolutional(bn=True, filters=get_width(v[3][0], width_multiple),
|
||||
activation='silu')
|
||||
blocks += 1
|
||||
layers.append([layer, blocks])
|
||||
elif v[2] == 'SPP':
|
||||
spp_idx = len(layers)
|
||||
layer = '\n# SPP\n'
|
||||
blocks = 0
|
||||
layer += yoloLayers.convolutional(bn=True, filters=get_width(v[3][0], width_multiple) / 2,
|
||||
activation='silu')
|
||||
blocks += 1
|
||||
layer += yoloLayers.maxpool(size=v[3][1][0])
|
||||
blocks += 1
|
||||
layer += yoloLayers.route(layers='-2')
|
||||
blocks += 1
|
||||
layer += yoloLayers.maxpool(size=v[3][1][1])
|
||||
blocks += 1
|
||||
layer += yoloLayers.route(layers='-4')
|
||||
blocks += 1
|
||||
layer += yoloLayers.maxpool(size=v[3][1][2])
|
||||
blocks += 1
|
||||
layer += yoloLayers.route(layers='-6, -5, -3, -1')
|
||||
blocks += 1
|
||||
layer += yoloLayers.convolutional(bn=True, filters=get_width(v[3][0], width_multiple),
|
||||
activation='silu')
|
||||
blocks += 1
|
||||
layers.append([layer, blocks])
|
||||
elif v[2] == 'SPPF':
|
||||
spp_idx = len(layers)
|
||||
layer = '\n# SPPF\n'
|
||||
blocks = 0
|
||||
layer += yoloLayers.convolutional(bn=True, filters=get_width(v[3][0], width_multiple) / 2,
|
||||
activation='silu')
|
||||
blocks += 1
|
||||
layer += yoloLayers.maxpool(size=v[3][1])
|
||||
blocks += 1
|
||||
layer += yoloLayers.maxpool(size=v[3][1])
|
||||
blocks += 1
|
||||
layer += yoloLayers.maxpool(size=v[3][1])
|
||||
blocks += 1
|
||||
layer += yoloLayers.route(layers='-4, -3, -2, -1')
|
||||
blocks += 1
|
||||
layer += yoloLayers.convolutional(bn=True, filters=get_width(v[3][0], width_multiple),
|
||||
activation='silu')
|
||||
blocks += 1
|
||||
layers.append([layer, blocks])
|
||||
elif v[2] == 'nn.Upsample':
|
||||
layer = '\n# nn.Upsample\n'
|
||||
blocks = 0
|
||||
layer += yoloLayers.upsample(stride=v[3][1])
|
||||
blocks += 1
|
||||
layers.append([layer, blocks])
|
||||
elif v[2] == 'Concat':
|
||||
route = v[0][1]
|
||||
route = yoloLayers.get_route(route, layers) if route > 0 else \
|
||||
yoloLayers.get_route(len(layers) + route, layers)
|
||||
layer = '\n# Concat\n'
|
||||
blocks = 0
|
||||
layer += yoloLayers.route(layers='-1, %d' % (route - 1))
|
||||
blocks += 1
|
||||
layers.append([layer, blocks])
|
||||
elif v[2] == 'Detect':
|
||||
for i, n in enumerate(v[0]):
|
||||
route = yoloLayers.get_route(n, layers)
|
||||
layer = '\n# Detect\n'
|
||||
blocks = 0
|
||||
layer += yoloLayers.route(layers='%d' % (route - 1))
|
||||
blocks += 1
|
||||
layer += yoloLayers.convolutional(filters=((nc + 5) * len(masks[i])), activation='logistic')
|
||||
blocks += 1
|
||||
layer += yoloLayers.yolo(mask=str(masks[i])[1:-1], anchors=anchors, classes=nc, num=num)
|
||||
blocks += 1
|
||||
layers.append([layer, blocks])
|
||||
for layer in layers:
|
||||
c.write(layer[0])
|
||||
|
||||
with open(wts_file, 'w') as f:
|
||||
wts_write = ''
|
||||
conv_count = 0
|
||||
cv1 = ''
|
||||
cv3 = ''
|
||||
cv3_idx = 0
|
||||
for k, v in model.state_dict().items():
|
||||
if 'num_batches_tracked' not in k and 'anchors' not in k and 'anchor_grid' not in k:
|
||||
vr = v.reshape(-1).cpu().numpy()
|
||||
idx = int(k.split('.')[1])
|
||||
if '.cv1.' in k and '.m.' not in k and idx != spp_idx:
|
||||
cv1 += '{} {} '.format(k, len(vr))
|
||||
for vv in vr:
|
||||
cv1 += ' '
|
||||
cv1 += struct.pack('>f', float(vv)).hex()
|
||||
cv1 += '\n'
|
||||
conv_count += 1
|
||||
elif cv1 != '' and '.m.' in k:
|
||||
wts_write += cv1
|
||||
cv1 = ''
|
||||
if '.cv3.' in k:
|
||||
cv3 += '{} {} '.format(k, len(vr))
|
||||
for vv in vr:
|
||||
cv3 += ' '
|
||||
cv3 += struct.pack('>f', float(vv)).hex()
|
||||
cv3 += '\n'
|
||||
cv3_idx = idx
|
||||
conv_count += 1
|
||||
elif cv3 != '' and cv3_idx != idx:
|
||||
wts_write += cv3
|
||||
cv3 = ''
|
||||
cv3_idx = 0
|
||||
if '.cv3.' not in k and not ('.cv1.' in k and '.m.' not in k and idx != spp_idx):
|
||||
wts_write += '{} {} '.format(k, len(vr))
|
||||
for vv in vr:
|
||||
wts_write += ' '
|
||||
wts_write += struct.pack('>f', float(vv)).hex()
|
||||
wts_write += '\n'
|
||||
conv_count += 1
|
||||
f.write('{}\n'.format(conv_count))
|
||||
f.write(wts_write)
|
||||
os.system('echo "%d" | cat - %s > temp && mv temp %s' % (layers.wc, wts_file, wts_file))
|
||||
|
||||
Reference in New Issue
Block a user