OpenCV  5.0.0alpha-pre
Open Source Computer Vision
Loading...
Searching...
No Matches
samples/python/snippets/kalman.py

An example using the standard Kalman filter in Python.

#!/usr/bin/env python
"""
Tracking of rotating point.
Point moves in a circle and is characterized by a 2D state:
its angle and its angular speed.
angle_k+1 = angle_k + speed_k + process_noise N(0, 1e-5)
The speed is nominally constant (it only drifts by process noise N(0, 1e-5)).
The measurement vector is 1D (the point angle):
measurement is the real angle + gaussian noise N(0, 1e-1).
The real point is drawn as a white star.
The real and the measured points are connected with a red line segment,
the real and the estimated points are connected with a yellow line segment,
the real and the corrected estimated points are connected with a green line segment.
The cyan square is the one-step forecast made from the corrected estimate.
(if Kalman filter works correctly,
 the yellow segment should be shorter than the red one and
 the green segment should be shorter than the yellow one).
Pressing any key (except ESC, 'q' or 'Q') will reset the tracking.
Pressing ESC, 'q' or 'Q' will stop the program.
"""

from math import cos, sin, sqrt, pi

import numpy as np
import cv2 as cv
def main():
    """Run the interactive Kalman-filter demo until ESC, 'q' or 'Q' is pressed.

    The outer loop (re)initialises the simulated true state and the filter.
    The inner loop advances the simulation by one step per displayed frame
    (waiting up to one second for a key), runs ``predict``/``correct`` and
    draws the true, measured, predicted and corrected points. Any other key
    restarts tracking from a fresh state. The HighGUI windows are closed
    before returning.
    """
    img_height = 500
    img_width = 500
    # A single radius for both axes, so the trajectory is a circle.
    radius = img_width / 3.0
    kalman = cv.KalmanFilter(2, 1, 0)  # 2 state dims, 1 measurement dim, 0 control dims

    def calc_point(angle):
        """Return the integer (x, y) pixel position of *angle* (radians) on the circle."""
        # Round both coordinates to the nearest pixel. Image y grows
        # downward, hence the subtraction for the vertical component.
        return (int(round(img_width / 2. + radius * cos(angle))),
                int(round(img_height / 2. - radius * sin(angle))))

    code = -1
    num_circle_steps = 12
    while True:
        img = np.zeros((img_height, img_width, 3), np.uint8)
        # True state x = [angle, angular speed]^T: one full turn every
        # num_circle_steps steps.
        state = np.array([[0.0], [(2 * pi) / num_circle_steps]])  # start state
        kalman.transitionMatrix = np.array([[1., 1.], [0., 1.]])  # F. input
        kalman.measurementMatrix = 1. * np.eye(1, 2)  # H. input: observe the angle only
        kalman.processNoiseCov = 1e-5 * np.eye(2)  # Q. input
        kalman.measurementNoiseCov = 1e-1 * np.ones((1, 1))  # R. input
        kalman.errorCovPost = 1. * np.eye(2, 2)  # P._k|k KF state var
        kalman.statePost = 0.1 * np.random.randn(2, 1)  # x^_k|k KF state var

        while True:
            # Dim the previous frame instead of clearing it, leaving a faint
            # trail. Note: multiplying by a float turns img into float64.
            img = img * 1e-3
            state_angle = state[0, 0]
            state_pt = calc_point(state_angle)
            # advance Kalman filter to next timestep
            # updates statePre, statePost, errorCovPre, errorCovPost
            # k -> k+1, x'(k) = F*x(k)
            # P'(k) = F*P(k)*Ft + Q
            prediction = kalman.predict()

            predict_pt = calc_point(prediction[0, 0])  # equivalent to calc_point(kalman.statePre[0,0])

            # generate measurement z = H*x + v with v ~ N(0, R).
            # randn() has unit variance, so it must be scaled by the standard
            # deviation sqrt(R); scaling by R itself would give variance R**2
            # and the filter's noise model would not match the simulation.
            measurement_noise = sqrt(kalman.measurementNoiseCov[0, 0]) * np.random.randn(1, 1)
            measurement = np.dot(kalman.measurementMatrix, state) + measurement_noise

            measurement_angle = measurement[0, 0]
            measurement_pt = calc_point(measurement_angle)

            # correct the state estimates based on measurements
            # updates statePost & errorCovPost
            kalman.correct(measurement)
            improved_pt = calc_point(kalman.statePost[0, 0])

            # plot points
            cv.drawMarker(img, measurement_pt, (0, 0, 255), cv.MARKER_SQUARE, 5, 2)
            cv.drawMarker(img, predict_pt, (0, 255, 255), cv.MARKER_SQUARE, 5, 2)
            cv.drawMarker(img, improved_pt, (0, 255, 0), cv.MARKER_SQUARE, 5, 2)
            cv.drawMarker(img, state_pt, (255, 255, 255), cv.MARKER_STAR, 10, 1)
            # forecast one step ahead from the corrected estimate
            forecast = np.dot(kalman.transitionMatrix, kalman.statePost)
            cv.drawMarker(img, calc_point(forecast[0, 0]),
                          (255, 255, 0), cv.MARKER_SQUARE, 12, 1)

            cv.line(img, state_pt, measurement_pt, (0, 0, 255), 1, cv.LINE_AA, 0)  # red measurement error
            cv.line(img, state_pt, predict_pt, (0, 255, 255), 1, cv.LINE_AA, 0)  # yellow pre-meas error
            cv.line(img, state_pt, improved_pt, (0, 255, 0), 1, cv.LINE_AA, 0)  # green post-meas error

            # update the real process: x_k+1 = F x_k + w_k, w_k ~ N(0, Q)
            process_noise = sqrt(kalman.processNoiseCov[0, 0]) * np.random.randn(2, 1)
            state = np.dot(kalman.transitionMatrix, state) + process_noise

            cv.imshow("Kalman", img)
            code = cv.waitKey(1000)
            if code != -1:
                break

        if code in [27, ord('q'), ord('Q')]:
            break

    # Release the HighGUI window created by imshow.
    cv.destroyAllWindows()
    print('Done')
93
94
if __name__ == '__main__':
    # Print the usage text (the module docstring), then start the demo.
    print(__doc__)
    main()
Kalman filter class.
Definition tracking.hpp:375
void imshow(const String &winname, InputArray mat)
Displays an image in the specified window.
int waitKey(int delay=0)
Waits for a pressed key.
void destroyAllWindows()
Destroys all of the HighGUI windows.
void drawMarker(InputOutputArray img, Point position, const Scalar &color, int markerType=MARKER_CROSS, int markerSize=20, int thickness=1, int line_type=8)
Draws a marker on a predefined position in an image.
void line(InputOutputArray img, Point pt1, Point pt2, const Scalar &color, int thickness=1, int lineType=LINE_8, int shift=0)
Draws a line segment connecting two points.
int main(int argc, char *argv[])
Definition highgui_qt.cpp:3