statspai.tmle

tmle

Targeted Maximum Likelihood Estimation (TMLE) with Super Learner.

TMLE is a doubly robust, semiparametrically efficient estimator of causal effects. It combines an initial outcome regression with a targeted bias-correction step that uses the propensity score.
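To make the targeting step concrete, here is a minimal standard-library sketch of the textbook TMLE update for the ATE with a binary outcome. It is not the statspai implementation: the helper names (`tmle_ate`, `expit`, `logit`), the simulated data, and the deliberately crude constant initial predictions are all assumptions made for illustration. Only the default propensity bounds (0.025, 0.975) are taken from this page.

```python
import math
import random


def expit(x):
    return 1.0 / (1.0 + math.exp(-x))


def logit(p):
    return math.log(p / (1.0 - p))


def tmle_ate(y, a, q1, q0, g, bounds=(0.025, 0.975), n_iter=50):
    """Targeting step for the ATE with a binary outcome (illustrative only).

    q1, q0: initial predictions of E[Y | A=1, W] and E[Y | A=0, W], in (0, 1).
    g: estimated propensity scores P(A=1 | W).
    """
    lower, upper = bounds
    g = [min(max(gi, lower), upper) for gi in g]  # bound the propensity scores
    # Clever covariate evaluated at the observed treatment.
    h = [ai / gi - (1 - ai) / (1 - gi) for ai, gi in zip(a, g)]
    # Initial prediction at the observed treatment, on the logit scale.
    offset = [logit(u if ai == 1 else v) for ai, u, v in zip(a, q1, q0)]

    # Fit the fluctuation parameter eps by Newton's method: a logistic
    # regression of y on h with a fixed offset and no intercept.
    eps = 0.0
    for _ in range(n_iter):
        p = [expit(o + eps * hi) for o, hi in zip(offset, h)]
        score = sum(hi * (yi - pi) for hi, yi, pi in zip(h, y, p))
        info = sum(hi ** 2 * pi * (1 - pi) for hi, pi in zip(h, p))
        step = score / info
        eps += step
        if abs(step) < 1e-12:
            break

    # Update both counterfactual predictions and average their difference.
    q1_star = [expit(logit(u) + eps / gi) for u, gi in zip(q1, g)]
    q0_star = [expit(logit(v) - eps / (1 - gi)) for v, gi in zip(q0, g)]
    ate = sum(u - v for u, v in zip(q1_star, q0_star)) / len(y)
    return ate, eps, q1_star, q0_star


# Simulated data (hypothetical): one covariate, binary treatment and outcome.
rng = random.Random(0)
n = 500
w = [rng.random() for _ in range(n)]
g_true = [expit(wi - 0.5) for wi in w]
a = [1 if rng.random() < gi else 0 for gi in g_true]
y = [1 if rng.random() < expit(-0.5 + ai + wi) else 0 for ai, wi in zip(a, w)]

# Crude initial outcome model; the targeting step corrects it using g.
q1 = [0.6] * n
q0 = [0.4] * n
ate, eps, q1_star, q0_star = tmle_ate(y, a, q1, q0, g_true)
print(f"eps = {eps:.4f}, targeted ATE estimate = {ate:.4f}")
```

After the update, the empirical mean of the clever covariate times the residual is approximately zero. Solving that score equation is what removes the first-order bias of the initial outcome regression.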

Components

  • TMLE: Full TMLE estimator for ATE/ATT with a targeting step
  • SuperLearner: Ensemble learner for nuisance parameter estimation

References

van der Laan, M. J. & Rose, S. (2011). Targeted Learning: Causal Inference for Observational and Experimental Data. Springer Series in Statistics. [@vanderlaan2011targeted]

van der Laan, M. J., Polley, E. C., & Hubbard, A. E. (2007). Super Learner. Statistical Applications in Genetics and Molecular Biology, 6(1). [@vanderlaan2007super]

TMLE

Targeted Maximum Likelihood Estimation.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data | DataFrame | | required |
| y | str | | required |
| treat | str | | required |
| covariates | list of str | | required |
| outcome_library | list of sklearn estimators | | None |
| propensity_library | list of sklearn estimators | | None |
| n_folds | int | | 5 |
| estimand | str | | 'ATE' |
| alpha | float | | 0.05 |
| propensity_bounds | tuple | | (0.025, 0.975) |
| random_state | int | | 42 |

fit

fit() -> CausalResult

Run TMLE and return causal effect estimates.

SuperLearner

Super Learner ensemble (van der Laan et al. 2007).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| library | list of sklearn estimators | Candidate learners. If None, uses a default diverse library. | None |
| n_folds | int | Number of cross-validation folds. | 5 |
| task | str | 'regression' or 'classification'. | 'regression' |
| random_state | int | | 42 |

fit

fit(X, y)

Fit the Super Learner.

  1. Get cross-validated predictions from each base learner.
  2. Find optimal weights via simplex-constrained least squares.
  3. Refit all base learners on full data.
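The three steps above can be sketched with the standard library alone. This is an illustration under stated assumptions, not the library's code: the two toy learners (`MeanLearner`, `LineLearner`), the helper names, the interleaved fold assignment, and the choice of projected gradient descent as the simplex-constrained least-squares solver are all hypothetical stand-ins.

```python
import random


class MeanLearner:
    """Predicts the training mean for every input."""
    def fit(self, x, y):
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, x):
        return [self.mean_ for _ in x]


class LineLearner:
    """Ordinary least squares on a single feature."""
    def fit(self, x, y):
        n = len(x)
        mx, my = sum(x) / n, sum(y) / n
        sxx = sum((xi - mx) ** 2 for xi in x)
        sxy = sum((xi - mx) * (yi - my) for xi, yi in zip(x, y))
        self.slope_ = sxy / sxx
        self.intercept_ = my - self.slope_ * mx
        return self

    def predict(self, x):
        return [self.intercept_ + self.slope_ * xi for xi in x]


def project_to_simplex(v):
    """Euclidean projection of v onto {w : w >= 0, sum(w) = 1}."""
    u = sorted(v, reverse=True)
    cumsum, theta = 0.0, 0.0
    for j, uj in enumerate(u, start=1):
        cumsum += uj
        t = (cumsum - 1.0) / j
        if uj - t > 0:
            theta = t
    return [max(vi - theta, 0.0) for vi in v]


def super_learner_fit(learners, x, y, n_folds=5, n_iter=2000):
    n, k = len(y), len(learners)
    # Step 1: out-of-fold predictions z[i][j] for observation i, learner j.
    z = [[0.0] * k for _ in range(n)]
    for fold in range(n_folds):
        test = [i for i in range(n) if i % n_folds == fold]
        train = [i for i in range(n) if i % n_folds != fold]
        x_tr, y_tr = [x[i] for i in train], [y[i] for i in train]
        for j, make in enumerate(learners):
            preds = make().fit(x_tr, y_tr).predict([x[i] for i in test])
            for i, p in zip(test, preds):
                z[i][j] = p
    # Step 2: minimise the cross-validated squared error over the simplex.
    step = n / (2.0 * sum(v * v for row in z for v in row))  # 1 / Lipschitz bound
    w = [1.0 / k] * k
    for _ in range(n_iter):
        resid = [sum(wj * zij for wj, zij in zip(w, row)) - yi
                 for row, yi in zip(z, y)]
        grad = [2.0 / n * sum(r * row[j] for r, row in zip(resid, z))
                for j in range(k)]
        w = project_to_simplex([wj - step * gj for wj, gj in zip(w, grad)])
    # Step 3: refit every learner on the full data.
    fitted = [make().fit(x, y) for make in learners]
    return w, fitted


def ensemble_predict(w, fitted, x):
    preds = [m.predict(x) for m in fitted]
    return [sum(wj * p[i] for wj, p in zip(w, preds)) for i in range(len(x))]


# Simulated linear data (hypothetical): the line learner should dominate.
rng = random.Random(0)
x = [rng.random() for _ in range(100)]
y = [1.0 + 2.0 * xi + rng.gauss(0.0, 0.1) for xi in x]
weights, fitted = super_learner_fit([MeanLearner, LineLearner], x, y)
print("weights:", [round(wj, 3) for wj in weights])
```

Because the weights are fitted to out-of-fold predictions rather than in-sample ones, a learner that overfits does not earn extra weight for doing so. The simplex constraint keeps the ensemble a convex combination of its candidates.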

predict

predict(X)

Predict using the weighted ensemble.

For classification (task='classification'), the returned values are the convex combination of base-learner probability predictions, clipped to (1e-6, 1 - 1e-6) so that callers can apply logit(.) without producing infinite values. For regression, no clipping is applied.
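A small standard-library sketch of why the clipping matters: base learners that return exactly 0 or 1 would otherwise make the logit infinite or undefined. The function name `ensemble_proba`, the weights, and the probabilities below are hypothetical; only the clipping bound 1e-6 comes from the description above.

```python
import math


def ensemble_proba(weights, base_probs, eps=1e-6):
    """Convex combination of base-learner probabilities, clipped to [eps, 1 - eps]."""
    n = len(base_probs[0])
    combined = [sum(w * probs[i] for w, probs in zip(weights, base_probs))
                for i in range(n)]
    return [min(max(p, eps), 1.0 - eps) for p in combined]


def logit(p):
    return math.log(p / (1.0 - p))


# Two base learners, three observations; both learners are certain about
# the first and last observation.
weights = [0.75, 0.25]
base_probs = [[0.0, 0.5, 1.0],
              [0.0, 0.8, 1.0]]

clipped = ensemble_proba(weights, base_probs)
logits = [logit(p) for p in clipped]  # finite for every observation
print(clipped)
print(logits)
```

Without the clipping, `logit(0.0)` raises a `ValueError` and `logit(1.0)` raises a `ZeroDivisionError` in plain Python, so any downstream step that works on the logit scale would fail on such predictions.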

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| X | ndarray(n, p) | | required |

Returns:

| Type | Description |
|---|---|
| ndarray(n) | |

predict_proba

predict_proba(X)

Predict probabilities (for classification task).

Identical to predict() under task='classification'; kept as a separate method for sklearn-style API parity.

summary

summary() -> str

Print Super Learner summary.

LTMLESurvivalResult dataclass

Counterfactual survival curves and contrasts.

HALRegressor

Bases: _BaseHAL

L1-penalised Highly Adaptive Lasso (HAL) regressor (sklearn-compatible duck-typed API).

HALClassifier

Bases: _BaseHAL

L1-penalised HAL logistic classifier (sklearn-compatible duck-typed API).