62 lines
1.5 KiB
Python
62 lines
1.5 KiB
Python
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
|
|
|
|
import torch
|
|
import torchvision.transforms as T
|
|
|
|
import numpy as np
|
|
import onnxruntime as ort
|
|
from PIL import Image, ImageDraw
|
|
|
|
|
|
def draw(images, labels, boxes, scores, thrh=0.6):
    """Draw detections above a score threshold on each image and save it.

    For every image ``images[i]``, the detections whose entry in
    ``scores[i]`` is strictly greater than ``thrh`` are kept. Each kept box
    is outlined in red and its own label is written in blue at the box's
    first two coordinates. Images are modified in place and written to
    ``results_{i}.jpg`` in the current working directory.

    Args:
        images: Sequence of PIL images to annotate (mutated in place).
        labels: Per-image label arrays, indexed in parallel with ``images``.
        boxes: Per-image box arrays; each box is handed to
            ``ImageDraw.rectangle`` (presumably ``[x0, y0, x1, y1]`` in
            pixels -- TODO confirm against the exported model).
        scores: Per-image score arrays aligned with ``labels`` and ``boxes``.
        thrh: Score threshold; only detections scoring above it are drawn.
    """
    for i, im in enumerate(images):
        canvas = ImageDraw.Draw(im)

        # Build the score mask once and apply it to both labels and boxes so
        # the two filtered arrays stay aligned element for element.
        keep = scores[i] > thrh
        kept_labels = labels[i][keep]
        kept_boxes = boxes[i][keep]

        # Pair every box with its own label. Indexing the labels with the
        # image index ``i`` would stamp one label on all boxes of the image
        # and raise IndexError when fewer than ``i + 1`` detections survive.
        for box, label in zip(kept_boxes, kept_labels):
            canvas.rectangle(list(box), outline='red')
            canvas.text((box[0], box[1]), text=str(label.item()), fill='blue')

        im.save(f'results_{i}.jpg')
|
|
|
|
|
|
def main(args):
    """Run one image through an ONNX detection model and save the result.

    Loads the model at ``args.onnx_file``, prints the device reported by
    ONNX Runtime, opens ``args.im_file`` as RGB, resizes it to 640x640,
    runs inference and passes the outputs to ``draw`` together with the
    original (un-resized) image.

    Args:
        args: Object exposing ``onnx_file`` and ``im_file`` attributes.
    """
    session = ort.InferenceSession(args.onnx_file)
    print(ort.get_device())

    image = Image.open(args.im_file).convert('RGB')
    width, height = image.size
    # Original (width, height) with a leading batch axis of size 1.
    target_size = torch.tensor([[width, height]])

    preprocess = T.Compose([T.Resize((640, 640)), T.ToTensor()])
    batch = preprocess(image).unsqueeze(0)

    feed = {
        'images': batch.data.numpy(),
        "orig_target_sizes": target_size.data.numpy(),
    }
    # output_names=None asks the session for every model output; they are
    # unpacked here in the order labels, boxes, scores.
    labels, boxes, scores = session.run(output_names=None, input_feed=feed)

    draw([image], labels, boxes, scores)
|
|
|
|
|
|
if __name__ == '__main__':
    import argparse

    # Both paths are mandatory: without them main() would pass None to the
    # model loader or image opener and fail with an unhelpful error, so let
    # argparse reject the invocation with a clear usage message instead.
    parser = argparse.ArgumentParser(
        description='Run an ONNX detection model on one image and save the '
                    'annotated result.',
    )
    parser.add_argument('--onnx-file', type=str, required=True,
                        help='path to the exported ONNX model')
    parser.add_argument('--im-file', type=str, required=True,
                        help='path to the input image')
    args = parser.parse_args()
    main(args)
|