Created
May 5, 2025 11:09
-
-
Save jbanety/76f5a655f70d1996753e315c5db16b22 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! Actor network (policy) | |
//! Le module actor contient tout ce qui est lié à la politique stochastique (l’Actor). Ce module peut inclure : | |
//! - Réseau de l’Actor : La structure et le forward pass. | |
//! - Paramètres de sortie : Moyennes, écart type/log-écart type, distributions. | |
//! - Calculs spécifiques : Gestion de l'entropie, application de transformations (par exemple, TANH). | |
//! | |
//! On modélise un réseau de neurones (NN) chargé de produire une distribution d’actions (ou des paramètres d’action) pour le trading. | |
//! L’entrée du NN correspond à l’état du marché (features normalisées). | |
//! La sortie correspond soit aux logits (pour un policy discret) ou aux paramètres d’une distribution gaussienne (pour des actions continues). | |
//! En scalping, les actions peuvent être : acheter / vendre / rester à l’écart, éventuellement avec un continuous size (pour la taille de position). | |
//! | |
use core::f32; | |
use burn::{ | |
module::Module, | |
nn::{ | |
attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig}, | |
Dropout, LayerNorm, Linear, | |
}, | |
prelude::*, | |
tensor::{ | |
activation::{sigmoid, softmax, tanh}, | |
Distribution, Tensor, | |
}, | |
}; | |
use serde::{Deserialize, Serialize}; | |
use super::{countable::Countable, models::LayersModel, CustomInitMapper, RawAction, Stateful}; | |
use crate::trading_ai::{ | |
core::{ | |
action_mask::ActionMask, | |
raw_action::{ACTION_LEVERAGE_DIM, ACTION_ORDER_TYPE_DIM}, | |
}, | |
env::{ | |
EnvironmentStateTensor, Leverage, STATE_BALANCE_INDEX, STATE_DURATION_INDEX, | |
STATE_LEVERAGE_INDEX, STATE_POSITION_ENTRY_DISTANCE_INDEX, | |
STATE_POSITION_LIQUIDATION_DISTANCE_INDEX, STATE_POSITION_SIZE_INDEX, | |
STATE_POSITION_SL_INDEX, STATE_POSITION_TP_INDEX, STATE_POSITION_TYPE_DIM, | |
STATE_POSITION_TYPE_INDEX, STATE_TOTAL_DIM, | |
}, | |
memory::{ | |
MarketSnapshotTensor, TechnicalIndicatorsTensor, INDICATORS_ATR_INDEX, | |
INDICATORS_STOCH_RSI_INDEX, MARKET_SNAPSHOT_DIM, SNAPSHOT_ASK_VOLUME_INDEX, | |
SNAPSHOT_BID_VOLUME_INDEX, SNAPSHOT_SPREAD_INDEX, SNAPTSHOT_CLOSE_INDEX, | |
TECHNICAL_INDICATORS_DIM, | |
}, | |
training::CurriculumManager, | |
utils::{ | |
create_cross_mask, create_dropout_layer, create_kaiming_linear_layer, create_layer_norm_1d, | |
create_xavier_linear_layer, get_tensor_data_as_f32_vec, log_prob_categorical, | |
log_prob_gaussian, | |
}, | |
}; | |
/// Type de tensor pour les logits | |
pub type LogProbsTensor<B> = Tensor<B, 2>; | |
/// Offset pour les indicateurs techniques sur le tensor concaténé | |
const CONCATENED_INDICATORS_OFFSET: usize = MARKET_SNAPSHOT_DIM; | |
/// Offset pour l'état de l'environnement sur le tensor concaténé | |
const CONCATENED_STATE_OFFSET: usize = MARKET_SNAPSHOT_DIM + TECHNICAL_INDICATORS_DIM; | |
/// Configuration complète de l'acteur | |
#[derive(Clone, Debug, Deserialize, Serialize)] | |
pub struct ActorConfig { | |
/// Dimension finale des actions | |
pub action_dim: usize, | |
// Paramètres de l'attention | |
/// Valeur minimale pour le masquage. Défaut: -1.0e4 | |
pub attention_min_float: f64, | |
/// Longueur de la séquence d'attention | |
pub attention_sequence_length: usize, | |
/// Utilisation du "quiet softmax" au lieu du softmax standard | |
/// Son utilisation peut : | |
/// - améliorer les performances en permettant aux têtes d'attention de ne pas déposer d'informations (si la séquence ne contient aucune information pertinente pour cette tête). | |
/// - réduire l'entropie des poids du modèle, améliorant la quantification et la compression. | |
pub attention_quiet_softmax: bool, | |
/// Le gain à utiliser dans la formule d'initialisation | |
pub attention_init_gain: f64, | |
/// Si vrai, utilise uniquement le fan out pour l'initialisation | |
pub attention_init_fan_out_only: bool, | |
/// Nombre de têtes d'attention | |
pub attention_n_heads: usize, | |
/// Dimensions successives des couches du processeur des MarketSnapshot | |
/// Dimension d'entrée du MarketSnapshot | |
//pub snapshot_input_dim: usize, | |
/// Ex: [128, 64] crée deux couches avec 128 puis 64 neurones | |
pub snapshot_processor_dims: Vec<usize>, | |
/// Utilisation d'un biais | |
pub snapshot_with_bias: bool, | |
/// Le gain à utiliser dans la formule d'initialisation | |
pub snapshot_init_gain: f64, | |
/// Si vrai, utilise uniquement le fan out pour l'initialisation | |
pub snapshot_init_fan_out_only: bool, | |
/// Dimensions successives des couches du processeur des indicateurs techniques | |
/// Dimension d'entrée des TechnicalIndicators | |
//pub indicators_input_dim: usize, | |
/// Dimension cachée pour TechnicalIndicators | |
pub indicators_processor_dims: Vec<usize>, | |
/// Utilisation d'un biais pour les indicateurs | |
pub indicators_with_bias: bool, | |
/// Le gain à utiliser dans la formule d'initialisation | |
pub indicators_init_gain: f64, | |
/// Si vrai, utilise uniquement le fan out pour l'initialisation | |
pub indicators_init_fan_out_only: bool, | |
/// Dimensions successives des couches du processeur de l'état de l'environnement | |
/// Dimension d'entrée de l'état de l'environnement | |
//pub state_input_dim: usize, | |
/// Dimension cachée pour l'état de l'environnement | |
pub state_processor_dims: Vec<usize>, | |
/// Utilisation d'un biais pour l'état de l'environnement | |
pub state_with_bias: bool, | |
/// Le gain à utiliser dans la formule d'initialisation | |
pub state_init_gain: f64, | |
/// Si vrai, utilise uniquement le fan out pour l'initialisation | |
pub state_init_fan_out_only: bool, | |
/// Dimension latente pour la représentation fusionnée des features | |
pub fusion_latent_dim: usize, | |
/// Utilisation d'un biais pour la fusion | |
pub fusion_bias: bool, | |
/// Gain à utiliser dans la formule d'initialisation | |
pub fusion_init_gain: f64, | |
/// Si vrai, utilise uniquement le fan out pour l'initialisation | |
pub fusion_init_fan_out_only: bool, | |
/// Gain à utiliser dans la formule d'initialisation | |
pub output_init_gain: f64, | |
/// Si vrai, utilise uniquement le fan out pour l'initialisation | |
pub output_init_fan_out_only: bool, | |
/// Dimension de sortie pour le type de position | |
pub output_position_type_dim: usize, | |
/// Dimension de sortie pour le type d'ordre | |
pub output_order_type_dim: usize, | |
/// Dimension de sortie pour le levier | |
pub output_leverage_dim: usize, | |
/// Dimension de sortie pour la distance d'entrée | |
pub output_entry_distance_dim: usize, | |
/// Dimension de sortie pour la taille de position | |
pub output_position_size_dim: usize, | |
/// Dimension de sortie pour le ratio de Take Profit | |
pub output_tp_ratio_dim: usize, | |
/// Dimension de sortie pour le ratio de Stop Loss | |
pub output_sl_ratio_dim: usize, | |
// Paramètres de régularisation | |
/// Taux de dropout | |
pub dropout_rate: Option<f64>, | |
/// Une valeur epsilon pour la stabilité numérique. Défaut 1e-5 | |
pub epsilon: f64, | |
/// Momemtum utilisé pour mettre à jour les métriques. Défaut 0.1 | |
pub momentum: f64, | |
// Initialisation personnalisée | |
/// Option d'utiliser une initialisation personnalisée | |
pub use_custom_init: bool, | |
/// Seed pour l'initialisation personnalisée (déterministe pour les tests) | |
pub custom_init_seed: Option<u64>, | |
} | |
impl Default for ActorConfig { | |
fn default() -> Self { | |
Self { | |
action_dim: 64, | |
attention_sequence_length: 64, // = action_dim pour simplifier | |
attention_n_heads: 8, // On augmente le nombre de têtes d'attention pour renforcer l'état de position | |
attention_min_float: -1.0e4, | |
attention_quiet_softmax: false, // Désactiver quiet softmax pour des signaux plus forts | |
attention_init_gain: 1.0, // Valeur standard pour ReLU | |
attention_init_fan_out_only: false, // Considère fan_in et fan_out pour une meilleure distribution | |
snapshot_processor_dims: vec![128, 64], // Deux couches cachées | |
snapshot_with_bias: true, | |
snapshot_init_gain: 1.0, // Valeur standard pour ReLU | |
snapshot_init_fan_out_only: false, // Considère fan_in | |
indicators_processor_dims: vec![128, 64], // Deux couches cachées | |
indicators_with_bias: true, | |
indicators_init_gain: 1.0, // Valeur standard pour ReLU | |
indicators_init_fan_out_only: false, // Considère fan_in et fan_out pour une meilleure distribution | |
state_processor_dims: vec![256, 128, 64], // 3 couches pour renforcer l'état de position | |
state_with_bias: true, | |
state_init_gain: 2.0, // Plus de poids aux gradients des états pour renforcer l'état de position | |
state_init_fan_out_only: false, // Considère fan_in et fan_out pour une meilleure distribution | |
fusion_latent_dim: 64, // = action_dim pour simplifier. Si je veux mettre plus grand => https://gitlab.com/etdsolutions/autoscalp3000/-/snippets/4796382 | |
fusion_bias: true, // Utilisation d'un biais pour la fusion | |
fusion_init_gain: 1.0, // Valeur standard pour ReLU | |
fusion_init_fan_out_only: false, // Considère fan_in | |
output_init_gain: 1.5, // Valeur standard pour ReLU | |
output_init_fan_out_only: false, // Considère fan_in | |
output_position_type_dim: 4, // Long/Short/Close/Hold | |
output_order_type_dim: 2, // Maker/Taker | |
output_leverage_dim: 12, // Leverage 1x/2x/3x/4x/5x/10x/20x/25x/50x/75x/100x/125x | |
output_entry_distance_dim: 1, // Distance relative | |
output_position_size_dim: 1, // Taille de position | |
output_tp_ratio_dim: 1, // Ratio de Take Profit | |
output_sl_ratio_dim: 1, // Ratio de Stop Loss | |
dropout_rate: Some(0.2), | |
epsilon: 1e-6, // Plus petit pour plus de précision sur les variations fines | |
momentum: 0.01, // Plus faible pour une meilleure stabilité sur données HF | |
use_custom_init: false, | |
custom_init_seed: None, | |
} | |
} | |
} | |
/// Sous-module pour traiter le MarketSnapshot | |
type SnapshotProcessor<B> = LayersModel<B>; | |
/// Sous-module pour traiter les TechnicalIndicators | |
type IndicatorsProcessor<B> = LayersModel<B>; | |
/// Sous-module pour traiter l'état de l'environnement | |
type StateProcessor<B> = LayersModel<B>; | |
/// Sous-module pour les têtes de sortie | |
#[derive(Module, Debug)] | |
pub struct OutputHeads<B: Backend> { | |
/// Type de position (Long/Short/Close/Hold) | |
pub position_type: Linear<B>, | |
/// Taille de position (moyenne) | |
position_size_mean: Linear<B>, | |
/// Taille de position (écart type) | |
position_size_std: Linear<B>, | |
/// Type d'ordre (Maker/Taker) | |
order_type: Linear<B>, | |
/// Leverage (1x/2x/3x/4x/5x/10x/20x/25x/50x/75x/100x/125x) | |
leverage: Linear<B>, | |
/// Distance relative d'entrée (moyenne) | |
entry_distance_mean: Linear<B>, | |
/// Distance relative d'entrée (écart type) | |
entry_distance_std: Linear<B>, | |
/// Take Profit ratio | |
tp_ratio: Linear<B>, | |
/// Stop Loss ratio | |
sl_ratio: Linear<B>, | |
/// Epsilon pour la stabilité numérique | |
epsilon: f32, | |
} | |
impl<B: Backend> OutputHeads<B> { | |
/// Crée un nouveau OutputHeads | |
pub fn new(config: &ActorConfig, device: &B::Device) -> Self { | |
Self { | |
position_type: create_xavier_linear_layer( | |
config.fusion_latent_dim, | |
config.output_position_type_dim, | |
true, | |
config.output_init_gain, | |
device, | |
), | |
order_type: create_xavier_linear_layer( | |
config.fusion_latent_dim, | |
config.output_order_type_dim, | |
true, | |
config.output_init_gain, | |
device, | |
), | |
leverage: create_xavier_linear_layer( | |
config.fusion_latent_dim, | |
config.output_leverage_dim, | |
true, | |
config.output_init_gain, | |
device, | |
), | |
entry_distance_mean: create_kaiming_linear_layer( | |
config.fusion_latent_dim, | |
config.output_entry_distance_dim, | |
true, | |
config.output_init_gain, | |
config.output_init_fan_out_only, | |
device, | |
), | |
entry_distance_std: create_xavier_linear_layer( | |
config.fusion_latent_dim, | |
config.output_entry_distance_dim, | |
true, | |
config.output_init_gain, | |
device, | |
), | |
position_size_mean: create_xavier_linear_layer( | |
config.fusion_latent_dim, | |
config.output_position_size_dim, | |
true, | |
config.output_init_gain, | |
device, | |
), | |
position_size_std: create_xavier_linear_layer( | |
config.fusion_latent_dim, | |
config.output_position_size_dim, | |
true, | |
config.output_init_gain, | |
device, | |
), | |
tp_ratio: create_kaiming_linear_layer( | |
config.fusion_latent_dim, | |
config.output_tp_ratio_dim, | |
true, | |
config.fusion_init_gain, | |
config.fusion_init_fan_out_only, | |
device, | |
), | |
sl_ratio: create_kaiming_linear_layer( | |
config.fusion_latent_dim, | |
config.output_sl_ratio_dim, | |
true, | |
config.fusion_init_gain, | |
config.fusion_init_fan_out_only, | |
device, | |
), | |
epsilon: config.epsilon as f32, | |
} | |
} | |
/// Forward pass pour les têtes de sortie | |
/// | |
/// Arguments: | |
/// * `x`: [batch_size, latent_dim] | |
/// - batch_size: taille du batch (toujours 1 en pratique) | |
/// - latent_dim: dimension de la représentation fusionnée | |
/// * `state`: Etat de l'environnement pour l'action masking | |
/// | |
/// Returns: | |
/// * RawAction: Structure contenant les sorties scalaires | |
/// * LogProbsTensor<B>: Log-probabilité pour l'action | |
pub fn forward( | |
&self, | |
x: Tensor<B, 2>, | |
state: &EnvironmentStateTensor<B>, | |
max_leverage: Leverage, | |
curriculum_manager: &mut CurriculumManager<B>, | |
device: &B::Device, | |
) -> (Vec<RawAction>, LogProbsTensor<B>) { | |
log::debug!(" -> OutputHeads::forward"); | |
let batch_size = x.dims()[0]; | |
log::debug!(" -> Batch size: {:?}", batch_size); | |
// On détermine si on utilise le curriculum pour ce batch | |
let use_curriculum = curriculum_manager.is_active() | |
&& rand::random::<f32>() < curriculum_manager.force_prob(); | |
log::info!(" -> Use curriculum: {:?}", use_curriculum); | |
// 1. Actions discrètes (softmax) | |
// --------------------------- | |
// Position Type (Long/Short/Close/Hold) => représente des probabilités pour des actions discrètes. | |
let position_type = if use_curriculum { | |
curriculum_manager.position_type(state, batch_size, device) | |
} else { | |
self.position_type.forward(x.clone()) | |
}; | |
// Action Masking pour éviter les actions invalides | |
let position_type = ActionMask::position_type(&position_type, &state); | |
// On utilise un softmax pour transformer les sorties en une distribution de probabilité. | |
// [batch_size, 4] où 4 = [Long, Short, Close, Hold] | |
let position_type = softmax(position_type, 1); | |
// Order Type (Maker/Taker) => représente des probabilités pour des actions discrètes. | |
let order_type = if use_curriculum { | |
curriculum_manager.order_type(batch_size, device) | |
} else { | |
self.order_type.forward(x.clone()) | |
}; | |
// On utilise un softmax pour transformer les sorties en une distribution de probabilité. | |
// [batch_size, 2] où 2 = [Maker, Taker] | |
let order_type = softmax(order_type, 1); | |
// Leverage (1x/2x/3x/4x/5x/10x/20x/25x/50x/75x/100x/125x) => représente des probabilités pour des actions discrètes. | |
let leverage = if use_curriculum { | |
curriculum_manager.leverage(batch_size, device) | |
} else { | |
self.leverage.forward(x.clone()) | |
}; | |
// Action Masking pour éviter les actions invalides | |
let leverage = ActionMask::leverage(&leverage, max_leverage); | |
// On utilise un softmax pour transformer les sorties en une distribution de probabilité. | |
// [batch_size, 12] où 12 = [1x, 2x, 3x, 4x, 5x, 10x, 20x, 25x, 50x, 75x, 100x, 125x] | |
let leverage = softmax(leverage, 1); | |
// 2. Actions continues - mean et log_std | |
// position_size et entry_distance font partie de la décision d’action directe de l’Agent | |
// → Il doit explorer pour apprendre. | |
// ----------------------------------- | |
// Position Size (Taille de position) => une valeur continue | |
let (position_size_mean, position_size_std, position_size_log_std) = | |
if use_curriculum && (curriculum_manager.was_long_or_short()) { | |
curriculum_manager.position_size(batch_size, device) | |
} else { | |
let mean = self.position_size_mean.forward(x.clone()); | |
let log_std = self.position_size_std.forward(x.clone()).clamp(-5.0, 2.0); | |
let std = log_std.clone().exp(); | |
(mean, std, log_std) | |
}; | |
// Entry Distance (Distance relative) => une valeur continue | |
let (entry_distance_mean, entry_distance_std, entry_distance_log_std) = | |
if use_curriculum && (curriculum_manager.was_long_or_short()) { | |
curriculum_manager.entry_distance(batch_size, device) | |
} else { | |
let mean = self.entry_distance_mean.forward(x.clone()); | |
let log_std = self.entry_distance_std.forward(x.clone()).clamp(-5.0, 2.0); | |
let std = log_std.clone().exp(); | |
(mean, std, log_std) | |
}; | |
// 3. Échantillonnage des actions continues | |
// ---------------------------------------- | |
let distribution = Distribution::Uniform(0.0, 1.0); | |
let noise_position = | |
Tensor::<B, 2>::random(position_size_mean.shape(), distribution, &x.device()); | |
let position_size = | |
(tanh(position_size_mean.clone() + position_size_std.clone() * noise_position) + 1.0) | |
/ 2.0; | |
let noise_entry = | |
Tensor::<B, 2>::random(entry_distance_mean.shape(), distribution, &x.device()); | |
let entry_distance = | |
(tanh(entry_distance_mean.clone() + entry_distance_std.clone() * noise_entry) + 1.0) | |
/ 2.0; | |
// 4. Log-probabilité des actions échantillonnées | |
// → Les actions avec log-probs participent à l'optimisation entropique de SAC | |
// ---------------------------------------------- | |
// Actions continues | |
let log_prob_position_size = log_prob_gaussian( | |
position_size.clone(), | |
position_size_mean.clone(), | |
position_size_std, | |
position_size_log_std, | |
self.epsilon, | |
); | |
let log_prob_entry_distance = log_prob_gaussian( | |
entry_distance.clone(), | |
entry_distance_mean.clone(), | |
entry_distance_std, | |
entry_distance_log_std, | |
self.epsilon, | |
); | |
// Actions discrètes | |
let log_prob_position_type = log_prob_categorical(position_type.clone(), self.epsilon); | |
let log_prob_order_type = log_prob_categorical(order_type.clone(), self.epsilon); | |
let log_probs = log_prob_position_type | |
+ log_prob_order_type | |
+ log_prob_position_size | |
+ log_prob_entry_distance; | |
// 5. Autres actions continues (déterninistes) | |
// tp_ratio et sl_ratio sont des réglages de gestion du risque | |
// → Il doit optimiser, pas tester aléatoirement. | |
// → Les actions sans log-probs (déterministes) sont optimisées uniquement par le gradient des Q-values | |
// ------------------------------------------- | |
// Take Profit Ratio => Une valeur continue, toujours positive. | |
let tp_ratio = if use_curriculum && curriculum_manager.was_long_or_short() { | |
curriculum_manager.get_forced_tp(batch_size, device) | |
} else { | |
self.tp_ratio.forward(x.clone()) | |
}; | |
// On utilise un tanh pour éviter les écrasements de valeurs (par rapport à ReLU) et on convertit [-1,1] en [0,1] | |
let tp_ratio = tanh(tp_ratio).add_scalar(1.0).div_scalar(2.0); | |
// Stop Loss Ratio => Une valeur continue, toujours positive. | |
let sl_ratio = if use_curriculum && curriculum_manager.was_long_or_short() { | |
curriculum_manager.get_forced_sl(batch_size, device) | |
} else { | |
self.sl_ratio.forward(x) | |
}; | |
// On utilise un tanh pour éviter les écrasements de valeurs (par rapport à ReLU) et on convertit [-1,1] en [0,1] | |
let sl_ratio = tanh(sl_ratio).add_scalar(1.0).div_scalar(2.0); | |
// 6. Construction du résultat pour chaque élément du batch | |
// ---------------------------------------------------- | |
let mut actions = Vec::with_capacity(batch_size); | |
// Extraction des données tensorielles | |
let position_type_data = get_tensor_data_as_f32_vec(&position_type); | |
let order_type_data = get_tensor_data_as_f32_vec(&order_type); | |
let position_size_data = get_tensor_data_as_f32_vec(&position_size); | |
let entry_distance_data = get_tensor_data_as_f32_vec(&entry_distance); | |
let tp_ratio_data = get_tensor_data_as_f32_vec(&tp_ratio); | |
let sl_ratio_data = get_tensor_data_as_f32_vec(&sl_ratio); | |
let leverage_data = get_tensor_data_as_f32_vec(&leverage); | |
// Construction des actions pour chaque élément du batch | |
for i in 0..batch_size { | |
// Extraction des probabilités pour position_type (qui est un vecteur de 4 valeurs) | |
let pos_type_offset = i * STATE_POSITION_TYPE_DIM; | |
let position_type = vec![ | |
position_type_data[pos_type_offset], | |
position_type_data[pos_type_offset + 1], | |
position_type_data[pos_type_offset + 2], | |
position_type_data[pos_type_offset + 3], | |
]; | |
// Extraction des probabilités pour order_type (qui est un vecteur de 2 valeurs) | |
let order_type_offset = i * ACTION_ORDER_TYPE_DIM; | |
let order_type = vec![ | |
order_type_data[order_type_offset], | |
order_type_data[order_type_offset + 1], | |
]; | |
// Extraction des probabilités pour leverage (qui est un vecteur de 12 valeurs) | |
let leverage_offset = i * ACTION_LEVERAGE_DIM; | |
let leverage = vec![ | |
leverage_data[leverage_offset], | |
leverage_data[leverage_offset + 1], | |
leverage_data[leverage_offset + 2], | |
leverage_data[leverage_offset + 3], | |
leverage_data[leverage_offset + 4], | |
leverage_data[leverage_offset + 5], | |
leverage_data[leverage_offset + 6], | |
leverage_data[leverage_offset + 7], | |
leverage_data[leverage_offset + 8], | |
leverage_data[leverage_offset + 9], | |
leverage_data[leverage_offset + 10], | |
leverage_data[leverage_offset + 11], | |
]; | |
// Extraction des valeurs scalaires | |
let position_size = position_size_data[i]; | |
let entry_distance = entry_distance_data[i]; | |
let tp_ratio = tp_ratio_data[i]; | |
let sl_ratio = sl_ratio_data[i]; | |
// Création de l'action | |
let mut action = RawAction { | |
position_type, | |
order_type, | |
position_size, | |
entry_distance, | |
tp_ratio, | |
sl_ratio, | |
leverage, | |
}; | |
// Validation des contraintes | |
self.validate_outputs(&mut action); | |
// Ajout au résultat | |
actions.push(action); | |
} | |
(actions, log_probs) | |
} | |
/// Validation des sorties et application de contraintes métier | |
fn validate_outputs(&self, outputs: &mut RawAction) { | |
log::debug!(" -> OutputHeads::validate_outputs"); | |
// 1. Vérifier que les sommes des probabilités sont correctes (≈ 1.0) | |
let position_type_sum: f32 = outputs.position_type.iter().sum(); | |
if (position_type_sum - 1.0).abs() > 0.01 { | |
// Normaliser les probabilités | |
let factor = 1.0 / position_type_sum; | |
for p in outputs.position_type.iter_mut() { | |
*p *= factor; | |
} | |
} | |
let order_type_sum: f32 = outputs.order_type.iter().sum(); | |
if (order_type_sum - 1.0).abs() > 0.01 { | |
// Normaliser les probabilités | |
let factor = 1.0 / order_type_sum; | |
for p in outputs.order_type.iter_mut() { | |
*p *= factor; | |
} | |
} | |
// 2. Contraindre les valeurs continues dans des plages valides | |
// Position size entre 0.0 et 1.0 | |
outputs.position_size = outputs.position_size.clamp(0.0, 1.0); | |
// TP/SL ratio strictement positifs (minimum 0.005 soit 0.5%) | |
outputs.tp_ratio = outputs.tp_ratio.max(0.005); | |
outputs.sl_ratio = outputs.sl_ratio.max(0.005); | |
// Entry distance doit être positive mais pas trop grande | |
outputs.entry_distance = outputs.entry_distance.clamp(0.0, 0.05); // Max 5% | |
// 3. Vérifier les cas incohérents | |
// Si Close a la plus haute probabilité mais qu'aucune position n'est | |
// actuellement ouverte, redistribuer les probabilités | |
// let close_idx = POSITION_TYPE_CLOSE_INDEX as usize; | |
// Empêcher l'agent de choisir close s'il n'y a pas de position | |
// Cette vérification est également faite dans l'environment, | |
// mais ajouter cette contrainte ici aide l'agent à apprendre | |
// que close n'est pas une action valide sans position | |
// if outputs.position_type[close_idx] > 0.5 { | |
// // Ne pas faire de correction ici, laissons l'environnement | |
// // gérer cette règle et fournir un signal d'apprentissage négatif | |
// log::warn!( | |
// " → Action Close a une probabilité élevée: {:.2} alors qu'aucune position n'est ouverte", | |
// outputs.position_type[close_idx] | |
// ); | |
// } | |
// 4. Assurer que les rapports risk/reward sont raisonnables | |
if outputs.tp_ratio < outputs.sl_ratio { | |
log::info!( | |
" → TP < SL: TP={:.4}, SL={:.4}", | |
outputs.tp_ratio, | |
outputs.sl_ratio | |
); | |
// Le TP doit généralement être plus grand que le SL | |
// Ajuster pour avoir un ratio risk/reward minimum de 1:1 | |
outputs.tp_ratio = outputs.sl_ratio * 1.0; | |
log::info!( | |
" → TP ajusté pour assurer un ratio risk/reward minimum de 1:1: TP={:.4}, SL={:.4}", | |
outputs.tp_ratio, | |
outputs.sl_ratio | |
); | |
} | |
} | |
} | |
/// Sous-module pour la fusion des features | |
#[derive(Module, Debug)] | |
pub struct Fusion<B: Backend> { | |
// Gate dynamique | |
gate_layer: Linear<B>, | |
// Attention multi-tête | |
mha: MultiHeadAttention<B>, | |
// Projet final après la fusion | |
projection_layer: Linear<B>, | |
// Normalisation | |
norm: LayerNorm<B>, | |
/// Dropout | |
dropout: Option<Dropout>, | |
/// Dimension d'entrée pour les snapshots | |
//snapshot_input_dim: usize, | |
/// Dimension d'entrée pour les indicateurs | |
//indicators_input_dim: usize, | |
/// Dimension de la représentation fusionnée | |
sequence_length: usize, | |
} | |
impl<B: Backend> Fusion<B> { | |
pub fn new(config: &ActorConfig, device: &B::Device) -> Self { | |
let mha_config = MultiHeadAttentionConfig { | |
n_heads: config.attention_n_heads, | |
d_model: config.fusion_latent_dim, | |
min_float: config.attention_min_float, | |
quiet_softmax: config.attention_quiet_softmax, | |
// Kaiming Uniform : | |
// - Si les activations suivantes utilisent ReLU ou GELU. | |
// - Si vous voulez une meilleure robustesse sur des entrées non normalisées. | |
// Xavier Uniform : | |
// - Si vos activations incluent des fonctions symétriques comme tanh. | |
// - Si vous constatez une instabilité des gradients avec Kaiming. | |
// Normal : | |
// - Si vous voulez expérimenter avec des variances faibles pour un apprentissage plus stable. | |
initializer: burn::nn::Initializer::KaimingUniform { | |
gain: config.attention_init_gain, | |
fan_out_only: config.attention_init_fan_out_only, | |
}, | |
dropout: config.dropout_rate.unwrap_or(0.0), // 0 si pas de dropout | |
}; | |
let dropout = config.dropout_rate.map(|rate| create_dropout_layer(rate)); | |
let total_input_dim = config.snapshot_processor_dims.last().unwrap() | |
+ config.indicators_processor_dims.last().unwrap() | |
+ config.state_processor_dims.last().unwrap(); | |
Self { | |
gate_layer: create_kaiming_linear_layer( | |
total_input_dim, // dimension d'entrée (concaténation des features) | |
total_input_dim, // dimension de sortie (pour matcher les dimensions des features) | |
config.fusion_bias, | |
config.fusion_init_gain, | |
config.fusion_init_fan_out_only, | |
device, | |
), | |
mha: mha_config.init(device), | |
projection_layer: create_kaiming_linear_layer( | |
config.fusion_latent_dim, | |
config.fusion_latent_dim, | |
config.fusion_bias, | |
config.fusion_init_gain, | |
config.fusion_init_fan_out_only, | |
device, | |
), | |
norm: create_layer_norm_1d(config.fusion_latent_dim, config.epsilon, device), | |
dropout, | |
//snapshot_input_dim: config.snapshot_input_dim, | |
//indicators_input_dim: config.indicators_input_dim, | |
sequence_length: config.attention_sequence_length, | |
} | |
} | |
/// Fusionne les features avec attention multi-tête | |
/// | |
/// # Arguments | |
/// * `snapshot_features` - [batch_size, MARKET_SNAPSHOT_DIM] Features du marché | |
/// * `indicators_features` - [batch_size, TECHNICAL_INDICATORS_DIM] Features des indicateurs techniques | |
/// * `state_features` - [batch_size, TOTAL_STATE_DIM] Features de l'état de l'environnement | |
/// | |
/// # Returns | |
/// * Tensor<B, 2> - [batch_size, fusion_latent_dim] Features fusionnées | |
pub fn fusion_features( | |
&self, | |
snapshot_features: Tensor<B, 2>, | |
indicators_features: Tensor<B, 2>, | |
state_features: Tensor<B, 2>, | |
device: &B::Device, | |
) -> Tensor<B, 2> { | |
log::debug!(" -> Fusion::fusion_features"); | |
// Étape 1 : Concaténation des features | |
let concatenated = Tensor::cat( | |
vec![ | |
snapshot_features.clone(), | |
indicators_features.clone(), | |
state_features.clone(), | |
], | |
1, // Concaténation sur la dernière dimension | |
); | |
log::debug!(" -> Shape of concatenated: {:?}", concatenated.dims()); | |
// Étape 2 : Gate dynamique avec 3 portes | |
let gate_weights = sigmoid(self.gate_layer.forward(concatenated)); | |
log::debug!(" -> gate_weights shape: {:?}", gate_weights.dims()); | |
// On va créer des gates adaptées aux dimensions de chaque feature | |
let snapshot_dim = snapshot_features.dims()[1]; | |
let indicators_dim = indicators_features.dims()[1]; | |
let state_dim = state_features.dims()[1]; | |
// Séparation des poids pour chaque type de feature | |
let gate_chunks = | |
gate_weights.split_with_sizes(vec![snapshot_dim, indicators_dim, state_dim], 1); | |
let snapshot_gate = gate_chunks[0].clone(); | |
let indicators_gate = gate_chunks[1].clone(); | |
let state_gate = gate_chunks[2].clone(); | |
// Application des gates | |
let gated_snapshot = snapshot_features * snapshot_gate; | |
let gated_indicators = indicators_features * indicators_gate; | |
let gated_state = state_features * state_gate; | |
// Somme pondérée des features | |
let fused_input = gated_snapshot + gated_indicators + gated_state; | |
// Récupération dynamique du batch_size et de la dimension cachée | |
let batch_size = fused_input.dims()[0]; | |
let latent_dim = fused_input.dims()[1]; | |
log::debug!(" -> Batch size: {:?}", batch_size); | |
log::debug!(" -> Latent dim: {:?}", latent_dim); | |
// Étape 3 : Préparation pour l'attention multi-tête | |
// Reshape pour ajouter la dimension sequence_length | |
// [batch_size, fusion_latent_dim] -> [batch_size, sequence_length, fusion_latent_dim] | |
let fused_input = fused_input | |
.unsqueeze_dim::<3>(2) // [batch_size, 1, fusion_latent_dim] | |
.repeat(&[1, 1, self.sequence_length]); // [batch_size, sequence_length, fusion_latent_dim] | |
log::debug!(" -> Shape of fused_input: {:?}", fused_input.dims()); | |
// Masquage de l'attention | |
// @TODO: repenser les masques suivant les cas possibles | |
// https://chatgpt.com/share/67861039-88a4-8013-82c1-16bd724da74d | |
// Paires d'indices pour l'attention multi-tête. | |
// FUSET INPUT = SNAPSHOT + INDICATORS + STATE | |
let index_pairs = [ | |
// Favoriser les interactions entre `spread` et `ATR` | |
( | |
SNAPSHOT_SPREAD_INDEX, | |
CONCATENED_INDICATORS_OFFSET + INDICATORS_ATR_INDEX, | |
), | |
// Favoriser les interactions entre `bid_volume` et `stoch_rsi` | |
( | |
SNAPSHOT_BID_VOLUME_INDEX, | |
CONCATENED_INDICATORS_OFFSET + INDICATORS_STOCH_RSI_INDEX, | |
), | |
// Si une position est ouverte, les volumes peuvent influencer les décisions (exemple : prise de profits rapides si le volume baisse brusquement). | |
( | |
SNAPSHOT_BID_VOLUME_INDEX, | |
CONCATENED_STATE_OFFSET + STATE_POSITION_TYPE_INDEX, | |
), | |
( | |
SNAPSHOT_ASK_VOLUME_INDEX, | |
CONCATENED_STATE_OFFSET + STATE_POSITION_TYPE_INDEX, | |
), | |
// Le spread peut jouer un rôle critique dans la taille de position à prendre. Un spread large implique un coût d'entrée élevé, réduisant potentiellement la taille de position optimale. | |
( | |
SNAPSHOT_SPREAD_INDEX, | |
CONCATENED_STATE_OFFSET + STATE_POSITION_TYPE_INDEX, | |
), | |
// L'ATR donne une mesure de volatilité qui peut directement influencer la distance TP/SL. Un ATR élevé justifie des distances TP/SL plus larges, et inversement. | |
( | |
CONCATENED_INDICATORS_OFFSET + INDICATORS_ATR_INDEX, | |
CONCATENED_STATE_OFFSET + STATE_POSITION_TP_INDEX, | |
), | |
( | |
CONCATENED_INDICATORS_OFFSET + INDICATORS_ATR_INDEX, | |
CONCATENED_STATE_OFFSET + STATE_POSITION_SL_INDEX, | |
), | |
// Le Stoch RSI peut influencer le choix de la distance d'entrée. Un RSI élevé (sursollicitation) pourrait impliquer des entrées plus prudentes (distance plus grande), tandis qu'un RSI bas pourrait encourager des entrées plus agressives (distance plus courte). | |
( | |
CONCATENED_INDICATORS_OFFSET + INDICATORS_STOCH_RSI_INDEX, | |
CONCATENED_STATE_OFFSET + STATE_POSITION_ENTRY_DISTANCE_INDEX, | |
), | |
// La balance actuelle pourrait ajuster les tailles de position. Par exemple, une balance faible pourrait réduire la taille maximale des positions prises, et une balance élevée pourrait encourager des tailles plus agressives. | |
( | |
CONCATENED_STATE_OFFSET + STATE_BALANCE_INDEX, | |
CONCATENED_STATE_OFFSET + STATE_POSITION_SIZE_INDEX, | |
), | |
// Si une position est ouverte depuis longtemps, cela peut influencer les ajustements dynamiques des TP/SL. Par exemple, une position maintenue longtemps pourrait justifier un ajustement pour réduire le risque. | |
( | |
CONCATENED_STATE_OFFSET + STATE_DURATION_INDEX, | |
CONCATENED_STATE_OFFSET + STATE_POSITION_TP_INDEX, | |
), | |
( | |
CONCATENED_STATE_OFFSET + STATE_DURATION_INDEX, | |
CONCATENED_STATE_OFFSET + STATE_POSITION_SL_INDEX, | |
), | |
// La relation entre close (ou open) et ATR pourrait détecter si le marché est calme ou très volatil. | |
( | |
SNAPTSHOT_CLOSE_INDEX, | |
CONCATENED_INDICATORS_OFFSET + INDICATORS_ATR_INDEX, | |
), | |
// Si on est trop proche de la liquidation et avec un gros levier → l’attention peut orienter vers un Close ou une réduction d’exposition | |
( | |
CONCATENED_STATE_OFFSET + STATE_POSITION_LIQUIDATION_DISTANCE_INDEX, | |
CONCATENED_STATE_OFFSET + STATE_LEVERAGE_INDEX, | |
), | |
// Si on est trop proche de la liquidation et avec un gros levier → l’attention peut orienter vers un Close ou une réduction d’exposition | |
( | |
CONCATENED_STATE_OFFSET + STATE_POSITION_LIQUIDATION_DISTANCE_INDEX, | |
CONCATENED_STATE_OFFSET + STATE_LEVERAGE_INDEX, | |
), | |
// Si la position est proche de la liquidation, peut-être qu’il faut ajuster TP/SL plus vite | |
( | |
CONCATENED_STATE_OFFSET + STATE_POSITION_LIQUIDATION_DISTANCE_INDEX, | |
CONCATENED_STATE_OFFSET + STATE_POSITION_TP_INDEX, | |
), | |
( | |
CONCATENED_STATE_OFFSET + STATE_POSITION_LIQUIDATION_DISTANCE_INDEX, | |
CONCATENED_STATE_OFFSET + STATE_POSITION_SL_INDEX, | |
), | |
// Si on est en zone risquée, peut-être que l’agent doit apprendre à réduire les tailles futures | |
( | |
CONCATENED_STATE_OFFSET + STATE_POSITION_LIQUIDATION_DISTANCE_INDEX, | |
CONCATENED_STATE_OFFSET + STATE_POSITION_SIZE_INDEX, | |
), | |
// Si la volatilité (ATR) est élevée et la liquidation proche, la probabilité de crash rapide augmente → potentiellement justifie un comportement ultra défensif. | |
( | |
CONCATENED_STATE_OFFSET + STATE_POSITION_LIQUIDATION_DISTANCE_INDEX, | |
CONCATENED_INDICATORS_OFFSET + INDICATORS_ATR_INDEX, | |
), | |
// Apprendre à entrer plus loin du marché si la distance de liquidation est faible (ex: pour éviter un squeeze direct dès l’ouverture). | |
( | |
CONCATENED_STATE_OFFSET + STATE_POSITION_LIQUIDATION_DISTANCE_INDEX, | |
CONCATENED_STATE_OFFSET + STATE_POSITION_ENTRY_DISTANCE_INDEX, | |
), | |
]; | |
if !index_pairs | |
.iter() | |
.all(|(i, j)| *i < self.sequence_length && *j < self.sequence_length) | |
{ | |
log::warn!("Des indices dépassent la longueur de la séquence pour le masque. On filtre pour que ça passe. Tu dois être en train de faire un test ;)"); | |
} | |
// Filtrer les paires d'indices qui sont dans les limites de la séquence | |
// Utile lors des tests | |
let filtered_index_pairs: Vec<(usize, usize)> = index_pairs | |
.into_iter() | |
.filter(|(i, j)| *i < self.sequence_length && *j < self.sequence_length) | |
.collect(); | |
// Vérification que les indices ne dépassent pas la séquence | |
debug_assert!( | |
filtered_index_pairs | |
.iter() | |
.all(|(i, j)| *i < self.sequence_length && *j < self.sequence_length), | |
"Some indices exceed sequence length" | |
); | |
let cross_mask = create_cross_mask::<B>( | |
self.sequence_length, | |
batch_size, | |
&filtered_index_pairs, | |
device, | |
); | |
log::debug!(" -> Shape of cross_mask: {:?}", cross_mask.dims()); | |
// Créer l'input pour la MHA en mode self-attention avec le masque | |
let mha_input = MhaInput::self_attn(fused_input.clone()).mask_attn(cross_mask); | |
log::debug!(" -> Shape of fused_input: {:?}", fused_input.dims()); | |
// Appliquer l'attention multi-tête | |
let mha_output = self.mha.forward(mha_input); | |
// Moyenne sur la dimension de séquence | |
let context = mha_output | |
.context | |
.mean_dim(1) | |
// et réduction de [batch_size, sequence_length, fusion_latent_dim] vers [batch_size, fusion_latent_dim] | |
.squeeze_dims(&[1]); | |
// Étape 4 : Projection finale | |
let projected = self.projection_layer.forward(context); | |
// Étape 4 : Dropout (si nécessaire, à voir....) | |
//let projected = self.dropout.forward(projected); | |
// Étape 5 : Normalisation des features | |
let normalized_output = self.norm.forward(projected); | |
normalized_output | |
} | |
} | |
/// Structure principale de l'acteur | |
#[derive(Module, Debug)] | |
pub struct Actor<B: Backend> { | |
/// Couche de traitement du MarketSnapshot | |
pub snapshot_processor: SnapshotProcessor<B>, | |
/// Couche de traitement des indicateurs techniques | |
pub indicators_processor: IndicatorsProcessor<B>, | |
/// Couche de traitement de l'état de l'environnement | |
pub state_processor: StateProcessor<B>, | |
/// Couche d'attention multi-têtes | |
pub fusion: Fusion<B>, | |
/// Têtes de sortie | |
pub output_heads: OutputHeads<B>, | |
} | |
impl<B: Backend> Actor<B> { | |
/// Crée un nouvel acteur | |
pub fn new(config: ActorConfig, device: &B::Device) -> Self { | |
let mut actor = Self { | |
snapshot_processor: SnapshotProcessor::new( | |
MARKET_SNAPSHOT_DIM, | |
config.snapshot_processor_dims.clone(), | |
config.snapshot_with_bias, | |
config.snapshot_init_gain, | |
config.snapshot_init_fan_out_only, | |
config.epsilon, | |
config.dropout_rate, | |
device, | |
), | |
indicators_processor: IndicatorsProcessor::new( | |
TECHNICAL_INDICATORS_DIM, | |
config.indicators_processor_dims.clone(), | |
config.indicators_with_bias, | |
config.indicators_init_gain, | |
config.indicators_init_fan_out_only, | |
config.epsilon, | |
config.dropout_rate, | |
device, | |
), | |
state_processor: StateProcessor::new( | |
STATE_TOTAL_DIM, | |
config.state_processor_dims.clone(), | |
config.state_with_bias, | |
config.state_init_gain, | |
config.state_init_fan_out_only, | |
config.epsilon, | |
config.dropout_rate, | |
device, | |
), | |
fusion: Fusion::new(&config, device), | |
output_heads: OutputHeads::new(&config, device), | |
}; | |
// Appliquer l'initialisation personnalisée si demandé | |
if config.use_custom_init { | |
let mut mapper = if let Some(seed) = config.custom_init_seed { | |
CustomInitMapper::with_seed(seed) | |
} else { | |
CustomInitMapper::random() | |
}; | |
actor = actor.map(&mut mapper); | |
} | |
actor | |
} | |
/// Forward pass principal de l'acteur | |
/// | |
/// # Arguments | |
/// * `snapshot` - Tensor [batch_size, MARKET_SNAPSHOT_DIM] contenant les données de marché | |
/// * `indicators` - Tensor [batch_size, TECHNICAL_INDICATORS_DIM] contenant les indicateurs techniques | |
/// * `state` - Tensor [batch_size, STATE_TOTAL_DIM] contenant l'état de l'environnement | |
/// | |
/// # Returns | |
/// * `Vec<RawAction>` - Vecteur de taille `batch_size` des actions à effectuer | |
/// * `LogProbsTensor<B>` - Tensor des log-probabilité des actions (pour le calcul de la perte d'entropie) | |
pub fn forward( | |
&self, | |
snapshot: MarketSnapshotTensor<B>, | |
indicators: TechnicalIndicatorsTensor<B>, | |
state: EnvironmentStateTensor<B>, | |
max_leverage: Leverage, | |
curriculum_manager: &mut CurriculumManager<B>, | |
device: &B::Device, | |
) -> (Vec<RawAction>, LogProbsTensor<B>) { | |
log::info!(" -> Actor forward pass"); | |
// 0. On s'assure que les dimensions sont correctes | |
assert_eq!( | |
snapshot.dims(), | |
[snapshot.dims()[0], MARKET_SNAPSHOT_DIM], | |
"Snapshot dimensions mismatch" | |
); | |
assert_eq!( | |
indicators.dims(), | |
[indicators.dims()[0], TECHNICAL_INDICATORS_DIM], | |
"Indicators dimensions mismatch" | |
); | |
assert_eq!( | |
state.dims(), | |
[state.dims()[0], STATE_TOTAL_DIM], | |
"State dimensions mismatch" | |
); | |
// 1. Processing des features primaires | |
let snapshot_features = self.snapshot_processor.forward(snapshot); | |
let indicators_features = self.indicators_processor.forward(indicators); | |
let state_features = self.state_processor.forward(state.clone()); | |
// 2. Fusion avec attention | |
let combined = self.fusion.fusion_features( | |
snapshot_features, | |
indicators_features, | |
state_features, | |
device, | |
); | |
// 3. Génération et validation des sorties | |
log::debug!("TODO: Optimiser le forward pour ne pas calculer log_probs tout le temps"); | |
self.output_heads | |
.forward(combined, &state, max_leverage, curriculum_manager, device) | |
} | |
/// Forward pass déterministe | |
pub fn forward_deterministic( | |
&self, | |
snapshot: MarketSnapshotTensor<B>, | |
indicators: TechnicalIndicatorsTensor<B>, | |
state: EnvironmentStateTensor<B>, | |
max_leverage: Leverage, | |
curriculum_manager: &mut CurriculumManager<B>, | |
device: &B::Device, | |
) -> RawAction { | |
log::info!(" => Actor forward pass (deterministic)"); | |
let (actions, _) = self.forward( | |
snapshot, | |
indicators, | |
state, | |
max_leverage, | |
curriculum_manager, | |
device, | |
); | |
actions | |
.into_iter() | |
.next() | |
.expect("Le batch devrait contenir au moins une action") | |
} | |
} | |
// Rend l'acteur stateful | |
impl<B: Backend> Stateful<B> for Actor<B> {} | |
// Rend l'acteur comptable | |
impl<B: Backend> Countable<B> for Actor<B> {} | |
#[cfg(test)] | |
mod tests { | |
use super::*; | |
use crate::trading_ai::{ | |
env::{Leverage, STATE_TOTAL_DIM}, | |
memory::{data_buffer::VecDataTensors, MarketSnapshotTensor, TechnicalIndicatorsTensor}, | |
test_utils::init_test, | |
training::CurriculumConfig, | |
utils::{check_tensor_is_finite, debug_tensor}, | |
Data, DataBuffer, DefaultAutoDiffBackend, MarketSnapshot, TechnicalIndicators, | |
}; | |
use burn::tensor::{Device, Tensor}; | |
use chrono::{TimeZone, Utc}; | |
/// Génère un batch de test avec des data simulées mais réalistes | |
fn get_test_batch<B: Backend>( | |
count: usize, | |
device: &B::Device, | |
) -> (MarketSnapshotTensor<B>, TechnicalIndicatorsTensor<B>, usize) { | |
let mut data_buffer = DataBuffer::new(); | |
data_buffer.extend( | |
vec![ | |
Data { | |
snapshot: MarketSnapshot { | |
timestamp: Utc.with_ymd_and_hms(2024, 2, 15, 14, 30, 0).unwrap(), | |
open: 52143.25, | |
high: 52144.80, | |
low: 52142.90, | |
close: 52144.50, | |
volume: 0.75, | |
best_bid: 52144.40, | |
best_ask: 52144.60, | |
bid_volume: 1.25, | |
ask_volume: 0.98, | |
spread: 0.20, | |
book_depths: [2.50, 1.75, 1.25, 0.85, 0.45, 0.65, 1.15, 1.45, 1.85, 2.25], | |
}, | |
indicators: TechnicalIndicators { | |
stoch_rsi: (75.8, 68.4), | |
ichimoku: (52143.50, 52142.80), | |
atr: 1.85, | |
}, | |
}, | |
Data { | |
snapshot: MarketSnapshot { | |
timestamp: Utc.with_ymd_and_hms(2024, 2, 15, 14, 30, 3).unwrap(), | |
open: 52144.50, | |
high: 52145.30, | |
low: 52144.20, | |
close: 52145.10, | |
volume: 1.15, | |
best_bid: 52145.00, | |
best_ask: 52145.20, | |
bid_volume: 1.45, | |
ask_volume: 0.85, | |
spread: 0.20, | |
book_depths: [2.75, 1.95, 1.45, 0.95, 0.55, 0.45, 0.95, 1.35, 1.75, 2.15], | |
}, | |
indicators: TechnicalIndicators { | |
stoch_rsi: (82.3, 71.2), | |
ichimoku: (52144.20, 52142.95), | |
atr: 1.92, | |
}, | |
}, | |
Data { | |
snapshot: MarketSnapshot { | |
timestamp: Utc.with_ymd_and_hms(2024, 2, 15, 14, 30, 6).unwrap(), | |
open: 52145.10, | |
high: 52145.40, | |
low: 52144.80, | |
close: 52144.90, | |
volume: 0.95, | |
best_bid: 52144.80, | |
best_ask: 52145.00, | |
bid_volume: 1.65, | |
ask_volume: 1.25, | |
spread: 0.20, | |
book_depths: [3.15, 2.25, 1.65, 1.15, 0.75, 0.85, 1.25, 1.55, 1.95, 2.45], | |
}, | |
indicators: TechnicalIndicators { | |
stoch_rsi: (78.5, 73.6), | |
ichimoku: (52144.60, 52143.15), | |
atr: 1.88, | |
}, | |
}, | |
Data { | |
snapshot: MarketSnapshot { | |
timestamp: Utc.with_ymd_and_hms(2024, 2, 15, 14, 30, 9).unwrap(), | |
open: 52144.90, | |
high: 52145.20, | |
low: 52144.60, | |
close: 52144.70, | |
volume: 1.35, | |
best_bid: 52144.60, | |
best_ask: 52144.90, | |
bid_volume: 1.85, | |
ask_volume: 1.45, | |
spread: 0.30, | |
book_depths: [2.95, 2.15, 1.85, 1.25, 0.85, 0.95, 1.45, 1.75, 2.15, 2.65], | |
}, | |
indicators: TechnicalIndicators { | |
stoch_rsi: (72.1, 74.2), | |
ichimoku: (52144.45, 52143.40), | |
atr: 1.83, | |
}, | |
}, | |
Data { | |
snapshot: MarketSnapshot { | |
timestamp: Utc.with_ymd_and_hms(2024, 2, 15, 14, 30, 12).unwrap(), | |
open: 52144.70, | |
high: 52144.90, | |
low: 52144.30, | |
close: 52144.40, | |
volume: 0.85, | |
best_bid: 52144.30, | |
best_ask: 52144.50, | |
bid_volume: 1.55, | |
ask_volume: 1.35, | |
spread: 0.20, | |
book_depths: [2.65, 1.95, 1.55, 1.15, 0.65, 0.75, 1.25, 1.65, 2.05, 2.55], | |
}, | |
indicators: TechnicalIndicators { | |
stoch_rsi: (65.4, 73.1), | |
ichimoku: (52144.20, 52143.55), | |
atr: 1.78, | |
}, | |
}, | |
], | |
50000.0, | |
); | |
let d = data_buffer.normalize().as_tensors(count, device); | |
debug_tensor("data_buffer", &d.0); | |
d | |
} | |
/// L'initialisation correcte de l'Actor et ses sous-modules | |
#[test] | |
fn test_actor_initialization() { | |
init_test(); | |
// Test avec config par défaut | |
let config = ActorConfig::default(); | |
let device = &<DefaultAutoDiffBackend as burn::prelude::Backend>::Device::default(); | |
let _ = Actor::<DefaultAutoDiffBackend>::new(config, device); | |
} | |
/// Le forward pass du SnapshotProcessor avec les bonnes dimensions | |
#[test] | |
fn test_snapshot_processor_forward() { | |
init_test(); | |
let config = ActorConfig::default(); | |
let device = &<DefaultAutoDiffBackend as burn::prelude::Backend>::Device::default(); | |
let actor = Actor::<DefaultAutoDiffBackend>::new(config.clone(), device); | |
let (snapshot_batch, indicators_batch, batch_size) = get_test_batch(5, device); | |
// Forward pass | |
let snapshot_processor_output = actor.snapshot_processor.forward(snapshot_batch); | |
// Vérifie les dimensions de sortie | |
assert_eq!( | |
snapshot_processor_output.dims(), | |
[ | |
batch_size, | |
*config.clone().snapshot_processor_dims.last().unwrap() | |
] | |
); | |
// Vérifions que les valeurs sont dans des plages raisonnables | |
assert!(check_tensor_is_finite(&snapshot_processor_output)); | |
// Forward pass | |
let indicators_processor_output = actor.indicators_processor.forward(indicators_batch); | |
// Vérifie les dimensions de sortie | |
assert_eq!( | |
indicators_processor_output.dims(), | |
[ | |
batch_size, | |
*config.clone().indicators_processor_dims.last().unwrap() | |
] | |
); | |
// Vérifions que les valeurs sont dans des plages raisonnables | |
assert!(check_tensor_is_finite(&indicators_processor_output)); | |
} | |
/// Le mécanisme de fusion des features | |
#[test] | |
fn test_fusion_features() { | |
init_test(); | |
let config = ActorConfig::default(); | |
let device = &<DefaultAutoDiffBackend as burn::prelude::Backend>::Device::default(); | |
let actor = Actor::<DefaultAutoDiffBackend>::new(config.clone(), device); | |
let batch_size = 2; | |
// Crée des features simulées | |
let snapshot_features = Tensor::<DefaultAutoDiffBackend, 2>::ones( | |
[batch_size, *config.snapshot_processor_dims.last().unwrap()], | |
&Device::<DefaultAutoDiffBackend>::default(), | |
); | |
let indicators_features = Tensor::<DefaultAutoDiffBackend, 2>::ones( | |
[ | |
batch_size, | |
*config.indicators_processor_dims.last().unwrap(), | |
], | |
&Device::<DefaultAutoDiffBackend>::default(), | |
); | |
let state_features = Tensor::<DefaultAutoDiffBackend, 2>::ones( | |
[batch_size, *config.state_processor_dims.last().unwrap()], | |
&Device::<DefaultAutoDiffBackend>::default(), | |
); | |
// Test la fusion | |
let fused = actor.fusion.fusion_features( | |
snapshot_features, | |
indicators_features, | |
state_features, | |
device, | |
); | |
// Vérifie les dimensions de sortie | |
assert_eq!(fused.dims(), [batch_size, config.fusion_latent_dim]); | |
} | |
/// Le forward pass complet avec les bonnes dimensions en sortie | |
#[test] | |
fn test_full_forward_pass() { | |
init_test(); | |
let config = ActorConfig::default(); | |
let device = &<DefaultAutoDiffBackend as burn::prelude::Backend>::Device::default(); | |
let actor = Actor::<DefaultAutoDiffBackend>::new(config.clone(), device); | |
let mut curriculum_manager = CurriculumManager::new(CurriculumConfig { | |
enabled: false, | |
..Default::default() | |
}); | |
let (snapshot_batch, indicators_batch, batch_size) = get_test_batch(5, device); | |
let state = Tensor::<DefaultAutoDiffBackend, 2>::ones( | |
[batch_size, STATE_TOTAL_DIM], | |
&Device::<DefaultAutoDiffBackend>::default(), | |
); | |
// Forward pass complet | |
let (actions, log_probs) = actor.forward( | |
snapshot_batch, | |
indicators_batch, | |
state, | |
Leverage::One, | |
&mut curriculum_manager, | |
device, | |
); | |
// Vérifier le nombre d'actions retournées | |
assert_eq!( | |
actions.len(), | |
batch_size, | |
"Le nombre d'actions devrait correspondre au batch_size" | |
); | |
assert_eq!( | |
log_probs.dims()[0], | |
batch_size, | |
"Le nombre de log-probs devrait correspondre au batch_size" | |
); | |
// Vérifier la structure de la première action | |
let first_action = &actions[0]; | |
// Vérifier les dimensions des vecteurs | |
assert_eq!( | |
first_action.position_type.len(), | |
config.output_position_type_dim | |
); | |
assert_eq!(first_action.order_type.len(), config.output_order_type_dim); | |
} | |
/// Les contraintes métier sur les sorties (sizing, TP/SL, etc.) | |
#[test] | |
fn test_output_constraints() { | |
init_test(); | |
let config = ActorConfig::default(); | |
let device = &<DefaultAutoDiffBackend as burn::prelude::Backend>::Device::default(); | |
let actor = Actor::<DefaultAutoDiffBackend>::new(config.clone(), device); | |
let max_leverage = Leverage::One; | |
let mut curriculum_manager = CurriculumManager::new(CurriculumConfig { | |
enabled: false, | |
..Default::default() | |
}); | |
let (snapshot_batch, indicators_batch, batch_size) = get_test_batch(5, device); | |
let state = Tensor::<DefaultAutoDiffBackend, 2>::ones( | |
[batch_size, STATE_TOTAL_DIM], | |
&Device::<DefaultAutoDiffBackend>::default(), | |
); | |
// Forward pass | |
let (actions, _) = actor.forward( | |
snapshot_batch, | |
indicators_batch, | |
state, | |
max_leverage, | |
&mut curriculum_manager, | |
device, | |
); | |
// Vérifier chaque action du batch | |
for (i, action) in actions.iter().enumerate() { | |
println!("Vérification de l'action {}", i); | |
// Taille de position doit être entre 0 et 1 | |
assert!( | |
action.position_size >= 0.0 && action.position_size <= 1.0, | |
"La taille de position doit être entre 0 et 1, reçu: {}", | |
action.position_size | |
); | |
// TP doit être positif | |
assert!( | |
action.tp_ratio >= 0.0, | |
"Le ratio TP doit être positif, reçu: {}", | |
action.tp_ratio | |
); | |
// SL doit être positif | |
assert!( | |
action.sl_ratio >= 0.0, | |
"Le ratio SL doit être positif, reçu: {}", | |
action.sl_ratio | |
); | |
// Entry distance doit être positif | |
assert!( | |
action.entry_distance >= 0.0, | |
"La distance d'entrée doit être positive, reçu: {}", | |
action.entry_distance | |
); | |
// Order type et position_type doivent être des vecteurs de probabilités valides | |
let sum_position_type: f32 = action.position_type.iter().sum(); | |
let sum_order_type: f32 = action.order_type.iter().sum(); | |
assert!( | |
(sum_position_type - 1.0).abs() < 1e-5, | |
"La somme des probabilités de position_type doit être ~1, reçu: {}", | |
sum_position_type | |
); | |
assert!( | |
(sum_order_type - 1.0).abs() < 1e-5, | |
"La somme des probabilités d'order_type doit être ~1, reçu: {}", | |
sum_order_type | |
); | |
// Vérifier que chaque probabilité est entre 0 et 1 | |
for (j, &prob) in action.position_type.iter().enumerate() { | |
assert!( | |
prob >= 0.0 && prob <= 1.0, | |
"La probabilité position_type[{}] doit être entre 0 et 1, reçu: {}", | |
j, | |
prob | |
); | |
} | |
for (j, &prob) in action.order_type.iter().enumerate() { | |
assert!( | |
prob >= 0.0 && prob <= 1.0, | |
"La probabilité order_type[{}] doit être entre 0 et 1, reçu: {}", | |
j, | |
prob | |
); | |
} | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment