Created
October 6, 2018 06:19
-
-
Save dnishiyama85/bbd11c8987deb8f126b7cfc39da9a607 to your computer and use it in GitHub Desktop.
Kalman Smoothing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
import matplotlib | |
matplotlib.use('TkAgg') | |
import matplotlib.pyplot as plt | |
# Input series: a CSV file with a column named 'x'.
# Other data files used with this script: './nile.csv', './weight.csv'.
df = pd.read_csv('./fat_percent.csv')

# State-space model (with G = F = [[1]] this is a local level model):
#   state:       x_t = G x_{t-1} + w_t,  w_t ~ N(0, W)
#   observation: y_t = F x_t     + v_t,  v_t ~ N(0, V)
G = np.array([[1.0]])
F = np.array([[1.0]])
# Noise covariances are stored as 2-D matrices (p x p and q x q) so that
# they add element-wise to the predicted covariances for any dimension;
# a 1-D array would broadcast row-wise once p or q exceeds 1.
# Alternative settings tried previously: W = exp(7.29), V = exp(9.62)
# (presumably for the Nile data) and V = 0.3**2.
W = np.array([[0.1**2]])
V = np.array([[1.5**2]])
p = G.shape[0]  # state vector dimension
q = F.shape[0]  # observation vector dimension
def update_filtering(m_prev, C_prev, y):
    """Advance the Kalman filter by one time step.

    Uses the module-level model matrices ``G`` (state transition), ``W``
    (state noise covariance), ``F`` (observation matrix) and ``V``
    (observation noise covariance).

    Args:
        m_prev: filtered state mean at the previous step, shape (p,).
        C_prev: filtered state covariance at the previous step, shape (p, p).
        y: observation at this step (scalar, or array-like of length q).
            If any entry is NaN the observation is treated as missing and
            the update (correction) step is skipped.

    Returns:
        Tuple ``(a, R, f, Q, m, C)``:
        one-step-ahead predicted state mean ``a`` and covariance ``R``,
        one-step-ahead predicted observation mean ``f`` and covariance ``Q``,
        and the filtered state mean ``m`` and covariance ``C``.
    """
    y = np.asarray(y, dtype=float)
    # One-step-ahead predictive distribution of the state.
    a = G.dot(m_prev)
    R = G.dot(C_prev).dot(G.T) + W
    # One-step-ahead predictive distribution of the observation (its
    # density at y is the one-step-ahead likelihood).
    f = F.dot(a)
    Q = F.dot(R).dot(F.T) + V
    if np.any(np.isnan(y)):
        # Missing observation: nothing to assimilate, so the filtered
        # distribution is just the predictive one. Without this, NaN would
        # propagate into every later state.
        return a, R, f, Q, a.copy(), R.copy()
    # Kalman gain K = R F^T Q^{-1}, computed as the solution of K Q = R F^T
    # (i.e. Q^T K^T = (R F^T)^T) rather than via an explicit inverse.
    K = np.linalg.solve(Q.T, R.dot(F.T).T).T
    # Filtered (corrected) state.
    m = a + K.dot(y - f)
    C = (np.eye(R.shape[0]) - K.dot(F)).dot(R)
    return a, R, f, Q, m, C
def kalman_filtering(y, m_init=None, C_init=None):
    """Run the Kalman filter forward over a whole observation series.

    Args:
        y: observations of length T (1-D when q == 1, otherwise shape
            (T, q)). Any array-like is accepted, including a pandas Series;
            elements are taken by position, not by index label.
        m_init: prior state mean before the first observation, shape (p,).
            Defaults to zeros.
        C_init: prior state covariance, shape (p, p). Defaults to
            ``1e7 * I`` (a very large variance, i.e. a weak prior).

    Returns:
        Tuple ``(a, R, f, Q, m, C)`` of arrays whose leading dimension is T,
        holding the per-step outputs of ``update_filtering``. An empty ``y``
        yields empty arrays.
    """
    # Positional access: y[t] on a pandas Series would look up by label and
    # break (or misalign) when the index is not 0..T-1.
    y = np.asarray(y, dtype=float)
    T = len(y)
    # Preallocate per-step outputs.
    a = np.zeros((T, p))
    R = np.zeros((T, p, p))
    f = np.zeros((T, q))
    Q = np.zeros((T, q, q))
    m = np.zeros((T, p))
    C = np.zeros((T, p, p))
    # Prior for the state before the first observation.
    m_prev = np.zeros(p) if m_init is None else np.asarray(m_init, dtype=float)
    C_prev = 1e7 * np.eye(p) if C_init is None else np.asarray(C_init, dtype=float)
    # Sequential filtering: each step starts from the previous filtered state.
    for t in range(T):
        a[t], R[t], f[t], Q[t], m[t], C[t] = update_filtering(m_prev, C_prev, y[t])
        m_prev, C_prev = m[t], C[t]
    return a, R, f, Q, m, C
def update_smoothing(s_next, a, R, m, C, t):
    """Take one backward step of the Rauch-Tung-Striebel smoother.

    Args:
        s_next: smoothed state mean at time t + 1, shape (p,).
        a: predicted state means for every time step, shape (T, p).
        R: predicted state covariances for every time step, shape (T, p, p).
        m: filtered state means for every time step, shape (T, p).
        C: filtered state covariances for every time step, shape (T, p, p).
        t: time index being smoothed; t + 1 must be a valid index.

    Returns:
        The smoothed state mean at time t, shape (p,).
    """
    nxt = t + 1
    # Smoother gain A_t = C_t G^T R_{t+1}^{-1}.
    gain = C[t] @ G.T @ np.linalg.inv(R[nxt])
    # Shift the filtered mean by the gain-weighted gap between the smoothed
    # and predicted next state.
    correction = gain @ (s_next - a[nxt])
    return m[t] + correction
def kalman_smoothing(a, R, m, C):
    """Run the Kalman smoother backwards over all time steps.

    Takes the complete output of the forward filter. At the last time step
    the smoothed mean equals the filtered mean; earlier steps are obtained
    one at a time, going backwards, with ``update_smoothing``.

    Args:
        a: predicted state means, shape (T, p).
        R: predicted state covariances, shape (T, p, p).
        m: filtered state means, shape (T, p).
        C: filtered state covariances, shape (T, p, p).

    Returns:
        Smoothed state means with the same shape as ``m``. An empty series
        (T == 0) yields an empty array instead of raising IndexError.
    """
    s = np.zeros_like(m, dtype=float)
    if len(s) == 0:
        return s
    last = len(s) - 1
    # Last time step: no future observations, so smoothed == filtered.
    s[last] = m[last]
    # Earlier time steps, from last - 1 down to 0.
    for t in range(last - 1, -1, -1):
        s[t] = update_smoothing(s[t + 1], a, R, m, C, t)
    return s
# Filter and smooth the loaded series, print the estimates, and plot them.
a, R, f, Q, m, C = kalman_filtering(df['x'])
s = kalman_smoothing(a, R, m, C)
print(m.flatten())
print(s.flatten())
# Plot both series against the DataFrame's own index so the smoothed
# estimates line up with the observations even if the index is not the
# default 0..T-1 range (the smoothed array carries no index of its own).
plt.plot(df.index, df['x'], label='observed')
plt.plot(df.index, s, label='smoothed')
plt.legend()
plt.savefig('./fat_percent.png')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment