statspai.causal
causal
DEPRECATED: use :mod:`statspai.forest` instead.
This package was renamed to ``statspai.forest`` in v1.10 because it
only ever housed the four forest-based causal estimators
(causal_forest / iv_forest / multi_arm_forest / forest
inference helpers). The ``causal`` name was misleading: it
implied a top-level causal-inference namespace when in fact the
content is a single-method family.
This shim keeps both ``from statspai.causal import X`` and
``from statspai.causal.causal_forest import X`` working by aliasing
the deprecated submodule paths through :data:`sys.modules` to the
real :mod:`statspai.forest` modules. Users see one
:class:`DeprecationWarning` on first import; everything else
continues working.
Plan to migrate within one minor release cycle.
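The aliasing idea described above can be sketched with the standard library alone. The module names `newpkg` and `oldpkg` and the helper `install_shim` are hypothetical stand-ins (they are not part of statspai), and the real shim may differ in detail; the point is only that entries in `sys.modules` decide what an import path resolves to.

```python
import sys
import types
import warnings

# Hypothetical stand-ins: "newpkg" plays the role of the renamed package,
# "oldpkg" the role of the deprecated one.
new_mod = types.ModuleType("newpkg")
new_sub = types.ModuleType("newpkg.causal_forest")
new_sub.CausalForest = type("CausalForest", (), {})
new_mod.causal_forest = new_sub
new_mod.CausalForest = new_sub.CausalForest
sys.modules["newpkg"] = new_mod
sys.modules["newpkg.causal_forest"] = new_sub


def install_shim(old_name: str, new_name: str, submodules: list) -> None:
    """Alias deprecated module paths to the new modules and warn once."""
    warnings.warn(
        f"{old_name} is deprecated; use {new_name} instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    # The import system consults sys.modules first, so aliasing the keys
    # is enough to make the old paths resolve to the new module objects.
    sys.modules[old_name] = sys.modules[new_name]
    for sub in submodules:
        sys.modules[f"{old_name}.{sub}"] = sys.modules[f"{new_name}.{sub}"]


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    install_shim("oldpkg", "newpkg", ["causal_forest"])

# Both deprecated import forms now resolve to the same class object.
from oldpkg import CausalForest as A
from oldpkg.causal_forest import CausalForest as B
```

In user code the migration itself is a one-line change: import from `statspai.forest` rather than `statspai.causal`.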
CausalForest
Bases: BaseModel
Causal Forest for heterogeneous treatment effect estimation.
This class implements the Causal Forest algorithm, which uses random forests to estimate conditional average treatment effects (CATE) in a non-parametric way.
The method combines ideas from:
1. Honest estimation to avoid overfitting
2. Double machine learning to handle confounding
3. Random forests for flexible function approximation
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_estimators | int | Number of trees in the forest | 100 |
| min_samples_leaf | int | Minimum number of samples required to be at a leaf node | 5 |
| max_depth | int | Maximum depth of trees | None |
| max_samples | float | Fraction of samples to use for each tree | 0.5 |
| model_y | estimator | Model for outcome regression (first stage) | None |
| model_t | estimator, optional | Model for treatment propensity (first stage) | None |
| discrete_treatment | bool | Whether treatment is discrete (binary/categorical) or continuous | True |
| honest | bool | Whether to use honest estimation (separate samples for splitting and effects) | True |
| bootstrap | bool | Whether to use bootstrap sampling for trees | True |
| random_state | int | Random state for reproducibility | None |
| n_jobs | int | Number of parallel jobs | 1 |
| verbose | int | Verbosity level | 0 |
Attributes:
| Name | Type | Description |
|---|---|---|
| fitted_ | bool | Whether the model has been fitted |
| params | Series | Not applicable for non-parametric methods, returns empty Series |
| std_errors | Series | Not applicable for non-parametric methods, returns empty Series |
| tvalues | Series | Not applicable for non-parametric methods, returns empty Series |
| pvalues | ndarray | Not applicable for non-parametric methods, returns empty array |
| diagnostics | dict | Model diagnostics and fit statistics |
| data_info | dict | Information about the data used in fitting |
Notes
This implementation is inspired by the EconML library's CausalForestDML but adapted to fit the StatsPAI architecture and interface.
Examples:
>>> import numpy as np
>>> import pandas as pd
>>> from statspai.forest import CausalForest
>>>
>>> # Generate sample data
>>> np.random.seed(42)
>>> n = 1000
>>> X = np.random.normal(0, 1, (n, 3))
>>> T = np.random.binomial(1, 0.5, n)
>>> Y = X[:, 0] * T + X[:, 1] + np.random.normal(0, 1, n)
>>> data = pd.DataFrame({
... 'Y': Y, 'T': T, 'X1': X[:, 0], 'X2': X[:, 1], 'X3': X[:, 2]
... })
>>>
>>> # Fit Causal Forest
>>> cf = CausalForest(n_estimators=50, random_state=42)
>>> cf.fit('Y ~ T | X1 + X2 + X3', data=data)
>>>
>>> # Estimate treatment effects
>>> cate = cf.effect(data[['X1', 'X2', 'X3']])
>>> print(f"Average treatment effect: {cate.mean():.3f}")
fit
fit(formula: Optional[str] = None, data: Optional[DataFrame] = None, Y: Optional[ndarray] = None, T: Optional[ndarray] = None, X: Optional[ndarray] = None, W: Optional[ndarray] = None) -> CausalForest
Fit the Causal Forest model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| formula | str | Formula specification in the form "Y ~ T \| X1 + X2 + ... [\| W1 + W2 + ...]" where Y is outcome, T is treatment, X are effect modifiers, W are controls | None |
| data | DataFrame | Data containing all variables if using formula interface | None |
| Y | array-like | Outcome variable (n_samples,) | None |
| T | array-like, optional | Treatment variable (n_samples,) | None |
| X | array-like | Effect modifier variables (n_samples, n_features) | None |
| W | array-like | Control variables for confounding adjustment (n_samples, n_controls) | None |
Returns:
| Name | Type | Description |
|---|---|---|
| self | CausalForest | Fitted estimator |
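To make the formula layout concrete, here is a small illustrative parser that splits a string of the documented form into its four roles. The function `split_formula` and its output keys are hypothetical and written for this example; this is not how the library parses formulas internally.

```python
def split_formula(formula: str) -> dict:
    """Split 'Y ~ T | X1 + X2 [| W1 + W2]' into named roles (illustration only)."""
    lhs, rhs = formula.split("~", 1)
    parts = [p.strip() for p in rhs.split("|")]
    if len(parts) not in (2, 3):
        raise ValueError("expected 'Y ~ T | X1 + ... [| W1 + ...]'")

    def names(term: str) -> list:
        # 'X1 + X2' -> ['X1', 'X2']
        return [v.strip() for v in term.split("+") if v.strip()]

    return {
        "outcome": lhs.strip(),
        "treatment": parts[0],
        "effect_modifiers": names(parts[1]),
        # The controls block is optional in the documented form.
        "controls": names(parts[2]) if len(parts) == 3 else [],
    }


roles = split_formula("Y ~ T | X1 + X2 + X3 | W1")
no_controls = split_formula("Y ~ T | X1 + X2 + X3")
```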
effect
Estimate conditional average treatment effects.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| X | array-like, shape (n_samples, n_features) | Effect modifier variables | required |
Returns:
| Name | Type | Description |
|---|---|---|
| effects | array-like, shape (n_samples,) | Estimated conditional average treatment effects |
predict
Generate treatment effect predictions (required by BaseModel).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | DataFrame | Data containing effect modifier variables. If None, uses training data. | None |
Returns:
| Type | Description |
|---|---|
| ndarray | Predicted treatment effects |
Notes
This method is required by the BaseModel interface. For Causal Forest, "predictions" are treatment effect estimates rather than outcome predictions.
effect_interval
Compute confidence intervals for treatment effects using bootstrap.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| X | array-like, shape (n_samples, n_features) | Effect modifier variables | required |
| alpha | float | Significance level (1-alpha is confidence level) | 0.05 |
Returns:
| Name | Type | Description |
|---|---|---|
| lower | array-like, shape (n_samples,) | Lower bounds of confidence intervals |
| upper | array-like, shape (n_samples,) | Upper bounds of confidence intervals |
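The role of `alpha` can be illustrated with a generic percentile bootstrap interval. This is a sketch under the assumption that a matrix of bootstrap replicate predictions is available; the helper `percentile_interval` is hypothetical, and the source does not say which bootstrap interval the method actually uses.

```python
import numpy as np


def percentile_interval(replicates: np.ndarray, alpha: float = 0.05):
    """Per-unit percentile interval from a (n_replicates, n_samples) array."""
    lower = np.quantile(replicates, alpha / 2, axis=0)
    upper = np.quantile(replicates, 1 - alpha / 2, axis=0)
    return lower, upper


rng = np.random.default_rng(0)
# Hypothetical bootstrap replicates of CATE predictions for four units.
true_cate = np.array([0.0, 0.5, 1.0, 2.0])
replicates = true_cate + rng.normal(0.0, 0.1, size=(500, 4))

# alpha=0.05 gives a 95% interval for each unit.
lower, upper = percentile_interval(replicates, alpha=0.05)
```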
average_treatment_effect
average_treatment_effect(X: Optional[ndarray] = None, T: Optional[ndarray] = None, target_sample: str = 'all', alpha: float = 0.05) -> Dict[str, float]
GRF-style ATE/ATT/ATC/ATO aggregation of CATE predictions.
forest_diagnostics
forest_diagnostics(X: Optional[ndarray] = None, T: Optional[ndarray] = None, propensity_bounds: Tuple[float, float] = (0.05, 0.95)) -> Dict[str, object]
Overlap and CATE-distribution diagnostics for this fitted forest.
summary
Return a summary of the fitted model.
Returns:
| Name | Type | Description |
|---|---|---|
| summary | str | Model summary string |
variable_importance
Permutation-based variable importance for the causal forest.
For each feature j, shuffle its column in the effect-modifier matrix and measure how much the cross-validated CATE predictions degrade (MSE increase). Higher degradation → more important for treatment-effect heterogeneity.
Returns a normalised importance score (sums to 1).
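The shuffle-and-compare procedure can be sketched generically with NumPy. Here `toy_cate` is a stand-in prediction function and `permutation_importance` is a hypothetical helper; the library's own implementation (including its cross-validation step) is not reproduced.

```python
import numpy as np


def permutation_importance(predict, X, reference, rng):
    """Normalised MSE increase after shuffling each column of X."""
    base_mse = np.mean((predict(X) - reference) ** 2)
    raw = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        X_perm = X.copy()
        X_perm[:, j] = rng.permutation(X_perm[:, j])  # break the link to feature j
        mse = np.mean((predict(X_perm) - reference) ** 2)
        raw[j] = max(mse - base_mse, 0.0)
    total = raw.sum()
    # Normalise so the scores sum to 1 (left as zeros if nothing degrades).
    return raw / total if total > 0 else raw


def toy_cate(X):
    """Stand-in CATE model whose heterogeneity depends only on column 0."""
    return 2.0 * X[:, 0]


rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
reference = toy_cate(X)
importance = permutation_importance(toy_cate, X, reference, rng)
```

Because the toy effect depends only on the first column, all of the importance lands on that feature.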
best_linear_projection
best_linear_projection(X_test: Optional[ndarray] = None, alpha: float = 0.05, clip: float = 0.01) -> DataFrame
Best Linear Projection (BLP) of CATE on features (Semenova-Chernozhukov 2021).
Constructs the augmented inverse-propensity-weighted (AIPW) doubly-robust
score :math:`\Gamma_i` and regresses it on :math:`X_i` with HC1 standard
errors:
.. math:: \Gamma_i = \hat{\tau}(X_i) + \frac{T_i - \hat{e}(X_i)}{\hat{e}(X_i)(1-\hat{e}(X_i))} \bigl(Y_i - \hat{m}(X_i) - (T_i - \hat{e}(X_i))\hat{\tau}(X_i)\bigr).
:math:`\Gamma_i` is unbiased for :math:`\tau(X_i)` under standard
cross-fitting / overlap conditions; OLS of :math:`\Gamma_i` on
:math:`(1, X_i)` recovers the population BLP coefficients with
valid heteroscedasticity-robust inference (HC1).
This replaces the earlier plug-in OLS of :math:`\hat{\tau}(X_i)` on
:math:`X_i`, which produces anti-conservative SEs (the SE on a fitted
model is not the SE on the population BLP).
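The two steps, building the score and regressing it with HC1 errors, can be sketched in NumPy. The helpers `aipw_scores` and `ols_hc1` are hypothetical names for this illustration, and oracle nuisances from a simulation stand in for the forest's cross-fitted predictions; this is not the library's code.

```python
import numpy as np


def aipw_scores(y, t, tau_hat, m_hat, e_hat, clip=0.01):
    """Doubly-robust score Gamma_i for a binary treatment."""
    e = np.clip(e_hat, clip, 1.0 - clip)  # guard against near-violations of overlap
    residual = y - m_hat - (t - e) * tau_hat
    return tau_hat + (t - e) / (e * (1.0 - e)) * residual


def ols_hc1(gamma, X):
    """OLS of gamma on (1, X) with HC1 robust standard errors."""
    n = gamma.shape[0]
    Z = np.column_stack([np.ones(n), X])
    k = Z.shape[1]
    bread = np.linalg.inv(Z.T @ Z)
    beta = bread @ Z.T @ gamma
    u = gamma - Z @ beta
    meat = Z.T @ (Z * (u ** 2)[:, None])
    cov = (n / (n - k)) * bread @ meat @ bread  # HC1 small-sample scaling
    return beta, np.sqrt(np.diag(cov))


# Simulated data with known CATE tau(x) = 1 + 0.5 * x0 and a fair-coin treatment.
rng = np.random.default_rng(0)
n = 4000
X = rng.normal(size=(n, 2))
t = rng.binomial(1, 0.5, size=n).astype(float)
tau = 1.0 + 0.5 * X[:, 0]
y = X[:, 1] + tau * t + rng.normal(size=n)

# Assumption: oracle nuisances replace the forest's cross-fitted estimates.
e_hat = np.full(n, 0.5)
m_hat = X[:, 1] + e_hat * tau  # E[Y | X]

gamma = aipw_scores(y, t, tau_hat=tau, m_hat=m_hat, e_hat=e_hat)
beta, se = ols_hc1(gamma, X)
```

In this simulation the coefficients should land near the true projection (intercept 1, slope 0.5 on the first feature, 0 on the second).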
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| X_test | array-like | Features to evaluate the BLP at; defaults to in-sample X. | None |
| alpha | float | Significance level for CIs (reported alongside coef/SE). | 0.05 |
| clip | float | Propensity clip (binary discrete treatment only) to prevent the inverse-propensity term from blowing up under near-violations of overlap. Counts of clipped units are exposed via | 0.01 |
Returns:
| Type | Description |
|---|---|
| DataFrame | Index |
References
Semenova V., Chernozhukov V. (2021). "Debiased Machine Learning of Conditional Average Treatment Effects and Other Causal Functions." Econometrics Journal 24(2): 264-289. DOI: 10.1093/ectj/utaa027.
att
Average Treatment Effect on the Treated.
calibration_test
calibration_test(forest: 'CausalForest', X: Optional[ndarray] = None, Y: Optional[ndarray] = None, T: Optional[ndarray] = None, alpha: float = 0.05) -> DataFrame
BLP-of-CATE calibration test (Chernozhukov-Demirer-Duflo-Fernandez-Val 2020).
Pseudo-outcome regression:
Ψ_i = α + β₁ · τ̂(X_i) + β₂ · (τ̂(X_i) - Eτ̂) + ε_i
where Ψ_i is the orthogonal AIPW pseudo-outcome built from the
forest's own propensity/outcome-model predictions. Hypotheses:
H₀^{(1)}: β₁ = 1 (well-calibrated mean forest prediction)
H₀^{(2)}: β₂ = 0 (no systematic CATE heterogeneity)
Rejecting H₀^{(2)} is the headline finding: it demonstrates that the forest captures real heterogeneity rather than noise.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| forest | fitted CausalForest | | required |
| X | optional arrays | If not given, the forest's stored training arrays are used. | None |
| Y | optional arrays | If not given, the forest's stored training arrays are used. | None |
| T | optional arrays | If not given, the forest's stored training arrays are used. | None |
| alpha | float | Significance level for reported CIs. | 0.05 |
Returns:
| Type | Description |
|---|---|
| DataFrame | Rows |
rate
rate(forest: 'CausalForest', X: Optional[ndarray] = None, Y: Optional[ndarray] = None, T: Optional[ndarray] = None, target: str = 'AUTOC', q_grid: int = 100, alpha: float = 0.05, seed: Optional[int] = None) -> Dict[str, float]
Rank-Weighted Average Treatment Effect (RATE; Yadlowsky et al. 2021).
Let S(x) = τ̂(x) denote a prioritisation score (higher = higher
priority). Define the TOC (targeting operator characteristic) curve:
TOC(q) = E[τ(X) | S(X) ≥ Q_{1-q}(S)] - E[τ(X)]
i.e. the expected CATE among the top-q fraction minus the
population ATE. Two scalar summaries are supported:
- AUTOC: ∫₀¹ TOC(q) dq, the unweighted area under the TOC curve. Emphasises prioritisation performance uniformly across the quantile range.
- QINI: ∫₀¹ q · TOC(q) dq, which down-weights narrow top fractions; closer to the classical uplift / Qini coefficient.
Estimation
The DR-RATE estimator from Yadlowsky et al. uses an AIPW pseudo-
outcome Ψ_i (computed from the forest's own cross-fitted
nuisance predictions) and reduces AUTOC / Qini to a weighted sum:
AUTOC_hat = (1/n) Σ_i Ψ_i · w_{AUTOC}(R_i / n) - Ψ̄
QINI_hat = (1/n) Σ_i Ψ_i · w_{QINI}(R_i / n) - (1/2) Ψ̄
where R_i is the descending rank of S(X_i) and the weights
are closed-form rank kernels. This representation makes the
estimator a sample mean of per-observation contributions φ_i, so
the variance admits the standard influence-function form
Var(AUTOC_hat) = (1/(n(n-1))) Σ_i (φ_i - φ̄)²
which replaces the conservative half-sample estimator used in the earlier draft of this function.
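The TOC curve and its two summaries can be computed directly from pseudo-outcomes and a score by sorting. The sketch below discretises the integrals on the grid q = k/n; it is not the library's closed-form rank-kernel implementation, and the helper names `toc_curve` and `rate_from_scores` are hypothetical.

```python
import numpy as np


def toc_curve(psi, score):
    """TOC(k/n): mean pseudo-outcome among the top-k units minus the overall mean."""
    order = np.argsort(-score, kind="stable")  # descending priority
    n = psi.shape[0]
    k = np.arange(1, n + 1)
    top_mean = np.cumsum(psi[order]) / k
    return k / n, top_mean - psi.mean()


def rate_from_scores(psi, score, target="AUTOC"):
    """Discretised AUTOC or QINI on the grid q = k/n."""
    q, toc = toc_curve(psi, score)
    if target == "AUTOC":
        return float(toc.mean())
    if target == "QINI":
        return float((q * toc).mean())
    raise ValueError("target must be 'AUTOC' or 'QINI'")


rng = np.random.default_rng(0)
n = 2000
tau = rng.normal(size=n)          # true CATE in this simulation
psi = tau + rng.normal(size=n)    # noisy stand-in for the AIPW pseudo-outcome

q, toc = toc_curve(psi, tau)
autoc_informative = rate_from_scores(psi, tau, "AUTOC")
autoc_random = rate_from_scores(psi, rng.normal(size=n), "AUTOC")
qini_informative = rate_from_scores(psi, tau, "QINI")
```

A score that ranks units by their true effect yields a clearly positive AUTOC, while an uninformative score yields a value near zero.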
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| forest | fitted CausalForest | | required |
| X | arrays | If omitted, falls back to the forest's stored training arrays. | None |
| Y | arrays | If omitted, falls back to the forest's stored training arrays. | None |
| T | arrays | If omitted, falls back to the forest's stored training arrays. | None |
| target | ('AUTOC', 'QINI') | | 'AUTOC' |
| q_grid | int | Number of quantile grid points used to report the TOC curve. Does not affect the point estimate or SE (those are computed from ranks exactly). | 100 |
| alpha | float | | 0.05 |
| seed | int | Ignored; kept for API backwards compatibility. | None |
Returns:
| Type | Description |
|---|---|
| dict | Keys ``estimate``, ``se``, ``ci_low``, ``ci_high``, ``target``, ``toc_curve`` (``(q_grid, 2)``), ``n``, ``method``. |
References
Yadlowsky, S., Fleming, S., Shah, N., Brunskill, E., Wager, S. (2021). "Evaluating Treatment Prioritization Rules via Rank-Weighted Average Treatment Effects." arXiv:2111.07966.
honest_variance
honest_variance(forest: 'CausalForest', X: Optional[ndarray] = None, n_splits: int = 25, seed: Optional[int] = None) -> Dict[str, float]
Half-sample bootstrap variance of the ATE/GATE estimate.
Repeatedly partition the sample into two halves, compute the mean predicted CATE on each, and aggregate. Returns the sample variance of the per-split means divided by the number of splits — a crude but robust uncertainty quantifier when the forest's internal variance estimator is unavailable.
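One possible reading of this description is sketched below: draw random half-samples, record the mean predicted CATE on each, and divide the sample variance of those means by the number of draws. The helper `half_sample_variance` is hypothetical and the library's exact aggregation may differ; the output keys follow the documented return dict.

```python
import numpy as np
from statistics import NormalDist


def half_sample_variance(cate, n_splits=25, seed=None):
    """Half-sample uncertainty for the mean CATE (one reading of the description)."""
    rng = np.random.default_rng(seed)
    n = cate.shape[0]
    half_means = np.empty(n_splits)
    for b in range(n_splits):
        idx = rng.choice(n, size=n // 2, replace=False)  # one random half
        half_means[b] = cate[idx].mean()
    ate = float(cate.mean())
    # Sample variance of the per-split means, divided by the number of splits.
    se = float(np.sqrt(half_means.var(ddof=1) / n_splits))
    z = NormalDist().inv_cdf(0.975)  # 95% normal interval
    return {"ate": ate, "se": se, "ci_low": ate - z * se, "ci_high": ate + z * se}


rng = np.random.default_rng(0)
cate_hat = rng.normal(1.0, 0.5, size=1000)  # hypothetical CATE predictions
result = half_sample_variance(cate_hat, n_splits=25, seed=0)
```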
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| forest | fitted CausalForest | | required |
| X | ndarray | | None |
| n_splits | int | Number of random half-sample draws. | 25 |
| seed | int | | None |
Returns:
| Type | Description |
|---|---|
| dict | ``ate``, ``se``, ``ci_low``, ``ci_high`` (95 %). |
average_treatment_effect
average_treatment_effect(forest: 'CausalForest', X: Optional[ndarray] = None, T: Optional[ndarray] = None, target_sample: str = 'all', alpha: float = 0.05, clip: float = 0.01) -> Dict[str, float]
Aggregate CATE predictions into ATE/ATT/ATC/ATO targets.
This mirrors the most-used grf::average_treatment_effect targets:
"all" (ATE), "treated" (ATT), "control" (ATC), and
"overlap" (ATO, weighted by e(X)(1-e(X))).
The estimate is the doubly-robust AIPW influence-function mean
(the estimator grf reports), not a plug-in average of the CATE
predictions. Using the forest's own cross-fitted nuisances
:math:`\hat m(X)=\hat E[Y\mid X]` and :math:`\hat e(X)=\hat E[T\mid X]`,
the ATE score is
.. math:: \Gamma_i = \hat\tau(X_i) + \frac{T_i-\hat e(X_i)}{\hat e(X_i)(1-\hat e(X_i))} \bigl(Y_i-\hat m(X_i)-(T_i-\hat e(X_i))\hat\tau(X_i)\bigr),
and the ATT/ATC scores use the analogous Robins doubly-robust
weighting. ``se`` is the influence-function standard error
:math:`\mathrm{sd}(\Gamma)/\sqrt n`. When the score cannot be
formed (out-of-sample X with no stored nuisances) the function
falls back to the plug-in CATE average and sets ``method='plug_in'``.
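For the "all" (ATE) target, the score mean and its influence-function standard error can be sketched as follows. The helper `aipw_ate` and its output keys are assumptions made for this illustration, and oracle nuisances from a simulation stand in for the forest's cross-fitted predictions; the ATT/ATC/ATO weightings are not reproduced.

```python
import numpy as np
from statistics import NormalDist


def aipw_ate(y, t, tau_hat, m_hat, e_hat, alpha=0.05, clip=0.01):
    """ATE as the mean of the AIPW score, with SE = sd(Gamma) / sqrt(n)."""
    e = np.clip(e_hat, clip, 1.0 - clip)
    gamma = tau_hat + (t - e) / (e * (1.0 - e)) * (y - m_hat - (t - e) * tau_hat)
    n = gamma.shape[0]
    estimate = float(gamma.mean())
    se = float(gamma.std(ddof=1) / np.sqrt(n))
    z = NormalDist().inv_cdf(1.0 - alpha / 2.0)
    return {
        "estimate": estimate,
        "se": se,
        "ci_low": estimate - z * se,
        "ci_high": estimate + z * se,
    }


# Simulation with confounded treatment assignment and true ATE equal to 1.
rng = np.random.default_rng(1)
n = 4000
X = rng.normal(size=(n, 2))
e_true = 1.0 / (1.0 + np.exp(-0.5 * X[:, 1]))  # propensity depends on x1
t = rng.binomial(1, e_true).astype(float)
tau = 1.0 + 0.5 * X[:, 0]
y = X[:, 1] + tau * t + rng.normal(size=n)
m_true = X[:, 1] + e_true * tau  # E[Y | X]

# Assumption: oracle nuisances replace the forest's cross-fitted estimates.
result = aipw_ate(y, t, tau_hat=tau, m_hat=m_true, e_hat=e_true)
```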
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| clip | float | Propensity scores are clipped to | 0.01 |
forest_diagnostics
forest_diagnostics(forest: 'CausalForest', X: Optional[ndarray] = None, T: Optional[ndarray] = None, propensity_bounds: Tuple[float, float] = (0.05, 0.95)) -> Dict[str, object]
Return overlap and CATE-distribution diagnostics for a fitted forest.
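As a rough illustration of what overlap and CATE-distribution diagnostics might contain, the sketch below summarises estimated propensities against `propensity_bounds` and reports a few CATE quantiles. The function `overlap_summary` and every key in its output are hypothetical; the source does not list the keys the library returns.

```python
import numpy as np


def overlap_summary(e_hat, tau_hat, propensity_bounds=(0.05, 0.95)):
    """Hypothetical diagnostics: propensity range, share outside bounds, CATE quantiles."""
    low, high = propensity_bounds
    outside = (e_hat < low) | (e_hat > high)
    return {
        "n": int(e_hat.size),
        "propensity_min": float(e_hat.min()),
        "propensity_max": float(e_hat.max()),
        "share_outside_bounds": float(outside.mean()),
        "cate_quantiles": np.quantile(tau_hat, [0.1, 0.5, 0.9]),
    }


# Hypothetical fitted values for five units; two propensities violate the bounds.
e_hat = np.array([0.02, 0.30, 0.50, 0.70, 0.97])
tau_hat = np.array([0.1, 0.4, 0.5, 0.6, 1.2])
diag = overlap_summary(e_hat, tau_hat)
```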