
statspai.assimilation


Assimilative Causal Inference (sp.assimilation).

Bridges Bayesian data assimilation (the workhorse of numerical weather prediction, climate reanalysis, and oceanography) with causal inference, as proposed in *Assimilative Causal Inference*, Nature Communications, 2026.

The core idea is to treat the causal effect as a latent time-varying state, and update its posterior belief as new randomised or observational data batches arrive. Each update fuses:

  1. A forecast step that propagates the previous posterior through a user-supplied dynamics model (default: a random walk).
  2. An analysis step that incorporates the fresh observational or experimental batch via a Kalman-style innovation.
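In the scalar Gaussian case these two steps reduce to the textbook Kalman recursion. The equations below are a sketch assuming the default random-walk dynamics with process variance Q (process_var) and observation variance σ_t² (the squared batch standard error), starting from the prior N(m_0, P_0). The innovation v_t is the quantity reported in the innovations attribute.

```latex
\begin{aligned}
\text{Forecast:}\quad   & m_{t|t-1} = m_{t-1}, \qquad P_{t|t-1} = P_{t-1} + Q \\
\text{Innovation:}\quad & v_t = \hat\theta_t - m_{t|t-1}, \qquad S_t = P_{t|t-1} + \sigma_t^2 \\
\text{Analysis:}\quad   & K_t = \frac{P_{t|t-1}}{S_t}, \qquad m_t = m_{t|t-1} + K_t\, v_t, \qquad P_t = (1 - K_t)\, P_{t|t-1}
\end{aligned}
```

Each analysis step shrinks the variance (P_t < P_{t|t-1}), while Q > 0 re-inflates it before every batch; that re-inflation is what lets the filter follow a drifting effect instead of averaging over all history.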

The result is a running posterior over the causal effect, with an effective-sample-size diagnostic that flags when new evidence should trigger a re-design of the experiment.
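This page does not give the formula behind the effective-sample-size diagnostic. A minimal sketch, assuming ESS is the posterior precision 1/P_t divided by the average per-batch precision 1/σ_s² seen so far (so an informative prior counts as extra pseudo-batches), shows the intended reading: under a static effect it counts batches, and under drift it levels off. The function running_ess is hypothetical, not part of StatsPAI.

```python
def running_ess(standard_errors, prior_var=1.0, process_var=0.0):
    """Hypothetical ESS: posterior precision 1/P_t divided by the average
    per-batch precision 1/sigma_s^2 over the batches seen so far."""
    P = prior_var
    precision_sum = 0.0
    ess = []
    for t, se in enumerate(standard_errors, start=1):
        obs_var = se ** 2
        P_pred = P + process_var                    # forecast: random-walk inflation
        P = P_pred * obs_var / (P_pred + obs_var)   # analysis: posterior variance
        precision_sum += 1.0 / obs_var
        ess.append((1.0 / P) / (precision_sum / t))  # "worth" in units of batches
    return ess


static = running_ess([0.1] * 20, prior_var=1e12)                        # diffuse prior, Q = 0
drifting = running_ess([0.1] * 50, prior_var=1e12, process_var=0.0025)  # Q > 0
```

With a diffuse prior and Q = 0 the sketch returns about t after t batches; with Q = 0.0025 and σ_t = 0.1 it settles near 2.56, i.e. the posterior is never worth more than two or three recent batches.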

Why this module exists in StatsPAI
  • Streaming A/B tests: the treatment effect may drift over seasons; assimilation lets you pool evidence without pretending the effect is static.
  • Adaptive monitoring: public-health or policy evaluation that wants to update the target estimate monthly instead of waiting for a single large study.
  • Multi-source evidence synthesis: combine an RCT prior with streaming real-world-evidence (RWE) updates under a transport-compatible framework (pairs naturally with sp.transport).

AssimilationResult dataclass

Output of assimilative_causal, causal_kalman, and particle_filter.

Attributes:

Name Type Description
posterior_mean ndarray

Running posterior means m_t.

posterior_sd ndarray

Running posterior standard deviations sqrt(P_t).

posterior_ci ndarray

(T, 2) array of posterior intervals at level 1 - alpha (95% CIs at the default alpha = 0.05).

innovations ndarray

θ̂_t - m_{t|t-1} at each step (useful for diagnostics).

ess ndarray

Effective sample size per step — number of past batches the current posterior is "worth" in precision terms.

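final_mean float

Posterior mean after the last batch.

final_sd float

Posterior standard deviation after the last batch.

final_ci tuple

Posterior interval (lower, upper) after the last batch.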
method str
diagnostics dict

trajectory

trajectory() -> DataFrame

Convert the running posterior to a tidy DataFrame.
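Assuming the intervals are Gaussian (m_t ± z·sd_t, the exact equal-tailed interval for the Kalman backend's Gaussian posterior; the particle backend may instead use weighted quantiles), posterior_ci and a tidy per-step table could be built as below. The helper names and column names are hypothetical; the actual trajectory() columns are not documented here.

```python
from statistics import NormalDist


def gaussian_intervals(means, sds, alpha=0.05):
    """Equal-tailed (1 - alpha) intervals m_t +/- z * sd_t (assumed Gaussian form)."""
    z = NormalDist().inv_cdf(1 - alpha / 2)   # 1.959964... for alpha = 0.05
    return [(m - z * s, m + z * s) for m, s in zip(means, sds)]


def trajectory_rows(means, sds, alpha=0.05):
    """Hypothetical tidy layout: one record per batch (column names assumed)."""
    rows = []
    intervals = gaussian_intervals(means, sds, alpha)
    for t, (m, s, (lo, hi)) in enumerate(zip(means, sds, intervals), start=1):
        rows.append({"step": t, "mean": m, "sd": s, "ci_lower": lo, "ci_upper": hi})
    return rows
```

Passing the list of records to pandas.DataFrame would give a table of the shape trajectory() describes.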

assimilative_causal

assimilative_causal(batches: Sequence[Any], estimator: Callable[[Any], Tuple[float, float]], *, prior_mean: float = 0.0, prior_var: float = 1.0, process_var: float = 0.0, alpha: float = 0.05, backend: str = 'kalman') -> AssimilationResult

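Run the assimilation pipeline of the Nature Communications (2026) paper end to end: apply estimator to each batch, then filter the resulting (theta_hat, standard_error) stream with the chosen backend.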

Parameters:

Name Type Description Default
batches sequence of Any

Each element is a batch dataset (DataFrame, ndarray, or anything estimator knows how to consume).

required
estimator callable

Maps one batch to (theta_hat, standard_error). For example, a lambda calling sp.dml(...) and extracting .estimate, .se.

required
prior_mean float

Forwarded to causal_kalman.

0.0
prior_var float

Forwarded to causal_kalman.

1.0
process_var float

Forwarded to causal_kalman.

0.0
alpha float

Forwarded to causal_kalman.

0.05
backend ('kalman', 'particle')

'kalman' uses the exact closed form. 'particle' routes to sp.assimilation.particle_filter, a bootstrap SIR filter with systematic resampling that handles non-Gaussian priors, non-Gaussian observation noise, or nonlinear dynamics. Under Gaussian data-generating processes the two backends agree to within Monte Carlo noise.

'kalman'
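The estimator argument only has to map one batch to a (theta_hat, standard_error) pair. As a dependency-free sketch, here is a difference-in-means estimator with the Neyman (unpooled) standard error for a randomised binary treatment. Representing a batch as a list of (y, d) tuples is an assumption made for illustration, and diff_in_means is a hypothetical name.

```python
import math
import random
from statistics import mean, variance


def diff_in_means(batch):
    """Estimator callable: batch -> (theta_hat, standard_error).

    Assumes `batch` is a list of (y, d) pairs from a randomised experiment
    with binary treatment d; uses the Neyman (unpooled) standard error.
    """
    treated = [y for y, d in batch if d == 1]
    control = [y for y, d in batch if d == 0]
    theta_hat = mean(treated) - mean(control)
    se = math.sqrt(variance(treated) / len(treated) + variance(control) / len(control))
    return theta_hat, se


rng = random.Random(0)


def one_batch(n=200):
    """Simulated randomised batch with true effect 0.5 and noise SD 0.3."""
    rows = []
    for _ in range(n):
        d = rng.randint(0, 1)
        rows.append((0.5 * d + rng.gauss(0, 0.3), d))
    return rows


batches = [one_batch() for _ in range(10)]
summaries = [diff_in_means(b) for b in batches]   # the (theta_hat, se) stream the filter consumes
```

Under the documented signature, a function like this could be passed directly, e.g. sp.assimilation.assimilative_causal(batches, diff_in_means).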

Returns:

Type Description
AssimilationResult
Notes

Assimilative Causal Inference is non-adversarial by design: it assumes the per-batch estimator is well calibrated (i.e. its CIs actually cover at their nominal rate). If the estimator is biased, the filter inherits that bias; with process_var=0 the posterior SD keeps shrinking as batches accumulate, so the interval becomes increasingly confident around the wrong value. Run the per-batch estimator through sp.smart.assumption_audit before feeding it to this pipeline.
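The bias-inheritance warning is easy to see numerically. The sketch below (plain Python, with a static-effect Kalman loop written out by hand rather than calling StatsPAI; static_kalman is a hypothetical name) feeds 50 batch estimates that are all shifted by +0.2 but report a standard error of 0.1. The posterior SD shrinks to about 0.1/√50 ≈ 0.014, so the 95% interval sits tightly around 0.7 and excludes the true effect 0.5.

```python
import random
from statistics import NormalDist


def static_kalman(estimates, ses, prior_mean=0.0, prior_var=1.0):
    """Scalar Kalman filter with Q = 0 (static effect); returns final mean and SD."""
    m, P = prior_mean, prior_var
    for theta_hat, se in zip(estimates, ses):
        K = P / (P + se ** 2)          # Kalman gain
        m = m + K * (theta_hat - m)    # analysis step on the innovation
        P = (1 - K) * P                # posterior variance shrinks every batch
    return m, P ** 0.5


rng = random.Random(1)
true_tau, bias, se = 0.5, 0.2, 0.1
# Each batch estimate is centred on true_tau + bias but still reports se = 0.1
biased = [true_tau + bias + rng.gauss(0, se) for _ in range(50)]
m, sd = static_kalman(biased, [se] * 50)
z = NormalDist().inv_cdf(0.975)
lo, hi = m - z * sd, m + z * sd
print(f"posterior {m:.3f} +/- {z * sd:.3f}; covers truth: {lo <= true_tau <= hi}")
```

Adding more batches only narrows the interval further; the fix has to come from a better-calibrated per-batch estimator.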

Examples:

>>> import statspai as sp, numpy as np, pandas as pd
>>> rng = np.random.default_rng(0)
>>> def one_batch(n):
...     x = rng.normal(size=n)
...     d = rng.integers(0, 2, n)
...     y = 0.5 * d + 0.2 * x + rng.normal(scale=0.3, size=n)
...     return pd.DataFrame({'y': y, 'd': d, 'x': x})
>>> batches = [one_batch(200) for _ in range(10)]
>>> def est(df):
...     r = sp.regress('y ~ d + x', data=df)
...     return float(r.params['d']), float(r.std_errors['d'])
>>> res = sp.assimilation.assimilative_causal(
...     batches, est, prior_mean=0.0, prior_var=1.0,
... )
>>> abs(res.final_mean - 0.5) < 0.1
True

causal_kalman

causal_kalman(estimates: Sequence[float], standard_errors: Sequence[float], *, prior_mean: float = 0.0, prior_var: float = 1.0, process_var: float = 0.0, alpha: float = 0.05) -> AssimilationResult

Closed-form Kalman filter over a stream of causal-effect estimates.

Parameters:

Name Type Description Default
estimates sequence of float

Batch-level estimates θ̂_t.

required
standard_errors sequence of float

Batch-level standard errors σ_t. Converted to variances.

required
prior_mean float

Prior mean m_0 of the Gaussian prior N(m_0, P_0) on the causal effect.

0.0
prior_var float

Prior variance P_0 of the Gaussian prior N(m_0, P_0) on the causal effect.

1.0
process_var float

State noise variance Q. 0 = static effect (all batches share one truth); >0 = random-walk drift.

0.0
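alpha float

Significance level of the posterior intervals: each interval has nominal coverage 1 - alpha (95% at the default).

0.05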
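To make the role of process_var concrete, here is a pure-Python sketch of the scalar random-walk Kalman filter (kalman_sketch and precision_pool are hypothetical names, not StatsPAI functions). It checks two properties: with Q = 0 the final posterior equals the precision-weighted pool of the prior and all batch estimates, and with Q > 0 the filter tracks an effect that jumps from 0.5 to 1.0 halfway through, while the static filter stalls near the all-history average of 0.75.

```python
import random


def kalman_sketch(estimates, ses, prior_mean=0.0, prior_var=1.0, process_var=0.0):
    """Scalar random-walk Kalman filter (illustrative sketch, not StatsPAI's code).

    Returns the running posterior means m_t, variances P_t and innovations.
    """
    m, P = prior_mean, prior_var
    means, variances, innovations = [], [], []
    for theta_hat, se in zip(estimates, ses):
        P_pred = P + process_var              # forecast: mean unchanged, variance + Q
        v = theta_hat - m                     # innovation theta_hat_t - m_{t|t-1}
        K = P_pred / (P_pred + se ** 2)       # Kalman gain
        m, P = m + K * v, (1 - K) * P_pred    # analysis step
        means.append(m)
        variances.append(P)
        innovations.append(v)
    return means, variances, innovations


def precision_pool(estimates, ses, prior_mean=0.0, prior_var=1.0):
    """Inverse-variance weighted average of the prior mean and all estimates."""
    weights = [1.0 / prior_var] + [1.0 / s ** 2 for s in ses]
    values = [prior_mean] + list(estimates)
    total = sum(weights)
    return sum(w * x for w, x in zip(weights, values)) / total, 1.0 / total


rng = random.Random(0)

# Q = 0: sequential filtering equals one-shot precision-weighted pooling
ests = [0.5 + rng.gauss(0, 0.1) for _ in range(20)]
ses = [0.1] * 20
means, variances, innovations = kalman_sketch(ests, ses)
pooled_mean, pooled_var = precision_pool(ests, ses)

# Q > 0: the true effect jumps from 0.5 to 1.0 after 20 batches
truth = [0.5] * 20 + [1.0] * 20
obs = [tau + rng.gauss(0, 0.1) for tau in truth]
static, _, _ = kalman_sketch(obs, [0.1] * 40, process_var=0.0)
tracking, _, _ = kalman_sketch(obs, [0.1] * 40, process_var=0.01)
```

The Q = 0 identity is why "static effect" and "pool all batches" are the same model here; choosing Q > 0 trades some precision for the ability to follow drift.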

Returns:

Type Description
AssimilationResult

Examples:

>>> import numpy as np
>>> import statspai as sp
>>> rng = np.random.default_rng(0)
>>> T = 20
>>> true_tau = 0.5
>>> ests = [true_tau + rng.normal(0, 0.1) for _ in range(T)]
>>> ses = [0.1] * T
>>> res = sp.assimilation.causal_kalman(
...     ests, ses, prior_mean=0.0, prior_var=1.0, process_var=0.0,
... )
>>> abs(res.final_mean - 0.5) < 0.1
True

particle_filter

particle_filter(estimates: Sequence[float], standard_errors: Sequence[float], *, prior_sampler: Optional[Callable[[Generator, int], ndarray]] = None, prior_mean: float = 0.0, prior_var: float = 1.0, transition_sampler: Optional[Callable[[ndarray, Generator], ndarray]] = None, process_sd: float = 0.0, observation_log_pdf: Optional[Callable[[float, ndarray, float], ndarray]] = None, n_particles: int = 2000, ess_resample_threshold: float = 0.5, alpha: float = 0.05, random_state: Optional[int] = None) -> AssimilationResult

SIR bootstrap particle filter over a stream of causal estimates.
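For intuition about what a bootstrap SIR filter does, here is a stdlib-only sketch with a Gaussian prior, Gaussian random-walk transition and Gaussian observation model. It mirrors the documented parameter names (prior_mean, prior_var, process_sd, n_particles, ess_resample_threshold) but is not the StatsPAI implementation; sir_filter and systematic_resample are hypothetical names, and the particle ESS is taken to be the usual 1/Σw_i².

```python
import math
import random


def systematic_resample(weights, rng):
    """Return N ancestor indices using one uniform offset and N evenly spaced pointers."""
    n = len(weights)
    u0 = rng.random() / n
    indices, cum, i = [], weights[0], 0
    for k in range(n):
        u = u0 + k / n
        while u > cum and i < n - 1:   # guard against round-off in the cumulative sum
            i += 1
            cum += weights[i]
        indices.append(i)
    return indices


def sir_filter(estimates, ses, prior_mean=0.0, prior_var=1.0, process_sd=0.0,
               n_particles=2000, ess_resample_threshold=0.5, seed=None):
    """Bootstrap SIR filter; returns the running posterior means (sketch only)."""
    rng = random.Random(seed)
    particles = [rng.gauss(prior_mean, math.sqrt(prior_var)) for _ in range(n_particles)]
    logw = [0.0] * n_particles
    means = []
    for theta_hat, se in zip(estimates, ses):
        # Propagate through the transition density (the bootstrap proposal)
        if process_sd > 0:
            particles = [p + rng.gauss(0.0, process_sd) for p in particles]
        # Reweight by the Gaussian likelihood N(theta_hat; p, se^2); constants cancel
        logw = [lw - 0.5 * ((theta_hat - p) / se) ** 2 for lw, p in zip(logw, particles)]
        top = max(logw)
        unnorm = [math.exp(lw - top) for lw in logw]   # shift by max for stability
        total = sum(unnorm)
        weights = [u / total for u in unnorm]
        means.append(sum(w * p for w, p in zip(weights, particles)))
        # Resample only when the particle ESS falls below the threshold
        ess = 1.0 / sum(w * w for w in weights)
        if ess / n_particles < ess_resample_threshold:
            idx = systematic_resample(weights, rng)
            particles = [particles[j] for j in idx]
            logw = [0.0] * n_particles
        else:
            logw = [lw - top for lw in logw]          # keep log-weights bounded
    return means
```

Under this all-Gaussian setup the last posterior mean should land close to the closed-form Kalman value, which is the agreement the backend description refers to. With process_sd = 0 the particles never move, so after a few resampling steps only a handful of distinct values remain; this degeneracy is a known weakness of static-parameter particle filters.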

Parameters:

Name Type Description Default
estimates sequence of float

Batch-level estimates θ̂_t.

required
standard_errors sequence of float

Batch-level standard errors σ_t.

required
required
prior_sampler callable

fn(rng, n) -> ndarray that draws n particles from p(θ_0). Defaults to N(prior_mean, prior_var).

None
prior_mean float

Used only when prior_sampler=None.

0.0
prior_var float

Used only when prior_sampler=None.

1.0
transition_sampler callable

fn(particles, rng) -> ndarray that propagates the particles one step (i.e. samples from the state-transition density). Defaults to a Gaussian random walk with SD process_sd.

None
process_sd float

Used only when transition_sampler=None. 0 = static effect.

0.0
observation_log_pdf callable

fn(theta_hat_t, particles, sigma_t) -> ndarray that returns log p(θ̂_t | θ_t) for every particle. Defaults to the Gaussian observation model N(θ_t, σ_t^2).

None
n_particles int

Number of particles N.

2000
ess_resample_threshold float

Resample whenever ESS / N falls below this value.

0.5
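alpha float

Significance level of the posterior intervals (nominal coverage 1 - alpha).

0.05
random_state int

Seed for the filter's random number generator; None leaves the draws unseeded.

None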
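The three callables let you swap in non-Gaussian components. As an illustration (the distributions, degrees of freedom and scales are arbitrary choices, and the function names are hypothetical), the sketch below builds a Student-t prior, a Student-t random-walk transition that allows occasional large jumps, and a Student-t observation log-density for heavy-tailed batch noise, each matching the signature documented above.

```python
import math

import numpy as np

NU = 4.0  # assumed degrees of freedom for all three heavy-tailed components


def t_prior_sampler(rng: np.random.Generator, n: int) -> np.ndarray:
    """p(theta_0): Student-t(4) centred at 0 with scale 0.5 (illustrative choice)."""
    return 0.5 * rng.standard_t(NU, size=n)


def t_jump_transition(particles: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Random walk with Student-t(4) steps of scale 0.05, allowing occasional big jumps."""
    return particles + 0.05 * rng.standard_t(NU, size=particles.shape)


def t_observation_log_pdf(theta_hat_t: float, particles: np.ndarray, sigma_t: float) -> np.ndarray:
    """log p(theta_hat_t | theta_t) under scaled Student-t(4) noise, one value per particle."""
    z = (theta_hat_t - particles) / sigma_t
    const = (math.lgamma((NU + 1) / 2) - math.lgamma(NU / 2)
             - 0.5 * math.log(NU * math.pi) - math.log(sigma_t))
    return const - 0.5 * (NU + 1) * np.log1p(z ** 2 / NU)
```

These could then be passed as particle_filter(ests, ses, prior_sampler=t_prior_sampler, transition_sampler=t_jump_transition, observation_log_pdf=t_observation_log_pdf, random_state=0); per the parameter descriptions, prior_mean, prior_var and process_sd are ignored once the corresponding callables are supplied.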

Returns:

Type Description
AssimilationResult

Examples:

>>> import numpy as np
>>> import statspai as sp
>>> rng = np.random.default_rng(0)
>>> T = 15
>>> tau = 0.5
>>> ests = [tau + rng.normal(0, 0.1) for _ in range(T)]
>>> ses = [0.1] * T
>>> # Default Gaussian observation model with a static effect (process_sd=0):
>>> # should match causal_kalman to within Monte Carlo noise.
>>> res_p = sp.assimilation.particle_filter(
...     ests, ses, n_particles=3000, random_state=0,
... )
>>> abs(res_p.final_mean - 0.5) < 0.15
True

assimilative_causal_particle

assimilative_causal_particle(batches: Sequence[Any], estimator: Callable[[Any], Tuple[float, float]], **kwargs: Any) -> AssimilationResult

Same as assimilative_causal but forces the particle filter.

Forwards every keyword argument to particle_filter.