175 lines
5.5 KiB
Python
175 lines
5.5 KiB
Python
"""Copyright(c) 2024 lyuwenyu. All Rights Reserved.
|
|
"""
|
|
|
|
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
from urllib.parse import urlparse
|
|
|
|
# The ``rtdetrv2_pytorch`` directory next to this file is appended to
# sys.path so that the ``from src.core import YAMLConfig`` import can resolve.
ROOT = Path(__file__).absolute().parent.joinpath('rtdetrv2_pytorch')
sys.path.append(os.fspath(ROOT))
|
|
|
|
from src.core import YAMLConfig
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
# Package names required by the entrypoints in this file (torch.hub's
# ``dependencies`` convention for hubconf modules).
dependencies = ['torch', 'torchvision']
|
|
|
|
|
|
def _load_checkpoint(path: str, map_location='cpu'):
|
|
scheme = urlparse(str(path)).scheme
|
|
if not scheme:
|
|
state = torch.load(path, map_location=map_location)
|
|
else:
|
|
state = torch.hub.load_state_dict_from_url(path, map_location=map_location)
|
|
return state
|
|
|
|
|
|
def _build_model(args):
    """Create a deploy-mode model + postprocessor wrapper from ``args``.

    Args:
        args: Object exposing ``config`` (YAML config path passed to
            ``YAMLConfig``) and ``resume`` (checkpoint path or URL; a falsy
            value skips weight loading).

    Returns:
        An ``nn.Module`` whose ``forward(images, orig_target_sizes)`` runs
        ``images`` through the deployed model and hands the raw outputs,
        together with ``orig_target_sizes``, to the deployed postprocessor.
    """
    cfg = YAMLConfig(args.config)

    if args.resume:
        ckpt = _load_checkpoint(args.resume, map_location='cpu')
        # Use the 'ema' entry when present, otherwise fall back to 'model'.
        weights = ckpt['ema']['module'] if 'ema' in ckpt else ckpt['model']
        # NOTE: weights go into the train-mode model, before deploy() is called.
        cfg.model.load_state_dict(weights)

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.model = cfg.model.deploy()
            self.postprocessor = cfg.postprocessor.deploy()

        def forward(self, images, orig_target_sizes):
            raw = self.model(images)
            return self.postprocessor(raw, orig_target_sizes)

    return Model()
|
|
|
|
|
|
# Registry of hub entrypoints: each entry gives the YAML config (relative to
# ROOT) and the checkpoint URL loaded when ``pretrained`` is requested.
CONFIG = dict(
    # rtdetr
    rtdetr_r18vd=dict(
        config=ROOT / 'configs/rtdetr/rtdetr_r18vd_6x_coco.yml',
        resume='https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r18vd_dec3_6x_coco_from_paddle.pth',
    ),
    rtdetr_r34vd=dict(
        config=ROOT / 'configs/rtdetr/rtdetr_r34vd_6x_coco.yml',
        resume='https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r34vd_dec4_6x_coco_from_paddle.pth',
    ),
    rtdetr_r50vd_m=dict(
        config=ROOT / 'configs/rtdetr/rtdetr_r50vd_m_6x_coco.yml',
        resume='https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_m_6x_coco_from_paddle.pth',
    ),
    rtdetr_r50vd=dict(
        config=ROOT / 'configs/rtdetr/rtdetr_r50vd_6x_coco.yml',
        resume='https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_6x_coco_from_paddle.pth',
    ),
    rtdetr_r101vd=dict(
        config=ROOT / 'configs/rtdetr/rtdetr_r101vd_6x_coco.yml',
        resume='https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r101vd_6x_coco_from_paddle.pth',
    ),
    # rtdetrv2
    rtdetrv2_r18vd=dict(
        config=ROOT / 'configs/rtdetrv2/rtdetrv2_r18vd_120e_coco.yml',
        resume='https://github.com/lyuwenyu/storage/releases/download/v0.2/rtdetrv2_r18vd_120e_coco_rerun_48.1.pth',
    ),
    rtdetrv2_r34vd=dict(
        config=ROOT / 'configs/rtdetrv2/rtdetrv2_r34vd_120e_coco.yml',
        resume='https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r34vd_120e_coco_ema.pth',
    ),
    rtdetrv2_r50vd_m=dict(
        config=ROOT / 'configs/rtdetrv2/rtdetrv2_r50vd_m_7x_coco.yml',
        resume='https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r50vd_m_7x_coco_ema.pth',
    ),
    rtdetrv2_r50vd=dict(
        config=ROOT / 'configs/rtdetrv2/rtdetrv2_r50vd_6x_coco.yml',
        resume='https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r50vd_6x_coco_ema.pth',
    ),
    rtdetrv2_r101vd=dict(
        config=ROOT / 'configs/rtdetrv2/rtdetrv2_r101vd_6x_coco.yml',
        resume='https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetrv2_r101vd_6x_coco_from_paddle.pth',
    ),
)
|
|
|
|
|
|
def _build_named(name, pretrained):
    """Build the hub model registered under ``name`` in ``CONFIG``.

    Args:
        name: Key into ``CONFIG``.
        pretrained: When truthy, keep the entry's ``'resume'`` checkpoint so
            its weights are loaded; otherwise set ``resume`` to ``''`` so no
            weights are loaded.

    Returns:
        The module produced by ``_build_model`` for that entry.
    """
    from types import SimpleNamespace

    # A fresh namespace per call exposes the entry's fields (``config``,
    # ``resume``) as attributes without mutating the shared CONFIG entry.
    args = SimpleNamespace(**CONFIG[name])
    if not pretrained:
        args.resume = ''
    return _build_model(args)


# rtdetr
def rtdetr_r18vd(pretrained=True,):
    """RT-DETR ``rtdetr_r18vd`` (``configs/rtdetr/rtdetr_r18vd_6x_coco.yml``).

    Args:
        pretrained: If True (default), load the checkpoint URL stored in
            ``CONFIG['rtdetr_r18vd']['resume']``.

    Returns:
        ``nn.Module`` called as ``model(images, orig_target_sizes)``.
    """
    return _build_named('rtdetr_r18vd', pretrained)


def rtdetr_r34vd(pretrained=True,):
    """RT-DETR ``rtdetr_r34vd`` (``configs/rtdetr/rtdetr_r34vd_6x_coco.yml``).

    Args:
        pretrained: If True (default), load the checkpoint URL stored in
            ``CONFIG['rtdetr_r34vd']['resume']``.

    Returns:
        ``nn.Module`` called as ``model(images, orig_target_sizes)``.
    """
    return _build_named('rtdetr_r34vd', pretrained)


def rtdetr_r50vd_m(pretrained=True):
    """RT-DETR ``rtdetr_r50vd_m`` (``configs/rtdetr/rtdetr_r50vd_m_6x_coco.yml``).

    Args:
        pretrained: If True (default), load the checkpoint URL stored in
            ``CONFIG['rtdetr_r50vd_m']['resume']``.

    Returns:
        ``nn.Module`` called as ``model(images, orig_target_sizes)``.
    """
    return _build_named('rtdetr_r50vd_m', pretrained)


def rtdetr_r50vd(pretrained=True):
    """RT-DETR ``rtdetr_r50vd`` (``configs/rtdetr/rtdetr_r50vd_6x_coco.yml``).

    Args:
        pretrained: If True (default), load the checkpoint URL stored in
            ``CONFIG['rtdetr_r50vd']['resume']``.

    Returns:
        ``nn.Module`` called as ``model(images, orig_target_sizes)``.
    """
    return _build_named('rtdetr_r50vd', pretrained)


def rtdetr_r101vd(pretrained=True):
    """RT-DETR ``rtdetr_r101vd`` (``configs/rtdetr/rtdetr_r101vd_6x_coco.yml``).

    Args:
        pretrained: If True (default), load the checkpoint URL stored in
            ``CONFIG['rtdetr_r101vd']['resume']``.

    Returns:
        ``nn.Module`` called as ``model(images, orig_target_sizes)``.
    """
    return _build_named('rtdetr_r101vd', pretrained)


# rtdetrv2
def rtdetrv2_r18vd(pretrained=True,):
    """RT-DETRv2 ``rtdetrv2_r18vd`` (``configs/rtdetrv2/rtdetrv2_r18vd_120e_coco.yml``).

    Args:
        pretrained: If True (default), load the checkpoint URL stored in
            ``CONFIG['rtdetrv2_r18vd']['resume']``.

    Returns:
        ``nn.Module`` called as ``model(images, orig_target_sizes)``.
    """
    return _build_named('rtdetrv2_r18vd', pretrained)


def rtdetrv2_r34vd(pretrained=True,):
    """RT-DETRv2 ``rtdetrv2_r34vd`` (``configs/rtdetrv2/rtdetrv2_r34vd_120e_coco.yml``).

    Args:
        pretrained: If True (default), load the checkpoint URL stored in
            ``CONFIG['rtdetrv2_r34vd']['resume']``.

    Returns:
        ``nn.Module`` called as ``model(images, orig_target_sizes)``.
    """
    return _build_named('rtdetrv2_r34vd', pretrained)


def rtdetrv2_r50vd_m(pretrained=True):
    """RT-DETRv2 ``rtdetrv2_r50vd_m`` (``configs/rtdetrv2/rtdetrv2_r50vd_m_7x_coco.yml``).

    Args:
        pretrained: If True (default), load the checkpoint URL stored in
            ``CONFIG['rtdetrv2_r50vd_m']['resume']``.

    Returns:
        ``nn.Module`` called as ``model(images, orig_target_sizes)``.
    """
    return _build_named('rtdetrv2_r50vd_m', pretrained)


def rtdetrv2_r50vd(pretrained=True):
    """RT-DETRv2 ``rtdetrv2_r50vd`` (``configs/rtdetrv2/rtdetrv2_r50vd_6x_coco.yml``).

    Args:
        pretrained: If True (default), load the checkpoint URL stored in
            ``CONFIG['rtdetrv2_r50vd']['resume']``.

    Returns:
        ``nn.Module`` called as ``model(images, orig_target_sizes)``.
    """
    return _build_named('rtdetrv2_r50vd', pretrained)


def rtdetrv2_r101vd(pretrained=True):
    """RT-DETRv2 ``rtdetrv2_r101vd`` (``configs/rtdetrv2/rtdetrv2_r101vd_6x_coco.yml``).

    Args:
        pretrained: If True (default), load the checkpoint URL stored in
            ``CONFIG['rtdetrv2_r101vd']['resume']``.

    Returns:
        ``nn.Module`` called as ``model(images, orig_target_sizes)``.
    """
    return _build_named('rtdetrv2_r101vd', pretrained)


# Short aliases for the RT-DETRv2 entrypoints (same function objects).
rtdetrv2_s = rtdetrv2_r18vd
rtdetrv2_m_r34 = rtdetrv2_r34vd
rtdetrv2_m_r50 = rtdetrv2_r50vd_m
rtdetrv2_l = rtdetrv2_r50vd
rtdetrv2_x = rtdetrv2_r101vd
|
|
|