statspai.metalearners

metalearners

Meta-Learners for heterogeneous treatment effect estimation.

Provides S/T/X/R/DR-Learner implementations that decompose CATE estimation into standard supervised-learning sub-problems. All learners accept any scikit-learn compatible estimator.

References

Künzel et al. (2019). Metalearners for estimating heterogeneous treatment effects using machine learning. PNAS, 116(10), 4156-4165. [@kunzel2019metalearners]

Nie & Wager (2021). Quasi-oracle estimation of heterogeneous treatment effects. Biometrika, 108(2), 299-319. [@nie2021quasi]

Kennedy (2023). Towards optimal doubly robust estimation of heterogeneous causal effects. Electronic Journal of Statistics, 17(2), 3008-3049. [@kennedy2023towards]

SLearner

S-Learner: single model, treatment as a feature.

Fits one model mu(X, D) and estimates CATE as: tau(x) = mu(x, 1) - mu(x, 0)

Simple, but can over-regularise the treatment effect and shrink it towards zero: the treatment variable is just one of many features, so a regularised model may give it little weight.
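As a minimal sketch of the formula above (not the SLearner implementation), the following fits a single least-squares model with a treatment interaction in place of the default GradientBoostingRegressor. The helper `design` and the simulated data are illustrative assumptions.

```python
import numpy as np

# Simulated data (assumption): randomised binary treatment, linear CATE.
rng = np.random.default_rng(0)
n = 2000
X = rng.normal(size=(n, 2))
D = rng.integers(0, 2, size=n)
tau_true = 1.0 + 0.5 * X[:, 0]
Y = X[:, 1] + tau_true * D + rng.normal(scale=0.1, size=n)


def design(X, d):
    """Features for the single model mu(X, D): intercept, X, D and D*X."""
    d = np.asarray(d, dtype=float).reshape(-1, 1)
    return np.hstack([np.ones((len(X), 1)), X, d, d * X])


# Fit one model mu(X, D) by least squares.
beta, *_ = np.linalg.lstsq(design(X, D), Y, rcond=None)

# CATE: predict with D forced to 1 and to 0, then take the difference.
tau_hat = design(X, np.ones(n)) @ beta - design(X, np.zeros(n)) @ beta
```

Without the explicit `D * X` interaction columns, a linear single model could only produce a constant effect, which is one way the treatment signal gets lost in an S-Learner.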

Parameters:

Name Type Description Default
model sklearn estimator

Outcome model mu(X, D). Default: GradientBoostingRegressor.

None

fit

fit(X, Y, D)

Fit mu(X, D).

effect

effect(X)

Estimate CATE: mu(X,1) - mu(X,0).

TLearner

T-Learner: separate models for each treatment arm.

Fits mu_1(X) on treated units and mu_0(X) on controls: tau(x) = mu_1(x) - mu_0(x)

Simple and flexible but can suffer from regularisation imbalance when treatment/control group sizes differ substantially.
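A minimal sketch of the two-model idea, assuming plain least-squares outcome models rather than the default GradientBoostingRegressor. The helpers `fit_ols` and `predict_ols` and the simulated data are hypothetical and only serve the illustration.

```python
import numpy as np

# Simulated data (assumption): randomised binary treatment, linear CATE.
rng = np.random.default_rng(1)
n = 2000
X = rng.normal(size=(n, 2))
D = rng.integers(0, 2, size=n)
tau_true = 2.0 - X[:, 1]
Y = X[:, 0] + tau_true * D + rng.normal(scale=0.1, size=n)


def fit_ols(X, y):
    """Least-squares fit with an intercept; returns the coefficient vector."""
    A = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return coef


def predict_ols(coef, X):
    return np.column_stack([np.ones(len(X)), X]) @ coef


# mu_1 is fitted on treated units only, mu_0 on controls only.
coef_1 = fit_ols(X[D == 1], Y[D == 1])
coef_0 = fit_ols(X[D == 0], Y[D == 0])

# tau(x) = mu_1(x) - mu_0(x), evaluated for every unit.
tau_hat = predict_ols(coef_1, X) - predict_ols(coef_0, X)
```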

Parameters:

Name Type Description Default
model_0 sklearn estimator

Control outcome model. Default: GradientBoostingRegressor.

None
model_1 sklearn estimator

Treated outcome model. Default: same type as model_0.

None

XLearner

X-Learner (Künzel et al. 2019).

Two-stage procedure:

1. Fit mu_0, mu_1 (T-Learner first stage).
2. Impute individual treatment effects:
   - For treated: D1_i = Y_i - mu_0(X_i)
   - For controls: D0_i = mu_1(X_i) - Y_i
3. Fit CATE models tau_1(X) on D1 and tau_0(X) on D0.
4. Combine: tau(x) = e(x) * tau_0(x) + (1 - e(x)) * tau_1(x), where e(x) is the propensity score.

Particularly effective when treatment/control groups are very unbalanced.
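The procedure described above can be sketched as follows. This is an illustration under stated assumptions, not the XLearner code: every model is a plain least-squares fit, the propensity score is taken to be the constant treated share (reasonable only for a randomised design), and the helper names are hypothetical.

```python
import numpy as np

# Simulated data (assumption): about 20% treated, linear CATE.
rng = np.random.default_rng(2)
n = 4000
X = rng.normal(size=(n, 2))
D = (rng.random(n) < 0.2).astype(int)
tau_true = 1.0 + X[:, 0]
Y = X[:, 1] + tau_true * D + rng.normal(scale=0.1, size=n)


def fit_ols(X, y):
    A = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return coef


def predict_ols(coef, X):
    return np.column_stack([np.ones(len(X)), X]) @ coef


t, c = D == 1, D == 0

# Step 1: outcome models per arm (T-Learner first stage).
mu1 = fit_ols(X[t], Y[t])
mu0 = fit_ols(X[c], Y[c])

# Step 2: imputed individual effects.
D1 = Y[t] - predict_ols(mu0, X[t])   # treated units
D0 = predict_ols(mu1, X[c]) - Y[c]   # control units

# Step 3: CATE models on the imputed effects.
tau1 = fit_ols(X[t], D1)
tau0 = fit_ols(X[c], D0)

# Step 4: propensity-weighted combination.
e_hat = np.full(n, D.mean())
tau_hat = e_hat * predict_ols(tau0, X) + (1 - e_hat) * predict_ols(tau1, X)
```

With few treated units, e(x) is small, so the combination leans on tau_1, which is built from the control outcome model mu_0 fitted on the larger arm.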

Parameters:

Name Type Description Default
model_0 sklearn estimator

Control outcome model.

None
model_1 sklearn estimator

Treated outcome model.

None
cate_model_0 sklearn estimator

CATE model for control-side imputed effects.

None
cate_model_1 sklearn estimator

CATE model for treated-side imputed effects.

None
propensity_model sklearn estimator

Model for e(x) = P(D=1|X).

None

RLearner

R-Learner (Nie & Wager 2021).

Based on the Robinson (1988) decomposition: Y - m(X) = tau(X) * (D - e(X)) + epsilon

Estimates nuisance functions m(X) = E[Y|X] and e(X) = E[D|X] via cross-fitting, then minimises the R-loss: L(tau) = sum_i [ (Y_i - m_hat(X_i)) - tau(X_i)*(D_i - e_hat(X_i)) ]^2

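Achieves quasi-oracle error rates under mild conditions: tau(X) is estimated about as well as if m(X) and e(X) were known, provided the cross-fitted nuisance estimates are reasonably accurate.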
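For a CATE model that is linear in X, minimising the R-loss reduces to a least-squares regression of the outcome residual on the treatment residual times the features. The sketch below assumes exactly that, with least-squares m(X), a constant e(X) estimated as the training-fold treated share, and hand-rolled cross-fitting; it is not the RLearner implementation.

```python
import numpy as np

# Simulated data (assumption): randomised binary treatment, linear CATE.
rng = np.random.default_rng(3)
n, n_folds = 3000, 5
X = rng.normal(size=(n, 2))
D = rng.integers(0, 2, size=n).astype(float)
tau_true = 1.0 + 0.5 * X[:, 0]
Y = X[:, 1] + tau_true * D + rng.normal(scale=0.1, size=n)

A = np.column_stack([np.ones(n), X])  # intercept + covariates

# Cross-fitting: each unit's nuisance values come from models
# fitted on the other folds.
m_hat = np.empty(n)
e_hat = np.empty(n)
for test_idx in np.array_split(rng.permutation(n), n_folds):
    train = np.ones(n, dtype=bool)
    train[test_idx] = False
    coef, *_ = np.linalg.lstsq(A[train], Y[train], rcond=None)
    m_hat[test_idx] = A[test_idx] @ coef   # m(X) = E[Y | X]
    e_hat[test_idx] = D[train].mean()      # e(X), constant here

# R-loss with tau(X) = A @ theta:
#   sum_i [ (Y_i - m_hat_i) - (A_i @ theta) * (D_i - e_hat_i) ]^2
Y_res = Y - m_hat
D_res = D - e_hat
theta, *_ = np.linalg.lstsq(A * D_res[:, None], Y_res, rcond=None)
tau_hat = A @ theta
```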

Parameters:

Name Type Description Default
outcome_model sklearn estimator

Model for m(X) = E[Y|X].

None
propensity_model sklearn estimator

Model for e(X) = P(D=1|X).

None
cate_model sklearn estimator

Model for tau(X). Fit on pseudo-outcome.

None
n_folds int

Cross-fitting folds for nuisance estimation.

5

DRLearner

DR-Learner (Kennedy 2023): doubly robust CATE estimation.

Constructs the doubly robust pseudo-outcome: phi = mu_1(X) - mu_0(X) + D * (Y - mu_1(X)) / e(X) - (1 - D) * (Y - mu_0(X)) / (1 - e(X))

Then regresses phi on X to obtain tau(X).

Achieves oracle rates and is robust to mis-specification of either the outcome or propensity model (but not both).
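The sketch below builds the pseudo-outcome and regresses it on X, once with sensible outcome models and once with deliberately wrong ones (mu_0 = mu_1 = 0) but a correct propensity score, to illustrate the robustness claim. It assumes least-squares models, a constant propensity for a randomised design, and no cross-fitting, so it is a simplification rather than the DRLearner code; all helper names are hypothetical.

```python
import numpy as np


def dr_pseudo_outcome(Y, D, mu0, mu1, e):
    """phi = mu_1 - mu_0 + D*(Y - mu_1)/e - (1 - D)*(Y - mu_0)/(1 - e)."""
    return mu1 - mu0 + D * (Y - mu1) / e - (1 - D) * (Y - mu0) / (1 - e)


def fit_ols(X, y):
    A = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return coef


def predict_ols(coef, X):
    return np.column_stack([np.ones(len(X)), X]) @ coef


# Simulated data (assumption): randomised binary treatment, linear CATE.
rng = np.random.default_rng(4)
n = 5000
X = rng.normal(size=(n, 2))
D = rng.integers(0, 2, size=n).astype(float)
tau_true = 1.0 + 0.5 * X[:, 0]
Y = X[:, 1] + tau_true * D + rng.normal(scale=0.1, size=n)

e_hat = np.full(n, D.mean())  # correct propensity for this design

# Case 1: outcome models fitted per arm.
mu1_hat = predict_ols(fit_ols(X[D == 1], Y[D == 1]), X)
mu0_hat = predict_ols(fit_ols(X[D == 0], Y[D == 0]), X)
phi_good = dr_pseudo_outcome(Y, D, mu0_hat, mu1_hat, e_hat)
tau_good = predict_ols(fit_ols(X, phi_good), X)

# Case 2: outcome models badly mis-specified (identically zero).
zeros = np.zeros(n)
phi_bad_mu = dr_pseudo_outcome(Y, D, zeros, zeros, e_hat)
tau_bad_mu = predict_ols(fit_ols(X, phi_bad_mu), X)
```

In the second case the estimate stays centred on the true effect but is much noisier, because the pseudo-outcome then relies entirely on inverse-propensity weighting.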

Parameters:

Name Type Description Default
outcome_model sklearn estimator

Model for mu_d(X) = E[Y|X, D=d].

None
propensity_model sklearn estimator

Model for e(X) = P(D=1|X).

None
cate_model sklearn estimator

Final-stage model for tau(X).

None
n_folds int

Cross-fitting folds for nuisance estimation.

5

AutoCATEResult dataclass

Leaderboard + winner from sp.auto_cate.

Attributes:

Name Type Description
leaderboard DataFrame

One row per learner with ATE, SE, CI, R-loss, BLP calibration columns, and CATE dispersion. Sorted by R-loss ascending.

best_learner str

Full name of the chosen winner (e.g. "DR-Learner").

best_result CausalResult

Full fitted result for the winner — supports .summary(), .tidy(), .glance().

results dict[str, CausalResult]

All fitted learners keyed by short code ('s', 't', ...).

agreement DataFrame

Pearson-rho matrix of CATE vectors across learners (in-sample). High agreement suggests stable heterogeneity; low agreement suggests model dependence.

selection_rule str

Human-readable description of the rule that picked the winner.

n_obs int

Sample size (after dropping NA on modelled columns).
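To illustrate what a Pearson agreement matrix of CATE vectors looks like, here is a small sketch with made-up CATE vectors. The dictionary `cates` and its contents are assumptions for the example and do not come from sp.auto_cate.

```python
import numpy as np

# Hypothetical in-sample CATE vectors keyed by learner short code.
rng = np.random.default_rng(5)
base = rng.normal(size=500)
cates = {
    "s": base + rng.normal(scale=0.1, size=500),
    "t": base + rng.normal(scale=0.1, size=500),
    "dr": rng.normal(size=500),  # unrelated to the other two
}

names = list(cates)
# Rows are learners, so np.corrcoef returns a learner-by-learner matrix.
agreement = np.corrcoef(np.vstack([cates[k] for k in names]))
```

Here the "s" and "t" entries correlate strongly, while "dr" correlates weakly with both, which is the pattern the attribute description reads as model dependence.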

FunctionalCATEResult dataclass

Output of FOCaL functional CATE estimator.

ClusterCATEResult dataclass

Per-cluster CATE table.

CATEEvalResult dataclass

Output of cate_eval().

Attributes:

Name Type Description
autoc float
autoc_se float
autoc_ci (float, float)
qini float
qini_se float
qini_ci (float, float)
toc_curve DataFrame

Columns q and toc; one row per quantile grid point.

n_obs int
target str

"AUTOC" (default) or "QINI".

method str

Always "Yadlowsky et al. 2025 (DR-RATE, IF-SE)".

plot

plot(ax=None, target: Optional[str] = None, figsize: Tuple[float, float] = (6.0, 4.0))

Plot the TOC curve plus a dashed zero line.

target defaults to self.target; pass "both" to overlay AUTOC's curve with a marker for the QINI weighted area.
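For orientation, a TOC curve and the two summary areas can be computed from per-unit effect scores roughly as below. This sketch assumes the usual rank-weighted average treatment effect definitions (TOC(q) is the mean score among the top q fraction by priority minus the overall mean, AUTOC integrates TOC(q), and QINI integrates q * TOC(q)); the function `toc_curve` is hypothetical, and the library's exact grid and standard errors are not reproduced.

```python
import numpy as np


def toc_curve(scores, priority):
    """Return (q, toc) on the grid q = 1/n, 2/n, ..., 1."""
    scores = np.asarray(scores, dtype=float)
    # Highest priority first.
    order = np.argsort(-np.asarray(priority, dtype=float), kind="stable")
    sorted_scores = scores[order]
    n = len(scores)
    counts = np.arange(1, n + 1)
    top_mean = np.cumsum(sorted_scores) / counts  # mean score in the top q
    return counts / n, top_mean - scores.mean()


# Toy example: priority ranks units perfectly by their score.
scores = np.array([4.0, 3.0, 2.0, 1.0])
q, toc = toc_curve(scores, priority=scores)

autoc = toc.mean()        # Riemann approximation of the integral of TOC(q)
qini = (q * toc).mean()   # Riemann approximation of the integral of q*TOC(q)
```

A curve that starts high and decays to zero at q = 1, as in this toy case, indicates that the priority ranking concentrates large effects at the top.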

metalearner

metalearner(data: DataFrame, y: str, treat: str, covariates: List[str], learner: str = 'dr', outcome_model: Optional[Any] = None, propensity_model: Optional[Any] = None, cate_model: Optional[Any] = None, n_folds: int = 5, n_bootstrap: int = 200, alpha: float = 0.05) -> CausalResult

Estimate heterogeneous treatment effects using meta-learners.

Parameters:

Name Type Description Default
data DataFrame

Input data.

required
y str

Outcome variable.

required
treat str

Binary treatment variable (0/1).

required
covariates list of str

Covariate / effect modifier variables.

required
learner str

Meta-learner type: 's', 't', 'x', 'r', or 'dr'.

'dr'
outcome_model sklearn estimator

Custom ML model for outcome nuisance.

None
propensity_model sklearn estimator

Custom propensity score model (used by X/R/DR learners).

None
cate_model sklearn estimator

Custom model for final CATE stage (R/DR learners).

None
n_folds int

Cross-fitting folds for nuisance estimation (used by R/DR learners and the unified AIPW SE path; see Notes).

5
n_bootstrap int

Deprecated and ignored as of v1.11.4. Previously the SE for S/T/X/R-Learner came from a re-sampling bootstrap of the fitted CATE values, which treats τ̂ as fixed and severely under-estimates uncertainty. The function now uses the AIPW influence function for SE regardless of learner=. The argument is kept for backward compatibility and will be removed in a future minor release.

200
alpha float

Significance level.

0.05

Returns:

Type Description
CausalResult

Result with ATE estimate, SE, CI, p-value, and individual CATE predictions accessible via result.model_info['cate'].

Notes

ATE / SE convention (v1.11.4+). Regardless of which CATE estimator the user selects via learner=, the population ATE and its SE are computed via the AIPW (DR) pseudo-outcome:

\[
\varphi_i = \hat\mu_1(X_i) - \hat\mu_0(X_i) + \frac{D_i (Y_i - \hat\mu_1(X_i))}{\hat e(X_i)} - \frac{(1-D_i)(Y_i - \hat\mu_0(X_i))}{1 - \hat e(X_i)}
\]

with $\widehat{\rm ATE} = \bar\varphi$ and $\widehat{\rm SE} = \sigma(\varphi)/\sqrt{n}$. AIPW is the semiparametric-efficient estimating function for $E[Y(1) - Y(0)]$ (van der Laan & Robins 2003; Kennedy 2023), so the SE is valid for any CATE estimator. The chosen learner= determines τ̂(X) (heterogeneity); ATE inference is learner-independent.

Prior to v1.11.4, S/T/X/R-Learner used mean(τ̂) as ATE and a re-sampling bootstrap of τ̂ as SE. The bootstrap silently treated τ̂ as fixed, which produced systematically too-small SEs and severe under-coverage. ⚠️ This is a correctness fix; numerical results will change for non-DR learners.
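The ATE / SE convention in these notes amounts to averaging the AIPW pseudo-outcome and dividing its standard deviation by the square root of n. The sketch below assumes the nuisance values are already available (here the true functions of a simulated design are plugged in) and uses a normal-approximation interval; `aipw_ate` is a hypothetical helper, not a statspai function.

```python
import numpy as np


def aipw_ate(Y, D, mu0, mu1, e, z=1.96):
    """ATE, SE and normal-approximation CI from the AIPW pseudo-outcome."""
    phi = mu1 - mu0 + D * (Y - mu1) / e - (1 - D) * (Y - mu0) / (1 - e)
    ate = phi.mean()
    se = phi.std(ddof=1) / np.sqrt(len(phi))  # sigma(phi) / sqrt(n)
    return ate, se, (ate - z * se, ate + z * se)


# Simulated data (assumption): randomised treatment, true ATE = 1.
rng = np.random.default_rng(7)
n = 2000
X = rng.normal(size=(n, 2))
D = rng.integers(0, 2, size=n).astype(float)
mu0 = X[:, 1]                        # true E[Y | X, D=0]
mu1 = X[:, 1] + 1.0 + 0.5 * X[:, 0]  # true E[Y | X, D=1]
Y = np.where(D == 1, mu1, mu0) + rng.normal(scale=0.1, size=n)

ate, se, ci = aipw_ate(Y, D, mu0, mu1, e=np.full(n, 0.5))
```

Because the SE comes from the spread of the pseudo-outcome itself, it does not depend on which learner produced τ̂(X).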

Examples:

>>> import statspai as sp
>>> result = sp.metalearner(df, y='wage', treat='training',
...                         covariates=['age', 'edu', 'exp'])
>>> print(result.summary())
>>> # Use X-Learner with custom models
>>> from sklearn.ensemble import RandomForestRegressor
>>> result = sp.metalearner(df, y='wage', treat='training',
...                         covariates=['age', 'edu'],
...                         learner='x',
...                         outcome_model=RandomForestRegressor())
>>> # Access individual CATE predictions
>>> cate = result.model_info['cate']  # array of per-unit effects

cate_summary

cate_summary(result: CausalResult) -> DataFrame

Descriptive statistics of the CATE distribution.

Parameters:

Name Type Description Default
result CausalResult

Result from metalearner() containing model_info['cate'].

required

Returns:

Type Description
DataFrame

Summary statistics: mean, sd, min, q25, median, q75, max, fraction positive, fraction significant (> 2*SE away from 0).

cate_by_group

cate_by_group(result: CausalResult, data: DataFrame, by: str, n_groups: int = 4, alpha: float = 0.05) -> DataFrame

Group-level average treatment effects.

Splits the CATE distribution by a covariate (or by CATE quantile groups if by='cate'; quartiles under the default n_groups=4) and reports group means with confidence intervals.

Parameters:

Name Type Description Default
result CausalResult

Result from metalearner().

required
data DataFrame

Original data (same rows as the estimation sample).

required
by str

Column name to group by, or 'cate' to group by CATE quartiles.

required
n_groups int

Number of quantile groups when by='cate' or when the grouping variable is continuous.

4
alpha float

Significance level for CIs.

0.05

Returns:

Type Description
DataFrame

Columns: group, n, mean_cate, se, ci_lower, ci_upper.
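A rough sketch of the by='cate' case, producing rows with the same column names. It assumes quantile binning of the CATE vector and a within-group standard error of sd / sqrt(n) with a normal-approximation interval; how cate_by_group() actually computes its SE is not stated here, and `group_cate_table` is a hypothetical helper.

```python
import numpy as np


def group_cate_table(cate, n_groups=4, z=1.96):
    """Quantile-group a CATE vector and summarise each group."""
    cate = np.asarray(cate, dtype=float)
    edges = np.quantile(cate, np.linspace(0, 1, n_groups + 1))
    # Use interior edges only, so labels run from 0 to n_groups - 1.
    labels = np.searchsorted(edges[1:-1], cate, side="right")
    rows = []
    for g in range(n_groups):
        vals = cate[labels == g]  # assumed non-empty (no heavy ties)
        mean = vals.mean()
        se = vals.std(ddof=1) / np.sqrt(len(vals))
        rows.append({
            "group": g,
            "n": len(vals),
            "mean_cate": mean,
            "se": se,
            "ci_lower": mean - z * se,
            "ci_upper": mean + z * se,
        })
    return rows


rows = group_cate_table(np.arange(100.0))  # toy CATE values 0..99
```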

cate_plot

cate_plot(result: CausalResult, kind: str = 'hist', ax=None, figsize: tuple = (8, 5), color: str = '#2C3E50', title: Optional[str] = None, **kwargs)

Plot the CATE distribution.

Parameters:

Name Type Description Default
result CausalResult

Result from metalearner().

required
kind str

'hist' for histogram, 'kde' for kernel density, 'both'.

'hist'
ax matplotlib Axes
None
figsize tuple
(8, 5)
color str
'#2C3E50'
title str
None

Returns:

Type Description
(fig, ax)

cate_group_plot

cate_group_plot(group_df: DataFrame, ax=None, figsize: tuple = (8, 5), color: str = '#2C3E50', title: Optional[str] = None)

Plot group-level CATEs with confidence intervals.

Parameters:

Name Type Description Default
group_df DataFrame

Output from cate_by_group().

required
ax matplotlib Axes
None
figsize tuple
(8, 5)
color str
'#2C3E50'
title str
None

Returns:

Type Description
(fig, ax)

predict_cate

predict_cate(result: CausalResult, new_data: DataFrame) -> ndarray

Predict CATE on new (out-of-sample) data.

Parameters:

Name Type Description Default
result CausalResult

Result from metalearner() containing a fitted estimator.

required
new_data DataFrame

New data with the same covariate columns used in estimation.

required

Returns:

Type Description
ndarray

Predicted CATE for each row of new_data.

compare_metalearners

compare_metalearners(data: DataFrame, y: str, treat: str, covariates: List[str], learners: Optional[List[str]] = None, **kwargs) -> DataFrame

Fit multiple meta-learners and compare their ATE estimates.

Parameters:

Name Type Description Default
data DataFrame

Input data.

required
y str

Outcome variable.

required
treat str

Binary treatment variable (0/1).

required
covariates list of str

Covariate / effect modifier variables.

required
learners list of str

Which learners to compare. Default: all five ('s','t','x','r','dr').

None
**kwargs

Additional arguments passed to metalearner().

{}

Returns:

Type Description
DataFrame

Comparison table with columns: learner, ate, se, ci_lower, ci_upper, pvalue, cate_std, cate_iqr.

Examples:

>>> import statspai as sp
>>> comp = sp.compare_metalearners(df, y='wage', treat='training',
...                                 covariates=['age', 'edu'])
>>> print(comp)

gate_test

gate_test(result: CausalResult, data: DataFrame, by: str, n_groups: int = 4, alpha: float = 0.05) -> Dict[str, Any]

Test for significant heterogeneity across GATE (Group ATE) groups.

Performs two tests:

1. Omnibus F-test: are all group CATEs equal? (ANOVA)
2. Top-vs-bottom: is the highest-CATE group significantly different from the lowest?
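The two tests can be illustrated on per-group arrays of CATE values. The sketch assumes a textbook one-way ANOVA F statistic and an unequal-variance z-test for the top-vs-bottom contrast; the exact test statistics used by gate_test() may differ, and both function names are hypothetical.

```python
import math
import numpy as np


def omnibus_f(groups):
    """One-way ANOVA F statistic for equality of group means."""
    k = len(groups)
    n = sum(len(g) for g in groups)
    grand = np.concatenate(groups).mean()
    ssb = sum(len(g) * (g.mean() - grand) ** 2 for g in groups)  # between
    ssw = sum(((g - g.mean()) ** 2).sum() for g in groups)       # within
    return (ssb / (k - 1)) / (ssw / (n - k))


def top_vs_bottom(groups):
    """Difference between highest- and lowest-mean groups, SE and p-value."""
    means = [g.mean() for g in groups]
    top = groups[int(np.argmax(means))]
    bottom = groups[int(np.argmin(means))]
    diff = top.mean() - bottom.mean()
    se = math.sqrt(top.var(ddof=1) / len(top)
                   + bottom.var(ddof=1) / len(bottom))
    p = math.erfc(abs(diff / se) / math.sqrt(2))  # two-sided normal p-value
    return diff, se, p


groups = [np.array([1.0, 2.0, 3.0]),
          np.array([2.0, 3.0, 4.0]),
          np.array([5.0, 6.0, 7.0])]
F = omnibus_f(groups)
diff, se, p = top_vs_bottom(groups)
```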

Parameters:

Name Type Description Default
result CausalResult

Result from metalearner().

required
data DataFrame

Original data (same rows as estimation).

required
by str

Column name to group by, or 'cate' for CATE quartiles.

required
n_groups int

Number of groups.

4
alpha float

Significance level.

0.05

Returns:

Type Description
dict

Keys: 'gate_table' (DataFrame), 'omnibus_F', 'omnibus_pvalue', 'top_vs_bottom_diff', 'top_vs_bottom_se', 'top_vs_bottom_pvalue'.

blp_test

blp_test(result: CausalResult, data: DataFrame, y: str, treat: str, covariates: List[str], n_folds: int = 5, alpha: float = 0.05) -> Dict[str, Any]

Best Linear Predictor (BLP) test for CATE heterogeneity.

Implements the calibration test from Chernozhukov et al. (2018, Econometrica) "Generic Machine Learning Inference on Heterogeneous Treatment Effects." Equivalent to grf::test_calibration() in R.

Fits via OLS: Y_i = alpha + beta_1 * (D_i - e(X_i)) + beta_2 * (D_i - e(X_i)) * (S(X_i) - mean(S)) + eps

where S(X) is the CATE proxy from the meta-learner and e(X) the propensity score (estimated via cross-fitting).

  • beta_1: tests whether ATE != 0 (the counterpart of the mean.forest.prediction term in grf)
  • beta_2: tests whether CATE heterogeneity is real (if beta_2 > 0 and significant, the learner has found genuine heterogeneity rather than noise)
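The regression above can be sketched with ordinary least squares. The example assumes a randomised design with a constant propensity estimated by the treated share (the function itself cross-fits e(X)), uses the true CATE as the proxy S(X) so that beta_2 should be close to 1, and reports classical homoskedastic standard errors, which may differ from what blp_test() computes.

```python
import numpy as np

# Simulated data (assumption): randomised binary treatment, ATE = 1.
rng = np.random.default_rng(6)
n = 8000
X = rng.normal(size=(n, 2))
D = rng.integers(0, 2, size=n).astype(float)
tau_true = 1.0 + 0.5 * X[:, 0]
Y = X[:, 1] + tau_true * D + rng.normal(scale=0.1, size=n)

S = tau_true                    # CATE proxy (a perfect one, for illustration)
e_hat = np.full(n, D.mean())    # constant propensity estimate

# Y = alpha + beta_1 * (D - e) + beta_2 * (D - e) * (S - mean(S)) + eps
W = D - e_hat
A = np.column_stack([np.ones(n), W, W * (S - S.mean())])
coef, *_ = np.linalg.lstsq(A, Y, rcond=None)

# Classical OLS standard errors.
resid = Y - A @ coef
sigma2 = resid @ resid / (n - A.shape[1])
se = np.sqrt(np.diag(sigma2 * np.linalg.inv(A.T @ A)))

alpha_hat, beta1, beta2 = coef
```

Replacing `S` with pure noise would drive `beta2` towards zero, which is the case the test is designed to flag.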

Parameters:

Name Type Description Default
result CausalResult

Result from metalearner().

required
data DataFrame

Original data.

required
y str

Outcome column name.

required
treat str

Treatment column name.

required
covariates list of str

Covariate column names.

required
n_folds int

Folds for propensity cross-fitting.

5
alpha float

Significance level.

0.05

Returns:

Type Description
dict

Keys: 'beta1' (ATE signal), 'beta1_se', 'beta1_pvalue', 'beta2' (heterogeneity signal), 'beta2_se', 'beta2_pvalue', 'heterogeneity_significant' (bool).

References

Chernozhukov, V., Demirer, M., Duflo, E., & Fernandez-Val, I. (2018). Generic Machine Learning Inference on Heterogeneous Treatment Effects in Randomized Experiments. Econometrica (forthcoming as of 2018 NBER WP). [@chernozhukov2018double]

focal_cate

focal_cate(data: DataFrame, y_columns: List[str], treat: str, covariates: List[str], test_data: Optional[DataFrame] = None, seed: int = 0) -> FunctionalCATEResult

Functional doubly-robust CATE estimator.

Parameters:

Name Type Description Default
data DataFrame

Training data with outcome columns y_columns (one per function evaluation point t).

required
y_columns list of str

Outcome columns; len = number of function points.

required
treat str
required
covariates list of str
required
test_data DataFrame

Defaults to data.

None
seed int
0

Returns:

Type Description
FunctionalCATEResult