
statspai.policy_learning

policy_learning

Policy Learning: Optimal treatment assignment from heterogeneous effects.

Learns an interpretable treatment assignment policy that maximises the expected welfare (value) of the population. Given estimated conditional average treatment effects (CATE), it finds the optimal tree-based policy, answering the question "who should be treated?"

Components
  • PolicyTree : Optimal depth-limited decision tree for treatment assignment (Athey & Wager 2021).
  • policy_value : Evaluate the expected value of a treatment policy using doubly robust scores.
References

Athey, S., & Wager, S. (2021). Policy Learning with Observational Data. Econometrica, 89(1), 133-161.

Zhou, Z., Athey, S., & Wager, S. (2023). Offline Multi-Action Policy Learning: Generalization and Optimization. Operations Research, 71(1), 148-183.

PolicyTree

Optimal depth-limited policy tree.

Parameters:

| Name | Type | Default |
|---|---|---|
| data | DataFrame | required |
| y | str | required |
| treat | str | required |
| covariates | list of str | required |
| policy_covariates | list of str | None |
| max_depth | int | 2 |
| min_leaf_size | int | 25 |
| n_folds | int | 5 |
| alpha | float | 0.05 |
| random_state | int | 42 |

fit

fit() -> Dict[str, Any]

Learn the optimal policy tree.
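The docs describe fit() only as learning the optimal policy tree. For intuition, the sketch below runs an exhaustive search for the best depth-1 tree (a single split), in the spirit of Athey & Wager (2021). Each candidate rule is scored by the sum of treatment-effect scores over the units it treats. This is not statspai's implementation. best_depth1_policy and its return layout are hypothetical.

```python
import numpy as np

def best_depth1_policy(X, scores, min_leaf_size=1):
    """Exhaustive search for the best single-split treatment rule (sketch).

    A rule's value is the sum of scores over the units it treats (treating
    nobody scores 0). Each leaf is treated iff its score sum is positive.
    Returns (value, feature, threshold, treat_left, treat_right); feature and
    threshold are None when no admissible split beats the no-split rule.
    """
    X, scores = np.asarray(X, float), np.asarray(scores, float)
    n, p = X.shape
    total = scores.sum()
    best = (max(total, 0.0), None, None, total > 0, total > 0)
    for j in range(p):
        order = np.argsort(X[:, j], kind="stable")
        xs, cum = X[order, j], np.cumsum(scores[order])
        # Left leaf = first k + 1 sorted units; both leaves need min_leaf_size units.
        for k in range(min_leaf_size - 1, n - min_leaf_size):
            if xs[k] == xs[k + 1]:
                continue  # cannot split between tied covariate values
            left, right = cum[k], total - cum[k]
            value = max(left, 0.0) + max(right, 0.0)
            if value > best[0]:
                best = (value, j, xs[k], left > 0, right > 0)
    return best

X = [[0.0], [1.0], [2.0], [3.0]]
scores = [-1.0, -2.0, 3.0, 4.0]
value, feature, threshold, treat_left, treat_right = best_depth1_policy(X, scores)
print(value, feature, threshold, treat_left, treat_right)  # split at x <= 1, treat right leaf only
```

A depth-2 search applies the same idea recursively: for every root split, it finds the best depth-1 tree in each child. The cost therefore grows quickly with max_depth.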

predict

predict(X_new: ndarray) -> ndarray

Predict treatment assignment for new data.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| X_new | ndarray (n, p) | Policy covariates for new observations. | required |

Returns:

| Type | Description |
|---|---|
| ndarray (n,) | Binary treatment recommendations (0 or 1). |
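predict() returns binary recommendations for new policy covariates. The traversal it performs can be sketched as below. The sketch assumes a hypothetical nested-dict tree layout, since statspai's internal representation is not documented here. It follows the "feature ≤ threshold goes left" convention that plot_tree displays. predict_policy is a hypothetical helper.

```python
import numpy as np

def predict_policy(tree, X_new):
    """Route each row of X_new down a tree and return its leaf action (0/1).

    Split nodes: {"feature": j, "threshold": t, "left": ..., "right": ...}
    Leaves:      {"action": 0 or 1}
    Rows with X[j] <= t go left, mirroring the 'feature <= threshold' labels.
    """
    X_new = np.atleast_2d(np.asarray(X_new, float))
    out = np.empty(X_new.shape[0], dtype=int)
    for i, row in enumerate(X_new):
        node = tree
        while "action" not in node:  # descend until a leaf is reached
            go_left = row[node["feature"]] <= node["threshold"]
            node = node["left"] if go_left else node["right"]
        out[i] = node["action"]
    return out

# A hypothetical depth-2 tree: treat only when x0 > 1 and x1 <= 0.5.
tree = {"feature": 0, "threshold": 1.0,
        "left": {"action": 0},
        "right": {"feature": 1, "threshold": 0.5,
                  "left": {"action": 1}, "right": {"action": 0}}}
print(predict_policy(tree, [[0.0, 0.0], [2.0, 0.2], [2.0, 0.9]]))  # [0 1 0]
```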

PolicyTreeResult

Bases: dict

Result of policy_tree().

Inherits from dict, so the legacy result['policy'] API keeps working (and isinstance(result, dict) is still True), while also exposing attribute access and the following members:

  • value_policy_se: influence-function standard error of the policy value, computed from the AIPW scores $\Gamma_i$ and the binary policy $\hat\pi(X_i)$. It is asymptotically valid under the standard cross-fitting and overlap conditions.
  • summary(), plot_tree(), to_latex() and cite(), which mirror Stata / R reporting idioms.
  • to_excel() for publication-ready exports.

The tree attribute holds the fitted PolicyTree instance, so PolicyTree.predict() remains reachable downstream.
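The sketch below computes a policy value and an influence-function standard error of the kind reported as value_policy_se, starting from AIPW scores and a binary policy. It assumes the value is measured relative to treating nobody, so each unit contributes $\hat\pi(X_i)\,\Gamma_i$. It also uses a normal-approximation interval. statspai's exact convention may differ, and policy_value_with_se is hypothetical.

```python
import numpy as np
from statistics import NormalDist

def policy_value_with_se(scores, policy, alpha=0.05):
    """Plug-in value of a binary policy with an influence-function SE (sketch).

    Assumed convention: psi_i = pi(X_i) * Gamma_i (gain over treating nobody).
    SE = sample standard deviation of psi_i / sqrt(n); CI is normal-approximate.
    """
    scores = np.asarray(scores, float)
    policy = np.asarray(policy, float)
    psi = policy * scores          # per-unit influence-function contributions
    n = psi.size
    value = psi.mean()
    se = psi.std(ddof=1) / np.sqrt(n)
    z = NormalDist().inv_cdf(1 - alpha / 2)
    return value, se, (value - z * se, value + z * se)

value, se, ci = policy_value_with_se([1.0, 3.0, -2.0, 2.0], [1, 1, 0, 1])
print(round(value, 3), round(se, 3))  # 1.5 0.645
```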

plot_tree

plot_tree(ax=None, figsize=(8.0, 5.0), node_color='#e8f0fe')

Draw the policy tree as a labeled hierarchical diagram.

Each split node shows feature ≤ threshold; each leaf shows TREAT / DON'T TREAT plus the leaf value (mean AIPW score). Requires matplotlib.

to_latex

to_latex(caption: Optional[str] = None, label: str = 'tab:policy_tree') -> str

Render a publication-style summary table (LaTeX).

to_excel

to_excel(path: str, digits: int = 4) -> str

Write a single-sheet Excel summary.

policy_value

policy_value(scores: ndarray, policy: ndarray) -> float

Evaluate the expected value of a treatment policy.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| scores | ndarray (n,) | Doubly robust scores (AIPW pseudo-outcomes for the treatment effect). A positive score indicates that the individual is estimated to benefit from treatment. | required |
| policy | ndarray (n,) | Binary policy recommendations (0 or 1). | required |

Returns:

| Type | Description |
|---|---|
| float | Estimated expected value of the policy. |
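The sketch below assumes policy_value averages $\hat\pi(X_i)\,\Gamma_i$, i.e. the estimated gain over treating nobody. policy_value_sketch is a hypothetical stand-in. Positive scores indicate benefit, so the rule that treats exactly the positive-score units attains the highest in-sample value. That rule overfits, however. A depth-limited tree gives up some of that value in exchange for interpretability and less overfitting.

```python
import numpy as np

def policy_value_sketch(scores, policy):
    """Mean of pi(X_i) * Gamma_i: estimated gain over treating nobody (assumed convention)."""
    scores = np.asarray(scores, float)
    policy = np.asarray(policy)
    if not np.isin(policy, (0, 1)).all():
        raise ValueError("policy must contain only 0/1 recommendations")
    return float(np.mean(policy * scores))

scores = np.array([1.0, 3.0, -2.0, 2.0])
treat_all = np.ones(4, dtype=int)
sign_rule = (scores > 0).astype(int)  # in-sample best over all 0/1 policies
print(policy_value_sketch(scores, treat_all))  # 1.0
print(policy_value_sketch(scores, sign_rule))  # 1.5
```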

direct_method

direct_method(X: ndarray, A: ndarray, R: ndarray, pi_target, n_actions: Optional[int] = None, alpha: float = 0.05) -> OPEResult

Direct-method off-policy evaluation (OPE): fits an outcome regression (plug-in Q-model) and averages its predictions under the target policy.
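A minimal sketch of the direct method follows. It fits a linear Q-model separately for each action. It also assumes a deterministic target policy given as an array of actions; statspai's pi_target may instead accept a callable or a probability matrix. direct_method_sketch is hypothetical and omits the confidence interval controlled by alpha.

```python
import numpy as np

def direct_method_sketch(X, A, R, target_actions, n_actions=None):
    """Plug-in OPE: fit Q(x, a) by per-action OLS, then average Q(x_i, pi(x_i))."""
    X = np.asarray(X, float)
    A = np.asarray(A, int)
    R = np.asarray(R, float)
    target_actions = np.asarray(target_actions, int)
    n_actions = n_actions or int(A.max()) + 1
    design = np.column_stack([np.ones(len(X)), X])  # intercept + covariates
    q_hat = np.zeros((len(X), n_actions))
    for a in range(n_actions):
        mask = A == a  # fit each action's model on units that received it
        coef, *_ = np.linalg.lstsq(design[mask], R[mask], rcond=None)
        q_hat[:, a] = design @ coef  # predict for every unit
    return float(q_hat[np.arange(len(X)), target_actions].mean())

X = np.array([[0.0], [1.0], [2.0], [3.0]] * 2)
A = np.array([0, 0, 0, 0, 1, 1, 1, 1])
R = X[:, 0] + 2.0 * A  # noiseless toy model: R = x + 2a
always_treat = np.ones(len(X), dtype=int)
print(direct_method_sketch(X, A, R, always_treat))  # approximately 3.5
```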

ips

ips(X: ndarray, A: ndarray, R: ndarray, pi_target, pi_behavior: Optional[ndarray] = None, clip: float = 50.0, alpha: float = 0.05) -> OPEResult

Inverse propensity score OPE.
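A minimal IPS sketch follows. It takes, for each unit, the probability of the logged action under the target policy and under the behaviour policy. It also assumes that clip caps the importance weights; the role of statspai's clip argument is not documented here. ips_sketch is hypothetical.

```python
import numpy as np

def ips_sketch(R, p_target, p_behavior, clip=50.0):
    """IPS estimate of a target policy's value (sketch).

    p_target[i]   : probability the target policy picks the logged action A_i
    p_behavior[i] : probability the logging policy picked A_i
    Weights are capped at `clip` (assumed meaning of the clip argument).
    """
    R = np.asarray(R, float)
    w = np.minimum(np.asarray(p_target, float) / np.asarray(p_behavior, float), clip)
    return float(np.mean(w * R))

# Target policy keeps units 0 and 2's logged actions; logging was a fair coin.
print(ips_sketch([1.0, 0.0, 1.0, 1.0], [1.0, 0.0, 1.0, 0.0], [0.5] * 4))  # 1.0
```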

snips

snips(X: ndarray, A: ndarray, R: ndarray, pi_target, pi_behavior: Optional[ndarray] = None, clip: float = 50.0, alpha: float = 0.05) -> OPEResult

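Self-normalised IPS: divides by the sum of the importance weights rather than by n. This reduces variance when some importance-sampling weights are large, at the cost of a small finite-sample bias.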
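The sketch below contrasts plain IPS with SNIPS on a toy sample. One large importance weight pushes the IPS estimate outside the reward range, while normalising by the sum of weights keeps SNIPS inside it. Function names are hypothetical, and clipping is assumed to act on the weights.

```python
import numpy as np

def clipped_weights(p_target, p_behavior, clip=50.0):
    """Importance weights pi_target / pi_behavior for the logged actions, capped at clip."""
    return np.minimum(np.asarray(p_target, float) / np.asarray(p_behavior, float), clip)

def snips_sketch(R, p_target, p_behavior, clip=50.0):
    """Self-normalised IPS (sketch); assumes at least one nonzero weight."""
    w = clipped_weights(p_target, p_behavior, clip)
    return float(np.sum(w * np.asarray(R, float)) / np.sum(w))

R = [1.0, 1.0, 0.0, 0.0]  # rewards lie in [0, 1]
p_behavior = [0.25, 0.5, 0.5, 0.5]
p_target = [1.0, 1.0, 0.0, 0.0]
w = clipped_weights(p_target, p_behavior)
ips_value = float(np.mean(w * np.asarray(R)))
print(ips_value)                              # 1.5: plain IPS, outside [0, 1]
print(snips_sketch(R, p_target, p_behavior))  # 1.0: SNIPS stays in range
```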

doubly_robust

doubly_robust(X: ndarray, A: ndarray, R: ndarray, pi_target, pi_behavior: Optional[ndarray] = None, n_actions: Optional[int] = None, clip: float = 50.0, alpha: float = 0.05) -> OPEResult

Doubly-robust OPE (Dudik et al. 2011).
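A minimal doubly robust sketch follows, in the form of Dudik et al. (2011). The estimate is the outcome-model prediction under the target policy plus an importance-weighted correction on the logged residual. The sketch assumes precomputed Q-model predictions q_hat, a target-policy probability matrix and logging propensities. dr_sketch is hypothetical. The two toy calls illustrate the "doubly" part. With a zero outcome model the estimate reduces to IPS. With an outcome model that fits the logged rewards, the correction term vanishes.

```python
import numpy as np

def dr_sketch(A, R, q_hat, pi_target, p_behavior, clip=50.0):
    """Doubly robust OPE for a (possibly stochastic) target policy (sketch).

    q_hat      : (n, K) outcome-model predictions Q(x_i, a)
    pi_target  : (n, K) target-policy action probabilities
    p_behavior : (n,)   logging probability of the observed action A_i
    """
    A = np.asarray(A, int)
    R = np.asarray(R, float)
    q_hat = np.asarray(q_hat, float)
    pi_target = np.asarray(pi_target, float)
    rows = np.arange(len(A))
    direct = (pi_target * q_hat).sum(axis=1)  # model-based term
    w = np.minimum(pi_target[rows, A] / np.asarray(p_behavior, float), clip)
    correction = w * (R - q_hat[rows, A])  # importance-weighted residual
    return float(np.mean(direct + correction))

A = [0, 1]
R = [1.0, 2.0]
pi_target = [[0.0, 1.0], [0.0, 1.0]]  # target always picks action 1
p_behavior = [0.5, 0.5]
print(dr_sketch(A, R, np.zeros((2, 2)), pi_target, p_behavior))  # 2.0 (pure IPS)
q_good = [[1.0, 5.0], [3.0, 2.0]]  # matches both logged rewards exactly
print(dr_sketch(A, R, q_good, pi_target, p_behavior))  # 3.5 (residuals vanish)
```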