Add YOLOv8 support
This commit is contained in:
@@ -8,6 +8,7 @@ from ppdet.utils.cli import ArgsParser
|
||||
from ppdet.engine import Trainer
|
||||
from ppdet.slim import build_slim_model
|
||||
|
||||
|
||||
class Layers(object):
|
||||
def __init__(self, size, fw, fc, letter_box):
|
||||
self.blocks = [0 for _ in range(300)]
|
||||
@@ -123,7 +124,7 @@ class Layers(object):
|
||||
def Shuffle(self, reshape=None, transpose1=None, transpose2=None, route=None, output=''):
|
||||
self.current += 1
|
||||
|
||||
r = 0
|
||||
r = None
|
||||
if route is not None:
|
||||
r = self.get_route(route)
|
||||
self.shuffle(reshape=reshape, transpose1=transpose1, transpose2=transpose2, route=r)
|
||||
@@ -156,7 +157,7 @@ class Layers(object):
|
||||
'channels=3\n' +
|
||||
lb)
|
||||
|
||||
def convolutional(self, cv, act='linear', detect=False):
|
||||
def convolutional(self, cv, act='linear'):
|
||||
self.blocks[self.current] += 1
|
||||
|
||||
self.get_state_dict(cv.state_dict())
|
||||
@@ -178,9 +179,6 @@ class Layers(object):
|
||||
bias = cv.conv.bias
|
||||
bn = True if hasattr(cv, 'bn') else False
|
||||
|
||||
if detect:
|
||||
act = 'logistic'
|
||||
|
||||
b = 'batch_normalize=1\n' if bn is True else ''
|
||||
g = 'groups=%d\n' % groups if groups > 1 else ''
|
||||
w = 'bias=0\n' if bias is None and bn is False else ''
|
||||
@@ -251,9 +249,9 @@ class Layers(object):
|
||||
def shuffle(self, reshape=None, transpose1=None, transpose2=None, route=None):
|
||||
self.blocks[self.current] += 1
|
||||
|
||||
r = 'reshape=%s\n' % str(reshape)[1:-1] if reshape is not None else ''
|
||||
t1 = 'transpose1=%s\n' % str(transpose1)[1:-1] if transpose1 is not None else ''
|
||||
t2 = 'transpose2=%s\n' % str(transpose2)[1:-1] if transpose2 is not None else ''
|
||||
r = 'reshape=%s\n' % ', '.join(str(x) for x in reshape) if reshape is not None else ''
|
||||
t1 = 'transpose1=%s\n' % ', '.join(str(x) for x in transpose1) if transpose1 is not None else ''
|
||||
t2 = 'transpose2=%s\n' % ', '.join(str(x) for x in transpose2) if transpose2 is not None else ''
|
||||
f = 'from=%d\n' % route if route is not None else ''
|
||||
|
||||
self.fc.write('\n[shuffle]\n' +
|
||||
@@ -419,13 +417,13 @@ with open(wts_file, 'w') as fw, open(cfg_file, 'w') as fc:
|
||||
layers.AvgPool2d()
|
||||
layers.ESEAttn(model.yolo_head.stem_cls[i])
|
||||
layers.Conv2D(model.yolo_head.pred_cls[i], act='sigmoid')
|
||||
layers.Shuffle(reshape=[model.yolo_head.num_classes, 0], route=feat, output='cls')
|
||||
layers.Shuffle(reshape=[model.yolo_head.num_classes, 'hw'], route=feat, output='cls')
|
||||
layers.ESEAttn(model.yolo_head.stem_reg[i], route=-7)
|
||||
layers.Conv2D(model.yolo_head.pred_reg[i])
|
||||
layers.Shuffle(reshape=[4, model.yolo_head.reg_max + 1, 0], transpose2=[1, 0, 2], route=feat)
|
||||
layers.Shuffle(reshape=[4, model.yolo_head.reg_max + 1, 'hw'], transpose2=[1, 0, 2], route=feat)
|
||||
layers.SoftMax(0)
|
||||
layers.Conv2D(model.yolo_head.proj_conv)
|
||||
layers.Shuffle(reshape=[4, 0], route=feat, output='reg')
|
||||
layers.Shuffle(reshape=[4, 'hw'], route=feat, output='reg')
|
||||
layers.Detect('cls')
|
||||
layers.Detect('reg')
|
||||
layers.get_anchors(model.yolo_head.anchor_points.reshape([-1]), model.yolo_head.stride_tensor)
|
||||
|
||||
Reference in New Issue
Block a user