statspai.matching

Matching module for StatsPAI.

Unified interface for matching estimators:

  • Nearest-neighbor matching (propensity score, Mahalanobis, Euclidean)
  • Exact matching
  • Coarsened Exact Matching (CEM)
  • Propensity score stratification / subclassification
  • Abadie-Imbens (2011) bias correction
  • Entropy balancing (Hainmueller 2012)
  • Covariate Balancing Propensity Score (Imai-Ratkovic 2014)
  • Genetic Matching (Diamond-Sekhon 2013)
  • Stable Balancing Weights (Zubizarreta 2015)
  • Optimal pair / full / cardinality matching (Rosenbaum 1989, 2012)
  • Overlap weights (Li-Morgan-Zaslavsky 2018)

The single entry point is match(), a method-aware dispatcher that routes the method= argument to the corresponding estimator. The standalone functions (ebalance, cbps, genmatch, sbw, optimal_match, cardinality_match, overlap_weights) remain fully accessible for users who need estimator-specific parameters.

References

  • Rosenbaum, P.R. and Rubin, D.B. (1983). Biometrika, 70(1), 41-55.
  • Abadie, A. and Imbens, G.W. (2006). Econometrica, 74(1), 235-267.
  • Abadie, A. and Imbens, G.W. (2011). JBES, 29(1), 1-11.
  • Iacus, S.M., King, G., and Porro, G. (2012). Political Analysis, 20(1), 1-24.
  • Hainmueller, J. (2012). Political Analysis, 20(1), 25-46.
  • Imai, K. and Ratkovic, M. (2014). JRSS-B, 76(1), 243-263.
  • Diamond, A. and Sekhon, J.S. (2013). REStat, 95(3), 932-945.
  • Zubizarreta, J.R. (2015). JASA, 110(511), 910-922.
  • Li, F., Morgan, K.L., and Zaslavsky, A.M. (2018). JASA, 113(521), 390-400.
  • Rosenbaum, P.R. (2012). JASA, 107(498), 691-700.
  • Cunningham, S. (2021). Causal Inference: The Mixtape. Yale University Press.

MatchEstimator

Unified matching estimator supporting multiple distance × method combinations.

fit

fit() -> CausalResult

Fit matching estimator and return results.

PSBalanceResult

Container for propensity score balance diagnostics.

Attributes:

Name Type Description
table DataFrame

Balance statistics per covariate: mean_treat, mean_control, smd_raw, smd_weighted, variance_ratio, ks_stat.

ps Series

Estimated propensity scores.
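
As an illustration of two of the per-covariate columns listed above, the following is a minimal sketch of a variance ratio and a two-sample Kolmogorov-Smirnov statistic for one covariate. It is not the library's implementation: the treated-over-control orientation of the ratio, the use of sample variances (ddof=1), and the variable names are assumptions made for the example.

```python
import numpy as np
from scipy.stats import ks_2samp

# Toy covariate values for the two groups (illustrative data only).
x_treated = np.array([1.0, 2.0, 3.0, 4.0])
x_control = np.array([2.0, 4.0, 6.0, 8.0])

# Variance ratio, assumed here to be treated / control with sample variances.
# Values far from 1 indicate that the spread differs between groups.
variance_ratio = x_treated.var(ddof=1) / x_control.var(ddof=1)

# KS statistic: largest vertical gap between the two empirical CDFs.
ks_stat = ks_2samp(x_treated, x_control).statistic
```

Both quantities complement the SMD, which only compares group means.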

summary

summary() -> str

Formatted balance summary table.

love_plot

love_plot(threshold: float = 0.1, **kwargs)

Convenience method that calls love_plot() on the stored balance data.

BalanceDiagnosticsResult

Container for raw/weighted matching balance diagnostics.

SBWResult

Bases: CausalResult

Stable balancing weights with a diagnostic panel.

Thin subclass of CausalResult that attaches the weight vector, effective sample size, and covariate balance table.

balanceplot

balanceplot(result: CausalResult, threshold: float = 0.1, ax=None, figsize: tuple = (8, None), title: str = None)

Love plot: covariate balance visualization (SMD dot plot).

Displays the standardized mean difference (SMD) for each covariate. A common rule of thumb for adequate balance is |SMD| < 0.1.

Parameters:

Name Type Description Default
result CausalResult

Result from match() or ebalance().

required
threshold float

SMD threshold lines.

0.1
ax matplotlib Axes
None
figsize tuple

Height auto-scales with number of covariates if None.

(8, None)
title str
None

Returns:

Type Description
(fig, ax)
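
The SMD that this plot displays can be sketched as follows. This is a minimal illustration rather than the library's code: the helper name standardized_mean_difference is hypothetical, and the pooled-SD denominator sqrt((s1^2 + s0^2) / 2) computed from the unweighted groups is one common convention that is assumed here.

```python
import numpy as np

def standardized_mean_difference(x, treat, weights=None):
    """SMD of one covariate between treated (1) and control (0) units.

    Hypothetical helper. The denominator is the pooled SD of the
    unweighted groups, sqrt((s1^2 + s0^2) / 2), an assumed convention.
    """
    x = np.asarray(x, dtype=float)
    treat = np.asarray(treat).astype(bool)
    w = np.ones_like(x) if weights is None else np.asarray(weights, dtype=float)

    # Group means, optionally weighted (for example by matching weights).
    mean_t = np.average(x[treat], weights=w[treat])
    mean_c = np.average(x[~treat], weights=w[~treat])

    pooled_sd = np.sqrt((x[treat].var(ddof=1) + x[~treat].var(ddof=1)) / 2.0)
    return (mean_t - mean_c) / pooled_sd

# Simulated covariate whose treated mean is shifted by half a standard deviation.
rng = np.random.default_rng(0)
treat = np.repeat([0, 1], 200)
x = rng.normal(loc=0.5 * treat, scale=1.0)

smd_raw = standardized_mean_difference(x, treat)
is_balanced = abs(smd_raw) < 0.1  # compare against the 0.1 threshold
```

Passing matching or IPW weights through the weights argument gives the adjusted SMD, so the raw and weighted values can be compared against the same threshold.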

psplot

psplot(data: DataFrame, treat: str, covariates: List[str], *, n_bins: int = 40, ax=None, figsize: tuple = (8, 5), title: str = None, labels: tuple = ('Control', 'Treated'), colors: tuple = ('#3498DB', '#E74C3C'), trim: Optional[float] = None)

Propensity score distribution plot (common support diagnostic).

Overlays histograms of the estimated propensity score for treated and control groups, so the user can visually assess whether the common support (overlap) assumption holds.

Parameters:

Name Type Description Default
data DataFrame
required
treat str

Binary treatment column.

required
covariates list of str

Covariates used to estimate the propensity score.

required
n_bins int

Number of histogram bins.

40
ax matplotlib Axes
None
figsize tuple
(8, 5)
title str
None
labels tuple of str

Labels for (control, treated).

('Control', 'Treated')
colors tuple of str

Colors for (control, treated).

('#3498DB', '#E74C3C')
trim float

If set, draw vertical lines at (trim, 1-trim) to show the recommended trimming region.

None

Returns:

Type Description
(fig, ax)

Examples:

>>> fig, ax = sp.psplot(df, treat='D', covariates=['x1', 'x2'])
>>> fig, ax = sp.psplot(df, treat='D', covariates=['x1', 'x2'],
...                      trim=0.1)

propensity_score

propensity_score(data: DataFrame, treatment: str, covariates: List[str], method: str = 'logit', trimming: Optional[str] = None) -> Series

Estimate propensity scores P(D=1|X).

Parameters:

Name Type Description Default
data DataFrame

Input data.

required
treatment str

Name of binary treatment column (0/1).

required
covariates list of str

Covariate column names.

required
method (logit, probit, gbm)

Estimation method. 'logit' uses IRLS (no sklearn needed). 'probit' uses scipy.optimize. 'gbm' tries sklearn GradientBoostingClassifier, falling back to logit with interactions.

'logit'
trimming (None, crump)

If 'crump', apply Crump et al. (2009) trimming after estimation. Trimmed observations receive NaN scores.

None

Returns:

Type Description
Series

Propensity scores indexed like data.
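
To make the 'logit' option concrete, here is a minimal sketch of logistic regression fitted by IRLS (Newton's method) with NumPy only. It is written under stated assumptions and is not the library's code: the function name logit_irls, the convergence tolerance, and the simulated data are all hypothetical.

```python
import numpy as np

def logit_irls(X, d, max_iter=50, tol=1e-10):
    """Fit P(D=1|X) by IRLS; returns (coefficients, fitted scores).

    Hypothetical helper. An intercept column is prepended to X.
    """
    X = np.column_stack([np.ones(len(X)), np.asarray(X, dtype=float)])
    d = np.asarray(d, dtype=float)
    beta = np.zeros(X.shape[1])
    for _ in range(max_iter):
        p = 1.0 / (1.0 + np.exp(-(X @ beta)))
        w = p * (1.0 - p)                     # IRLS working weights
        grad = X.T @ (d - p)                  # score vector
        hess = X.T @ (X * w[:, None])         # X' W X
        delta = np.linalg.solve(hess, grad)   # Newton step
        beta = beta + delta
        if np.max(np.abs(delta)) < tol:
            break
    return beta, 1.0 / (1.0 + np.exp(-(X @ beta)))

# Simulated data with a known treatment model (illustrative only).
rng = np.random.default_rng(0)
n = 500
X = rng.normal(size=(n, 2))
true_p = 1.0 / (1.0 + np.exp(-(-0.3 + 0.8 * X[:, 0] - 0.5 * X[:, 1])))
d = (rng.random(n) < true_p).astype(float)

beta_hat, scores = logit_irls(X, d)
```

Because the model includes an intercept, the fitted scores average to the observed treatment share at convergence, which is a quick sanity check on any logit fit.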

overlap_plot

overlap_plot(data: DataFrame, treatment: str, covariates: List[str], ps: Optional[Series] = None, method: str = 'logit', ax=None, figsize: Tuple[float, float] = (8, 4), title: str = 'Propensity Score Overlap') -> Tuple

Mirrored density plot of propensity scores by treatment group.

Parameters:

Name Type Description Default
data DataFrame

Input data.

required
treatment str

Binary treatment column.

required
covariates list of str

Covariates for PS estimation (ignored if ps supplied).

required
ps Series

Pre-estimated propensity scores.

None
method str

PS estimation method if ps is None.

'logit'
ax matplotlib Axes

Axes to plot on. If None, a new figure is created.

None
figsize tuple

Figure size (width, height).

(8, 4)
title str

Plot title.

'Propensity Score Overlap'

Returns:

Type Description
(fig, ax) : tuple

Matplotlib figure and axes.

trimming

trimming(data: DataFrame, treatment: str, covariates: List[str], method: str = 'crump', ps: Optional[Series] = None, ps_method: str = 'logit') -> DataFrame

Trim sample to optimal overlap region.

Parameters:

Name Type Description Default
data DataFrame

Input data.

required
treatment str

Binary treatment column.

required
covariates list of str

Covariates for PS estimation (if ps not supplied).

required
method (crump, sturmer)

'crump' uses Crump et al. (2009) optimal rule. 'sturmer' trims at the fixed [0.1, 0.9] interval.

'crump'
ps Series

Pre-estimated propensity scores. If None, estimated via ps_method.

None
ps_method str

Method for PS estimation if ps is None.

'logit'

Returns:

Type Description
DataFrame

Trimmed data (rows with PS in the overlap region).
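
The fixed-interval rule described for 'sturmer' can be sketched as a boolean mask over the propensity scores. This is a minimal illustration, not the library's code: the helper name trim_fixed_interval and the inclusive bounds are assumptions, and the data-driven Crump et al. (2009) cutoff is not reproduced here.

```python
import numpy as np

def trim_fixed_interval(ps, lower=0.1, upper=0.9):
    """Boolean mask of observations whose score lies in [lower, upper].

    Hypothetical helper; whether the bounds are inclusive is an assumption.
    """
    ps = np.asarray(ps, dtype=float)
    return (ps >= lower) & (ps <= upper)

# Six illustrative scores; the first and last fall outside [0.1, 0.9].
ps = np.array([0.02, 0.10, 0.35, 0.50, 0.90, 0.97])
keep = trim_fixed_interval(ps)
trimmed_ps = ps[keep]
```

The same mask can be used to subset the rows of a DataFrame, for example with data.loc[keep].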

love_plot

love_plot(data: DataFrame, treatment: str, covariates: List[str], weights: Optional[Union[ndarray, Series]] = None, threshold: float = 0.1, ps_method: str = 'logit', ax=None, figsize: Tuple[float, float] = (7, None), title: str = 'Covariate Balance (Love Plot)') -> Tuple

Love plot: dot plot of standardized mean differences before and after weighting.

Parameters:

Name Type Description Default
data DataFrame

Input data.

required
treatment str

Binary treatment column.

required
covariates list of str

Covariate columns.

required
weights array-like

IPW or matching weights. If None, inverse-PS weights are computed.

None
threshold float

SMD threshold for the vertical dashed line (default 0.1).

0.1
ps_method str

PS estimation method for balance computation.

'logit'
ax matplotlib Axes
None
figsize tuple

(width, height). Height defaults to 0.4 * n_covariates + 1.

(7, None)
title str

Plot title.

'Covariate Balance (Love Plot)'

Returns:

Type Description
(fig, ax) : tuple

ps_balance

ps_balance(data: DataFrame, treatment: str, covariates: List[str], weights: Optional[Union[ndarray, Series]] = None, method: str = 'logit') -> PSBalanceResult

Compute comprehensive propensity score balance table.

Parameters:

Name Type Description Default
data DataFrame

Input data.

required
treatment str

Binary treatment column.

required
covariates list of str

Covariate columns to assess balance for.

required
weights array-like

IPW or matching weights. If None, inverse-PS weights are computed automatically from estimated propensity scores.

None
method str

PS estimation method ('logit', 'probit', 'gbm').

'logit'

Returns:

Type Description
PSBalanceResult

Object with .table, .ps, .summary(), .love_plot().

balance_diagnostics

balance_diagnostics(data: DataFrame, treatment: str, covariates: List[str], weights: Optional[Union[ndarray, Series, str]] = None, ps: Optional[Union[ndarray, Series, str]] = None, method: str = 'logit', threshold: float = 0.1) -> BalanceDiagnosticsResult

Unified balance diagnostics for matching and weighting estimators.

Parameters:

Name Type Description Default
data DataFrame

Analysis frame.

required
treatment str

Binary treatment indicator.

required
covariates list of str

Covariates to audit.

required
weights array-like or str

Observation weights after matching/weighting. If omitted, ATE inverse-propensity weights are computed from ps.

None
ps array-like or str

Propensity scores. If omitted, estimated with method.

None
method (logit, probit, gbm)

Propensity-score model when ps is not supplied.

'logit'
threshold float

Balance threshold for absolute standardized mean differences.

0.1

Returns:

Type Description
BalanceDiagnosticsResult

.table has one row per covariate; .summary_stats records max/mean SMDs, imbalance counts, effective sample size, and propensity-score overlap.
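
Two of the quantities mentioned above, ATE inverse-propensity weights and the effective sample size, can be sketched as follows. The weights use the standard formula D/e + (1-D)/(1-e); the effective sample size uses Kish's formula (sum w)^2 / sum w^2, which is assumed here and may differ from what the library reports (for example, if it is computed per treatment arm). Both function names are hypothetical.

```python
import numpy as np

def ate_ipw_weights(treat, ps):
    """ATE inverse-propensity weights: 1/e for treated, 1/(1-e) for controls."""
    treat = np.asarray(treat, dtype=float)
    ps = np.asarray(ps, dtype=float)
    return treat / ps + (1.0 - treat) / (1.0 - ps)

def effective_sample_size(w):
    """Kish effective sample size, (sum w)^2 / sum(w^2)."""
    w = np.asarray(w, dtype=float)
    return w.sum() ** 2 / np.sum(w ** 2)

# Four illustrative units: two treated, two controls.
treat = np.array([1, 1, 0, 0])
ps = np.array([0.5, 0.25, 0.5, 0.2])

w = ate_ipw_weights(treat, ps)                  # weights 2, 4, 2, 1.25
ess_all = effective_sample_size(w)              # smaller than 4: unequal weights
ess_equal = effective_sample_size(np.ones(4))   # equal weights give n = 4
```

A large gap between the nominal sample size and the effective sample size signals that a few observations carry most of the weight.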

optimal_match

optimal_match(data: DataFrame, treatment: str, outcome: str, covariates: List[str], metric: str = 'mahalanobis', caliper: Optional[float] = None) -> OptimalMatchResult

Optimal 1:1 matching via the Hungarian algorithm.

Each treated unit is matched to exactly one control; the total sum of matched distances is globally minimised. Requires n_treated ≤ n_control.

Parameters:

Name Type Description Default
caliper float

Drop any pair with distance greater than caliper.

None
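
The assignment step can be illustrated with SciPy's linear_sum_assignment, which solves the same minimum-cost assignment problem as the Hungarian algorithm. This is a sketch under stated assumptions, not the library's code: the simulated covariates, the pooled covariance used for the Mahalanobis metric, and the caliper value of 1.5 are all made up for the example.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

# Simulated covariates: 5 treated units and 12 controls (n_treated <= n_control).
rng = np.random.default_rng(1)
X_treated = rng.normal(size=(5, 2))
X_control = rng.normal(size=(12, 2))

# Mahalanobis distance using the inverse covariance of the pooled sample
# (the pooling choice is an assumption for this sketch).
VI = np.linalg.inv(np.cov(np.vstack([X_treated, X_control]), rowvar=False))
cost = cdist(X_treated, X_control, metric="mahalanobis", VI=VI)

# Minimum-cost assignment: each treated row gets one distinct control column.
rows, cols = linear_sum_assignment(cost)
pair_dist = cost[rows, cols]

# Apply the caliper after matching by dropping pairs that are too far apart.
caliper = 1.5
keep = pair_dist <= caliper
pairs = list(zip(rows[keep], cols[keep]))
```

Unlike greedy nearest-neighbor matching, the assignment minimises the total distance over all pairs at once, so an early match cannot use up the control that a later treated unit needs most.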

cardinality_match

cardinality_match(data: DataFrame, treatment: str, outcome: str, covariates: List[str], smd_tolerance: float = 0.1) -> CardinalityMatchResult

Cardinality matching: maximise the number of matched pairs subject to a standardised-mean-difference tolerance on every covariate.

Formulation (Zubizarreta 2014):

maximise   sum_j z_j
s.t.       |mean(X_k | T=1) - sum_j z_j X_{jk} / sum_j z_j|
            <= smd_tolerance * SD(X_k)   ∀ k
           z_j ∈ {0, 1}  for each control j

Solves a continuous LP relaxation with scipy.optimize.linprog, then rounds the weights to 0/1 with a threshold, which is usually adequate in applied work. The matched-pair sample is built by pairing each selected control, in sequence, with its nearest treated unit in covariate space.
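
The relaxed program can be sketched with scipy.optimize.linprog. Multiplying the balance constraint through by sum_j z_j makes it linear in z: sum_j z_j (X_jk - mean_k - delta_k) <= 0 and sum_j z_j (-(X_jk - mean_k) - delta_k) <= 0, where delta_k = smd_tolerance * SD(X_k). This is an illustration under stated assumptions, not the library's code: the simulated data, the pooled SD definition, and the 0.5 rounding threshold are all assumed.

```python
import numpy as np
from scipy.optimize import linprog

# Simulated covariates: treated units are shifted relative to controls.
rng = np.random.default_rng(2)
X_t = rng.normal(loc=0.3, size=(30, 2))
X_c = rng.normal(loc=0.0, size=(100, 2))
smd_tolerance = 0.1

target = X_t.mean(axis=0)  # treated covariate means
# Pooled SD per covariate (an assumed definition of SD(X_k)).
sd = np.sqrt((X_t.var(axis=0, ddof=1) + X_c.var(axis=0, ddof=1)) / 2.0)
delta = smd_tolerance * sd

# Linearised two-sided balance constraints, one pair per covariate.
dev = X_c - target
A_ub = np.vstack([(dev - delta).T, (-dev - delta).T])
b_ub = np.zeros(A_ub.shape[0])

# linprog minimises, so maximise sum(z) by minimising -sum(z), with 0 <= z <= 1.
res = linprog(c=-np.ones(len(X_c)), A_ub=A_ub, b_ub=b_ub, bounds=(0, 1))
z_relaxed = res.x

# Round to a 0/1 selection; the 0.5 cutoff is an assumption of this sketch.
z = (z_relaxed > 0.5).astype(int)
```

Rounding can push the selected sample slightly outside the tolerance, so it is worth recomputing the SMDs on the rounded selection before using it.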