first commit
This commit is contained in:
100
rtdetrv2_pytorch/tools/export_onnx.py
Normal file
100
rtdetrv2_pytorch/tools/export_onnx.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from src.core import YAMLConfig, yaml_utils
|
||||
|
||||
|
||||
def main(args, ):
|
||||
"""main
|
||||
"""
|
||||
update_dict = yaml_utils.parse_cli(args.update) if args.update else {}
|
||||
update_dict.update({k: v for k, v in args.__dict__.items() \
|
||||
if k not in ['update', ] and v is not None})
|
||||
cfg = YAMLConfig(args.config, **update_dict)
|
||||
|
||||
if args.resume:
|
||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
||||
if 'ema' in checkpoint:
|
||||
state = checkpoint['ema']['module']
|
||||
else:
|
||||
state = checkpoint['model']
|
||||
|
||||
# NOTE load train mode state -> convert to deploy mode
|
||||
cfg.model.load_state_dict(state)
|
||||
|
||||
else:
|
||||
# raise AttributeError('Only support resume to load model.state_dict by now.')
|
||||
print('not load model.state_dict, use default init state dict...')
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, ) -> None:
|
||||
super().__init__()
|
||||
self.model = cfg.model.deploy()
|
||||
self.postprocessor = cfg.postprocessor.deploy()
|
||||
|
||||
def forward(self, images, orig_target_sizes):
|
||||
outputs = self.model(images)
|
||||
outputs = self.postprocessor(outputs, orig_target_sizes)
|
||||
return outputs
|
||||
|
||||
model = Model()
|
||||
|
||||
data = torch.rand(1, 3, args.input_size, args.input_size)
|
||||
size = torch.tensor([[args.input_size, args.input_size]])
|
||||
_ = model(data, size)
|
||||
|
||||
dynamic_axes = {
|
||||
'images': {0: 'N', },
|
||||
'orig_target_sizes': {0: 'N'}
|
||||
}
|
||||
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(data, size),
|
||||
args.output_file,
|
||||
input_names=['images', 'orig_target_sizes'],
|
||||
output_names=['labels', 'boxes', 'scores'],
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=16,
|
||||
verbose=False,
|
||||
do_constant_folding=True,
|
||||
)
|
||||
|
||||
if args.check:
|
||||
import onnx
|
||||
onnx_model = onnx.load(args.output_file)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
print('Check export onnx model done...')
|
||||
|
||||
if args.simplify:
|
||||
import onnx
|
||||
import onnxsim
|
||||
dynamic = True
|
||||
# input_shapes = {'images': [1, 3, 640, 640], 'orig_target_sizes': [1, 2]} if dynamic else None
|
||||
input_shapes = {'images': data.shape, 'orig_target_sizes': size.shape} if dynamic else None
|
||||
onnx_model_simplify, check = onnxsim.simplify(args.output_file, input_shapes=input_shapes, dynamic_input_shape=dynamic)
|
||||
onnx.save(onnx_model_simplify, args.output_file)
|
||||
print(f'Simplify onnx model {check}...')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', '-c', type=str)
|
||||
parser.add_argument('--resume', '-r', type=str)
|
||||
parser.add_argument('--output_file', '-o', type=str, default='model.onnx')
|
||||
parser.add_argument('--input_size', '-s', type=int, default=640)
|
||||
parser.add_argument('--check', action='store_true', default=False)
|
||||
parser.add_argument('--simplify', action='store_true', default=False)
|
||||
parser.add_argument('--update', '-u', nargs='+', help='update yaml config')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
Reference in New Issue
Block a user