Skip to content

Instantly share code, notes, and snippets.

@bakfoo
Created August 16, 2025 13:50
Show Gist options
  • Save bakfoo/03c18a336e32de5ef0cc40fc21310ddf to your computer and use it in GitHub Desktop.
Save bakfoo/03c18a336e32de5ef0cc40fc21310ddf to your computer and use it in GitHub Desktop.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from tqdm import tqdm
# --- Right-hand side of the model's system of differential equations ---
def attrition_model_ode(t, y, params, rand_vals):
    """Return the time derivatives of the eight-variable attrition model.

    Args:
        t: Current time. Not used by the equations; it is present only so the
            function matches the ``fun(t, y)`` call signature of ``solve_ivp``.
        y: State vector in the order ``[P1, P2, Q1, Q2, C1, C2, N1, N2]``.
        params: Mapping of model coefficients. Keys read here: ``a1``, ``a2``,
            ``S1``, ``S2``, ``b1``, ``b2``, ``g1``, ``g2``, ``d1``, ``d2``,
            ``h1``, ``h2``, ``M1``, ``M2`` and ``eps``.
        rand_vals: Mapping of random perturbations. Keys read here: ``p1``,
            ``p2``, ``q1_prod``, ``q1_cons``, ``q2_prod``, ``q2_cons``,
            ``loss1``, ``loss2``, ``n1`` and ``n2``.

    Returns:
        A list ``[dP1, dP2, dQ1, dQ2, dC1, dC2, dN1, dN2]`` of derivatives in
        the same order as ``y``.
    """
    P1, P2, Q1, Q2, C1, C2, N1, N2 = y
    # Clamp the stock-like variables at zero so a small solver undershoot
    # cannot feed negative values into the rates. C1 and C2 are not clamped.
    P1, P2, Q1, Q2, N1, N2 = (
        max(0, P1), max(0, P2), max(0, Q1), max(0, Q2), max(0, N1), max(0, N2)
    )

    eps = params['eps']

    def noise(key, scale=1):
        # Multiplicative perturbation 1 + scale * eps * rand_vals[key].
        return 1 + scale * eps * rand_vals[key]

    # P: the growth rate falls linearly to zero as P approaches S.
    dP1_dt = params['a1'] * noise('p1') * (1 - P1 / params['S1'])
    dP2_dt = params['a2'] * noise('p2') * (1 - P2 / params['S2'])

    # Q: gain proportional to P minus a loss proportional to Q itself.
    dQ1_dt = (params['b1'] * noise('q1_prod') * P1
              - params['g1'] * noise('q1_cons') * Q1)
    dQ2_dt = (params['b2'] * noise('q2_prod') * P2
              - params['g2'] * noise('q2_cons') * Q2)

    # Losses of each side are driven by the *other* side's Q term; the noise
    # amplitude on the loss terms is twice that of the other terms.
    loss1_rate = params['d2'] * params['g2'] * noise('loss1', 2) * Q2
    loss2_rate = params['d1'] * params['g1'] * noise('loss2', 2) * Q1
    # Once N is exhausted no further losses are accrued.
    if N1 <= 0:
        loss1_rate = 0
    if N2 <= 0:
        loss2_rate = 0

    # C accumulates the losses.
    dC1_dt = loss1_rate
    dC2_dt = loss2_rate

    # Recruitment shrinks as N + C approaches the cap M and never goes negative.
    recruitment_factor1 = max(0, 1 - (N1 + C1) / params['M1'])
    recruitment_factor2 = max(0, 1 - (N2 + C2) / params['M2'])
    dN1_dt = params['h1'] * noise('n1') * recruitment_factor1 - loss1_rate
    dN2_dt = params['h2'] * noise('n2') * recruitment_factor2 - loss2_rate

    return [dP1_dt, dP2_dt, dQ1_dt, dQ2_dt, dC1_dt, dC2_dt, dN1_dt, dN2_dt]
# --- Run a single stochastic realization (version with per-variable atol) ---
def run_single_simulation(params, y0, dur, rand_keys, atol, rng=None):
    """Integrate the attrition model day by day with fresh noise each day.

    For each of ``dur - 1`` unit-length intervals a new set of standard-normal
    values is drawn (one per entry of ``rand_keys``), held fixed for that
    interval, and the ODE system is integrated with LSODA from the state at
    the start of the interval to its end.

    Args:
        params: Model coefficients, passed through to ``attrition_model_ode``.
        y0: Initial state ``[P1, P2, Q1, Q2, C1, C2, N1, N2]``; any array-like
            of eight numbers. It is not modified.
        dur: Number of rows (days) to produce, including the initial state.
        rand_keys: Names of the random perturbations to draw each day.
        atol: Absolute tolerance forwarded to ``solve_ivp`` (scalar or one
            value per state variable).
        rng: Optional ``numpy.random.Generator`` used for the draws. When
            omitted, the global ``np.random.normal`` is used, so existing
            callers (and ``np.random.seed``) behave as before.

    Returns:
        A ``pandas.DataFrame`` with one row per stored day, the eight state
        columns and a ``month`` column equal to the row index divided by 30.
        If the solver fails on some day, a message is printed and the frame
        contains only the rows obtained up to that point.
    """
    draw_normal = np.random.normal if rng is None else rng.normal

    # Work on a float copy so list/tuple inputs are accepted and the caller's
    # array is never aliased or mutated.
    y_current = np.array(y0, dtype=float)
    results = [y_current.copy()]

    for t_day in range(dur - 1):
        rand_vals = {key: draw_normal(0, 1) for key in rand_keys}
        sol = solve_ivp(
            # Bind this day's noise as a default so the callback cannot pick
            # up a later iteration's values.
            fun=lambda t, y, rv=rand_vals: attrition_model_ode(t, y, params, rv),
            t_span=[t_day, t_day + 1],
            y0=y_current,
            method='LSODA',
            atol=atol,  # absolute tolerance matched to each variable's scale
        )
        if not sol.success:
            print(f"Solver failed at day {t_day} with message: {sol.message}")
            break  # abandon this realization on failure

        # The last column of sol.y is the state at the end of t_span, so no
        # dense-output interpolant is needed.
        y_current = sol.y[:, -1]
        results.append(y_current)

    df = pd.DataFrame(
        np.array(results),
        columns=['P1', 'P2', 'Q1', 'Q2', 'C1', 'C2', 'N1', 'N2'],
    )
    df['month'] = np.arange(len(results)) / 30
    return df
# --- Parameters and initial values ---
dur = 30 * 12 * 6  # 6 years, counted as 30-day months and 12-month years
params = {
    'a1': 10, 'a2': 20, 'S2': 20*10**3, 'S1': 20*(20*10**3),
    'b1': 1.0, 'b2': 1.0, 'g1': 20000 / (10*10**6 / 4), 'g2': 20000 / (10*10**6),
    'd1': 1/30, 'd2': 1/30, 'h1': 1000, 'h2': 300,
    'M1': 1.5*10**6, 'M2': (1.5*10**6) * 4, 'eps': 2.0,
}
# One standard-normal value is drawn per key for every simulated day.
rand_keys = ['p1', 'p2', 'q1_prod', 'q1_cons', 'q2_prod', 'q2_cons',
             'loss1', 'loss2', 'n1', 'n2']
# Initial state, in the order P1, P2, Q1, Q2, C1, C2, N1, N2.
y0 = np.array([1000, 5000, 2.5e6, 10e6, 0, 0, 5e5, 3e5])
# Absolute tolerance (atol) chosen per variable to match its scale.
#          P1, P2, Q1,  Q2,  C1, C2, N1, N2
abs_tol = [1, 1, 100, 100, 10, 10, 10, 10]

# --- Run the ensemble ---
num_realizations = 20
c1_trajectories = []
print(f"Running {num_realizations} simulations for {dur/30/12:.0f} years...")
for _ in tqdm(range(num_realizations)):
    result_df = run_single_simulation(params, y0, dur, rand_keys, atol=abs_tol)
    if not result_df.empty:
        c1_trajectories.append(result_df['C1'])
print("Simulations complete. Plotting results...")

# --- Plot ---
fig, ax = plt.subplots(figsize=(10, 7))
months_axis = np.arange(dur) / 30
for trajectory in c1_trajectories:
    # A realization that stopped early is shorter than dur, so trim the time
    # axis to the trajectory's own length.
    ax.plot(months_axis[:len(trajectory)], trajectory / 1000,
            color='brown', alpha=0.8, linewidth=1.5)
# Shaded band 500-1000 and dashed line at 750 (in thousands).
# NOTE(review): presumably an external reference range/estimate - confirm.
ax.axhspan(500, 1000, color='skyblue', alpha=0.4, zorder=0)
ax.axhline(750, color='blue', linestyle='--', linewidth=2)
ax.set_title(f'{len(c1_trajectories)} realizations for the dynamics of C1 (Ukrainian casualties)')
ax.set_xlabel('month')
ax.set_ylabel('Casualties, k')
ax.set_xlim(0, dur/30)  # x-axis range follows dur
ax.set_ylim(0, 1500)  # y-axis range with some headroom
ax.set_xticks(np.arange(0, dur/30 + 1, 12))  # one tick per 12 months
ax.grid(axis='x', linestyle='--', color='lightgray')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment