# Python source file: 373 lines, 12 KiB
import argparse
|
|
import os
|
|
import struct
|
|
import torch
|
|
from yolox.exp import get_exp
|
|
|
|
|
|
class Layers(object):
    """Emit a cfg description and a text weight dump for a YOLOX model.

    The cfg text is written to ``fc`` as bracketed sections (``[net]``,
    ``[convolutional]``, ``[route]``, ``[shortcut]``, ...). Tensor values are
    written to ``fw``, one tensor per line, as::

        <name> <count>  <hex> <hex> ...

    where every ``<hex>`` is the big-endian float32 encoding of one value.

    Capitalized methods (``Conv``, ``CSPLayer``, ...) handle one module of
    the source model each. They advance ``self.current`` and may emit several
    cfg sections. Lowercase methods (``convolutional``, ``route``, ...) emit
    exactly one cfg section and count it in ``self.blocks[self.current]``.
    ``get_route`` uses these counts to turn a module index into an absolute
    section index.
    """

    def __init__(self, size, fw, fc):
        """Set up the bookkeeping and immediately write the ``[net]`` section.

        Args:
            size: Sequence read as ``(height,)`` or ``(height, width)``; with
                a single element the width equals the height.
            fw: Writable text stream receiving the weight dump.
            fc: Writable text stream receiving the cfg sections.
        """
        # blocks[i] = number of cfg sections emitted for the i-th module call.
        # The table has a fixed capacity of 300 module calls.
        self.blocks = [0] * 300
        self.current = -1

        self.height = size[0]
        self.width = size[0] if len(size) == 1 else size[1]

        # Module indices recorded for later Concat/Route/Detect calls.
        self.backbone_outs = []
        self.fpn_feats = []
        self.pan_feats = []
        self.yolo_head = []

        self.fw = fw
        self.fc = fc
        self.wc = 0  # number of tensors written to fw so far

        self.net()

    # ------------------------------------------------------------------
    # Module-level emitters (one call per source-model module)
    # ------------------------------------------------------------------

    def _emit_conv(self, module):
        """Emit ``module`` as cfg and return how many sections were written.

        A module named ``DWConv`` is written as two sections (``dconv`` then
        ``pconv``); anything else is written as one.
        """
        if module._get_name() == 'DWConv':
            self.convolutional(module.dconv)
            self.convolutional(module.pconv)
            return 2
        self.convolutional(module)
        return 1

    def Conv(self, child):
        """Emit a convolution module (plain or ``DWConv``)."""
        self.current += 1

        self._emit_conv(child)

    def Focus(self, child):
        """Emit a ``[reorg]`` section followed by ``child.conv``."""
        self.current += 1

        self.reorg()
        self.convolutional(child.conv)

    def BaseConv(self, child, stage='', act=None):
        """Emit one convolution, optionally forcing its activation.

        Args:
            child: Module passed to ``convolutional``.
            stage: When ``'fpn'``, the module index is recorded in
                ``fpn_feats``.
            act: Activation name to write instead of the one derived from
                ``child``; ``None`` keeps the derived one.
        """
        self.current += 1

        self.convolutional(child, act=act)
        if stage == 'fpn':
            self.fpn_feats.append(self.current)

    def CSPLayer(self, child, stage=''):
        """Emit a CSP layer: two branches, bottlenecks, concat and ``conv3``.

        Args:
            child: Module exposing ``conv1``, ``conv2``, ``conv3`` and an
                iterable ``m`` of bottlenecks, each with ``conv1``, ``conv2``
                and ``use_add``.
            stage: ``'backbone'`` or ``'pan'`` records the module index in
                ``backbone_outs`` or ``pan_feats`` respectively.
        """
        self.current += 1

        self.convolutional(child.conv2)
        self.route('-2')
        self.convolutional(child.conv1)

        # Relative offset from the final route back to the conv2 branch:
        # it starts at -3 and grows by every section emitted in the loop.
        idx = -3
        for m in child.m:
            self.convolutional(m.conv1)
            emitted = 1 + self._emit_conv(m.conv2)
            if m.use_add:
                # Residual connection back to the bottleneck's input.
                self.shortcut(-(emitted + 1))
                emitted += 1
            idx -= emitted

        self.route('-1, %d' % idx)
        self.convolutional(child.conv3)

        if stage == 'backbone':
            self.backbone_outs.append(self.current)
        elif stage == 'pan':
            self.pan_feats.append(self.current)

    def SPPBottleneck(self, child):
        """Emit ``conv1``, three max-pools fed from it, a concat and ``conv2``."""
        self.current += 1

        self.convolutional(child.conv1)
        self.maxpool(child.m[0])
        self.route('-2')
        self.maxpool(child.m[1])
        self.route('-4')
        self.maxpool(child.m[2])
        self.route('-6, -5, -3, -1')
        self.convolutional(child.conv2)

    def Upsample(self, child):
        """Emit an ``[upsample]`` section for ``child``."""
        self.current += 1

        self.upsample(child)

    def Concat(self, route):
        """Concatenate the previous section with module ``route``'s output.

        ``route`` is a module index; it is converted to an absolute section
        index with ``get_route``.
        """
        self.current += 1

        r = self.get_route(route)
        self.route('-1, %d' % r)

    def Route(self, route):
        """Emit a single-input ``[route]``.

        A positive ``route`` is a module index resolved through
        ``get_route``; zero or negative values are written unchanged.
        """
        self.current += 1

        target = self.get_route(route) if route > 0 else route
        self.route('%d' % target)

    def RouteShuffleOut(self, route):
        """Emit a route plus a reshape and register the module as a head output.

        Args:
            route: Text written verbatim as the ``layers=`` value.
        """
        self.current += 1

        self.route(route)
        self.shuffle(reshape=['c', 'hw'])
        self.yolo_head.append(self.current)

    def Detect(self, strides):
        """Join all registered head outputs and emit the detection section."""
        self.current += 1

        # Head outputs are joined in reverse registration order.
        routes = [self.get_route(r) for r in self.yolo_head[::-1]]
        self.route(str(routes)[1:-1], axis=1)
        self.shuffle(transpose1=[1, 0])
        self.yolo(strides)

    # ------------------------------------------------------------------
    # Section-level emitters (exactly one cfg section each)
    # ------------------------------------------------------------------

    def net(self):
        """Write the ``[net]`` header section; it is not counted in ``blocks``."""
        self.fc.write('[net]\n' +
                      'width=%d\n' % self.width +
                      'height=%d\n' % self.height +
                      'channels=3\n' +
                      'letter_box=1\n')

    def reorg(self):
        """Write a ``[reorg]`` section."""
        self.blocks[self.current] += 1

        self.fc.write('\n[reorg]\n')

    def convolutional(self, cv, act=None, detect=False):
        """Dump ``cv``'s tensors and write its ``[convolutional]`` section.

        Args:
            cv: Either a module whose ``_get_name()`` is ``'Conv2d'``, or a
                wrapper exposing ``.conv`` and optionally ``.bn`` / ``.act``.
            act: Activation name to write. When ``None`` it is ``'linear'``
                for a bare ``Conv2d`` and otherwise derived from ``cv.act``.
            detect: Unused; kept so existing callers keep working.
        """
        self.blocks[self.current] += 1

        self.get_state_dict(cv.state_dict())

        if cv._get_name() == 'Conv2d':
            conv = cv
            bn = False
            if act is None:
                act = 'linear'
        else:
            conv = cv.conv
            bn = hasattr(cv, 'bn')
            if act is None:
                if hasattr(cv, 'act'):
                    act = self.get_activation(cv.act._get_name())
                else:
                    act = 'linear'

        parts = ['\n[convolutional]\n']
        if bn:
            parts.append('batch_normalize=1\n')
        parts.append('filters=%d\n' % conv.out_channels)
        parts.append('size=%s\n' % self.get_value(conv.kernel_size))
        parts.append('stride=%s\n' % self.get_value(conv.stride))
        parts.append('pad=%s\n' % self.get_value(conv.padding))
        if conv.groups > 1:
            parts.append('groups=%d\n' % conv.groups)
        if conv.bias is None and not bn:
            # Written only when the layer has neither a bias nor batch norm.
            parts.append('bias=0\n')
        parts.append('activation=%s\n' % act)

        self.fc.write(''.join(parts))

    def route(self, layers, axis=0):
        """Write a ``[route]`` section.

        Args:
            layers: Text written verbatim as the ``layers=`` value.
            axis: Written as ``axis=`` only when non-zero.
        """
        self.blocks[self.current] += 1

        text = '\n[route]\n' + 'layers=%s\n' % layers
        if axis != 0:
            text += 'axis=%d\n' % axis

        self.fc.write(text)

    def shortcut(self, r, ew='add', act='linear'):
        """Write a ``[shortcut]`` section.

        Args:
            r: Integer written as ``from=``.
            ew: ``'mul'`` adds a ``mode=mul`` line; any other value adds none.
            act: Activation name written as ``activation=``.
        """
        self.blocks[self.current] += 1

        text = '\n[shortcut]\n' + 'from=%d\n' % r
        if ew == 'mul':
            text += 'mode=mul\n'
        text += 'activation=%s\n' % act

        self.fc.write(text)

    def maxpool(self, m):
        """Write a ``[maxpool]`` (or ``[maxpool_up]`` when ``ceil_mode``) section.

        ``stride`` and ``kernel_size`` may be ints or 2-element tuples; they
        are normalized with ``get_value`` exactly like the convolution
        parameters, so a tuple no longer breaks the ``%d`` formatting.
        """
        self.blocks[self.current] += 1

        section = 'maxpool_up' if m.ceil_mode else 'maxpool'

        self.fc.write('\n[%s]\n' % section +
                      'stride=%s\n' % self.get_value(m.stride) +
                      'size=%s\n' % self.get_value(m.kernel_size))

    def upsample(self, child):
        """Write an ``[upsample]`` section using ``child.scale_factor``."""
        self.blocks[self.current] += 1

        # %d truncates a float scale factor to its integer part.
        self.fc.write('\n[upsample]\n' +
                      'stride=%d\n' % child.scale_factor)

    def shuffle(self, reshape=None, transpose1=None, transpose2=None, route=None):
        """Write a ``[shuffle]`` section containing only the options given.

        Args:
            reshape: Iterable written comma-separated as ``reshape=``.
            transpose1: Iterable written comma-separated as ``transpose1=``.
            transpose2: Iterable written comma-separated as ``transpose2=``.
            route: Integer written as ``from=``.
        """
        self.blocks[self.current] += 1

        parts = ['\n[shuffle]\n']
        for key, values in (('reshape', reshape),
                            ('transpose1', transpose1),
                            ('transpose2', transpose2)):
            if values is not None:
                parts.append('%s=%s\n' % (key, ', '.join(str(x) for x in values)))
        if route is not None:
            parts.append('from=%d\n' % route)

        self.fc.write(''.join(parts))

    def yolo(self, strides):
        """Write the ``[detect_x]`` section.

        ``strides`` is rendered with ``str()`` and its first and last
        characters (the brackets of a list/tuple repr) are dropped.
        """
        self.blocks[self.current] += 1

        self.fc.write('\n[detect_x]\n' +
                      'strides=%s\n' % str(strides)[1:-1])

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def get_state_dict(self, state_dict):
        """Append every tensor of ``state_dict`` to the weight stream.

        Entries whose key contains ``num_batches_tracked`` are skipped. Each
        remaining tensor is flattened and written as one line, and ``self.wc``
        is incremented once per written tensor.
        """
        pack = struct.Struct('>f').pack
        for k, v in state_dict.items():
            if 'num_batches_tracked' in k:
                continue
            vr = v.reshape(-1).numpy()
            # One write per tensor instead of two per scalar; same bytes.
            values = ''.join(' ' + pack(float(vv)).hex() for vv in vr)
            self.fw.write('{} {} '.format(k, len(vr)) + values + '\n')
            self.wc += 1

    def get_value(self, key):
        """Collapse an int-or-pair parameter to its cfg representation.

        Returns the int itself, the shared value of a pair with equal
        elements, or the pair rendered as ``'a, b'`` when they differ.
        """
        if isinstance(key, int):
            return key
        return key[0] if key[0] == key[1] else str(key)[1:-1]

    def get_route(self, n):
        """Return the absolute index of the last section emitted by module ``n``.

        This is the number of sections emitted by modules ``0..n`` minus one;
        a negative ``n`` therefore yields ``-1``.
        """
        return sum(self.blocks[:max(n + 1, 0)]) - 1

    def get_activation(self, act):
        """Map an activation module name to its cfg name (default ``'linear'``)."""
        names = {
            'Hardswish': 'hardswish',
            'LeakyReLU': 'leaky',
            'SiLU': 'silu',
        }
        return names.get(act, 'linear')
|
|
|
|
|
|
def parse_args():
    """Read the command line and return ``(weights_path, exp_path)``.

    Both ``-w/--weights`` and ``-e/--exp`` are mandatory and must point to
    existing files.

    Raises:
        SystemExit: If either path is not an existing file (weights are
            checked first), or if argparse rejects the command line.
    """
    cli = argparse.ArgumentParser(description='PyTorch YOLOX conversion')
    cli.add_argument('-w', '--weights', required=True, help='Input weights (.pth) file path (required)')
    cli.add_argument('-e', '--exp', required=True, help='Input exp (.py) file path (required)')
    parsed = cli.parse_args()

    for path, label in ((parsed.weights, 'weights'), (parsed.exp, 'exp')):
        if not os.path.isfile(path):
            raise SystemExit('Invalid %s file' % label)

    return parsed.weights, parsed.exp
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Conversion script: load the YOLOX checkpoint, walk the model in a fixed
# order and write <name>.cfg / <name>.wts into the working directory.
# ---------------------------------------------------------------------------

pth_file, exp_file = parse_args()

exp = get_exp(exp_file)
model = exp.get_model()
model.load_state_dict(torch.load(pth_file, map_location='cpu')['model'])
model.to('cpu').eval()

model_name = exp.exp_name
# NOTE(review): Layers reads size[0] as height and size[1] as width. If
# exp.input_size is already (height, width), this swap exchanges the two for
# non-square inputs - confirm the ordering against the YOLOX Exp definition.
inference_size = (exp.input_size[1], exp.input_size[0])

backbone = model.backbone._get_name()
head = model.head._get_name()

# Output files are always prefixed with "yolox" exactly once.
base_name = model_name if 'yolox' in model_name else 'yolox_' + model_name
wts_file = base_name + '.wts'
cfg_file = base_name + '.cfg'

with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc:
    layers = Layers(inference_size, fw, fc)

    if backbone == 'YOLOPAFPN':
        layers.fc.write('\n# YOLOPAFPN\n')

        # The call order below defines the section order of the cfg file and
        # the module indices stored in backbone_outs / fpn_feats / pan_feats.
        layers.Focus(model.backbone.backbone.stem)
        layers.Conv(model.backbone.backbone.dark2[0])
        layers.CSPLayer(model.backbone.backbone.dark2[1])
        layers.Conv(model.backbone.backbone.dark3[0])
        layers.CSPLayer(model.backbone.backbone.dark3[1], 'backbone')
        layers.Conv(model.backbone.backbone.dark4[0])
        layers.CSPLayer(model.backbone.backbone.dark4[1], 'backbone')
        layers.Conv(model.backbone.backbone.dark5[0])
        layers.SPPBottleneck(model.backbone.backbone.dark5[1])
        layers.CSPLayer(model.backbone.backbone.dark5[2], 'backbone')
        layers.BaseConv(model.backbone.lateral_conv0, 'fpn')
        layers.Upsample(model.backbone.upsample)
        layers.Concat(layers.backbone_outs[1])
        layers.CSPLayer(model.backbone.C3_p4)
        layers.BaseConv(model.backbone.reduce_conv1, 'fpn')
        layers.Upsample(model.backbone.upsample)
        layers.Concat(layers.backbone_outs[0])
        layers.CSPLayer(model.backbone.C3_p3, 'pan')
        layers.Conv(model.backbone.bu_conv2)
        layers.Concat(layers.fpn_feats[1])
        layers.CSPLayer(model.backbone.C3_n3, 'pan')
        layers.Conv(model.backbone.bu_conv1)
        layers.Concat(layers.fpn_feats[0])
        layers.CSPLayer(model.backbone.C3_n4, 'pan')
        # The head below walks the PAN outputs in reverse emission order.
        layers.pan_feats = layers.pan_feats[::-1]
    else:
        raise SystemExit('Model not supported')

    if head == 'YOLOXHead':
        layers.fc.write('\n# YOLOXHead\n')

        for i, feat in enumerate(layers.pan_feats):
            # pan_feats was reversed above, so idx maps back to the head
            # branch that belongs to this feature.
            idx = len(layers.pan_feats) - i - 1
            # A DWConv emits two cfg sections per conv, which shifts the
            # relative route offsets used below.
            dw = model.head.cls_convs[idx][0]._get_name() == 'DWConv'
            if i > 0:
                # The first feature is the section emitted just before the
                # head; later ones must be routed in explicitly.
                layers.Route(feat)
            layers.BaseConv(model.head.stems[idx])
            layers.Conv(model.head.cls_convs[idx][0])
            layers.Conv(model.head.cls_convs[idx][1])
            layers.BaseConv(model.head.cls_preds[idx], act='logistic')
            # Jump back to the stem output to start the regression branch.
            layers.Route(-6 if dw else -4)
            layers.Conv(model.head.reg_convs[idx][0])
            layers.Conv(model.head.reg_convs[idx][1])
            layers.BaseConv(model.head.obj_preds[idx], act='logistic')
            layers.Route(-2)
            layers.BaseConv(model.head.reg_preds[idx])
            layers.RouteShuffleOut('-1, -3, -9' if dw else '-1, -3, -7')
        layers.Detect(model.head.strides)

    else:
        raise SystemExit('Model not supported')

# The tensor count is only known after the dump is complete, so it is
# prepended afterwards as the first line of the .wts file. This is done in
# Python rather than through a shell pipeline: no shell quoting of the file
# name, no dependency on POSIX tools, and no scratch file named "temp".
tmp_file = wts_file + '.tmp'
with open(wts_file, 'r') as src, open(tmp_file, 'w') as dst:
    dst.write('%d\n' % layers.wc)
    # Copy in 1 MiB chunks; individual lines can be very long.
    for chunk in iter(lambda: src.read(1 << 20), ''):
        dst.write(chunk)
os.replace(tmp_file, wts_file)
|