Skip to content

Instantly share code, notes, and snippets.

@dnishiyama85
Created October 6, 2018 06:19
Show Gist options
  • Save dnishiyama85/bbd11c8987deb8f126b7cfc39da9a607 to your computer and use it in GitHub Desktop.
Save dnishiyama85/bbd11c8987deb8f126b7cfc39da9a607 to your computer and use it in GitHub Desktop.
Kalman Smoothing
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
#df = pd.read_csv('./nile.csv')
df = pd.read_csv('./fat_percent.csv')
#df = pd.read_csv('./weight.csv')
# 状態空間モデル
G = np.array([[1.0]])
# W = np.array([np.exp(7.29)])
W = np.array([0.1**2])
F = np.array([[1.0]])
# V = np.array([np.exp(9.62)])
V = np.array([1.5**2])
#V = np.array([0.3**2])
p = 1 # 状態ベクトルの次元
q = 1 # 観測ベクトルの次元
# フィルタリングを1ステップ更新
def update_filtering(m_prev, C_prev, y):
# 一期先予測分布
a = G.dot(m_prev)
R= G.dot(C_prev.dot(G.T)) + W
# 一期先予測尤度
f = F.dot(a)
Q = F.dot(R.dot(F.T)) + V
# カルマン利得
K = R.dot(F.T.dot(np.linalg.inv(Q)))
# 状態
m = a + K.dot(y - f)
C = (np.eye(p, p) - K.dot(F)).dot(R)
return a, R, f, Q, m, C
# カルマンフィルタリング
def kalman_filtering(y):
# 領域確保
T = len(y)
a = np.zeros((T, p))
R = np.zeros((T, p, p))
f = np.zeros((T, q))
Q = np.zeros((T, q, q))
K = np.zeros((T, p, p))
m = np.zeros((T, p))
C = np.zeros((T, p, p))
# 初期値
m_init = np.array([0.0])
C_init = np.array([[1e7]])
# t = 0 のときは m_init, C_init を使うので特別
a[0], R[0], f[0], Q[0], m[0], C[0] = update_filtering(m_init, C_init, y[0])
# t >= 1 の逐次処理
for t in range(1, T):
a[t], R[t], f[t], Q[t], m[t], C[t] = update_filtering(m[t - 1], C[t - 1], y[t])
return a, R, f, Q, m, C
def update_smoothing(s_next, a, R, m, C, t):
# 平滑化利得
A = C[t].dot(G.T).dot(np.linalg.inv(R[t + 1]))
# 状態の更新
s = m[t] + A.dot(s_next - a[t + 1])
return s
# カルマン平滑化(全時点でのフィルタリング結果を引数に取る)
def kalman_smoothing(a, R, m, C):
t_max = len(m) - 1
# 領域確保
s = np.zeros((t_max + 1, p))
# 時点 t = t_max
s[t_max] = m[t_max]
# 時点 t < t_max
for t in range(t_max - 1, -1, -1):
s[t] = update_smoothing(s[t + 1], a, R, m, C, t)
return s
a, R, f, Q, m, C = kalman_filtering(df['x'])
s = kalman_smoothing(a, R, m, C)
print(m.flatten())
print(s.flatten())
plt.plot(df['x'])
# plt.plot(m)
plt.plot(s)
plt.savefig('./fat_percent.png')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment