statspai.assimilation
Assimilative Causal Inference (sp.assimilation).
Bridges Bayesian data assimilation (the workhorse of numerical weather prediction, climate reanalysis, and oceanography) with causal inference. Proposed in *Assimilative Causal Inference*, Nature Communications, 2026.
The core idea is to treat the causal effect as a latent time-varying state, and update its posterior belief as new randomised or observational data batches arrive. Each update fuses:
- A forecast step propagating the prior through a user-supplied dynamics model (default: random-walk).
- An analysis step that incorporates the fresh observational or experimental batch via a Kalman-style innovation.
The result is a running posterior over the causal effect, with an effective-sample-size diagnostic that flags when new evidence should trigger a re-design of the experiment.
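As a rough illustration of the forecast and analysis steps described above, the following sketch runs one cycle for a scalar effect under random-walk dynamics. The function names `forecast` and `analysis` are hypothetical and chosen for this example; this is the textbook scalar Kalman update, not necessarily how the module implements it.

```python
def forecast(mean, var, process_var=0.0):
    """Random-walk forecast: the mean carries over, the variance grows."""
    return mean, var + process_var


def analysis(mean_f, var_f, estimate, se):
    """Kalman-style analysis for a scalar state observed with noise se**2."""
    innovation = estimate - mean_f        # surprise carried by the new batch
    gain = var_f / (var_f + se ** 2)      # weight given to the new batch
    mean_a = mean_f + gain * innovation
    var_a = (1.0 - gain) * var_f
    return mean_a, var_a, innovation


# One cycle: prior N(0, 1), then a batch estimate of 0.5 with standard error 0.1.
m_f, v_f = forecast(0.0, 1.0, process_var=0.0)
m_a, v_a, innov = analysis(m_f, v_f, estimate=0.5, se=0.1)
print(round(m_a, 4), round(v_a, 4))  # 0.495 0.0099
```

Because the prior variance (1.0) is much larger than the batch variance (0.01), the gain is close to one and the posterior lands almost on the batch estimate.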
Why this module exists in StatsPAI
- Streaming A/B tests: the treatment effect may drift over seasons; assimilation lets you pool evidence without pretending the effect is static.
- Adaptive monitoring: public-health or policy evaluation that wants to update the target estimate monthly instead of waiting for a single large study.
- Multi-source evidence synthesis: combine an RCT prior with streaming RWE (real-world evidence) updates under a transport-compatible framework (pairs naturally with :mod:sp.transport).
AssimilationResult dataclass
Output of :func:assimilative_causal / :func:causal_kalman.
Attributes:
| Name | Type | Description |
|---|---|---|
| posterior_mean | ndarray | Running posterior means |
| posterior_sd | ndarray | Running posterior standard deviations |
| posterior_ci | ndarray | |
| innovations | ndarray | |
| ess | ndarray | Effective sample size per step: the number of past batches the current posterior is "worth" in precision terms. |
| final_mean | float | |
| final_sd | float | |
| final_ci | tuple | |
| method | str | |
| diagnostics | dict | |
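The table above does not say how `posterior_ci` is built. A minimal sketch, assuming a Gaussian posterior at each step and a two-sided interval at level `1 - alpha`, might look like the following. The helper `gaussian_ci` and the `(T, 2)` array layout are assumptions made for illustration and are not taken from the library.

```python
from statistics import NormalDist

import numpy as np


def gaussian_ci(mean, sd, alpha=0.05):
    """Two-sided (1 - alpha) interval for a Gaussian posterior, one row per step."""
    z = NormalDist().inv_cdf(1.0 - alpha / 2.0)   # about 1.96 for alpha = 0.05
    mean = np.asarray(mean, dtype=float)
    sd = np.asarray(sd, dtype=float)
    return np.column_stack([mean - z * sd, mean + z * sd])


# Made-up running posterior summaries for three steps (illustrative numbers only).
posterior_mean = np.array([0.40, 0.47, 0.51])
posterior_sd = np.array([0.10, 0.07, 0.06])
ci = gaussian_ci(posterior_mean, posterior_sd, alpha=0.05)
print(ci.shape)  # (3, 2)
```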
assimilative_causal
assimilative_causal(batches: Sequence[Any], estimator: Callable[[Any], Tuple[float, float]], *, prior_mean: float = 0.0, prior_var: float = 1.0, process_var: float = 0.0, alpha: float = 0.05, backend: str = 'kalman') -> AssimilationResult
Run the Nature-Comms-2026 assimilation pipeline end-to-end.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batches | sequence of Any | Each element is a batch dataset (DataFrame, ndarray, or anything | required |
| estimator | callable | Maps one batch to a Tuple[float, float] (per the signature above). | required |
| prior_mean | float | Forwarded to :func: | 0.0 |
| prior_var | float | Forwarded to :func: | 1.0 |
| process_var | float | Forwarded to :func: | 0.0 |
| alpha | float | Forwarded to :func: | 0.05 |
| backend | ('kalman', 'particle') | | 'kalman' |
Returns:
| Type | Description |
|---|---|
| AssimilationResult | |
Notes
Assimilative Causal Inference is non-adversarial by design: it
assumes the per-batch estimator is well-calibrated (i.e. the CIs
actually cover at their nominal rate). If the estimator is biased
the filter inherits that bias. Run the per-batch estimator through
:func:sp.smart.assumption_audit before feeding it to this pipeline.
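One informal way to probe the calibration assumption above is a simulation check: generate many batches with a known effect and count how often the nominal 95% interval covers it. The sketch below uses a plain difference-in-means estimator on simulated data. It is not `sp.smart.assumption_audit`, and the helper names `diff_in_means` and `empirical_coverage` are hypothetical.

```python
import numpy as np


def diff_in_means(y, d):
    """Return (estimate, standard error) for a two-arm difference in means."""
    y1, y0 = y[d == 1], y[d == 0]
    est = y1.mean() - y0.mean()
    se = np.sqrt(y1.var(ddof=1) / y1.size + y0.var(ddof=1) / y0.size)
    return est, se


def empirical_coverage(true_effect=0.5, n=200, n_sims=2000, z=1.96, seed=0):
    """Share of simulated batches whose nominal 95% CI covers the true effect."""
    rng = np.random.default_rng(seed)
    hits = 0
    for _ in range(n_sims):
        d = rng.integers(0, 2, n)
        y = true_effect * d + rng.normal(scale=0.3, size=n)
        est, se = diff_in_means(y, d)
        hits += int(abs(est - true_effect) <= z * se)
    return hits / n_sims


coverage = empirical_coverage()
print(f"empirical coverage: {coverage:.3f}")  # should sit near 0.95
```

Coverage well below the nominal rate would suggest the standard errors are too small, in which case the filter would become overconfident.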
Examples:
>>> import statspai as sp, numpy as np, pandas as pd
>>> rng = np.random.default_rng(0)
>>> def one_batch(n):
... x = rng.normal(size=n)
... d = rng.integers(0, 2, n)
... y = 0.5 * d + 0.2 * x + rng.normal(scale=0.3, size=n)
... return pd.DataFrame({'y': y, 'd': d, 'x': x})
>>> batches = [one_batch(200) for _ in range(10)]
>>> def est(df):
... r = sp.regress('y ~ d + x', data=df)
... return float(r.params['d']), float(r.std_errors['d'])
>>> res = sp.assimilation.assimilative_causal(
... batches, est, prior_mean=0.0, prior_var=1.0,
... )
>>> abs(res.final_mean - 0.5) < 0.1
True
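The `estimator` argument only needs to return a pair of floats, so it does not have to rely on `sp.regress`. As a sketch under that assumption, here is a NumPy-only estimator for the same data-generating process. The name `ols_effect` is hypothetical, and the standard error assumes homoskedastic errors.

```python
import numpy as np


def ols_effect(batch):
    """Return (estimate, standard error) for the coefficient on d in y ~ 1 + d + x."""
    y = np.asarray(batch["y"], dtype=float)
    X = np.column_stack([
        np.ones_like(y),
        np.asarray(batch["d"], dtype=float),
        np.asarray(batch["x"], dtype=float),
    ])
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)
    resid = y - X @ beta
    sigma2 = resid @ resid / (len(y) - X.shape[1])   # homoskedastic error variance
    cov = sigma2 * np.linalg.inv(X.T @ X)
    return float(beta[1]), float(np.sqrt(cov[1, 1]))


# Same simulated design as the example above, passed as a dict of arrays.
rng = np.random.default_rng(0)
n = 200
x = rng.normal(size=n)
d = rng.integers(0, 2, n)
y = 0.5 * d + 0.2 * x + rng.normal(scale=0.3, size=n)
est, se = ols_effect({"y": y, "d": d, "x": x})
```

Indexing with `batch["y"]` works for a dict of arrays and for a DataFrame with those column names.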
causal_kalman
causal_kalman(estimates: Sequence[float], standard_errors: Sequence[float], *, prior_mean: float = 0.0, prior_var: float = 1.0, process_var: float = 0.0, alpha: float = 0.05) -> AssimilationResult
Closed-form Kalman filter over a stream of causal-effect estimates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| estimates | sequence of float | Batch-level estimates | required |
| standard_errors | sequence of float | Batch-level standard errors | required |
| prior_mean | float | Prior mean | 0.0 |
| prior_var | float | Prior variance | 1.0 |
| process_var | float | State noise variance | 0.0 |
| alpha | float | | 0.05 |
Returns:
| Type | Description |
|---|---|
| AssimilationResult | |
Examples:
>>> import numpy as np
>>> import statspai as sp
>>> rng = np.random.default_rng(0)
>>> T = 20
>>> true_tau = 0.5
>>> ests = [true_tau + rng.normal(0, 0.1) for _ in range(T)]
>>> ses = [0.1] * T
>>> res = sp.assimilation.causal_kalman(
... ests, ses, prior_mean=0.0, prior_var=1.0, process_var=0.0,
... )
>>> abs(res.final_mean - 0.5) < 0.1
True
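To make the role of `process_var` concrete, the sketch below implements a scalar Kalman loop in plain NumPy and checks two textbook properties: with `process_var = 0` the final posterior matches precision-weighted pooling of the prior and all batches, and with `process_var > 0` the posterior variance stays larger because older batches are discounted. The function `scalar_kalman` is a stand-in written for this illustration, not the library's code.

```python
import numpy as np


def scalar_kalman(estimates, standard_errors, prior_mean=0.0, prior_var=1.0,
                  process_var=0.0):
    """Filter a stream of (estimate, se) pairs; return running means and variances."""
    mean, var = prior_mean, prior_var
    means, variances = [], []
    for est, se in zip(estimates, standard_errors):
        var = var + process_var              # forecast step (random walk)
        gain = var / (var + se ** 2)         # analysis step
        mean = mean + gain * (est - mean)
        var = (1.0 - gain) * var
        means.append(mean)
        variances.append(var)
    return np.array(means), np.array(variances)


rng = np.random.default_rng(0)
ests = 0.5 + rng.normal(0, 0.1, size=20)
ses = np.full(20, 0.1)

means, variances = scalar_kalman(ests, ses, prior_mean=0.0, prior_var=1.0)

# Static case: compare with precision-weighted pooling of prior and batches.
precision = 1.0 / 1.0 + np.sum(1.0 / ses ** 2)
pooled = (0.0 / 1.0 + np.sum(ests / ses ** 2)) / precision
print(np.isclose(means[-1], pooled), np.isclose(variances[-1], 1.0 / precision))

# Drifting case: state noise keeps the posterior variance from shrinking as far.
_, drift_vars = scalar_kalman(ests, ses, process_var=0.01)
print(drift_vars[-1] > variances[-1])
```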
particle_filter
particle_filter(estimates: Sequence[float], standard_errors: Sequence[float], *, prior_sampler: Optional[Callable[[Generator, int], ndarray]] = None, prior_mean: float = 0.0, prior_var: float = 1.0, transition_sampler: Optional[Callable[[ndarray, Generator], ndarray]] = None, process_sd: float = 0.0, observation_log_pdf: Optional[Callable[[float, ndarray, float], ndarray]] = None, n_particles: int = 2000, ess_resample_threshold: float = 0.5, alpha: float = 0.05, random_state: Optional[int] = None) -> AssimilationResult
SIR bootstrap particle filter over a stream of causal estimates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| estimates | sequences | Batch-level | required |
| standard_errors | sequences | Batch-level | required |
| prior_sampler | callable | | None |
| prior_mean | float | Used only when | 0.0 |
| prior_var | float | Used only when | 1.0 |
| transition_sampler | callable | | None |
| process_sd | float | Used only when | 0.0 |
| observation_log_pdf | callable | | None |
| n_particles | int | | 2000 |
| ess_resample_threshold | float | Resample whenever | 0.5 |
| alpha | float | | 0.05 |
| random_state | int | | None |
Returns:
| Type | Description |
|---|---|
| AssimilationResult | |
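The signature types `observation_log_pdf` as `Callable[[float, ndarray, float], ndarray]` but the table leaves it undescribed. One plausible use is a heavy-tailed observation model for batches with occasional outliers. The sketch below assumes the argument order is (batch estimate, particle values, batch standard error), which this page does not confirm. The function name and the `df` parameter are hypothetical.

```python
import math

import numpy as np


def student_t_log_pdf(estimate, particles, se, df=4.0):
    """Log-density of `estimate` under a scaled Student-t centred on each particle."""
    z = (estimate - np.asarray(particles, dtype=float)) / se
    log_norm = (
        math.lgamma((df + 1.0) / 2.0)
        - math.lgamma(df / 2.0)
        - 0.5 * math.log(df * math.pi)
        - math.log(se)
    )
    return log_norm - 0.5 * (df + 1.0) * np.log1p(z ** 2 / df)


# Evaluate the log-likelihood of one batch estimate across a few particle values.
particles = np.linspace(-1.0, 2.0, 7)
log_w = student_t_log_pdf(0.5, particles, 0.1)
```

Under that assumed argument order, a function of this shape could be passed as `observation_log_pdf=student_t_log_pdf`. The heavier tails mean a single outlying batch moves the weights less than it would under a Gaussian likelihood.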
Examples:
>>> import numpy as np
>>> import statspai as sp
>>> rng = np.random.default_rng(0)
>>> T = 15
>>> tau = 0.5
>>> ests = [tau + rng.normal(0, 0.1) for _ in range(T)]
>>> ses = [0.1] * T
>>> # Default Gaussian obs + random-walk dynamics — should match the
>>> # Kalman filter to within Monte-Carlo noise.
>>> res_p = sp.assimilation.particle_filter(
... ests, ses, n_particles=3000, random_state=0,
... )
>>> abs(res_p.final_mean - 0.5) < 0.15
True
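For readers unfamiliar with SIR filtering, the sketch below shows the basic loop with a Gaussian observation model, random-walk dynamics, and resampling triggered when the effective sample size falls below a fraction of the particle count. It is a generic bootstrap filter written for illustration. The function name `sir_filter` is hypothetical, and the library's resampling scheme and ESS rule may differ.

```python
import numpy as np


def sir_filter(estimates, standard_errors, n_particles=2000, prior_mean=0.0,
               prior_var=1.0, process_sd=0.0, ess_resample_threshold=0.5, seed=0):
    """Bootstrap particle filter; returns the running posterior means."""
    rng = np.random.default_rng(seed)
    particles = rng.normal(prior_mean, np.sqrt(prior_var), size=n_particles)
    log_w = np.full(n_particles, -np.log(n_particles))
    means = []
    for est, se in zip(estimates, standard_errors):
        # Forecast: propagate each particle through the random walk.
        particles = particles + rng.normal(0.0, process_sd, size=n_particles)
        # Analysis: reweight by the Gaussian likelihood of the batch estimate.
        log_w = log_w - 0.5 * ((est - particles) / se) ** 2
        log_w -= np.logaddexp.reduce(log_w)      # normalise in log space
        w = np.exp(log_w)
        means.append(float(np.sum(w * particles)))
        # Resample when the effective sample size drops below the threshold.
        ess = 1.0 / np.sum(w ** 2)
        if ess < ess_resample_threshold * n_particles:
            idx = rng.choice(n_particles, size=n_particles, p=w)
            particles = particles[idx]
            log_w = np.full(n_particles, -np.log(n_particles))
    return np.array(means)


rng = np.random.default_rng(0)
ests = 0.5 + rng.normal(0, 0.1, size=15)
ses = np.full(15, 0.1)
pf_means = sir_filter(ests, ses, n_particles=3000)
```

With `process_sd = 0` the particles never move after resampling, so the number of distinct values shrinks over time. A small positive `process_sd` or a jitter step is a common remedy.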
assimilative_causal_particle
assimilative_causal_particle(batches: Sequence[Any], estimator: Callable[[Any], Tuple[float, float]], **kwargs: Any) -> AssimilationResult
Same as :func:assimilative_causal but forces the particle filter.
Forwards every kwarg to :func:particle_filter.
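The wrapper pattern described above (apply the estimator to each batch, then hand the two resulting streams and any extra keyword arguments to a filter) can be sketched generically. Everything below is hypothetical, including `run_estimator_then_filter` and the stand-in `pooled_mean` filter. It illustrates the forwarding idea only and is not the library's implementation.

```python
from typing import Any, Callable, Sequence, Tuple


def run_estimator_then_filter(
    batches: Sequence[Any],
    estimator: Callable[[Any], Tuple[float, float]],
    filter_fn: Callable[..., Any],
    **kwargs: Any,
) -> Any:
    """Apply `estimator` to each batch, then pass both streams to `filter_fn`."""
    pairs = [estimator(batch) for batch in batches]
    estimates = [float(est) for est, _ in pairs]
    standard_errors = [float(se) for _, se in pairs]
    return filter_fn(estimates, standard_errors, **kwargs)


# Stand-in filter (hypothetical): precision-weighted mean of the stream.
def pooled_mean(estimates, standard_errors, scale=1.0):
    weights = [1.0 / se ** 2 for se in standard_errors]
    return scale * sum(w * e for w, e in zip(weights, estimates)) / sum(weights)


# Toy batches: each batch is a list of numbers, summarised by its mean.
batches = [[1.0, 2.0, 3.0], [2.0, 4.0]]
mean_and_se = lambda b: (sum(b) / len(b), 1.0 / len(b) ** 0.5)
result = run_estimator_then_filter(batches, mean_and_se, pooled_mean, scale=1.0)
```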