first commit
This commit is contained in:
11
third_party/trt_yolov8/samples/gen_onnx.py
vendored
Normal file
11
third_party/trt_yolov8/samples/gen_onnx.py
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
from ultralytics import YOLO
|
||||
|
||||
'''
|
||||
NOTE: trt_yolov8 do not need ONNX format at all, this script just used for visualizing yolov8 networks on `https://netron.app`.
|
||||
'''
|
||||
|
||||
# Load a model
|
||||
model = YOLO("../../../../vp_data/models/trt/others/yolov8n-seg.pt") # load a pretrained model (recommended for training)
|
||||
|
||||
# Export the model
|
||||
path = model.export(format="onnx") # export the model to ONNX format which could be visualized on netron.app
|
||||
57
third_party/trt_yolov8/samples/gen_wts.py
vendored
Executable file
57
third_party/trt_yolov8/samples/gen_wts.py
vendored
Executable file
@@ -0,0 +1,57 @@
|
||||
import sys
|
||||
import argparse
|
||||
import os
|
||||
import struct
|
||||
import torch
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Convert .pt file to .wts')
|
||||
parser.add_argument('-w', '--weights', required=True,
|
||||
help='Input weights (.pt) file path (required)')
|
||||
parser.add_argument(
|
||||
'-o', '--output', help='Output (.wts) file path (optional)')
|
||||
parser.add_argument(
|
||||
'-t', '--type', type=str, default='detect', choices=['detect', 'cls', 'seg'],
|
||||
help='determines the model is detection/classification')
|
||||
args = parser.parse_args()
|
||||
if not os.path.isfile(args.weights):
|
||||
raise SystemExit('Invalid input file')
|
||||
if not args.output:
|
||||
args.output = os.path.splitext(args.weights)[0] + '.wts'
|
||||
elif os.path.isdir(args.output):
|
||||
args.output = os.path.join(
|
||||
args.output,
|
||||
os.path.splitext(os.path.basename(args.weights))[0] + '.wts')
|
||||
return args.weights, args.output, args.type
|
||||
|
||||
|
||||
pt_file, wts_file, m_type = parse_args()
|
||||
|
||||
print(f'Generating .wts for {m_type} model')
|
||||
|
||||
# Load model
|
||||
print(f'Loading {pt_file}')
|
||||
|
||||
# Initialize
|
||||
device = 'cpu'
|
||||
|
||||
# Load model
|
||||
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32
|
||||
|
||||
if m_type in ['detect', 'seg']:
|
||||
anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None]
|
||||
|
||||
delattr(model.model[-1], 'anchors')
|
||||
|
||||
model.to(device).eval()
|
||||
|
||||
with open(wts_file, 'w') as f:
|
||||
f.write('{}\n'.format(len(model.state_dict().keys())))
|
||||
for k, v in model.state_dict().items():
|
||||
vr = v.reshape(-1).cpu().numpy()
|
||||
f.write('{} {} '.format(k, len(vr)))
|
||||
for vv in vr:
|
||||
f.write(' ')
|
||||
f.write(struct.pack('>f', float(vv)).hex())
|
||||
f.write('\n')
|
||||
36
third_party/trt_yolov8/samples/trt_yolov8_cls_test.cpp
vendored
Normal file
36
third_party/trt_yolov8/samples/trt_yolov8_cls_test.cpp
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
|
||||
|
||||
#include "../trt_yolov8_classifier.h"
|
||||
|
||||
int main() {
|
||||
trt_yolov8::trt_yolov8_classifier detector("./vp_data/models/trt/others/yolov8s-cls_v8.5.engine");
|
||||
|
||||
auto image1 = cv::imread("./vp_data/test_images/vehicle_cls/1.jpg");
|
||||
auto image2 = cv::imread("./vp_data/test_images/vehicle_cls/5.jpg");
|
||||
std::unordered_map<int, std::string> labels_map;
|
||||
read_labels("./vp_data/models/imagenet_1000labels1.txt", labels_map);
|
||||
|
||||
|
||||
std::vector<std::vector<Classification>> classifications;
|
||||
std::vector<cv::Mat> images = {image1, image2};
|
||||
detector.classify(images, classifications, 5); // top3 by default
|
||||
|
||||
for (int i = 0; i < classifications.size(); ++i) {
|
||||
auto& classification = classifications[i];
|
||||
auto& image = images[i];
|
||||
|
||||
for (int j = 0; j < classification.size(); ++j) {
|
||||
std::cout << "(top" << j + 1 << ") class_id:" << classification[j].class_id << " conf:" << classification[j].conf << std::endl;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
// draw top1's label on image
|
||||
cv::putText(image, "top1: " + labels_map.at(classification[0].class_id), cv::Point(10, 10), 1.5, 1, cv::Scalar(0, 0, 255));
|
||||
}
|
||||
|
||||
cv::imshow("cls1", image1);
|
||||
cv::imshow("cls2", image2);
|
||||
|
||||
cv::waitKey(0);
|
||||
return 0;
|
||||
}
|
||||
25
third_party/trt_yolov8/samples/trt_yolov8_det_test.cpp
vendored
Normal file
25
third_party/trt_yolov8/samples/trt_yolov8_det_test.cpp
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
|
||||
|
||||
#include "../trt_yolov8_detector.h"
|
||||
|
||||
int main() {
|
||||
trt_yolov8::trt_yolov8_detector detector("./vp_data/models/trt/others/yolov8s_v8.5.engine");
|
||||
|
||||
cv::VideoCapture cap("./vp_data/test_video/face2.mp4");
|
||||
cv::Mat frame;
|
||||
while (true) {
|
||||
if (!cap.read(frame)) {
|
||||
cap.set(cv::CAP_PROP_POS_FRAMES, 0);
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<std::vector<Detection>> detections;
|
||||
std::vector<cv::Mat> frames = {frame};
|
||||
detector.detect(frames, detections);
|
||||
|
||||
draw_bbox(frames, detections);
|
||||
cv::imshow("detect", frame);
|
||||
cv::waitKey(40);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
25
third_party/trt_yolov8/samples/trt_yolov8_pose_test.cpp
vendored
Normal file
25
third_party/trt_yolov8/samples/trt_yolov8_pose_test.cpp
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
|
||||
|
||||
#include "../trt_yolov8_pose_detector.h"
|
||||
|
||||
int main() {
|
||||
trt_yolov8::trt_yolov8_pose_detector detector("./vp_data/models/trt/others/yolov8s-pose_v8.5.engine");
|
||||
|
||||
cv::VideoCapture cap("./vp_data/test_video/face2.mp4");
|
||||
cv::Mat frame;
|
||||
while (true) {
|
||||
if (!cap.read(frame)) {
|
||||
cap.set(cv::CAP_PROP_POS_FRAMES, 0);
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<std::vector<Detection>> detections;
|
||||
std::vector<cv::Mat> frames = {frame};
|
||||
detector.detect(frames, detections);
|
||||
|
||||
draw_bbox_keypoints_line(frames, detections);
|
||||
cv::imshow("pose", frame);
|
||||
cv::waitKey(40);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
29
third_party/trt_yolov8/samples/trt_yolov8_seg_test.cpp
vendored
Normal file
29
third_party/trt_yolov8/samples/trt_yolov8_seg_test.cpp
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
|
||||
|
||||
#include "../trt_yolov8_seg_detector.h"
|
||||
|
||||
int main() {
|
||||
trt_yolov8::trt_yolov8_seg_detector detector("./vp_data/models/trt/others/yolov8s-seg_v8.5.engine");
|
||||
|
||||
cv::VideoCapture cap("./vp_data/test_video/face2.mp4");
|
||||
std::unordered_map<int, std::string> labels_map;
|
||||
read_labels("./vp_data/models/coco_80classes.txt", labels_map);
|
||||
|
||||
cv::Mat frame;
|
||||
while (true) {
|
||||
if (!cap.read(frame)) {
|
||||
cap.set(cv::CAP_PROP_POS_FRAMES, 0);
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<std::vector<Detection>> detections;
|
||||
std::vector<std::vector<cv::Mat>> masks;
|
||||
std::vector<cv::Mat> frames = {frame};
|
||||
detector.detect(frames, detections, masks);
|
||||
|
||||
draw_mask_bbox(frame, detections[0], masks[0], labels_map);
|
||||
cv::imshow("seg", frame);
|
||||
cv::waitKey(40);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
45
third_party/trt_yolov8/samples/trt_yolov8_wts_2_engine.cpp
vendored
Normal file
45
third_party/trt_yolov8/samples/trt_yolov8_wts_2_engine.cpp
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
|
||||
|
||||
#include "../trt_yolov8_detector.h"
|
||||
#include "../trt_yolov8_pose_detector.h"
|
||||
#include "../trt_yolov8_seg_detector.h"
|
||||
#include "../trt_yolov8_classifier.h"
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
/* run command:
|
||||
* ./trt_yolov8_wts_2_engine [-det/-seg/-pose/-cls] [.wts] [.engine] [n/s/m/l/x/n2/s2/m2/l2/x2/n6/s6/m6/l6/x6]
|
||||
*/
|
||||
|
||||
if (argc != 5) {
|
||||
std::cerr << "arguments not right!" << std::endl;
|
||||
std::cerr << "./trt_yolov8_wts_2_engine [-det/-seg/-pose/-cls] [.wts] [.engine] [n/s/m/l/x/n2/s2/m2/l2/x2/n6/s6/m6/l6/x6]" << std::endl;
|
||||
}
|
||||
|
||||
std::string task_type = std::string(argv[1]);
|
||||
std::string wts_name = std::string(argv[2]);
|
||||
std::string engine_name = std::string(argv[3]);
|
||||
std::string sub_type = std::string(argv[4]);
|
||||
|
||||
if (task_type == "-det") {
|
||||
trt_yolov8::trt_yolov8_detector detector;
|
||||
detector.wts_2_engine(wts_name, engine_name, sub_type);
|
||||
}
|
||||
else if (task_type == "-seg") {
|
||||
trt_yolov8::trt_yolov8_seg_detector detector;
|
||||
detector.wts_2_engine(wts_name, engine_name, sub_type);
|
||||
}
|
||||
else if (task_type == "-pose") {
|
||||
trt_yolov8::trt_yolov8_pose_detector detector;
|
||||
detector.wts_2_engine(wts_name, engine_name, sub_type);
|
||||
}
|
||||
else if (task_type == "-cls") {
|
||||
trt_yolov8::trt_yolov8_classifier classifier;
|
||||
classifier.wts_2_engine(wts_name, engine_name, sub_type);
|
||||
}
|
||||
else {
|
||||
std::cerr << "arguments not right!" << std::endl;
|
||||
std::cerr << "./trt_yolov8_wts_2_engine [-det/-seg/-pose/-cls] [.wts] [.engine] [n/s/m/l/x/n2/s2/m2/l2/x2/n6/s6/m6/l6/x6]" << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user