statspai.metalearners¶
metalearners ¶
Meta-Learners for heterogeneous treatment effect estimation.
Provides S/T/X/R/DR-Learner implementations that decompose CATE estimation into standard supervised-learning sub-problems. All learners accept any scikit-learn compatible estimator.
References
Künzel et al. (2019). Metalearners for estimating heterogeneous treatment effects using machine learning. PNAS, 116(10), 4156-4165. [@kunzel2019metalearners]
Nie & Wager (2021). Quasi-oracle estimation of heterogeneous treatment effects. Biometrika, 108(2), 299-319. [@nie2021quasi]
Kennedy (2023). Towards optimal doubly robust estimation of heterogeneous causal effects. Electronic Journal of Statistics, 17(2), 3008-3049. [@kennedy2023towards]
SLearner ¶
S-Learner: single model, treatment as a feature.
Fits one model mu(X, D) and estimates CATE as: tau(x) = mu(x, 1) - mu(x, 0)
Simple, but because the treatment variable is just one of many features, a regularised model may shrink the estimated treatment effect towards zero.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | sklearn estimator | Outcome model mu(X, D). Default: GradientBoostingRegressor. | None |
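The S-Learner recipe above can be sketched in a few lines of NumPy. This is only an illustration of the idea, not the SLearner class itself: the least-squares model with D-by-X interactions, the `design` helper and the simulated data are all assumptions made for the example.

```python
import numpy as np

def design(X, d):
    """Hypothetical basis for one outcome model mu(X, D): main effects plus D-by-X interactions."""
    d = np.asarray(d, dtype=float).reshape(-1, 1)
    return np.column_stack([np.ones(len(X)), X, d, d * X])

rng = np.random.default_rng(0)
n = 500
X = rng.normal(size=(n, 2))
D = rng.integers(0, 2, size=n)
Y = 1.0 + X[:, 0] + D * (2.0 + X[:, 1])  # noise-free data, true tau(x) = 2 + x_1

# Fit a single model on (X, D) jointly, with treatment entering as a feature.
coef, *_ = np.linalg.lstsq(design(X, D), Y, rcond=None)

# tau(x) = mu(x, 1) - mu(x, 0): predict with D forced to 1, then to 0.
tau_hat = design(X, np.ones(n)) @ coef - design(X, np.zeros(n)) @ coef
```

Because the simulated outcome is noise-free and lies in the span of the basis, `tau_hat` recovers 2 + x_1 up to floating-point error. With a regularised learner and noisy data the recovered effect would generally be attenuated.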
TLearner ¶
T-Learner: separate models for each treatment arm.
Fits mu_1(X) on treated units and mu_0(X) on controls: tau(x) = mu_1(x) - mu_0(x)
Simple and flexible but can suffer from regularisation imbalance when treatment/control group sizes differ substantially.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_0 | sklearn estimator | Control outcome model. Default: GradientBoostingRegressor. | None |
| model_1 | sklearn estimator | Treated outcome model. Default: same type as model_0. | None |
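A minimal sketch of the T-Learner idea, assuming plain least-squares models in place of the default gradient boosting. The `fit_ols` helper and the simulated data are hypothetical and serve only to make the example self-contained.

```python
import numpy as np

def fit_ols(X, y):
    """Least-squares linear model with intercept; returns a predict function."""
    A = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return lambda Xn: np.column_stack([np.ones(len(Xn)), Xn]) @ coef

rng = np.random.default_rng(1)
n = 400
X = rng.normal(size=(n, 2))
D = rng.integers(0, 2, size=n)
Y = 1.0 + X[:, 0] + D * (2.0 + X[:, 1])  # noise-free data, true tau(x) = 2 + x_1

# One outcome model per arm, each fitted only on its own units.
mu0 = fit_ols(X[D == 0], Y[D == 0])
mu1 = fit_ols(X[D == 1], Y[D == 1])

# tau(x) = mu_1(x) - mu_0(x), evaluated for every unit.
tau_hat = mu1(X) - mu0(X)
```

Each arm's model sees only part of the sample, which is why a large imbalance between arms can leave one model much more heavily regularised than the other.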
XLearner ¶
X-Learner (Künzel et al. 2019).
Two-stage procedure:
1. Fit mu_0, mu_1 (T-Learner first stage).
2. Impute individual treatment effects:
   - For treated: D1_i = Y_i - mu_0(X_i)
   - For controls: D0_i = mu_1(X_i) - Y_i
3. Fit CATE models tau_1(X) on D1 and tau_0(X) on D0.
4. Combine: tau(x) = e(x) * tau_0(x) + (1 - e(x)) * tau_1(x), where e(x) is the propensity score.
Particularly effective when treatment/control groups are very unbalanced.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_0 | sklearn estimator | Control outcome model. | None |
| model_1 | sklearn estimator | Treated outcome model. | None |
| cate_model_0 | sklearn estimator | CATE model for control-side imputed effects. | None |
| cate_model_1 | sklearn estimator | CATE model for treated-side imputed effects. | None |
| propensity_model | sklearn estimator | Model for e(x) = P(D=1\|X). | None |
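The four X-Learner steps can be traced with a small NumPy sketch. It assumes least-squares models for every stage and a constant propensity equal to the treated share. Both are simplifications chosen for the example, and the `fit_ols` helper and the simulated data are hypothetical.

```python
import numpy as np

def fit_ols(X, y):
    """Least-squares linear model with intercept; returns a predict function."""
    A = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return lambda Xn: np.column_stack([np.ones(len(Xn)), Xn]) @ coef

rng = np.random.default_rng(2)
n = 600
X = rng.normal(size=(n, 2))
D = (rng.random(n) < 0.2).astype(int)    # unbalanced arms: roughly 20% treated
Y = 1.0 + X[:, 0] + D * (2.0 + X[:, 1])  # noise-free data, true tau(x) = 2 + x_1
treated, control = D == 1, D == 0

# Step 1: arm-specific outcome models (T-Learner first stage).
mu0 = fit_ols(X[control], Y[control])
mu1 = fit_ols(X[treated], Y[treated])

# Step 2: imputed individual treatment effects.
D1 = Y[treated] - mu0(X[treated])   # treated: observed minus predicted control outcome
D0 = mu1(X[control]) - Y[control]   # controls: predicted treated outcome minus observed

# Step 3: CATE models fitted on the imputed effects.
tau1 = fit_ols(X[treated], D1)
tau0 = fit_ols(X[control], D0)

# Step 4: propensity-weighted combination (constant e(x) is an assumption here).
e_hat = np.full(n, D.mean())
tau_hat = e_hat * tau0(X) + (1.0 - e_hat) * tau1(X)
```

With few treated units e(x) is small, so the combination leans on tau_1, the model whose imputed effects use mu_0 fitted on the larger control arm.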
RLearner ¶
R-Learner (Nie & Wager 2021).
Based on the Robinson (1988) decomposition: Y - m(X) = tau(X) * (D - e(X)) + epsilon
Estimates nuisance functions m(X) = E[Y|X] and e(X) = E[D|X] via cross-fitting, then minimises the R-loss: L(tau) = sum_i [ (Y_i - m_hat(X_i)) - tau(X_i)*(D_i - e_hat(X_i)) ]^2
Achieves quasi-oracle rates under mild conditions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| outcome_model | sklearn estimator | Model for m(X) = E[Y\|X]. | None |
| propensity_model | sklearn estimator | Model for e(X) = P(D=1\|X). | None |
| cate_model | sklearn estimator | Model for tau(X). Fit on pseudo-outcome. | None |
| n_folds | int | Cross-fitting folds for nuisance estimation. | 5 |
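As a rough illustration of the R-loss, the sketch below cross-fits the nuisances and then restricts tau(X) to a linear function, in which case minimising the R-loss reduces to least squares of the outcome residual on the treatment residual times [1, X]. The linear restriction, the constant propensity stand-in, the `fit_ols` helper and the simulated data are all assumptions of the example, not a description of how RLearner is implemented.

```python
import numpy as np

def fit_ols(X, y):
    """Least-squares linear model with intercept; returns a predict function."""
    A = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return lambda Xn: np.column_stack([np.ones(len(Xn)), Xn]) @ coef

rng = np.random.default_rng(3)
n, n_folds = 2000, 5
X = rng.normal(size=(n, 2))
D = rng.integers(0, 2, size=n).astype(float)   # randomised treatment, e(X) = 0.5
Y = 1.0 + X[:, 0] + D * (2.0 + X[:, 1]) + rng.normal(scale=0.5, size=n)

# Cross-fitting: each unit's nuisance predictions come from models fitted on other folds.
fold = rng.permutation(n) % n_folds
m_hat, e_hat = np.empty(n), np.empty(n)
for k in range(n_folds):
    train, test = fold != k, fold == k
    m_hat[test] = fit_ols(X[train], Y[train])(X[test])  # m(X) = E[Y|X]
    e_hat[test] = D[train].mean()                       # constant propensity stand-in

y_res = Y - m_hat   # outcome residual
d_res = D - e_hat   # treatment residual

# For tau(x) = b0 + b1*x1 + b2*x2 the R-loss is an ordinary least-squares problem.
basis = np.column_stack([np.ones(n), X])
beta, *_ = np.linalg.lstsq(d_res[:, None] * basis, y_res, rcond=None)
tau_hat = basis @ beta
```

In this simulation the true effect is tau(x) = 2 + x_2 in the code's 1-based coefficient naming, so `beta` should land near (2, 0, 1).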
DRLearner ¶
DR-Learner (Kennedy 2023): doubly robust CATE estimation.
Constructs the doubly robust pseudo-outcome: phi(X) = mu_1(X) - mu_0(X) + D(Y - mu_1(X)) / e(X) - (1-D)(Y - mu_0(X)) / (1-e(X))
Then regresses phi on X to obtain tau(X).
Achieves oracle rates and is robust to mis-specification of either the outcome or propensity model (but not both).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| outcome_model | sklearn estimator | Model for mu_d(X) = E[Y\|X, D=d]. | None |
| propensity_model | sklearn estimator | Model for e(X) = P(D=1\|X). | None |
| cate_model | sklearn estimator | Final-stage model for tau(X). | None |
| n_folds | int | Cross-fitting folds for nuisance estimation. | 5 |
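The pseudo-outcome construction can be sketched as follows, assuming least-squares nuisance and final-stage models, a constant propensity stand-in, and propensity clipping for numerical safety. None of these choices is stated in the source, and the `fit_ols` helper and the simulated data are hypothetical.

```python
import numpy as np

def fit_ols(X, y):
    """Least-squares linear model with intercept; returns a predict function."""
    A = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return lambda Xn: np.column_stack([np.ones(len(Xn)), Xn]) @ coef

rng = np.random.default_rng(4)
n, n_folds = 2000, 5
X = rng.normal(size=(n, 2))
D = rng.integers(0, 2, size=n)
Y = 1.0 + X[:, 0] + D * (2.0 + X[:, 1]) + rng.normal(scale=0.5, size=n)

# Cross-fitted nuisances: mu_0, mu_1 and e are predicted out-of-fold.
fold = rng.permutation(n) % n_folds
mu0, mu1, e = np.empty(n), np.empty(n), np.empty(n)
for k in range(n_folds):
    train, test = fold != k, fold == k
    mu0[test] = fit_ols(X[train & (D == 0)], Y[train & (D == 0)])(X[test])
    mu1[test] = fit_ols(X[train & (D == 1)], Y[train & (D == 1)])(X[test])
    e[test] = D[train].mean()        # constant propensity stand-in

e = np.clip(e, 0.01, 0.99)           # clipping is an assumption of this sketch

# Doubly robust pseudo-outcome phi, matching the formula above.
phi = mu1 - mu0 + D * (Y - mu1) / e - (1 - D) * (Y - mu0) / (1 - e)

# Final stage: regress phi on X to obtain tau(X).
tau_hat = fit_ols(X, phi)(X)
```

The two correction terms average to zero when the outcome models are right, and they repair the bias of mu_1 - mu_0 when only the propensity model is right, which is the double robustness described above.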
AutoCATEResult dataclass ¶
Leaderboard + winner from sp.auto_cate.
Attributes:
| Name | Type | Description |
|---|---|---|
| leaderboard | DataFrame | One row per learner with ATE, SE, CI, R-loss, BLP calibration columns, and CATE dispersion. Sorted by R-loss ascending. |
| best_learner | str | Full name of the chosen winner (e.g. |
| best_result | CausalResult | Full fitted result for the winner; supports |
| results | dict[str, CausalResult] | All fitted learners keyed by short code ( |
| agreement | DataFrame | Pearson-rho matrix of CATE vectors across learners (in-sample). High agreement suggests stable heterogeneity; low agreement suggests model dependence. |
| selection_rule | str | Human-readable description of the rule that picked the winner. |
| n_obs | int | Sample size (after dropping NA on modelled columns). |
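To make the agreement attribute concrete, here is one way a Pearson correlation matrix of CATE vectors could be computed. The short codes and the synthetic vectors are hypothetical, and the sketch returns a plain array rather than the DataFrame described above.

```python
import numpy as np

rng = np.random.default_rng(5)
base = rng.normal(size=200)

# Hypothetical in-sample CATE vectors, keyed by made-up short codes.
cate = {
    "s": base,
    "t": 2.0 * base + 1.0,   # affine transform of "s": correlation +1
    "x": -base,              # sign-flipped ranking: correlation -1
}

names = list(cate)
# Rows are learners, so corrcoef returns a learner-by-learner matrix.
agreement = np.corrcoef(np.vstack([cate[k] for k in names]))
```

Pearson correlation ignores location and scale, so two learners can agree perfectly on the ranking of units while still disagreeing on the size of the effects.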
FunctionalCATEResult dataclass ¶
Output of FOCaL functional CATE estimator.
ClusterCATEResult dataclass ¶
Per-cluster CATE table.
CATEEvalResult dataclass ¶
Output of cate_eval.
Attributes:
| Name | Type | Description |
|---|---|---|
| autoc | float |  |
| autoc_se | float |  |
| autoc_ci | (float, float) |  |
| qini | float |  |
| qini_se | float |  |
| qini_ci | (float, float) |  |
| toc_curve | DataFrame | Columns |
| n_obs | int |  |
| target | str |  |
| method | str | Always |
plot ¶
Plot the TOC curve plus a dashed zero line.
target defaults to self.target; pass "both" to overlay
AUTOC's curve with a marker for the QINI weighted area.
metalearner ¶
metalearner(data: DataFrame, y: str, treat: str, covariates: List[str], learner: str = 'dr', outcome_model: Optional[Any] = None, propensity_model: Optional[Any] = None, cate_model: Optional[Any] = None, n_folds: int = 5, n_bootstrap: int = 200, alpha: float = 0.05) -> CausalResult
Estimate heterogeneous treatment effects using meta-learners.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | DataFrame | Input data. | required |
| y | str | Outcome variable. | required |
| treat | str | Binary treatment variable (0/1). | required |
| covariates | list of str | Covariate / effect modifier variables. | required |
| learner | str | Meta-learner type: 's', 't', 'x', 'r', or 'dr'. | 'dr' |
| outcome_model | sklearn estimator | Custom ML model for outcome nuisance. | None |
| propensity_model | sklearn estimator | Custom propensity score model (used by X/R/DR learners). | None |
| cate_model | sklearn estimator | Custom model for final CATE stage (R/DR learners). | None |
| n_folds | int | Cross-fitting folds for nuisance estimation (used by R/DR learners and the unified AIPW SE path; see Notes). | 5 |
| n_bootstrap | int | Deprecated and ignored as of v1.11.4. Previously the SE for S/T/X/R-Learner came from a re-sampling bootstrap of the fitted CATE values, which treats τ̂ as fixed and severely under-estimates uncertainty. The function now uses the AIPW influence function for SE regardless of | 200 |
| alpha | float | Significance level. | 0.05 |

Returns:
| Type | Description |
|---|---|
| CausalResult | Result with ATE estimate, SE, CI, p-value, and individual CATE predictions accessible via |
Notes
ATE / SE convention (v1.11.4+). Regardless of which CATE
estimator the user selects via learner=, the population ATE and
its SE are computed via the AIPW (DR) pseudo-outcome:
$$
\varphi_i = \hat\mu_1(X_i) - \hat\mu_0(X_i) + \frac{D_i (Y_i - \hat\mu_1(X_i))}{\hat e(X_i)} - \frac{(1-D_i)(Y_i - \hat\mu_0(X_i))}{1 - \hat e(X_i)}
$$
with $\widehat{\mathrm{ATE}} = \bar\varphi$ and
$\widehat{\mathrm{SE}} = \sigma(\varphi)/\sqrt{n}$. AIPW is
the semiparametric-efficient estimating function for
$E[Y(1) - Y(0)]$ (van der Laan & Robins 2003; Kennedy 2023),
so the SE is valid for any CATE estimator. The chosen
learner= determines τ̂(X) (heterogeneity); ATE inference is
learner-independent.
Prior to v1.11.4, S/T/X/R-Learner used mean(τ̂) as ATE and a
re-sampling bootstrap of τ̂ as SE. The bootstrap silently treated
τ̂ as fixed, which gave systematically too-small SEs and severe
under-coverage. Warning: this is a correctness fix; numerical
results will change for non-DR learners.
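The ATE and SE convention in these notes amounts to averaging the AIPW pseudo-outcome and dividing its standard deviation by the square root of n. A minimal sketch follows, assuming the nuisance predictions are already available. The function name `aipw_ate`, the use of the sample standard deviation (ddof=1) and the normal-quantile interval are assumptions of the example rather than documented behaviour.

```python
import numpy as np
from statistics import NormalDist

def aipw_ate(y, d, mu0, mu1, e, alpha=0.05):
    """Hypothetical helper: ATE, SE and CI from the AIPW pseudo-outcome."""
    y, d, mu0, mu1, e = (np.asarray(a, dtype=float) for a in (y, d, mu0, mu1, e))
    phi = mu1 - mu0 + d * (y - mu1) / e - (1 - d) * (y - mu0) / (1 - e)
    ate = phi.mean()
    se = phi.std(ddof=1) / np.sqrt(len(phi))   # ddof=1 is an assumption
    z = NormalDist().inv_cdf(1 - alpha / 2)    # two-sided normal quantile
    return ate, se, (ate - z * se, ate + z * se)

# Tiny hand-checkable input: the pseudo-outcomes work out to [3, 1, 3, 1].
ate, se, ci = aipw_ate(
    y=[3, 1, 4, 2],
    d=[1, 0, 1, 0],
    mu0=[1, 1, 2, 2],
    mu1=[2, 2, 3, 3],
    e=[0.5, 0.5, 0.5, 0.5],
)
# ate is 2.0 and se is sqrt(1/3), about 0.577
```

Only the nuisance predictions enter this calculation, which is why the resulting interval does not depend on which learner produced τ̂(X).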
Examples:
>>> import statspai as sp
>>> result = sp.metalearner(df, y='wage', treat='training',
... covariates=['age', 'edu', 'exp'])
>>> print(result.summary())
cate_summary ¶
Descriptive statistics of the CATE distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| result | CausalResult | Result from | required |

Returns:
| Type | Description |
|---|---|
| DataFrame | Summary statistics: mean, sd, min, q25, median, q75, max, fraction positive, fraction significant (> 2*SE away from 0). |
cate_by_group ¶
cate_by_group(result: CausalResult, data: DataFrame, by: str, n_groups: int = 4, alpha: float = 0.05) -> DataFrame
Group-level average treatment effects.
Splits the CATE distribution by a covariate (or by CATE quartiles
if by='cate') and reports group means with confidence intervals.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| result | CausalResult | Result from | required |
| data | DataFrame | Original data (same rows as the estimation sample). | required |
| by | str | Column name to group by, or 'cate' to group by CATE quartiles. | required |
| n_groups | int | Number of quantile groups when | 4 |
| alpha | float | Significance level for CIs. | 0.05 |

Returns:
| Type | Description |
|---|---|
| DataFrame | Columns: group, n, mean_cate, se, ci_lower, ci_upper. |
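As an illustration of grouping by CATE quantiles, the sketch below sorts a CATE vector, splits it into equal-sized groups and reports the same columns as the returned table. The function name `group_cate_table`, the rank-based split and the within-group standard error are assumptions. The source does not say how cate_by_group computes its SE.

```python
import numpy as np
from statistics import NormalDist

def group_cate_table(tau, n_groups=4, alpha=0.05):
    """Hypothetical helper: per-group mean CATE with a normal-approximation CI."""
    tau = np.asarray(tau, dtype=float)
    z = NormalDist().inv_cdf(1 - alpha / 2)
    order = np.argsort(tau)                      # unit indices, lowest CATE first
    rows = []
    for g, idx in enumerate(np.array_split(order, n_groups), start=1):
        vals = tau[idx]
        mean = vals.mean()
        se = vals.std(ddof=1) / np.sqrt(len(vals))   # within-group SE (assumption)
        rows.append({"group": g, "n": len(vals), "mean_cate": mean, "se": se,
                     "ci_lower": mean - z * se, "ci_upper": mean + z * se})
    return rows

rng = np.random.default_rng(6)
tau = rng.permutation(np.arange(1.0, 13.0))      # the values 1..12 in random order
table = group_cate_table(tau)
# group means are 2, 5, 8 and 11, with three units per group
```

A standard error built from the spread of fitted CATEs treats them as fixed observations, so it describes dispersion within the group rather than the full estimation uncertainty.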
cate_plot ¶
cate_plot(result: CausalResult, kind: str = 'hist', ax=None, figsize: tuple = (8, 5), color: str = '#2C3E50', title: Optional[str] = None, **kwargs)
Plot the CATE distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| result | CausalResult | Result from | required |
| kind | str | 'hist' for histogram, 'kde' for kernel density, 'both'. | 'hist' |
| ax | matplotlib Axes |  | None |
| figsize | tuple |  | (8, 5) |
| color | str |  | '#2C3E50' |
| title | str |  | None |

Returns:
| Type | Description |
|---|---|
| (fig, ax) |  |
cate_group_plot ¶
cate_group_plot(group_df: DataFrame, ax=None, figsize: tuple = (8, 5), color: str = '#2C3E50', title: Optional[str] = None)
Plot group-level CATEs with confidence intervals.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| group_df | DataFrame | Output from | required |
| ax | matplotlib Axes |  | None |
| figsize | tuple |  | (8, 5) |
| color | str |  | '#2C3E50' |
| title | str |  | None |

Returns:
| Type | Description |
|---|---|
| (fig, ax) |  |
predict_cate ¶
Predict CATE on new (out-of-sample) data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| result | CausalResult | Result from | required |
| new_data | DataFrame | New data with the same covariate columns used in estimation. | required |

Returns:
| Type | Description |
|---|---|
| ndarray | Predicted CATE for each row of new_data. |
compare_metalearners ¶
compare_metalearners(data: DataFrame, y: str, treat: str, covariates: List[str], learners: Optional[List[str]] = None, **kwargs) -> DataFrame
Fit multiple meta-learners and compare their ATE estimates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | DataFrame | Input data. | required |
| y | str | Outcome variable. | required |
| treat | str | Binary treatment variable (0/1). | required |
| covariates | list of str | Covariate / effect modifier variables. | required |
| learners | list of str | Which learners to compare. Default: all five ('s','t','x','r','dr'). | None |
| **kwargs |  | Additional arguments passed to | {} |

Returns:
| Type | Description |
|---|---|
| DataFrame | Comparison table with columns: learner, ate, se, ci_lower, ci_upper, pvalue, cate_std, cate_iqr. |
gate_test ¶
gate_test(result: CausalResult, data: DataFrame, by: str, n_groups: int = 4, alpha: float = 0.05) -> Dict[str, Any]
Test for significant heterogeneity across GATE (Group ATE) groups.
Performs two tests:
1. Omnibus F-test: are all group CATEs equal? (ANOVA)
2. Top-vs-bottom: is the highest-CATE group significantly different from the lowest?
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| result | CausalResult | Result from | required |
| data | DataFrame | Original data (same rows as estimation). | required |
| by | str | Column name to group by, or 'cate' for CATE quartiles. | required |
| n_groups | int | Number of groups. | 4 |
| alpha | float | Significance level. | 0.05 |

Returns:
| Type | Description |
|---|---|
| dict | Keys: 'gate_table' (DataFrame), 'omnibus_F', 'omnibus_pvalue', 'top_vs_bottom_diff', 'top_vs_bottom_se', 'top_vs_bottom_pvalue'. |
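The two tests can be illustrated with a small hand-checkable sketch. It computes a one-way ANOVA F statistic across groups and an unpooled two-sample z-test between the top and bottom groups. The function name `gate_heterogeneity`, the unpooled standard error and the normal approximation are assumptions. The omnibus p-value is left out because it needs an F-distribution CDF, which the standard library does not provide.

```python
import math
import numpy as np

def gate_heterogeneity(groups):
    """Hypothetical helper. groups: CATE arrays ordered from lowest to highest group."""
    groups = [np.asarray(g, dtype=float) for g in groups]
    k = len(groups)
    n = sum(len(g) for g in groups)
    grand = np.concatenate(groups).mean()

    # One-way ANOVA: between-group over within-group mean squares.
    ss_between = sum(len(g) * (g.mean() - grand) ** 2 for g in groups)
    ss_within = sum(((g - g.mean()) ** 2).sum() for g in groups)
    f_stat = (ss_between / (k - 1)) / (ss_within / (n - k))

    # Top-vs-bottom contrast with an unpooled SE (assumption of this sketch).
    lo, hi = groups[0], groups[-1]
    diff = hi.mean() - lo.mean()
    se = math.sqrt(hi.var(ddof=1) / len(hi) + lo.var(ddof=1) / len(lo))
    pvalue = math.erfc(abs(diff / se) / math.sqrt(2))   # two-sided normal p-value
    return {"omnibus_F": f_stat, "top_vs_bottom_diff": diff,
            "top_vs_bottom_se": se, "top_vs_bottom_pvalue": pvalue}

out = gate_heterogeneity([[1, 2, 3], [2, 3, 4], [6, 7, 8]])
# omnibus_F is 21.0 and top_vs_bottom_diff is 5.0
```

In this input the group means are 2, 3 and 7 around a grand mean of 4, giving a between-group mean square of 21 against a within-group mean square of 1.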
blp_test ¶
blp_test(result: CausalResult, data: DataFrame, y: str, treat: str, covariates: List[str], n_folds: int = 5, alpha: float = 0.05) -> Dict[str, Any]
Best Linear Predictor (BLP) test for CATE heterogeneity.
Implements the calibration test from Chernozhukov et al. (2018,
Econometrica) "Generic Machine Learning Inference on Heterogeneous
Treatment Effects." Equivalent to grf::test_calibration() in R.
Fits via OLS: Y_i = alpha + beta_1 * (D_i - e(X_i)) + beta_2 * (D_i - e(X_i)) * (S(X_i) - mean(S)) + eps
where S(X) is the CATE proxy from the meta-learner and e(X) the propensity score (estimated via cross-fitting).
- beta_1: tests whether ATE != 0 (mean forest prediction)
- beta_2: tests whether CATE heterogeneity is real (if beta_2 > 0 and significant, the learner has found genuine heterogeneity rather than noise)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| result | CausalResult | Result from | required |
| data | DataFrame | Original data. | required |
| y | str | Outcome column name. | required |
| treat | str | Treatment column name. | required |
| covariates | list of str | Covariate column names. | required |
| n_folds | int | Folds for propensity cross-fitting. | 5 |
| alpha | float | Significance level. | 0.05 |

Returns:
| Type | Description |
|---|---|
| dict | Keys: 'beta1' (ATE signal), 'beta1_se', 'beta1_pvalue', 'beta2' (heterogeneity signal), 'beta2_se', 'beta2_pvalue', 'heterogeneity_significant' (bool). |
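The BLP regression itself is a three-column OLS, sketched below on simulated data. For simplicity it assumes a known randomisation probability e = 0.5 instead of a cross-fitted propensity, uses the true CATE as the proxy S, and reports classical homoskedastic standard errors with normal p-values. These are assumptions of the example, not a statement of what blp_test does internally.

```python
import math
import numpy as np

rng = np.random.default_rng(7)
n = 4000
X = rng.normal(size=(n, 2))
e = np.full(n, 0.5)                        # known propensity (assumption)
D = (rng.random(n) < e).astype(float)
tau = 2.0 + X[:, 1]                        # true CATE; the ATE is 2
Y = 1.0 + X[:, 0] + D * tau + rng.normal(scale=0.5, size=n)
S = tau                                    # stand-in CATE proxy for the sketch

# Y = alpha + beta_1 (D - e) + beta_2 (D - e)(S - mean(S)) + eps
Z = np.column_stack([np.ones(n), D - e, (D - e) * (S - S.mean())])
beta, *_ = np.linalg.lstsq(Z, Y, rcond=None)

# Classical OLS standard errors and two-sided normal p-values.
resid = Y - Z @ beta
sigma2 = resid @ resid / (n - Z.shape[1])
se = np.sqrt(np.diag(sigma2 * np.linalg.inv(Z.T @ Z)))
pvals = [math.erfc(abs(b / s) / math.sqrt(2)) for b, s in zip(beta, se)]

result = {
    "beta1": beta[1], "beta1_se": se[1], "beta1_pvalue": pvals[1],
    "beta2": beta[2], "beta2_se": se[2], "beta2_pvalue": pvals[2],
    "heterogeneity_significant": bool(beta[2] > 0 and pvals[2] < 0.05),
}
```

With a well-calibrated proxy, beta_1 should sit near the ATE and beta_2 near 1. A proxy that is pure noise would instead push beta_2 towards 0.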
References
Chernozhukov, V., Demirer, M., Duflo, E., & Fernandez-Val, I. (2018). Generic Machine Learning Inference on Heterogeneous Treatment Effects in Randomized Experiments. Econometrica (forthcoming as of 2018 NBER WP). [@chernozhukov2018double]
focal_cate ¶
focal_cate(data: DataFrame, y_columns: List[str], treat: str, covariates: List[str], test_data: Optional[DataFrame] = None, seed: int = 0) -> FunctionalCATEResult
Functional doubly-robust CATE estimator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | DataFrame | Training data with outcome columns y_columns (one per function evaluation point t). | required |
| y_columns | list of str | Outcome columns; len = number of function points. | required |
| treat | str |  | required |
| covariates | list of str |  | required |
| test_data | DataFrame | Defaults to | None |
| seed | int |  | 0 |

Returns:
| Type | Description |
|---|---|
| FunctionalCATEResult |  |