statspai.tmle
Targeted Maximum Likelihood Estimation (TMLE) with Super Learner.
TMLE is a doubly robust, semiparametrically efficient estimator of causal effects. It combines an initial outcome regression with a targeted bias-correction step that uses the propensity score.
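The targeting step can be illustrated with a minimal sketch for the ATE with a binary outcome and a binary treatment. It assumes plain logistic regressions in place of the Super Learner for both nuisance models, and the function tmle_ate_binary is hypothetical, not part of the statspai API. The propensity truncation reuses the default propensity_bounds of (0.025, 0.975) documented below.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression


def expit(x):
    return 1.0 / (1.0 + np.exp(-x))


def logit(p):
    return np.log(p / (1.0 - p))


def tmle_ate_binary(Y, A, W, bounds=(0.025, 0.975), tol=1e-10, max_iter=50):
    """Hypothetical sketch: TMLE of the ATE for binary outcome Y and binary treatment A."""
    n = len(Y)
    # 1. Initial outcome regression Q(A, W) = P(Y = 1 | A, W) (near-unpenalised logit).
    q_fit = LogisticRegression(C=1e6, max_iter=1000).fit(np.column_stack([A, W]), Y)
    Q1 = q_fit.predict_proba(np.column_stack([np.ones(n), W]))[:, 1]
    Q0 = q_fit.predict_proba(np.column_stack([np.zeros(n), W]))[:, 1]
    Q1, Q0 = np.clip(Q1, 1e-6, 1 - 1e-6), np.clip(Q0, 1e-6, 1 - 1e-6)
    QA = np.where(A == 1, Q1, Q0)
    # 2. Propensity score g(W) = P(A = 1 | W), truncated to `bounds`.
    g = LogisticRegression(C=1e6, max_iter=1000).fit(W, A).predict_proba(W)[:, 1]
    g = np.clip(g, *bounds)
    # 3. Clever covariate H(A, W) = A / g(W) - (1 - A) / (1 - g(W)).
    H1, H0 = 1.0 / g, -1.0 / (1.0 - g)
    HA = np.where(A == 1, H1, H0)
    # 4. Targeting: logistic regression of Y on HA with offset logit(QA) and no
    #    intercept, solved for the fluctuation parameter eps by Newton-Raphson.
    offset = logit(QA)
    eps = 0.0
    for _ in range(max_iter):
        p = expit(offset + eps * HA)
        step = np.sum(HA * (Y - p)) / np.sum(HA**2 * p * (1 - p))
        eps += step
        if abs(step) < tol:
            break
    # 5. Update the counterfactual predictions and take the plug-in mean difference.
    Q1s, Q0s = expit(logit(Q1) + eps * H1), expit(logit(Q0) + eps * H0)
    QAs = np.where(A == 1, Q1s, Q0s)
    ate = np.mean(Q1s - Q0s)
    # 6. Efficient influence curve gives the standard error.
    ic = HA * (Y - QAs) + (Q1s - Q0s) - ate
    se = np.std(ic, ddof=1) / np.sqrt(n)
    return ate, se, ic


rng = np.random.default_rng(1)
n = 2000
W = rng.normal(size=(n, 2))
A = rng.binomial(1, expit(0.4 * W[:, 0] - 0.3 * W[:, 1]))
Y = rng.binomial(1, expit(-0.5 + 1.0 * A + 0.8 * W[:, 0] + 0.5 * W[:, 1]))
ate, se, ic = tmle_ate_binary(Y, A, W)
print(f"ATE = {ate:.3f} (SE {se:.3f})")
```

Because the fluctuation solves the score equation for eps, the influence curve has mean zero at the targeted fit. This is the property that makes the plug-in estimate doubly robust.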
Components
- TMLE : Full TMLE estimator for ATE/ATT with targeting step
- SuperLearner : Ensemble learner for nuisance parameter estimation
References
van der Laan, M. J. & Rose, S. (2011). Targeted Learning: Causal Inference for Observational and Experimental Data. Springer Series in Statistics. [@vanderlaan2011targeted]
van der Laan, M. J., Polley, E. C., & Hubbard, A. E. (2007). Super Learner. Statistical Applications in Genetics and Molecular Biology, 6(1). [@vanderlaan2007super]
TMLE
Targeted Maximum Likelihood Estimation.
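Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data | DataFrame | Input data. | required |
| y | str | Name of the outcome column. | required |
| treat | str | Name of the treatment column. | required |
| covariates | list of str | Names of the covariate columns. | required |
| outcome_library | list of sklearn estimators | Candidate learners for the outcome regression. | None |
| propensity_library | list of sklearn estimators | Candidate learners for the propensity score. | None |
| n_folds | int | Number of cross-validation folds. | 5 |
| estimand | str | Target estimand: 'ATE' or 'ATT'. | 'ATE' |
| alpha | float | Significance level for confidence intervals. | 0.05 |
| propensity_bounds | tuple | Lower and upper bounds at which estimated propensity scores are truncated. | (0.025, 0.975) |
| random_state | int | Random seed. | 42 |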
SuperLearner
Super Learner ensemble (van der Laan et al. 2007).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| library | list of sklearn estimators | Candidate learners. If None, uses a default diverse library. | None |
| n_folds | int | Number of cross-validation folds. | 5 |
| task | str | 'regression' or 'classification'. | 'regression' |
| random_state | int | Random seed. | 42 |
fit
Fit the Super Learner.
- Get cross-validated predictions from each base learner.
- Find the ensemble weights by least squares on the cross-validated predictions, constrained to the simplex (non-negative weights summing to one).
- Refit all base learners on full data.
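The three fitting steps above can be sketched as follows. This is an illustration under stated assumptions, not the statspai implementation: the helper names simplex_ls_weights, fit_super_learner and predict_super_learner are hypothetical, and the simplex-constrained least-squares problem is solved here with SLSQP from scipy.optimize, which may differ from the solver statspai uses.

```python
import numpy as np
from scipy.optimize import minimize
from sklearn.base import clone
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold
from sklearn.tree import DecisionTreeRegressor


def simplex_ls_weights(Z, y):
    """Least-squares weights with w >= 0 and sum(w) == 1 (solved with SLSQP)."""
    k = Z.shape[1]
    res = minimize(
        lambda w: np.mean((y - Z @ w) ** 2),
        np.full(k, 1.0 / k),  # start from equal weights
        method="SLSQP",
        bounds=[(0.0, 1.0)] * k,
        constraints=[{"type": "eq", "fun": lambda w: np.sum(w) - 1.0}],
    )
    w = np.clip(res.x, 0.0, None)  # remove tiny negative round-off
    return w / w.sum()


def fit_super_learner(X, y, library, n_folds=5, random_state=42):
    # Step 1: out-of-fold (cross-validated) predictions from each base learner.
    Z = np.zeros((len(y), len(library)))
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=random_state)
    for train_idx, test_idx in kf.split(X):
        for j, est in enumerate(library):
            model = clone(est).fit(X[train_idx], y[train_idx])
            Z[test_idx, j] = model.predict(X[test_idx])
    # Step 2: simplex-constrained least-squares weights on the CV predictions.
    weights = simplex_ls_weights(Z, y)
    # Step 3: refit every base learner on the full data.
    fitted = [clone(est).fit(X, y) for est in library]
    return weights, fitted


def predict_super_learner(weights, fitted, X):
    # Weighted (convex) combination of the refitted base learners.
    return np.column_stack([m.predict(X) for m in fitted]) @ weights


rng = np.random.default_rng(0)
X = rng.normal(size=(300, 3))
y = 2.0 * X[:, 0] - X[:, 1] + rng.normal(scale=0.5, size=300)
library = [LinearRegression(), DecisionTreeRegressor(max_depth=3)]
w, fitted = fit_super_learner(X, y, library)
preds = predict_super_learner(w, fitted, X)
print(w)
```

Weighting on out-of-fold predictions rather than in-sample fits keeps flexible learners (here the tree) from being rewarded for overfitting. With a linear data-generating process, most of the weight should go to the linear model.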
predict
Predict using the weighted ensemble.
For classification (task='classification'), the returned values are the convex combination of the base learners' predicted probabilities, clipped to (1e-6, 1 - 1e-6) so that callers can take logit(.) without producing infinite values. For regression, no clipping is applied.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| X | ndarray (n, p) | Feature matrix. | required |

Returns:

| Type | Description |
|---|---|
| ndarray (n,) | Ensemble predictions (probabilities when task='classification'). |
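The clipping rule described for task='classification' can be sketched in a few lines of NumPy. The helper ensemble_proba is hypothetical; it only illustrates why the bounds keep logit(.) finite even when every base learner predicts exactly 0 or 1.

```python
import numpy as np


def ensemble_proba(P, weights, eps=1e-6):
    """Convex combination of base-learner probabilities, clipped into (eps, 1 - eps)."""
    return np.clip(P @ weights, eps, 1.0 - eps)


# Three observations, two base learners; the first two rows are degenerate (0 and 1).
P = np.array([[0.0, 0.0], [1.0, 1.0], [0.3, 0.5]])
w = np.array([0.25, 0.75])  # non-negative weights summing to 1
p = ensemble_proba(P, w)
z = np.log(p / (1.0 - p))  # logit stays finite thanks to the clipping
print(p, z)
```

Without the clip, the first two rows would give logit values of -inf and +inf. Such values would break any downstream step that uses logit(p) as an offset, such as the TMLE fluctuation.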
predict_proba
Predict probabilities (classification task only).
Identical to predict under task='classification'; kept as a separate method for parity with the scikit-learn API.
LTMLESurvivalResult (dataclass)
Counterfactual survival curves and contrasts.
HALRegressor
Bases: _BaseHAL
L1-penalised Highly Adaptive Lasso (HAL) regressor with a duck-typed, sklearn-compatible API.
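The idea behind an L1-penalised HAL regressor can be sketched with the zero-order HAL basis. For every non-empty subset of covariates and every knot, the basis includes an indicator that all covariates in the subset are at or above the knot's values, and the coefficients are then fitted with the lasso. This is an assumption-laden illustration, not the _BaseHAL internals. The helpers hal_basis, fit_hal and predict_hal are hypothetical, training points serve as knots, and sklearn's LassoCV chooses the penalty.

```python
import itertools

import numpy as np
from sklearn.linear_model import LassoCV


def hal_basis(X, knots):
    """Zero-order HAL design: for every non-empty covariate subset s and every
    knot row k, the indicator prod_{j in s} 1{x_j >= k_j}."""
    p = X.shape[1]
    blocks = []
    for r in range(1, p + 1):
        for s in itertools.combinations(range(p), r):
            cols = list(s)
            # (n_obs, n_knots, |s|) comparison, then require every covariate in s
            ge = X[:, cols][:, None, :] >= knots[:, cols][None, :, :]
            blocks.append(ge.all(axis=2).astype(float))
    return np.hstack(blocks)


def fit_hal(X, y, cv=5, random_state=0):
    knots = X.copy()  # training points serve as knots
    lasso = LassoCV(cv=cv, random_state=random_state).fit(hal_basis(X, knots), y)
    return knots, lasso


def predict_hal(knots, lasso, X_new):
    return lasso.predict(hal_basis(X_new, knots))


rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(200, 2))
y = (X[:, 0] > 0).astype(float) + 0.5 * X[:, 1] + rng.normal(scale=0.2, size=200)
knots, lasso = fit_hal(X, y)
fitted = predict_hal(knots, lasso, X)
r2 = 1 - np.sum((y - fitted) ** 2) / np.sum((y - y.mean()) ** 2)
print(f"in-sample R^2 = {r2:.2f}, non-zero coefficients = {np.sum(lasso.coef_ != 0)}")
```

The design grows as n times (2^p - 1) columns, so this naive construction only suits small p. A classifier variant would replace the lasso with an L1-penalised logistic regression on the same basis.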
HALClassifier
Bases: _BaseHAL
L1-penalised Highly Adaptive Lasso (HAL) logistic classifier with a duck-typed, sklearn-compatible API.