#include "vp_classifier_node.h"

#include <cassert>
#include <cstddef>

// NOTE(review): the template argument lists in this file (std::vector<int>,
// std::vector<cv::Mat>, std::vector<std::shared_ptr<vp_objects::vp_frame_meta>>)
// were reconstructed from how the values are used below -- confirm they match the
// declarations in vp_classifier_node.h.

namespace vp_nodes {

    /// Builds a classifier node.
    ///
    /// Every argument except `need_softmax` is forwarded unchanged to
    /// vp_secondary_infer_node; `need_softmax` is stored and later decides whether
    /// postprocess() normalises the raw network output with a softmax before
    /// picking the best class.
    vp_classifier_node::vp_classifier_node(std::string node_name,
                                           std::string model_path,
                                           std::string model_config_path,
                                           std::string labels_path,
                                           int input_width,
                                           int input_height,
                                           int batch_size,
                                           std::vector<int> p_class_ids_applied_to,
                                           int min_width_applied_to,
                                           int min_height_applied_to,
                                           int crop_padding,
                                           bool need_softmax,
                                           float scale,
                                           cv::Scalar mean,
                                           cv::Scalar std,
                                           bool swap_rb,
                                           bool swap_chn):
                                           vp_secondary_infer_node(node_name,
                                                                   model_path,
                                                                   model_config_path,
                                                                   labels_path,
                                                                   input_width,
                                                                   input_height,
                                                                   batch_size,
                                                                   p_class_ids_applied_to,
                                                                   min_width_applied_to,
                                                                   min_height_applied_to,
                                                                   crop_padding,
                                                                   scale,
                                                                   mean,
                                                                   std,
                                                                   swap_rb,
                                                                   swap_chn),
                                           need_softmax(need_softmax) {
        this->initialized();
    }

    /// Calls deinitialized() when the node is destroyed.
    vp_classifier_node::~vp_classifier_node() {
        deinitialized();
    }

    /// Writes classification results back onto the targets of the frame.
    ///
    /// Only the first output head (`raw_outputs[0]`) is read; it must be a 2-D
    /// matrix and is consumed one row at a time. Row `i` is assigned to the next
    /// target, in target order, for which need_apply() returns true. For that
    /// target the winning class id, its score and its label are appended to
    /// `secondary_class_ids`, `secondary_scores` and `secondary_labels`.
    /// Rows for which no applicable target remains are ignored.
    ///
    /// @param raw_outputs            Network output heads; must not be empty.
    /// @param frame_meta_with_batch  Frame metadata; must contain exactly one entry.
    void vp_classifier_node::postprocess(const std::vector<cv::Mat>& raw_outputs,
                                         const std::vector<std::shared_ptr<vp_objects::vp_frame_meta>>& frame_meta_with_batch) {
        // make sure heads of output are not zero
        assert(raw_outputs.size() > 0);
        assert(frame_meta_with_batch.size() == 1);

        // just one head of output
        const auto& output = raw_outputs[0];
        assert(output.dims == 2);

        const auto& frame_meta = frame_meta_with_batch[0];
        auto& targets = frame_meta->targets;

        // Position of the first target that has not yet been matched with an output row.
        std::size_t next_target = 0;

        for (int i = 0; i < output.rows; i++) {
            for (std::size_t j = next_target; j < targets.size(); j++) {
                auto& target = targets[j];
                if (!need_apply(target->primary_class_id, target->width, target->height)) {
                    // skip: need_apply() rejected this target (class id / size filter)
                    continue;
                }

                const auto prob = output.row(i);

                // Best raw value and its column; `score` doubles as the row maximum below.
                cv::Point class_id_point;
                double score = 0.0;
                cv::minMaxLoc(prob, nullptr, &score, nullptr, &class_id_point);

                if (need_softmax) {
                    // Subtract the row maximum before exponentiating so cv::exp cannot overflow.
                    const float max_prob = static_cast<float>(score);
                    cv::Mat softmax_prob;
                    cv::exp(prob - max_prob, softmax_prob);
                    const float sum = static_cast<float>(cv::sum(softmax_prob)[0]);
                    softmax_prob /= sum;
                    cv::minMaxLoc(softmax_prob.reshape(1, 1), nullptr, &score, nullptr, &class_id_point);
                }

                const int class_id = class_id_point.x;
                // Fall back to an empty label when the class id has no entry in `labels`.
                auto label = (static_cast<std::size_t>(class_id) < labels.size()) ? labels[class_id] : "";

                // update back to frame meta
                target->secondary_class_ids.push_back(class_id);
                target->secondary_scores.push_back(score);
                target->secondary_labels.push_back(label);

                // this row is consumed; the next row starts searching after this target
                next_target = j + 1;
                break;
            }
        }
    }
}