
statspai.causal_rl


Causal Reinforcement Learning (StatsPAI v0.10).

Tools bridging reinforcement learning and causal inference for offline (batch) learning under unobserved confounding.

References
  • Li, Zhang & Bareinboim (2025), arXiv 2510.21110 — Confounding-Robust Deep RL.
  • Cunha, Liu, French & Mian (2025), arXiv 2512.18135 — Unifying Causal RL.
  • Chemingui, Deshwal, Fern, Nguyen-Tang & Doppa (2025), arXiv 2510.22027 — Online Optimization for Offline Safe RL.

CausalDQNResult dataclass

Output of confounding-robust Q-learning.

BanditBenchmarkResult dataclass

Output from a causal-RL benchmark run.

OfflineSafeResult dataclass

Output of safe offline policy learning.

StructuralMDPResult dataclass

Output of structural_mdp (the fitted linear SVAR).

counterfactual_rollout

counterfactual_rollout(initial_state: ndarray, policy: Callable[[ndarray], ndarray], horizon: int = 10) -> Dict[str, ndarray]

Roll out the fitted SVAR for horizon steps, starting from initial_state and choosing each action with policy, to obtain a counterfactual (state, action, reward) trajectory.
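As a rough illustration of what such a rollout involves, the sketch below iterates the two SVAR equations documented under structural_mdp (s_{t+1} = A s_t + B a_t, r_t = coef_s @ s_t + coef_a @ a_t) with the noise term set to zero. The function name svar_rollout, the explicit A, B, coef_s, coef_a arguments, and the dictionary keys are hypothetical; whether counterfactual_rollout adds noise, and which keys it actually returns, is not specified on this page.

```python
import numpy as np
from typing import Callable, Dict

def svar_rollout(A: np.ndarray, B: np.ndarray, coef_s: np.ndarray, coef_a: np.ndarray,
                 initial_state: np.ndarray, policy: Callable[[np.ndarray], np.ndarray],
                 horizon: int = 10) -> Dict[str, np.ndarray]:
    """Hypothetical sketch: noise-free rollout of a linear SVAR under a new policy."""
    s = np.asarray(initial_state, dtype=float)
    states, actions, rewards = [], [], []
    for _ in range(horizon):
        a = np.atleast_1d(np.asarray(policy(s), dtype=float))  # action chosen by the new policy
        r = float(coef_s @ s + coef_a @ a)                       # reward equation
        states.append(s)
        actions.append(a)
        rewards.append(r)
        s = A @ s + B @ a                                        # transition equation, noise set to 0
    return {"states": np.array(states), "actions": np.array(actions),
            "rewards": np.array(rewards)}
```

With A = 0.5 I, a zero policy, and coef_s = (1, 1), each step halves the state and therefore the reward, which makes the mechanics easy to check by hand.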

causal_rl_benchmark

causal_rl_benchmark(name: str = 'confounded_bandit', n_episodes: int = 1000, confounding_strength: float = 0.5, seed: int = 0) -> BanditBenchmarkResult

Generate a synthetic causal-RL benchmark dataset.

Parameters:

  • name ({'confounded_bandit', 'confounded_dosage', 'confounded_pricing', 'confounded_targeting', 'confounded_routing'}, default 'confounded_bandit'): Which synthetic benchmark to generate.
  • n_episodes (int, default 1000): Number of episodes to generate.
  • confounding_strength (float in [0, 1], default 0.5): Magnitude of unmeasured confounding U → (action, reward).
  • seed (int, default 0): Random seed.

Returns:

  • BanditBenchmarkResult
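To make "unmeasured confounding U → (action, reward)" concrete, here is a minimal sketch of a hypothetical confounded two-armed bandit in which a hidden binary U raises both the chance of pulling arm 1 and the reward. This is not the generator used by causal_rl_benchmark; the names simulate_confounded_bandit, naive_effect, adjusted_effect, the true arm effect tau, and the particular functional forms are assumptions chosen for illustration.

```python
import numpy as np

def simulate_confounded_bandit(n: int, strength: float, tau: float = 0.2, seed: int = 0):
    """Hypothetical data-generating process: binary arm, binary unmeasured confounder U."""
    rng = np.random.default_rng(seed)
    u = rng.binomial(1, 0.5, size=n)                        # unmeasured confounder
    p_arm1 = 0.5 + 0.4 * strength * (2 * u - 1)             # U -> action (logging policy)
    a = rng.binomial(1, p_arm1)
    r = tau * a + strength * u + rng.normal(0.0, 0.1, n)    # U -> reward; true arm effect is tau
    return a, r, u

def naive_effect(a, r):
    # Difference in mean reward between arms, ignoring U.
    return r[a == 1].mean() - r[a == 0].mean()

def adjusted_effect(a, r, u):
    # Back-door adjustment on U: only possible here because the simulation reveals U.
    return sum((u == k).mean() * (r[(a == 1) & (u == k)].mean() - r[(a == 0) & (u == k)].mean())
               for k in (0, 1))
```

With strength=1.0 the naive difference lands near tau + 0.8 (about 1.0), while adjusting for U recovers roughly tau = 0.2; with strength=0.0 the two agree. In real logged data U is not observed, which is the setting this module targets.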

offline_safe_policy

offline_safe_policy(data: DataFrame, state: str, action: str, reward: str, cost: str, cost_threshold: float = 0.5, discount: float = 0.95, n_iter: int = 100, seed: int = 0) -> OfflineSafeResult

Safe offline policy learning under a cost constraint.

Parameters:

  • data (DataFrame, required): Transition data with state, action, reward, and cost columns.
  • state (str, required): Name of the state column. Must be discrete.
  • action (str, required): Name of the action column. Must be discrete.
  • reward (str, required): Name of the reward column.
  • cost (str, required): Name of the cost column.
  • cost_threshold (float, default 0.5): Maximum allowed expected cost per step.
  • discount (float, default 0.95): Discount factor.
  • n_iter (int, default 100): Number of iterations.
  • seed (int, default 0): Random seed.

Returns:

  • OfflineSafeResult
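One common way to attack a cost-constrained problem like this is Lagrangian relaxation: fold the cost into the reward as r - λ·cost, solve the penalised MDP, and raise λ while the constraint is violated. The sketch below does this on a tabular model estimated from logged transitions. It is not necessarily what offline_safe_policy implements; the name lagrangian_safe_policy, the explicit next-state array s_next (the documented signature has no next-state column), the step size lr, and returning the best feasible policy seen are all assumptions. "Expected cost per step" is read here as the normalised discounted cost (1 - discount)·E[Σ discount^t c_t] under the empirical state distribution.

```python
import numpy as np

def lagrangian_safe_policy(s, a, r, c, s_next, n_states, n_actions,
                           cost_threshold=0.5, discount=0.95, n_iter=100, lr=0.5):
    """Hypothetical sketch: Lagrangian relaxation for cost-constrained tabular offline RL."""
    s, a, s_next = (np.asarray(x, dtype=int) for x in (s, a, s_next))
    r, c = np.asarray(r, dtype=float), np.asarray(c, dtype=float)

    # Empirical model: mean reward/cost and transition probabilities per logged (s, a).
    counts = np.zeros((n_states, n_actions))
    R = np.zeros((n_states, n_actions))
    C = np.zeros((n_states, n_actions))
    P = np.zeros((n_states, n_actions, n_states))
    np.add.at(counts, (s, a), 1.0)
    np.add.at(R, (s, a), r)
    np.add.at(C, (s, a), c)
    np.add.at(P, (s, a, s_next), 1.0)
    seen = counts > 0
    R[seen] /= counts[seen]
    C[seen] /= counts[seen]
    P[seen] /= counts[seen][:, None]
    d0 = np.bincount(s, minlength=n_states) / len(s)  # empirical state distribution

    def greedy(lam):
        # Value iteration on the penalised reward R - lam * C, using only logged actions.
        q = np.zeros((n_states, n_actions))
        for _ in range(500):
            v = np.where(seen, q, -np.inf).max(axis=1)
            v[~seen.any(axis=1)] = 0.0  # states with no logged action get value 0
            q = R - lam * C + discount * (P @ v)
        return np.where(seen, q, -np.inf).argmax(axis=1)

    def evaluate(pi, M):
        # Normalised discounted value (1 - discount) * d0 . V^pi of per-step quantity M.
        idx = np.arange(n_states)
        v = np.linalg.solve(np.eye(n_states) - discount * P[idx, pi], M[idx, pi])
        return (1 - discount) * d0 @ v

    lam, best_pi, best_ret = 0.0, None, -np.inf
    for _ in range(n_iter):
        pi = greedy(lam)
        ret, cost = evaluate(pi, R), evaluate(pi, C)
        if cost <= cost_threshold and ret > best_ret:  # keep the best feasible policy
            best_pi, best_ret = pi, ret
        lam = max(0.0, lam + lr * (cost - cost_threshold))  # dual (sub)gradient step
    return best_pi, lam
```

Because λ can oscillate when the greedy policy flips between a high-reward, high-cost action and a safe one, the sketch keeps the best feasible policy encountered rather than the last one; it returns None if no feasible policy was found.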

causal_bandit

causal_bandit(arms: Sequence[str], *, reward_fn: Callable[[str, dict], float], context: Optional[dict] = None, n_samples: int = 500, rng_seed: int = 0) -> CausalBanditResult

Bareinboim-Forney-Pearl contextual causal bandit.

Given a callable reward_fn(arm, context) that samples the potential outcome of an arm under the current context, the function draws n_samples outcomes per arm to form a Monte Carlo estimate of E[Y(a) | context], then selects the arm with the largest estimate (the argmax).

Parameters:

  • arms (sequence of str, required): Arm labels.
  • reward_fn (callable, required): Stochastic reward sampler. Must accept (arm, context) and return a scalar reward.
  • context (dict, default None): Context passed to reward_fn.
  • n_samples (int, default 500): Monte Carlo draws per arm.
  • rng_seed (int, default 0): Random seed.

Returns:

  • CausalBanditResult
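The estimation step described above amounts to averaging draws per arm and taking the argmax. The sketch below shows only that step; the name mc_causal_bandit and the (best_arm, estimates) return value are hypothetical stand-ins for CausalBanditResult, and here the sampler owns its own random state, since how causal_bandit uses rng_seed is not documented on this page.

```python
import numpy as np
from typing import Callable, Dict, Optional, Sequence, Tuple

def mc_causal_bandit(arms: Sequence[str], reward_fn: Callable[[str, dict], float],
                     context: Optional[dict] = None, n_samples: int = 500
                     ) -> Tuple[str, Dict[str, float]]:
    """Hypothetical sketch: Monte Carlo estimate of E[Y(a) | context] per arm, then argmax."""
    ctx = {} if context is None else context
    # Average n_samples draws of the potential outcome for each arm.
    estimates = {arm: float(np.mean([reward_fn(arm, ctx) for _ in range(n_samples)]))
                 for arm in arms}
    best = max(estimates, key=estimates.get)  # arm with the highest estimated mean reward
    return best, estimates

# Usage: a sampler whose mean depends on the arm and on a context shift.
rng = np.random.default_rng(0)
means = {"control": 0.0, "treat": 1.0}
best, est = mc_causal_bandit(
    ["control", "treat"],
    lambda arm, ctx: rng.normal(means[arm] + ctx.get("shift", 0.0), 1.0),
    context={"shift": 0.5},
)
```

The Monte Carlo standard error per arm is roughly the reward standard deviation divided by the square root of n_samples, which is what n_samples trades off against runtime.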

counterfactual_policy_optimization

counterfactual_policy_optimization(data: DataFrame, *, state: str, action: str, reward: str, target_policy: Callable[[float], float], noise_sd: float = 1.0) -> CFPolicyResult

Counterfactual policy evaluation under a linear-Gaussian SCM.

Assumes a one-step SCM

r = alpha * s + beta * a + eps,  eps ~ Normal(0, noise_sd²)

so that, holding s fixed, the noise recovered from an observed row (eps = r - alpha * s - beta * a) determines a unique counterfactual reward for any new action (noise inversion).

Parameters:

  • data (DataFrame, required): One row per trajectory; must contain numeric state, action, and reward columns.
  • state (str, required): Name of the state column (s).
  • action (str, required): Name of the action column (a).
  • reward (str, required): Name of the reward column (r).
  • target_policy (callable(float) -> float, required): Proposed policy a_new = π(s).
  • noise_sd (float, default 1.0): Standard deviation of the noise term eps.

Returns:

  • CFPolicyResult
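The noise-inversion idea above follows the abduction, action, prediction pattern: recover each row's eps, swap in the new action, and recompute r with the same eps. A minimal sketch, assuming alpha and beta are estimated by ordinary least squares without an intercept (as the SCM has none); the name linear_gaussian_cf_value and its return tuple are hypothetical. Because every row's noise is recovered exactly, this point estimate does not need noise_sd; how counterfactual_policy_optimization uses that parameter is outside this sketch.

```python
import numpy as np
from typing import Callable

def linear_gaussian_cf_value(s, a, r, target_policy: Callable[[float], float]):
    """Hypothetical sketch: counterfactual rewards in r = alpha*s + beta*a + eps."""
    s, a, r = (np.asarray(x, dtype=float) for x in (s, a, r))
    # Fit alpha and beta by OLS with no intercept, matching the one-step SCM.
    (alpha, beta), *_ = np.linalg.lstsq(np.column_stack([s, a]), r, rcond=None)
    eps = r - alpha * s - beta * a                               # abduction: each row's noise
    a_new = np.array([target_policy(float(si)) for si in s])    # action: apply pi(s) per row
    r_cf = alpha * s + beta * a_new + eps                        # prediction: same noise, new action
    return r_cf.mean(), r_cf, (alpha, beta)
```

A useful sanity check: with the noise held fixed, each counterfactual reward differs from the observed one by exactly beta * (π(s) - a).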

structural_mdp

structural_mdp(data: DataFrame, *, state_cols: Sequence[str], action_cols: Sequence[str], reward: str, next_state_cols: Optional[Sequence[str]] = None, time: Optional[str] = None, trajectory: Optional[str] = None) -> StructuralMDPResult

Fit a linear SVAR for a Markov decision process.

Estimates:

s_{t+1} = A s_t + B a_t + noise
r_t     = coef_s @ s_t + coef_a @ a_t

from logged tuples. Supports either per-trajectory data (a trajectory column groups consecutive transitions) or single-stream data ordered by a time column.

Parameters:

  • data (DataFrame, required): Logged transitions.
  • state_cols (sequence of str, required): State columns (s_t).
  • action_cols (sequence of str, required): Action columns (a_t).
  • reward (str, required): Reward column (r_t).
  • next_state_cols (sequence of str, default None): If given, each row is a complete (s, a, r, s') tuple. If omitted, the function derives s_{t+1} by shifting s within each trajectory.
  • time (str, default None): Required if next_state_cols is None; used to order rows within each trajectory.
  • trajectory (str, default None): Trajectory identifier for multi-episode data.

Returns:

  • StructuralMDPResult
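For the case without next_state_cols, the estimation can be sketched as two least-squares regressions: s_{t+1} on (s_t, a_t), pairing each row with the next row of the same trajectory, and r_t on (s_t, a_t). This is a minimal sketch under stated assumptions (plain OLS, no intercept, no gaps in the time index), not structural_mdp's actual code; the name fit_linear_svar and its array-based arguments are hypothetical.

```python
import numpy as np

def fit_linear_svar(S, Acts, r, traj, time):
    """Hypothetical sketch: OLS fit of s_{t+1} = A s_t + B a_t and r_t = coef_s@s_t + coef_a@a_t."""
    S, Acts = np.asarray(S, dtype=float), np.asarray(Acts, dtype=float)
    r, traj, time = np.asarray(r, dtype=float), np.asarray(traj), np.asarray(time)
    order = np.lexsort((time, traj))              # sort by trajectory, then by time
    S, Acts, r, traj = S[order], Acts[order], r[order], traj[order]
    k = S.shape[1]

    # Pair each row with the next row of the same trajectory to obtain s_{t+1}.
    same = traj[:-1] == traj[1:]
    X = np.hstack([S[:-1][same], Acts[:-1][same]])
    W, *_ = np.linalg.lstsq(X, S[1:][same], rcond=None)   # shape (k + m, k)
    A, B = W[:k].T, W[k:].T

    # Reward equation uses every row, with no intercept, matching the r_t equation.
    coef, *_ = np.linalg.lstsq(np.hstack([S, Acts]), r, rcond=None)
    return A, B, coef[:k], coef[k:]
```

Sorting by (trajectory, time) before shifting ensures the last row of one episode is never paired with the first row of the next.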