first commit
This commit is contained in:
147
rtdetr_pytorch/tools/export_onnx.py
Normal file
147
rtdetr_pytorch/tools/export_onnx.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""by lyuwenyu
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from src.core import YAMLConfig
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def main(args, ):
|
||||
"""main
|
||||
"""
|
||||
cfg = YAMLConfig(args.config, resume=args.resume)
|
||||
|
||||
if args.resume:
|
||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
||||
if 'ema' in checkpoint:
|
||||
state = checkpoint['ema']['module']
|
||||
else:
|
||||
state = checkpoint['model']
|
||||
else:
|
||||
raise AttributeError('only support resume to load model.state_dict by now.')
|
||||
|
||||
# NOTE load train mode state -> convert to deploy mode
|
||||
cfg.model.load_state_dict(state)
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, ) -> None:
|
||||
super().__init__()
|
||||
self.model = cfg.model.deploy()
|
||||
self.postprocessor = cfg.postprocessor.deploy()
|
||||
print(self.postprocessor.deploy_mode)
|
||||
|
||||
def forward(self, images, orig_target_sizes):
|
||||
outputs = self.model(images)
|
||||
return self.postprocessor(outputs, orig_target_sizes)
|
||||
|
||||
|
||||
model = Model()
|
||||
|
||||
dynamic_axes = {
|
||||
'images': {0: 'N', },
|
||||
'orig_target_sizes': {0: 'N'}
|
||||
}
|
||||
|
||||
data = torch.rand(1, 3, 640, 640)
|
||||
size = torch.tensor([[640, 640]])
|
||||
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(data, size),
|
||||
args.file_name,
|
||||
input_names=['images', 'orig_target_sizes'],
|
||||
output_names=['labels', 'boxes', 'scores'],
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=16,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
|
||||
if args.check:
|
||||
import onnx
|
||||
onnx_model = onnx.load(args.file_name)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
print('Check export onnx model done...')
|
||||
|
||||
|
||||
if args.simplify:
|
||||
import onnxsim
|
||||
dynamic = True
|
||||
input_shapes = {'images': data.shape, 'orig_target_sizes': size.shape} if dynamic else None
|
||||
onnx_model_simplify, check = onnxsim.simplify(args.file_name, input_shapes=input_shapes, dynamic_input_shape=dynamic)
|
||||
onnx.save(onnx_model_simplify, args.file_name)
|
||||
print(f'Simplify onnx model {check}...')
|
||||
|
||||
|
||||
# import onnxruntime as ort
|
||||
# from PIL import Image, ImageDraw, ImageFont
|
||||
# from torchvision.transforms import ToTensor
|
||||
# from src.data.coco.coco_dataset import mscoco_category2name, mscoco_category2label, mscoco_label2category
|
||||
|
||||
# # print(onnx.helper.printable_graph(mm.graph))
|
||||
|
||||
# # Load the original image without resizing
|
||||
# original_im = Image.open('./hongkong.jpg').convert('RGB')
|
||||
# original_size = original_im.size
|
||||
|
||||
# # Resize the image for model input
|
||||
# im = original_im.resize((640, 640))
|
||||
# im_data = ToTensor()(im)[None]
|
||||
# print(im_data.shape)
|
||||
|
||||
# sess = ort.InferenceSession(args.file_name)
|
||||
# output = sess.run(
|
||||
# # output_names=['labels', 'boxes', 'scores'],
|
||||
# output_names=None,
|
||||
# input_feed={'images': im_data.data.numpy(), "orig_target_sizes": size.data.numpy()}
|
||||
# )
|
||||
|
||||
# # print(type(output))
|
||||
# # print([out.shape for out in output])
|
||||
|
||||
# labels, boxes, scores = output
|
||||
|
||||
# draw = ImageDraw.Draw(original_im) # Draw on the original image
|
||||
# thrh = 0.6
|
||||
|
||||
# for i in range(im_data.shape[0]):
|
||||
|
||||
# scr = scores[i]
|
||||
# lab = labels[i][scr > thrh]
|
||||
# box = boxes[i][scr > thrh]
|
||||
|
||||
# print(i, sum(scr > thrh))
|
||||
|
||||
# for b, l in zip(box, lab):
|
||||
# # Scale the bounding boxes back to the original image size
|
||||
# b = [coord * original_size[j % 2] / 640 for j, coord in enumerate(b)]
|
||||
# # Get the category name from the label
|
||||
# category_name = mscoco_category2name[mscoco_label2category[l]]
|
||||
# draw.rectangle(list(b), outline='red', width=2)
|
||||
# font = ImageFont.truetype("Arial.ttf", 15)
|
||||
# draw.text((b[0], b[1]), text=category_name, fill='yellow', font=font)
|
||||
|
||||
# # Save the original image with bounding boxes
|
||||
# original_im.save('test.jpg')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', '-c', type=str, )
|
||||
parser.add_argument('--resume', '-r', type=str, )
|
||||
parser.add_argument('--file-name', '-f', type=str, default='model.onnx')
|
||||
parser.add_argument('--check', action='store_true', default=False,)
|
||||
parser.add_argument('--simplify', action='store_true', default=False,)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
Reference in New Issue
Block a user