
statspai.causal


DEPRECATED: use :mod:`statspai.forest` instead.

This package was renamed to statspai.forest in v1.10 because it only ever housed the four forest-based modules (causal_forest, iv_forest, multi_arm_forest, and the forest inference helpers). The name causal was misleading: it implied a top-level causal-inference namespace, when in fact the package contains a single method family.

This shim keeps both ``from statspai.causal import X`` and ``from statspai.causal.causal_forest import X`` working by aliasing the deprecated submodule paths through :data:`sys.modules` to the real :mod:`statspai.forest` modules. Users see one :class:`DeprecationWarning` on first import; everything else continues to work.
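The aliasing trick can be sketched with the standard library alone. In this minimal sketch the package names ``newpkg`` and ``oldpkg`` and the helper ``install_deprecated_alias`` are hypothetical stand-ins for :mod:`statspai.forest`, :mod:`statspai.causal` and the real shim; it is not the StatsPAI source.

```python
import sys
import types
import warnings

# Hypothetical stand-ins for the new package and one of its submodules.
newpkg = types.ModuleType("newpkg")
newpkg_cf = types.ModuleType("newpkg.causal_forest")
newpkg_cf.CausalForest = type("CausalForest", (), {})  # dummy class
newpkg.causal_forest = newpkg_cf
newpkg.CausalForest = newpkg_cf.CausalForest
sys.modules["newpkg"] = newpkg
sys.modules["newpkg.causal_forest"] = newpkg_cf


def install_deprecated_alias(old: str, new: str, submodules: list) -> None:
    """Point the old import paths at the already-imported new modules."""
    warnings.warn(
        f"{old} is deprecated; use {new} instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    sys.modules[old] = sys.modules[new]
    for sub in submodules:
        # The import system consults sys.modules first, so the old dotted
        # path now resolves to the very same module object as the new one.
        sys.modules[f"{old}.{sub}"] = sys.modules[f"{new}.{sub}"]


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    install_deprecated_alias("oldpkg", "newpkg", ["causal_forest"])

from oldpkg import CausalForest as A
from oldpkg.causal_forest import CausalForest as B

print(A is B is newpkg.CausalForest)  # True
```

In a real package the warning would more naturally live in the deprecated package's ``__init__``, so that it fires once, when that module is first imported.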

Plan to migrate within one minor cycle.

CausalForest

Bases: BaseModel

Causal Forest for heterogeneous treatment effect estimation

This class implements the Causal Forest algorithm, which uses random forests to estimate conditional average treatment effects (CATE) in a non-parametric way.

The method combines ideas from:

1. Honest estimation, to avoid overfitting
2. Double machine learning, to handle confounding
3. Random forests, for flexible function approximation
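The honesty idea can be illustrated with a single split. This is a toy sketch, not StatsPAI's tree-growing code. The helper ``honest_stump_cate`` is hypothetical and assumes a randomized binary treatment, so that a within-leaf difference in means estimates the leaf's average effect.

```python
import numpy as np


def honest_stump_cate(x, t, y, seed=0):
    """Hypothetical one-split 'honest' CATE estimator (illustration only).

    Half A chooses the split threshold; half B, never used for
    splitting, estimates the effect in each leaf. Assumes a randomized
    binary treatment, so a within-leaf difference in means is unbiased.
    """
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(y))
    a, b = idx[: len(y) // 2], idx[len(y) // 2:]

    def diff_in_means(mask, tt, yy):
        return yy[mask & (tt == 1)].mean() - yy[mask & (tt == 0)].mean()

    # Splitting half: choose the threshold that maximizes the gap
    # between the two child effect estimates.
    best_thr, best_gap = None, -np.inf
    for thr in np.quantile(x[a], np.linspace(0.1, 0.9, 17)):
        left = x[a] <= thr
        gap = abs(diff_in_means(left, t[a], y[a])
                  - diff_in_means(~left, t[a], y[a]))
        if gap > best_gap:
            best_thr, best_gap = thr, gap

    # Estimation half: leaf effects use only units from half B.
    left_b = x[b] <= best_thr
    return (best_thr,
            diff_in_means(left_b, t[b], y[b]),
            diff_in_means(~left_b, t[b], y[b]))


rng = np.random.default_rng(42)
n = 4000
x = rng.normal(size=n)
t = rng.binomial(1, 0.5, size=n)
y = 2.0 * (x > 0) * t + rng.normal(size=n)  # effect is 2 when x > 0, else 0
thr, tau_left, tau_right = honest_stump_cate(x, t, y)
```

Here ``thr`` should land near 0, with ``tau_left`` near 0 and ``tau_right`` near 2. Because half B played no part in choosing the split, its leaf estimates are not inflated by the search over thresholds. That inflation is the overfitting that honesty guards against.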

Parameters:

  • n_estimators (int, default 100): Number of trees in the forest.
  • min_samples_leaf (int, default 5): Minimum number of samples required at a leaf node.
  • max_depth (int, default None): Maximum depth of the trees.
  • max_samples (float, default 0.5): Fraction of samples to use for each tree.
  • model_y (estimator, default None): Model for the outcome regression (first stage).
  • model_t (estimator, optional, default None): Model for the treatment propensity (first stage).
  • discrete_treatment (bool, default True): Whether the treatment is discrete (binary/categorical) or continuous.
  • honest (bool, default True): Whether to use honest estimation (separate samples for splitting and for effect estimation).
  • bootstrap (bool, default True): Whether to use bootstrap sampling for the trees.
  • random_state (int, default None): Random state for reproducibility.
  • n_jobs (int, default 1): Number of parallel jobs.
  • verbose (int, default 0): Verbosity level.

Attributes:

  • fitted_ (bool): Whether the model has been fitted.
  • params (Series): Not applicable to non-parametric methods; returns an empty Series.
  • std_errors (Series): Not applicable to non-parametric methods; returns an empty Series.
  • tvalues (Series): Not applicable to non-parametric methods; returns an empty Series.
  • pvalues (ndarray): Not applicable to non-parametric methods; returns an empty array.
  • diagnostics (dict): Model diagnostics and fit statistics.
  • data_info (dict): Information about the data used in fitting.

Notes

This implementation is inspired by the EconML library's CausalForestDML but adapted to fit the StatsPAI architecture and interface.

Examples:

>>> import numpy as np
>>> import pandas as pd
>>> from statspai.forest import CausalForest
>>> 
>>> # Generate sample data
>>> np.random.seed(42)
>>> n = 1000
>>> X = np.random.normal(0, 1, (n, 3))
>>> T = np.random.binomial(1, 0.5, n)
>>> Y = X[:, 0] * T + X[:, 1] + np.random.normal(0, 1, n)
>>> data = pd.DataFrame({
...     'Y': Y, 'T': T, 'X1': X[:, 0], 'X2': X[:, 1], 'X3': X[:, 2]
... })
>>> 
>>> # Fit Causal Forest
>>> cf = CausalForest(n_estimators=50, random_state=42)
>>> cf.fit('Y ~ T | X1 + X2 + X3', data=data)
>>> 
>>> # Estimate treatment effects
>>> cate = cf.effect(data[['X1', 'X2', 'X3']])
>>> print(f"Average treatment effect: {cate.mean():.3f}")

fit

fit(formula: Optional[str] = None, data: Optional[DataFrame] = None, Y: Optional[ndarray] = None, T: Optional[ndarray] = None, X: Optional[ndarray] = None, W: Optional[ndarray] = None) -> CausalForest

Fit the Causal Forest model

Parameters:

  • formula (str, default None): Formula of the form "Y ~ T | X1 + X2 + ... [| W1 + W2 + ...]", where Y is the outcome, T the treatment, X the effect modifiers, and W the controls.
  • data (DataFrame, default None): Data containing all variables when using the formula interface.
  • Y (array-like, default None): Outcome variable, shape (n_samples,).
  • T (array-like, optional, default None): Treatment variable, shape (n_samples,).
  • X (array-like, default None): Effect modifier variables, shape (n_samples, n_features).
  • W (array-like, default None): Control variables for confounding adjustment, shape (n_samples, n_controls).

Returns:

  • self (CausalForest): Fitted estimator.
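The formula grammar can be illustrated with a small parser. ``parse_cf_formula`` is a hypothetical helper written for this page, not StatsPAI's own parser. It only splits the string and does not check the names against the DataFrame columns.

```python
def parse_cf_formula(formula: str) -> dict:
    """Split 'Y ~ T | X1 + X2 [| W1 + W2]' into its parts (hypothetical helper)."""
    if "~" not in formula:
        raise ValueError("formula must contain '~'")
    lhs, rhs = formula.split("~", 1)
    parts = [p.strip() for p in rhs.split("|")]
    if len(parts) not in (2, 3):
        raise ValueError("expected 'Y ~ T | X...' optionally followed by '| W...'")

    def terms(block: str) -> list:
        # Individual variable names joined by '+'.
        return [v.strip() for v in block.split("+") if v.strip()]

    return {
        "Y": lhs.strip(),
        "T": parts[0],
        "X": terms(parts[1]),
        "W": terms(parts[2]) if len(parts) == 3 else [],
    }


print(parse_cf_formula("Y ~ T | X1 + X2 + X3"))
# {'Y': 'Y', 'T': 'T', 'X': ['X1', 'X2', 'X3'], 'W': []}
print(parse_cf_formula("Y ~ T | X1 + X2 | W1"))
# {'Y': 'Y', 'T': 'T', 'X': ['X1', 'X2'], 'W': ['W1']}
```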

effect

effect(X: ndarray) -> ndarray

Estimate conditional average treatment effects

Parameters:

  • X (array-like, shape (n_samples, n_features), required): Effect modifier variables.

Returns:

  • effects (array-like, shape (n_samples,)): Estimated conditional average treatment effects.

predict

predict(data: Optional[DataFrame] = None) -> ndarray

Generate treatment effect predictions (required by BaseModel)

Parameters:

  • data (DataFrame, default None): Data containing the effect modifier variables. If None, the training data are used.

Returns:

  • ndarray: Predicted treatment effects.

Notes

This method is required by the BaseModel interface. For Causal Forest, "predictions" are treatment effect estimates rather than outcome predictions.

effect_interval

effect_interval(X: ndarray, alpha: float = 0.05) -> Tuple[ndarray, ndarray]

Compute confidence intervals for treatment effects using bootstrap

Parameters:

  • X (array-like, shape (n_samples, n_features), required): Effect modifier variables.
  • alpha (float, default 0.05): Significance level; the confidence level is 1 - alpha.

Returns:

  • lower (array-like, shape (n_samples,)): Lower bounds of the confidence intervals.
  • upper (array-like, shape (n_samples,)): Upper bounds of the confidence intervals.
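The docstring does not say which bootstrap interval is used. Assuming a percentile bootstrap over per-replicate CATE predictions, the bounds are column-wise quantiles. ``percentile_interval`` and its ``boot_preds`` input are hypothetical.

```python
import numpy as np


def percentile_interval(boot_preds: np.ndarray, alpha: float = 0.05):
    """Column-wise percentile-bootstrap bounds (hypothetical helper).

    boot_preds has shape (n_boot, n_samples): one row of CATE
    predictions per bootstrap replicate.
    """
    lower = np.quantile(boot_preds, alpha / 2, axis=0)
    upper = np.quantile(boot_preds, 1 - alpha / 2, axis=0)
    return lower, upper


rng = np.random.default_rng(0)
# Simulated replicates for three units whose CATEs are 0, 1 and 2.
boot = rng.normal(loc=[0.0, 1.0, 2.0], scale=0.1, size=(2000, 3))
lower, upper = percentile_interval(boot, alpha=0.05)
```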

average_treatment_effect

average_treatment_effect(X: Optional[ndarray] = None, T: Optional[ndarray] = None, target_sample: str = 'all', alpha: float = 0.05) -> Dict[str, float]

GRF-style ATE/ATT/ATC/ATO aggregation of CATE predictions.

forest_diagnostics

forest_diagnostics(X: Optional[ndarray] = None, T: Optional[ndarray] = None, propensity_bounds: Tuple[float, float] = (0.05, 0.95)) -> Dict[str, object]

Overlap and CATE-distribution diagnostics for this fitted forest.

summary

summary() -> str

Return a summary of the fitted model

Returns:

  • summary (str): Model summary string.

variable_importance

variable_importance() -> Series

Permutation-based variable importance for the causal forest.

For each feature j, shuffle its column in the effect-modifier matrix and measure how much the cross-validated CATE predictions degrade (the increase in MSE). The larger the degradation, the more important the feature is for treatment-effect heterogeneity.

Returns normalised importance scores that sum to 1.
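The following is a minimal sketch of the permutation idea. It assumes that the degradation is measured as the MSE between predictions on permuted and unpermuted X for a fixed CATE function. The docstring mentions cross-validated predictions, which this sketch does not reproduce, and ``permutation_importance_cate`` is a hypothetical helper.

```python
import numpy as np


def permutation_importance_cate(predict, X, n_repeats=5, seed=0):
    """Normalised permutation importance for a CATE predictor (sketch).

    Degradation is the MSE between predictions on permuted and
    unpermuted X; scores are rescaled to sum to 1.
    """
    rng = np.random.default_rng(seed)
    base = predict(X)
    scores = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        for _ in range(n_repeats):
            Xp = X.copy()
            Xp[:, j] = rng.permutation(Xp[:, j])  # break feature j's link
            scores[j] += np.mean((predict(Xp) - base) ** 2)
    scores /= n_repeats
    total = scores.sum()
    return scores / total if total > 0 else scores


rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
# Toy CATE that depends only on the first feature.
imp = permutation_importance_cate(lambda Z: 2.0 * Z[:, 0], X)
print(imp)  # [1. 0. 0.]
```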

best_linear_projection

best_linear_projection(X_test: Optional[ndarray] = None, alpha: float = 0.05, clip: float = 0.01) -> DataFrame

Best Linear Projection (BLP) of CATE on features (Semenova-Chernozhukov 2021).

Constructs the augmented inverse-propensity-weighted (AIPW) doubly-robust score :math:`\Gamma_i` and regresses it on :math:`X_i` with HC1 standard errors:

.. math::

   \Gamma_i = \hat{\tau}(X_i) + \frac{T_i - \hat{e}(X_i)}{\hat{e}(X_i)(1-\hat{e}(X_i))} \bigl(Y_i - \hat{m}(X_i) - (T_i - \hat{e}(X_i))\hat{\tau}(X_i)\bigr).

:math:`\Gamma_i` is unbiased for :math:`\tau(X_i)` under standard cross-fitting / overlap conditions; OLS of :math:`\Gamma_i` on :math:`(1, X_i)` recovers the population BLP coefficients with valid heteroscedasticity-robust inference (HC1).

This replaces the earlier plug-in OLS of :math:`\hat{\tau}(X_i)` on :math:`X_i`, which produces anti-conservative SEs. That regression treats the fitted :math:`\hat{\tau}(X_i)` as if it were data, so its SEs ignore the estimation error in :math:`\hat{\tau}` and do not describe uncertainty about the population BLP.
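The score construction and HC1 regression can be sketched in NumPy. This sketch takes the cross-fitted nuisances ``tau_hat``, ``e_hat`` and ``m_hat`` as caller-supplied arrays (hypothetical inputs) and uses oracle values in the demo. It also omits the t-statistics, p-values and CIs that the method reports. ``blp_hc1`` is a hypothetical helper.

```python
import numpy as np


def blp_hc1(X, Y, T, tau_hat, e_hat, m_hat, clip=0.01):
    """BLP of CATE on (1, X) via the AIPW score, with HC1 SEs (sketch)."""
    e = np.clip(e_hat, clip, 1 - clip)            # guard the inverse-propensity term
    gamma = tau_hat + (T - e) / (e * (1 - e)) * (Y - m_hat - (T - e) * tau_hat)
    Z = np.column_stack([np.ones(len(Y)), X])     # design matrix (1, X_i)
    n, k = Z.shape
    ZtZ_inv = np.linalg.inv(Z.T @ Z)
    coef = ZtZ_inv @ (Z.T @ gamma)
    resid = gamma - Z @ coef
    meat = (Z * resid[:, None] ** 2).T @ Z        # sum_i u_i^2 z_i z_i'
    cov = n / (n - k) * ZtZ_inv @ meat @ ZtZ_inv  # HC1 = HC0 * n / (n - k)
    return coef, np.sqrt(np.diag(cov))


rng = np.random.default_rng(3)
n = 2000
X = rng.normal(size=(n, 2))
tau = 1.0 + 2.0 * X[:, 0]                         # true BLP coefficients: (1, 2, 0)
e = np.full(n, 0.5)
T = rng.binomial(1, e)
Y = X[:, 1] + tau * T + rng.normal(size=n)
m = X[:, 1] + e * tau                             # oracle E[Y | X]
coef, se = blp_hc1(X, Y, T, tau_hat=tau, e_hat=e, m_hat=m)
```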

Parameters:

  • X_test (array-like, default None): Features at which to evaluate the BLP; defaults to the in-sample X.
  • alpha (float, default 0.05): Significance level for the CIs reported alongside coef/SE.
  • clip (float, default 0.01): Propensity clip (binary discrete treatment only) that prevents the inverse-propensity term from blowing up under near-violations of overlap. Counts of clipped units are exposed via self.diagnostics.

Returns:

  • DataFrame: Index ["Intercept", *features] with columns [coef, se, t, p, ci_lower, ci_upper]; SEs are HC1.

References

Semenova V., Chernozhukov V. (2021). "Debiased Machine Learning of Conditional Average Treatment Effects and Other Causal Functions." Econometrics Journal 24(2): 264-289. DOI: 10.1093/ectj/utaa027.

ate

ate(X: Optional[ndarray] = None) -> float

Average Treatment Effect (mean CATE).

att

att(X: Optional[ndarray] = None, T: Optional[ndarray] = None) -> float

Average Treatment Effect on the Treated.

calibration_test

calibration_test(forest: 'CausalForest', X: Optional[ndarray] = None, Y: Optional[ndarray] = None, T: Optional[ndarray] = None, alpha: float = 0.05) -> DataFrame

BLP-of-CATE calibration test (Chernozhukov-Demirer-Duflo-Fernandez-Val 2020).

Pseudo-outcome regression:

Ψ_i = β₁ · Eτ̂ + β₂ · (τ̂(X_i) - Eτ̂) + ε_i

where Eτ̂ is the sample mean of τ̂(X_i) and Ψ_i is the orthogonal AIPW pseudo-outcome built from the forest's own propensity/outcome-model predictions. Hypotheses:

H₀^{(1)}: β₁ = 1   (well-calibrated mean forest prediction)
H₀^{(2)}: β₂ = 0   (no systematic CATE heterogeneity)

Rejecting H₀^{(2)} is the headline finding: it indicates that the forest captures real heterogeneity rather than noise.
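Below is a sketch of the calibration regression. It assumes that the pseudo-outcome ``psi`` has already been built (for example with the AIPW score used for the BLP), that the regression has no intercept, and that the SEs are HC1. ``calibration_regression`` is a hypothetical helper.

```python
import numpy as np


def calibration_regression(psi, tau_hat):
    """Regress psi on [mean(tau_hat), tau_hat - mean(tau_hat)] without intercept.

    Returns both coefficients with HC1 t-statistics for
    H0: beta_mean = 1 and H0: beta_differential = 0 (sketch only).
    Requires mean(tau_hat) != 0 so the first regressor is not all zeros.
    """
    tbar = tau_hat.mean()
    Z = np.column_stack([np.full_like(tau_hat, tbar), tau_hat - tbar])
    n, k = Z.shape
    ZtZ_inv = np.linalg.inv(Z.T @ Z)
    beta = ZtZ_inv @ (Z.T @ psi)
    u = psi - Z @ beta
    cov = n / (n - k) * ZtZ_inv @ ((Z * u[:, None] ** 2).T @ Z) @ ZtZ_inv
    se = np.sqrt(np.diag(cov))
    return {
        "beta_mean": beta[0],
        "t_mean_vs_1": (beta[0] - 1.0) / se[0],
        "beta_differential": beta[1],
        "t_differential_vs_0": beta[1] / se[1],
    }


rng = np.random.default_rng(5)
n = 4000
tau = 1.0 + 0.5 * rng.normal(size=n)      # true CATE
tau_hat = tau                             # a perfectly calibrated forest
psi = tau + 2.0 * rng.normal(size=n)      # noisy but unbiased pseudo-outcome
res = calibration_regression(psi, tau_hat)
```

With a well-calibrated ``tau_hat`` both coefficients sit near 1, and the large ``t_differential_vs_0`` rejects H₀^{(2)}.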

Parameters:

  • forest (fitted CausalForest, required)
  • X, Y, T (optional arrays, default None): If not given, the forest's stored training arrays are used.
  • alpha (float, default 0.05): Significance level for the reported CIs.

Returns:

  • DataFrame: Rows beta_mean and beta_differential with columns coef, se, t, p, ci_low, ci_high.

rate

rate(forest: 'CausalForest', X: Optional[ndarray] = None, Y: Optional[ndarray] = None, T: Optional[ndarray] = None, target: str = 'AUTOC', q_grid: int = 100, alpha: float = 0.05, seed: Optional[int] = None) -> Dict[str, float]

Rank-Weighted Average Treatment Effect (RATE; Yadlowsky et al. 2023).

Let S(x) = τ̂(x) denote a prioritisation score (higher = higher priority). Define the TOC (targeting operator characteristic) curve:

TOC(q) = E[τ(X) | S(X) ≥ Q_{1-q}(S)] - E[τ(X)]

i.e. the expected CATE among the top-q fraction minus the population ATE. Two scalar summaries are supported:

  • AUTOC: ∫₀¹ TOC(q) dq, the unweighted area under the TOC curve. Emphasises prioritisation performance uniformly across the quantile range.
  • QINI: ∫₀¹ q · TOC(q) dq, which down-weights narrow top fractions; closer to the classical uplift / Qini coefficient.

Estimation

The DR-RATE estimator from Yadlowsky et al. uses an AIPW pseudo-outcome Ψ_i (computed from the forest's own cross-fitted nuisance predictions) and reduces AUTOC / Qini to a weighted sum:

AUTOC_hat = (1/n) Σ_i Ψ_i · w_{AUTOC}(R_i / n) - Ψ̄
QINI_hat  = (1/n) Σ_i Ψ_i · w_{QINI}(R_i / n)  - (1/2) Ψ̄

where R_i is the descending rank of S(X_i) and the weights are closed-form rank kernels. This representation makes the estimator a sample mean of per-observation contributions φ_i, so the variance admits the standard influence-function form

Var(AUTOC_hat) = (1/(n(n-1))) Σ_i (φ_i - φ̄)²

which replaces the conservative half-sample estimator used in the earlier draft of this function.
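One way to realise the rank-weighted sums is to average TOC(j/n) over j = 1..n. This gives the discrete weights w_AUTOC(R) = Σ_{j=R}^{n} 1/j and w_QINI(R) = (n - R + 1)/n. These weights are a derivation for illustration and not necessarily the closed-form kernels StatsPAI uses. In this version, the Qini centring term uses the exact discrete factor (n + 1)/(2n) in place of 1/2, and the SE treats the ranks as fixed. ``rank_weighted_ate`` is a hypothetical helper.

```python
import numpy as np


def rank_weighted_ate(psi, score, target="AUTOC"):
    """AUTOC / Qini as a mean of per-unit contributions phi_i (sketch).

    psi: AIPW pseudo-outcomes; score: prioritisation score S(X_i).
    Ranks are treated as fixed when computing the SE.
    """
    n = len(psi)
    order = np.argsort(-score, kind="stable")
    ranks = np.empty(n, dtype=int)
    ranks[order] = np.arange(1, n + 1)        # R_i = 1 for the top-priority unit
    if target == "AUTOC":
        # w(R) = sum_{j=R}^{n} 1/j; these weights average to 1.
        tail = np.cumsum(1.0 / np.arange(n, 0, -1))[::-1]
        phi = psi * (tail[ranks - 1] - 1.0)
    elif target == "QINI":
        w = (n - ranks + 1) / n
        phi = psi * (w - (n + 1) / (2 * n))
    else:
        raise ValueError("target must be 'AUTOC' or 'QINI'")
    est = phi.mean()
    se = np.sqrt(np.sum((phi - est) ** 2) / (n * (n - 1)))
    return est, se


rng = np.random.default_rng(1)
n = 2000
x = rng.normal(size=n)
psi = x + rng.normal(size=n)                  # true CATE equals x
est, se = rank_weighted_ate(psi, score=x, target="AUTOC")
est_rev, _ = rank_weighted_ate(psi, score=-x, target="AUTOC")
```

A score that ranks units by their true CATE gives a clearly positive AUTOC, and reversing the ranking flips its sign. A constant Ψ (no heterogeneity) gives zero for both targets.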

Parameters:

  • forest (fitted CausalForest, required)
  • X, Y, T (arrays, default None): If omitted, falls back to the forest's stored training arrays.
  • target ({'AUTOC', 'QINI'}, default 'AUTOC'): Which scalar summary to estimate.
  • q_grid (int, default 100): Number of quantile grid points used to report the TOC curve. Does not affect the point estimate or SE (those are computed exactly from ranks).
  • alpha (float, default 0.05): Significance level for the reported confidence interval.
  • seed (int, default None): Ignored; kept for API backwards compatibility.

Returns:

  • dict with keys ``estimate``, ``se``, ``ci_low``, ``ci_high``, ``target``, ``toc_curve`` (shape ``(q_grid, 2)``), ``n``, ``method``.

References

Yadlowsky S., Fleming S., Shah N., Brunskill E., Wager S. (2021). "Evaluating Treatment Prioritization Rules via Rank-Weighted Average Treatment Effects." arXiv:2111.07966.

honest_variance

honest_variance(forest: 'CausalForest', X: Optional[ndarray] = None, n_splits: int = 25, seed: Optional[int] = None) -> Dict[str, float]

Half-sample bootstrap variance of the ATE/GATE estimate.

Repeatedly partitions the sample into two halves, computes the mean predicted CATE on each half, and aggregates. Returns the sample variance of the per-split means divided by the number of splits: a crude but robust uncertainty measure for use when the forest's internal variance estimator is unavailable.

Parameters:

  • forest (fitted CausalForest, required)
  • X (ndarray, default None)
  • n_splits (int, default 25): Number of random half-sample draws.
  • seed (int, default None)

Returns:

  • dict with ``ate``, ``se``, ``ci_low``, ``ci_high`` (95%).

average_treatment_effect

average_treatment_effect(forest: 'CausalForest', X: Optional[ndarray] = None, T: Optional[ndarray] = None, target_sample: str = 'all', alpha: float = 0.05, clip: float = 0.01) -> Dict[str, float]

Aggregate CATE predictions into ATE/ATT/ATC/ATO targets.

This mirrors the most-used grf::average_treatment_effect targets: "all" (ATE), "treated" (ATT), "control" (ATC), and "overlap" (ATO, weighted by e(X)(1-e(X))).

The estimate is the doubly-robust AIPW influence-function mean (the estimator grf reports), not a plug-in average of the CATE predictions. Using the forest's own cross-fitted nuisances :math:`\hat m(X)=\hat E[Y\mid X]` and :math:`\hat e(X)=\hat E[T\mid X]`, the ATE score is

.. math::

   \Gamma_i = \hat\tau(X_i) + \frac{T_i-\hat e(X_i)}{\hat e(X_i)(1-\hat e(X_i))} \bigl(Y_i-\hat m(X_i)-(T_i-\hat e(X_i))\hat\tau(X_i)\bigr),

and the ATT/ATC scores use the analogous Robins doubly-robust weighting. ``se`` is the influence-function standard error :math:`\mathrm{sd}(\Gamma)/\sqrt n`. When the score cannot be formed (out-of-sample X with no stored nuisances), the function falls back to the plug-in CATE average and sets ``method='plug_in'``.
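The ATE branch of this score can be sketched directly. ``aipw_ate`` is a hypothetical helper that takes caller-supplied nuisance arrays. The demo uses oracle nuisances and a deliberately biased ``tau_hat`` to show that the correction term removes bias that a plug-in average would keep. The ATT/ATC/ATO variants are not shown.

```python
import numpy as np


def aipw_ate(Y, T, tau_hat, e_hat, m_hat, clip=0.01):
    """AIPW ATE and influence-function SE (hypothetical helper)."""
    e = np.clip(e_hat, clip, 1 - clip)  # stabilise the inverse-propensity term
    gamma = tau_hat + (T - e) / (e * (1 - e)) * (Y - m_hat - (T - e) * tau_hat)
    return gamma.mean(), gamma.std(ddof=1) / np.sqrt(len(gamma))


rng = np.random.default_rng(7)
n = 4000
x = rng.normal(size=n)
tau = 1.0 + 0.5 * x                 # true CATE, ATE = 1
e = np.full(n, 0.5)                 # randomized treatment
T = rng.binomial(1, e)
Y = x + tau * T + rng.normal(size=n)
m = x + e * tau                     # oracle E[Y | X]
tau_hat = tau + 0.5                 # CATE predictions biased upward by 0.5

ate, se = aipw_ate(Y, T, tau_hat, e, m)
plug_in = tau_hat.mean()            # close to 1.5: inherits the bias
```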

Parameters:

  • clip (float, default 0.01): Propensity scores are clipped to [clip, 1-clip] before the inverse-propensity term is applied, which stabilises the score under near-violations of overlap.

forest_diagnostics

forest_diagnostics(forest: 'CausalForest', X: Optional[ndarray] = None, T: Optional[ndarray] = None, propensity_bounds: Tuple[float, float] = (0.05, 0.95)) -> Dict[str, object]

Return overlap and CATE-distribution diagnostics for a fitted forest.
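The following sketch shows the kind of summary such diagnostics might compute. The dictionary keys, the choice of CATE quantiles and the helper ``overlap_and_cate_summary`` are hypothetical, not what StatsPAI returns.

```python
import numpy as np


def overlap_and_cate_summary(e_hat, tau_hat, propensity_bounds=(0.05, 0.95)):
    """Overlap and CATE-distribution summary (hypothetical keys)."""
    lo, hi = propensity_bounds
    outside = (e_hat < lo) | (e_hat > hi)   # units with poor overlap
    return {
        "share_outside_bounds": float(outside.mean()),
        "propensity_min": float(e_hat.min()),
        "propensity_max": float(e_hat.max()),
        "cate_quantiles": np.quantile(tau_hat, [0.05, 0.25, 0.5, 0.75, 0.95]),
        "cate_sd": float(tau_hat.std(ddof=1)),
    }


diag = overlap_and_cate_summary(
    e_hat=np.array([0.01, 0.2, 0.5, 0.97]),
    tau_hat=np.array([0.0, 1.0, 2.0, 3.0]),
)
print(diag["share_outside_bounds"])  # 0.5
```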