Files
VideoPipe/nodes/track/sort/KalmanTracker.cpp
2026-06-03 12:43:14 +08:00

177 lines
4.4 KiB
C++

///////////////////////////////////////////////////////////////////////////////
// KalmanTracker.cpp: implementation of the KalmanTracker class
#include "KalmanTracker.h"
// Definition of the static counter shared by every KalmanTracker instance;
// it starts at zero.
// NOTE(review): presumably incremented when a tracker is constructed so that
// each tracker gets a unique id -- confirm against KalmanTracker.h.
int KalmanTracker::kf_count = 0;
// Build the Kalman filter and seed it from an initial bounding box.
//
// The filter has 7 state variables and 4 measured variables. The measured
// part is the box in [cx,cy,s,r] style (centre x, centre y, area,
// width/height ratio); state entries 4..6 are the per-step increments that
// the transition matrix adds to cx, cy and s respectively. The ratio r has
// no increment term and is carried over unchanged by the model.
//
// stateMat: initial bounding box in [x,y,w,h] style. Its height is used as
//           a divisor and is not checked here.
void KalmanTracker::init_kf(StateType stateMat)
{
    const int kStateDim = 7;
    const int kMeasureDim = 4;

    kf = KalmanFilter(kStateDim, kMeasureDim, 0);
    measurement = cv::Mat::zeros(kMeasureDim, 1, CV_32F);

    // Identity plus the three "value += increment" couplings (0<-4, 1<-5, 2<-6).
    kf.transitionMatrix = (cv::Mat_<float>(kStateDim, kStateDim) <<
        1, 0, 0, 0, 1, 0, 0,
        0, 1, 0, 0, 0, 1, 0,
        0, 0, 1, 0, 0, 0, 1,
        0, 0, 0, 1, 0, 0, 0,
        0, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 0, 1);

    // The measurement picks out the first four state entries; the noise and
    // error covariances are scaled identities.
    setIdentity(kf.measurementMatrix);
    setIdentity(kf.processNoiseCov, Scalar::all(1e-2));
    setIdentity(kf.measurementNoiseCov, Scalar::all(1e-1));
    setIdentity(kf.errorCovPost, Scalar::all(1));

    // Seed the measured part of the state from the box, in [cx,cy,s,r] style.
    const float centerX = stateMat.x + stateMat.width / 2;
    const float centerY = stateMat.y + stateMat.height / 2;
    const float boxArea = stateMat.area();
    const float aspect = stateMat.width / stateMat.height;

    cv::Mat &post = kf.statePost;
    post.at<float>(0, 0) = centerX;
    post.at<float>(1, 0) = centerY;
    post.at<float>(2, 0) = boxArea;
    post.at<float>(3, 0) = aspect;
}
// Advance the filter by one step and return the predicted bounding box.
//
// Before predicting, the area increment (state index 6) is zeroed whenever
// adding it to the current area (state index 2) would give a non-positive
// area. Without this guard the predicted area can drop to zero or below,
// and get_rect_xysr() then produces a NaN box (for a positive ratio r).
// The reference SORT algorithm applies the same guard.
//
// Bookkeeping performed on every call:
//   - m_age is incremented;
//   - m_hit_streak is reset to 0 if at least one earlier predict() has
//     already happened since the last update();
//   - m_time_since_update is incremented;
//   - the predicted box is appended to m_history.
//
// Returns the predicted box in [x,y,w,h] style.
StateType KalmanTracker::predict()
{
    // Keep the predicted area positive: drop the area increment if it would
    // push the area to zero or below.
    float &areaIncrement = kf.statePost.at<float>(6, 0);
    if (kf.statePost.at<float>(2, 0) + areaIncrement <= 0)
        areaIncrement = 0;

    // predict
    Mat p = kf.predict();

    m_age += 1;
    if (m_time_since_update > 0)
        m_hit_streak = 0;
    m_time_since_update += 1;

    StateType predictBox = get_rect_xysr(p.at<float>(0, 0), p.at<float>(1, 0), p.at<float>(2, 0), p.at<float>(3, 0));
    m_history.push_back(predictBox);
    return m_history.back();
}
// Correct the filter with an observed bounding box.
//
// Resets m_time_since_update, discards the stored prediction history,
// and counts the observation in both m_hits and m_hit_streak.
//
// stateMat: observed bounding box in [x,y,w,h] style. Its height is used as
//           a divisor and is not checked here.
void KalmanTracker::update(StateType stateMat)
{
    // An observation arrived: restart the "time since update" clock and
    // drop the predictions accumulated since the previous observation.
    m_history.clear();
    m_time_since_update = 0;
    ++m_hits;
    ++m_hit_streak;

    // Express the observed box in [cx,cy,s,r] style.
    const float centerX = stateMat.x + stateMat.width / 2;
    const float centerY = stateMat.y + stateMat.height / 2;
    const float boxArea = stateMat.area();
    const float aspect = stateMat.width / stateMat.height;

    measurement.at<float>(0, 0) = centerX;
    measurement.at<float>(1, 0) = centerY;
    measurement.at<float>(2, 0) = boxArea;
    measurement.at<float>(3, 0) = aspect;

    // Fold the measurement into the filter state.
    kf.correct(measurement);
}
// Return the filter's current corrected state (kf.statePost) as a bounding
// box in [x,y,w,h] style. Only the first four state entries
// ([cx,cy,s,r]) are used.
StateType KalmanTracker::get_state()
{
    const Mat &post = kf.statePost;
    const float centerX = post.at<float>(0, 0);
    const float centerY = post.at<float>(1, 0);
    const float boxArea = post.at<float>(2, 0);
    const float aspect = post.at<float>(3, 0);
    return get_rect_xysr(centerX, centerY, boxArea, aspect);
}
// Convert a bounding box from [cx,cy,s,r] style to [x,y,w,h] style.
//
// cx, cy: box centre.
// s:      box area (w * h).
// r:      aspect ratio (w / h).
//
// The width is recovered as sqrt(s * r) and the height as s / width. The
// inputs are not validated: a negative s * r or a zero width propagates as
// NaN/inf into the returned box.
//
// A top-left coordinate that falls below zero is clamped to zero as long as
// the corresponding centre coordinate is still positive; width and height
// are left as computed.
StateType KalmanTracker::get_rect_xysr(float cx, float cy, float s, float r)
{
    const float width = sqrt(s * r);
    const float height = s / width;

    float left = cx - width / 2;
    float top = cy - height / 2;

    // Clamp to the origin only while the centre itself is still positive.
    if (cx > 0 && left < 0)
        left = 0;
    if (cy > 0 && top < 0)
        top = 0;

    return StateType(left, top, width, height);
}
/*
// --------------------------------------------------------------------
// Kalman filter demonstration: a 2-D ball demo that tracks the mouse cursor
// --------------------------------------------------------------------
const int winHeight = 600;
const int winWidth = 800;
Point mousePosition = Point(winWidth >> 1, winHeight >> 1);
// mouse event callback
void mouseEvent(int event, int x, int y, int flags, void *param)
{
if (event == CV_EVENT_MOUSEMOVE) {
mousePosition = Point(x, y);
}
}
void TestKF();
void main()
{
TestKF();
}
void TestKF()
{
int stateNum = 4;
int measureNum = 2;
KalmanFilter kf = KalmanFilter(stateNum, measureNum, 0);
// initialization
Mat processNoise(stateNum, 1, CV_32F);
Mat measurement = Mat::zeros(measureNum, 1, CV_32F);
kf.transitionMatrix = *(Mat_<float>(stateNum, stateNum) <<
1, 0, 1, 0,
0, 1, 0, 1,
0, 0, 1, 0,
0, 0, 0, 1);
setIdentity(kf.measurementMatrix);
setIdentity(kf.processNoiseCov, Scalar::all(1e-2));
setIdentity(kf.measurementNoiseCov, Scalar::all(1e-1));
setIdentity(kf.errorCovPost, Scalar::all(1));
randn(kf.statePost, Scalar::all(0), Scalar::all(winHeight));
namedWindow("Kalman");
setMouseCallback("Kalman", mouseEvent);
Mat img(winHeight, winWidth, CV_8UC3);
while (1)
{
// predict
Mat prediction = kf.predict();
Point predictPt = Point(prediction.at<float>(0, 0), prediction.at<float>(1, 0));
// generate measurement
Point statePt = mousePosition;
measurement.at<float>(0, 0) = statePt.x;
measurement.at<float>(1, 0) = statePt.y;
// update
kf.correct(measurement);
// visualization
img.setTo(Scalar(255, 255, 255));
circle(img, predictPt, 8, CV_RGB(0, 255, 0), -1); // predicted point as green
circle(img, statePt, 8, CV_RGB(255, 0, 0), -1); // current position as red
imshow("Kalman", img);
char code = (char)waitKey(100);
if (code == 27 || code == 'q' || code == 'Q')
break;
}
destroyWindow("Kalman");
}
*/