first commit
This commit is contained in:
84
rtdetrv2_pytorch/references/deploy/rtdetrv2_torch.py
Normal file
84
rtdetrv2_pytorch/references/deploy/rtdetrv2_torch.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from src.core import YAMLConfig
|
||||
|
||||
|
||||
def draw(images, labels, boxes, scores, thrh = 0.6):
|
||||
for i, im in enumerate(images):
|
||||
draw = ImageDraw.Draw(im)
|
||||
|
||||
scr = scores[i]
|
||||
lab = labels[i][scr > thrh]
|
||||
box = boxes[i][scr > thrh]
|
||||
scrs = scores[i][scr > thrh]
|
||||
|
||||
for j,b in enumerate(box):
|
||||
draw.rectangle(list(b), outline='red',)
|
||||
draw.text((b[0], b[1]), text=f"{lab[j].item()} {round(scrs[j].item(),2)}", fill='blue', )
|
||||
|
||||
im.save(f'results_{i}.jpg')
|
||||
|
||||
|
||||
def main(args, ):
|
||||
"""main
|
||||
"""
|
||||
cfg = YAMLConfig(args.config, resume=args.resume)
|
||||
|
||||
if args.resume:
|
||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
||||
if 'ema' in checkpoint:
|
||||
state = checkpoint['ema']['module']
|
||||
else:
|
||||
state = checkpoint['model']
|
||||
else:
|
||||
raise AttributeError('Only support resume to load model.state_dict by now.')
|
||||
|
||||
# NOTE load train mode state -> convert to deploy mode
|
||||
cfg.model.load_state_dict(state)
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, ) -> None:
|
||||
super().__init__()
|
||||
self.model = cfg.model.deploy()
|
||||
self.postprocessor = cfg.postprocessor.deploy()
|
||||
|
||||
def forward(self, images, orig_target_sizes):
|
||||
outputs = self.model(images)
|
||||
outputs = self.postprocessor(outputs, orig_target_sizes)
|
||||
return outputs
|
||||
|
||||
model = Model().to(args.device)
|
||||
|
||||
im_pil = Image.open(args.im_file).convert('RGB')
|
||||
w, h = im_pil.size
|
||||
orig_size = torch.tensor([w, h])[None].to(args.device)
|
||||
|
||||
transforms = T.Compose([
|
||||
T.Resize((640, 640)),
|
||||
T.ToTensor(),
|
||||
])
|
||||
im_data = transforms(im_pil)[None].to(args.device)
|
||||
|
||||
output = model(im_data, orig_size)
|
||||
labels, boxes, scores = output
|
||||
|
||||
draw([im_pil], labels, boxes, scores)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-c', '--config', type=str, )
|
||||
parser.add_argument('-r', '--resume', type=str, )
|
||||
parser.add_argument('-f', '--im-file', type=str, )
|
||||
parser.add_argument('-d', '--device', type=str, default='cpu')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
Reference in New Issue
Block a user