statspai.dag
dag
DAG (Directed Acyclic Graph) module for causal reasoning.
Declare causal graphs, compute adjustment sets, check for collider bias,
enumerate paths, detect bad controls, and visualize causal structures —
the Python equivalent of R's dagitty and ggdag.
>>> import statspai as sp
>>> g = sp.dag('X -> Y; Z -> X; Z -> Y')
>>> g.adjustment_sets('X', 'Y')
[{'Z'}]
>>> g.backdoor_paths('X', 'Y')
>>> g.bad_controls('X', 'Y')
>>> g.summary('X', 'Y')
>>> g.do('X')  # interventional graph
>>> sp.dag_example('discrimination')  # classic textbook DAG
DAG
A directed acyclic graph for causal reasoning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| spec | str | Edge specification, e.g. 'X -> Y; Z -> X; Z -> Y'. | '' |
Examples:
observed_nodes (property)
Nodes that are not latent (latent node names start with '_L_').
is_collider
Check if node is a collider on path.
A node is a collider on a path if both its neighbours on the
path are parents of it (arrows point into it: → node ←).
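A minimal standalone sketch of the collider test, assuming the graph is held as a plain dict mapping each node to the set of its parents. The Parents alias and this is_collider signature are illustrative only, not statspai's internal representation.

```python
from typing import Dict, List, Set

Parents = Dict[str, Set[str]]  # node -> set of its parents (illustrative layout)


def is_collider(parents: Parents, path: List[str], node: str) -> bool:
    """True if `node` sits on `path` as  prev -> node <- next."""
    i = path.index(node)
    # Path endpoints have only one neighbour, so they are never colliders.
    if i == 0 or i == len(path) - 1:
        return False
    pas = parents.get(node, set())
    # Both neighbours on the path must be parents, i.e. both arrows point in.
    return path[i - 1] in pas and path[i + 1] in pas


# X -> C <- Y makes C a collider; Z -> X and Z -> Y make Z a fork.
g = {"C": {"X", "Y"}, "X": {"Z"}, "Y": {"Z"}}
print(is_collider(g, ["X", "C", "Y"], "C"))  # True
print(is_collider(g, ["X", "Z", "Y"], "Z"))  # False
```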
all_paths
Enumerate all simple paths between x and y.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | str | Start node. | required |
| y | str | End node. | required |
| directed_only | bool | If True, only follow directed edges parent→child. If False (default), traverse edges in either direction (needed for finding backdoor paths). | False |
Returns:
| Type | Description |
|---|---|
| list of list of str | Each inner list is an ordered path from x to y. |
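The path enumeration can be sketched as a depth-first search over a neighbour map, again assuming a hypothetical dict-of-parents layout. With directed_only=False the search may walk against arrows, which is what makes backdoor paths reachable; refusing to revisit a node keeps every path simple. This is an illustration rather than the statspai implementation, and it can be exponential on dense graphs.

```python
from typing import Dict, List, Set

Parents = Dict[str, Set[str]]  # node -> set of its parents (illustrative layout)


def all_paths(parents: Parents, x: str, y: str, directed_only: bool = False) -> List[List[str]]:
    # Neighbour map: children only if directed_only, else children and parents.
    nbrs: Dict[str, Set[str]] = {}
    for child, pas in parents.items():
        for p in pas:
            nbrs.setdefault(p, set()).add(child)
            if not directed_only:
                nbrs.setdefault(child, set()).add(p)
    paths: List[List[str]] = []

    def dfs(path: List[str]) -> None:
        node = path[-1]
        if node == y:
            paths.append(list(path))
            return
        for nxt in sorted(nbrs.get(node, set())):
            if nxt not in path:  # simple path: never revisit a node
                path.append(nxt)
                dfs(path)
                path.pop()

    dfs([x])
    return paths


g = {"X": {"Z"}, "Y": {"X", "Z"}}   # Z -> X -> Y, Z -> Y
print(all_paths(g, "X", "Y"))                      # [['X', 'Y'], ['X', 'Z', 'Y']]
print(all_paths(g, "X", "Y", directed_only=True))  # [['X', 'Y']]
```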
causal_paths
All directed (causal) paths from exposure to outcome.
These are the paths through which the treatment actually causes changes in the outcome.
backdoor_paths
All backdoor (non-causal) paths from exposure to outcome.
A backdoor path is any path from exposure to outcome whose first edge points into the exposure (exposure ← ...); such paths can transmit spurious, non-causal association.
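One way to read this definition in code: a backdoor path is a simple path in the undirected skeleton whose first step goes from the exposure to one of its parents. The sketch below assumes the same hypothetical dict-of-parents layout and only enumerates such paths; it does not test whether they are open.

```python
from typing import Dict, List, Set

Parents = Dict[str, Set[str]]  # node -> set of its parents (illustrative layout)


def backdoor_paths(parents: Parents, x: str, y: str) -> List[List[str]]:
    # Undirected skeleton: any edge may be traversed in either direction.
    adj: Dict[str, Set[str]] = {}
    for child, pas in parents.items():
        for p in pas:
            adj.setdefault(child, set()).add(p)
            adj.setdefault(p, set()).add(child)
    out: List[List[str]] = []

    def dfs(path: List[str]) -> None:
        node = path[-1]
        if node == y:
            out.append(path)
            return
        for nxt in sorted(adj.get(node, set())):
            if nxt not in path:
                dfs(path + [nxt])

    # A backdoor path must leave x through one of x's parents (x <- p ...).
    for p in sorted(parents.get(x, set())):
        dfs([x, p])
    return out


confounded = {"X": {"Z"}, "Y": {"X", "Z"}}   # Z -> X -> Y, Z -> Y
mediated = {"M": {"X"}, "Y": {"M"}}           # X -> M -> Y
print(backdoor_paths(confounded, "X", "Y"))   # [['X', 'Z', 'Y']]
print(backdoor_paths(mediated, "X", "Y"))     # []
```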
is_path_open
Check if a path is open (active) given a conditioning set.
Rules (Pearl 2009):
- A non-collider on the path blocks it if the non-collider is conditioned on.
- A collider on the path blocks it unless the collider or one of its descendants is conditioned on.
A path is open only if none of its nodes blocks it.
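The two blocking rules translate almost line by line into code. In this sketch (hypothetical helpers, dict-of-parents layout assumed), a node is a collider when both of its path neighbours are its parents, and the path is declared open only if no interior node blocks it.

```python
from typing import Dict, List, Set

Parents = Dict[str, Set[str]]  # node -> set of its parents (illustrative layout)


def descendants(parents: Parents, node: str) -> Set[str]:
    # Invert the parent map, then walk downwards from `node`.
    children: Dict[str, Set[str]] = {}
    for child, pas in parents.items():
        for p in pas:
            children.setdefault(p, set()).add(child)
    seen: Set[str] = set()
    stack = [node]
    while stack:
        for c in children.get(stack.pop(), set()):
            if c not in seen:
                seen.add(c)
                stack.append(c)
    return seen


def is_path_open(parents: Parents, path: List[str], conditioned: Set[str]) -> bool:
    # Only interior nodes can block a path.
    for i in range(1, len(path) - 1):
        prev, node, nxt = path[i - 1], path[i], path[i + 1]
        pas = parents.get(node, set())
        if prev in pas and nxt in pas:
            # Collider: blocks unless it or one of its descendants is conditioned on.
            if node not in conditioned and not (descendants(parents, node) & conditioned):
                return False
        elif node in conditioned:
            # Chain or fork: blocks when conditioned on.
            return False
    return True


fork = {"X": {"Z"}, "Y": {"Z"}}             # X <- Z -> Y
collider = {"C": {"X", "Y"}, "D": {"C"}}    # X -> C <- Y, C -> D
print(is_path_open(fork, ["X", "Z", "Y"], set()))       # True
print(is_path_open(fork, ["X", "Z", "Y"], {"Z"}))       # False
print(is_path_open(collider, ["X", "C", "Y"], set()))   # False
print(is_path_open(collider, ["X", "C", "Y"], {"D"}))   # True
```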
path_status
Classify every path between exposure and outcome.
Returns:
| Type | Description |
|---|---|
| list of dict | Each dict has keys 'path', 'type', and 'open'. |
Examples:
>>> g = sp.dag('Z -> X; Z -> Y; X -> Y')
>>> g.path_status('X', 'Y')
[{'path': ['X', 'Y'], 'type': 'causal', 'open': True},
{'path': ['X', 'Z', 'Y'], 'type': 'backdoor', 'open': True}]
>>> g.path_status('X', 'Y', conditioned={'Z'})
[{'path': ['X', 'Y'], 'type': 'causal', 'open': True},
{'path': ['X', 'Z', 'Y'], 'type': 'backdoor', 'open': False}]
classify_variable
Classify the role(s) of node relative to exposure → outcome.
Returns a set that may include:
'confounder', 'mediator', 'collider',
'instrument', 'ancestor_of_treatment',
'ancestor_of_outcome'.
Examples:
bad_controls
Identify variables that should not be conditioned on.
Returns a dict mapping variable names to the reason they are bad controls. Based on Cinelli, Forney & Pearl (2022) and the "bad controls" discussion in Cunningham (2021, ch. 3).
Categories of bad controls:
- descendant_of_treatment: conditioning on a descendant of exposure blocks part of the causal effect (over-control bias).
- collider: conditioning opens a previously closed backdoor path (collider bias / selection bias).
- mediator: conditioning on a mediator blocks the indirect causal effect (over-control / mediation bias).
- M-bias: conditioning on a pre-treatment variable that is a collider on a backdoor path, opening a non-causal path.
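Two of these categories follow from ancestry alone: a mediator is a descendant of the exposure that is also an ancestor of the outcome, and any other descendant of the exposure is a post-treatment variable. The helper below, post_treatment_flags, is hypothetical and covers only those two categories under an assumed dict-of-parents layout; detecting colliders and M-bias additionally requires path analysis.

```python
from typing import Dict, Set

Parents = Dict[str, Set[str]]  # node -> set of its parents (illustrative layout)


def _reach(start: str, step: Dict[str, Set[str]]) -> Set[str]:
    # All nodes reachable from `start` by repeatedly following `step`.
    seen: Set[str] = set()
    stack = [start]
    while stack:
        for n in step.get(stack.pop(), set()):
            if n not in seen:
                seen.add(n)
                stack.append(n)
    return seen


def post_treatment_flags(parents: Parents, exposure: str, outcome: str) -> Dict[str, str]:
    """Hypothetical helper: flag descendants of the exposure as bad controls."""
    children: Dict[str, Set[str]] = {}
    for child, pas in parents.items():
        for p in pas:
            children.setdefault(p, set()).add(child)
    desc_x = _reach(exposure, children)   # descendants of the exposure
    anc_y = _reach(outcome, parents)      # ancestors of the outcome
    flags = {}
    for v in desc_x - {outcome}:
        # Descendant of X and ancestor of Y: v lies on a directed X -> ... -> Y path.
        flags[v] = "mediator" if v in anc_y else "descendant_of_treatment"
    return flags


# Z -> X -> M -> Y, X -> Y, Z -> Y, and X -> D (D does not affect Y).
g = {"X": {"Z"}, "M": {"X"}, "Y": {"X", "M", "Z"}, "D": {"X"}}
print(sorted(post_treatment_flags(g, "X", "Y").items()))
# [('D', 'descendant_of_treatment'), ('M', 'mediator')]
```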
Examples:
>>> g = sp.dag('X -> Y; Z -> X; Z -> Y')
>>> g.bad_controls('X', 'Y')
do
Return the interventional graph G_{\overline{X}}: the graph with all incoming edges to the intervention node(s) removed.
This implements Pearl's do-operator at the graphical level.
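At the graph level, do(X) is plain edge deletion: every arrow into an intervened node is removed and all other edges are kept. A sketch under the same assumed dict-of-parents layout; the do helper here is standalone and is not the DAG.do method itself.

```python
from typing import Dict, Iterable, Set, Union

Parents = Dict[str, Set[str]]  # node -> set of its parents (illustrative layout)


def do(parents: Parents, intervention: Union[str, Iterable[str]]) -> Parents:
    """Return a new graph with every edge into the intervened node(s) removed."""
    targets = {intervention} if isinstance(intervention, str) else set(intervention)
    # Copy each parent set so the original graph is left untouched;
    # intervened nodes lose all of their parents.
    return {node: set() if node in targets else set(pas) for node, pas in parents.items()}


g = {"X": {"Z"}, "Y": {"X", "Z"}}   # Z -> X -> Y, Z -> Y
g_do_x = do(g, "X")
print(g_do_x["X"])            # set()
print(sorted(g_do_x["Y"]))    # ['X', 'Z']
print(g["X"])                 # {'Z'}  (original unchanged)
```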
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| intervention | str or set of str | The node(s) being intervened on. | required |
Returns:
| Type | Description |
|---|---|
| DAG | A new DAG with incoming edges to intervention removed. |
Examples:
>>> g = sp.dag('X -> Y; Z -> X; Z -> Y')
>>> g.do('X')  # interventional graph
frontdoor_sets
Find sets satisfying Pearl's frontdoor criterion.
A set M satisfies the frontdoor criterion relative to (X, Y) if:
- M intercepts all directed paths from X to Y.
- There is no unblocked backdoor path from X to M.
- All backdoor paths from M to Y are blocked by X.
Returns:
| Type | Description |
|---|---|
| list of set | Valid frontdoor adjustment sets (possibly empty). |
Examples:
d_separated
Test if x and y are d-separated given conditioned.
Uses the Bayes-Ball algorithm (Shachter 1998).
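For a cross-check, d-separation can also be tested with a different but equivalent technique, the moralization criterion: keep only x, y, the conditioning set and their ancestors, connect co-parents, drop edge directions, delete the conditioned nodes, and ask whether x and y are still connected. The sketch below uses that method rather than Bayes-Ball and assumes a dict-of-parents graph; it is not the statspai implementation.

```python
from itertools import combinations
from typing import Dict, Set

Parents = Dict[str, Set[str]]  # node -> set of its parents (illustrative layout)


def d_separated(parents: Parents, x: str, y: str, conditioned: Set[str]) -> bool:
    """d-separation via the moralized ancestral graph (not Bayes-Ball)."""
    # 1. Keep only x, y, the conditioning set, and all of their ancestors.
    keep: Set[str] = set()
    stack = [x, y, *conditioned]
    while stack:
        node = stack.pop()
        if node not in keep:
            keep.add(node)
            stack.extend(parents.get(node, set()))
    # 2. Moralize: link each node to its parents, "marry" co-parents,
    #    and treat every edge as undirected.
    adj: Dict[str, Set[str]] = {n: set() for n in keep}
    for child in keep:
        pas = sorted(parents.get(child, set()))
        for p in pas:
            adj[child].add(p)
            adj[p].add(child)
        for a, b in combinations(pas, 2):
            adj[a].add(b)
            adj[b].add(a)
    # 3. Remove conditioned nodes; x and y are d-separated iff y is unreachable.
    seen, stack = {x}, [x]
    while stack:
        for nbr in adj[stack.pop()]:
            if nbr == y:
                return False
            if nbr not in seen and nbr not in conditioned:
                seen.add(nbr)
                stack.append(nbr)
    return True


confounded = {"X": {"Z"}, "Y": {"Z"}}   # X <- Z -> Y
collider = {"C": {"X", "Y"}}            # X -> C <- Y
print(d_separated(confounded, "X", "Y", set()))   # False: open fork
print(d_separated(confounded, "X", "Y", {"Z"}))   # True: fork blocked
print(d_separated(collider, "X", "Y", set()))     # True: collider blocks
print(d_separated(collider, "X", "Y", {"C"}))     # False: conditioning opens it
```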
adjustment_sets
adjustment_sets(exposure: str, outcome: str, method: str = 'backdoor', minimal: bool = True) -> List[Set[str]]
Find valid adjustment sets for estimating the causal effect of exposure on outcome.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| exposure | str | Treatment (exposure) node. | required |
| outcome | str | Outcome node. | required |
| method | str | | 'backdoor' |
| minimal | bool | If True, return only minimal sufficient adjustment sets. | True |
Returns:
| Type | Description |
|---|---|
| list of set | Each set is a valid adjustment set (possibly empty). |
plot
plot(exposure: Optional[str] = None, outcome: Optional[str] = None, conditioned: Optional[Set[str]] = None, positions: Optional[Dict[str, Tuple[float, float]]] = None, figsize: tuple = (8, 6), seed: int = 42, title: Optional[str] = None, style: str = 'ggdag', node_size: float = 0.22, font_size: int = 12, ax=None)
Plot the DAG with publication-quality styling.
When exposure and outcome are provided, nodes are colour-coded
by causal role (like R's ggdag):
- Exposure: green
- Outcome: blue
- Confounder: orange
- Mediator: purple
- Collider / bad control: red
- Unobserved (latent): grey dashed outline
- Adjusted / conditioned: hatched fill
Bidirected edges (latent common causes) are rendered as curved dashed arcs rather than routing through hidden nodes.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| exposure | str | Exposure (treatment) node for role colouring. | None |
| outcome | str | Outcome node for role colouring. | None |
| conditioned | set of str | Nodes being conditioned on (shown with hatched fill). | None |
| positions | dict | | None |
| figsize | tuple | Figure size (width, height) in inches. | (8, 6) |
| seed | int | Random seed for layout jitter. | 42 |
| title | str | Plot title. Auto-generated if exposure and outcome are given. | None |
| style | str | | 'ggdag' |
| node_size | float | Radius of node circles in data coordinates. | 0.22 |
| font_size | int | Font size for node labels. | 12 |
| ax | matplotlib Axes | Axes to draw on. If … | None |
Returns:
| Type | Description |
|---|---|
| (fig, ax) | Matplotlib figure and axes. |
summary
IdentificationResult (dataclass)
Outcome of an identification query.
Attributes:
| Name | Type | Description |
|---|---|---|
| identifiable | bool | True iff P(Y \| do(X)) is identifiable from the observed distribution. |
| estimand | str | Do-free formula when identifiable; a structured hedge otherwise. |
| c_components | list[set[str]] | The c-components of the ancestral semi-Markovian graph G[An(Y)]. |
| hedge | Optional[tuple[frozenset, frozenset]] | Witness C-forest pair (F, F') that proves non-identifiability. |
| explanation | str | Human-readable proof / refutation. |
SWIGGraph
Single World Intervention Graph for a DAG under do(X=x).
Attributes:
| Name | Type | Description |
|---|---|---|
| parent | DAG | Source DAG. |
| intervention | dict[str, str] | Variable → value label. |
| nodes | set[str] | All SWIG nodes (split halves + potential outcomes). |
| edges | dict[str, set[str]] | Adjacency map on SWIG nodes. |
SCM (dataclass)
Structural Causal Model.
Each node has:
- parents: iterable of parent node names
- equation: callable(parents_dict, noise) -> value
- noise: callable that returns a draw from the exogenous noise distribution (defaults to standard normal)
Example
>>> scm = SCM()
>>> scm.add("X", [], lambda pa, u: u, lambda rng: rng.normal())
>>> scm.add("Y", ["X"], lambda pa, u: 2*pa["X"] + u)
counterfactual
counterfactual(evidence: dict, intervention: dict, n_samples: int = 2000, seed: int | None = None, tol: float = 0.01) -> dict
Compute E[Y(intervention) | evidence] via abduction-action-prediction.
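Abduction-action-prediction can be illustrated on the two-node linear SCM from the example above, X = U_x and Y = 2X + U_y with standard normal noise. The sketch assumes that toy model, a tolerance of 0.1, and a fixed number of proposal draws; unlike the documented n_samples, which counts accepted draws, the hypothetical n_draws here counts all proposals. Given the factual evidence X = 1, Y = 3, abduction pins U_y near 1, so under do(X = 0) the counterfactual Y should be close to 1.

```python
import random
from statistics import mean


def counterfactual_y(evidence_x: float, evidence_y: float, do_x: float,
                     n_draws: int = 200_000, tol: float = 0.1, seed: int = 0) -> float:
    """E[Y(do_x) | X = evidence_x, Y = evidence_y] in the toy SCM
    X = U_x,  Y = 2 * X + U_y,  with U_x, U_y ~ N(0, 1)."""
    rng = random.Random(seed)
    kept_u_y = []
    for _ in range(n_draws):
        u_x, u_y = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        x = u_x              # factual world
        y = 2 * x + u_y
        # Abduction: keep only noise draws consistent with the evidence.
        if abs(x - evidence_x) < tol and abs(y - evidence_y) < tol:
            kept_u_y.append(u_y)
    # Action: replace X's equation with X := do_x.
    # Prediction: push the retained noise through Y's equation.
    return mean(2 * do_x + u_y for u_y in kept_u_y)


# Factual world: X = 1, Y = 3, so U_y must be about 1.
# Under do(X = 0), Y = U_y, so the estimate should be close to 1.
print(round(counterfactual_y(1.0, 3.0, do_x=0.0), 2))
```

Only about 0.2% of proposals survive the two-sided tolerance here, which is why rejection sampling becomes expensive as the evidence set or the required precision grows.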
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| evidence | dict | Observed values of some subset of nodes (factual world). | required |
| intervention | dict | Values to set for do-intervened variables. | required |
| n_samples | int | Number of accepted noise draws for rejection sampling. | 2000 |
| tol | float | Tolerance for matching continuous evidence. | 1e-2 |
Returns:
| Type | Description |
|---|---|
| dict[str, ndarray] | Counterfactual samples for every node. |
LLMCausalAssessResult (dataclass)
Output of llm_causal_assess.
PairwiseBenchmarkResult (dataclass)
Output of pairwise_causal_benchmark.
dag
dag(spec: str = '') -> DAG
Create a causal DAG from a string specification.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| spec | str | Edge specification, e.g. 'X -> Y; Z -> X; Z -> Y'. | '' |
Returns:
| Type | Description |
|---|---|
| DAG | |
Examples:
>>> import statspai as sp
>>> g = sp.dag('X -> Y; Z -> X; Z -> Y')
>>> g.adjustment_sets('X', 'Y')
[{'Z'}]
dag_example
dag_example(name: str) -> DAG
Load a classic textbook DAG by name.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Name of the example DAG, e.g. 'discrimination'. | required |
Returns:
| Type | Description |
|---|---|
| DAG | The example DAG. Call … |
Examples:
>>> sp.dag_example('discrimination')  # classic textbook DAG
dag_example_positions
Return hand-tuned node positions for a named example DAG.
dag_simulate
Run a classic DAG simulation from Cunningham (2021, ch. 3).
Available simulations:
- 'discrimination': Gender discrimination / occupational sorting. The true effect of discrimination on wage is -1. Conditioning on occupation alone flips the sign (collider bias).
- 'movie_star': Beauty–Talent collider. Beauty and Talent are independent in the population, but conditioning on Star status induces a spurious negative correlation.
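The 'movie_star' mechanism is easy to reproduce with the standard library. The sketch assumes Beauty and Talent are independent standard normals and that the top 15% by Beauty + Talent become stars; those choices are illustrative and need not match the data-generating process inside dag_simulate.

```python
import random


def corr(a, b):
    # Pearson correlation, written out to avoid third-party dependencies.
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    va = sum((x - ma) ** 2 for x in a)
    vb = sum((y - mb) ** 2 for y in b)
    return cov / (va * vb) ** 0.5


rng = random.Random(42)
n = 10_000
beauty = [rng.gauss(0.0, 1.0) for _ in range(n)]
talent = [rng.gauss(0.0, 1.0) for _ in range(n)]   # independent of beauty
score = [b + t for b, t in zip(beauty, talent)]
cutoff = sorted(score)[int(0.85 * n)]               # assumption: top 15% become stars
stars = [i for i in range(n) if score[i] >= cutoff]

print(round(corr(beauty, talent), 2))               # near 0 in the full population
print(round(corr([beauty[i] for i in stars],
                 [talent[i] for i in stars]), 2))   # strongly negative among stars
```

Selecting on a common effect (Star) is exactly conditioning on a collider: among stars, a low Beauty value makes a high Talent value more likely, and vice versa.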
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | | required |
| n | int | Number of observations. | 10000 |
| seed | int | Random seed for reproducibility. | 42 |
Returns:
| Type | Description |
|---|---|
| DataFrame | Simulated dataset. |
Examples:
>>> df = sp.dag_simulate('discrimination')
>>> import statsmodels.formula.api as smf
>>> # Biased: wrong sign due to collider
>>> smf.ols('wage ~ female + occupation', data=df).fit().params['female']
>>> # Correct: includes ability
>>> smf.ols('wage ~ female + occupation + ability', data=df).fit().params['female']
identify
identify(dag, treatment, outcome) -> IdentificationResult
Run the Shpitser-Pearl ID algorithm on dag.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dag | DAG | | required |
| treatment | str or Iterable[str] | Set of variables X being intervened on. | required |
| outcome | str or Iterable[str] | Set of outcome variables Y. | required |
Returns:
| Type | Description |
|---|---|
| IdentificationResult | |
rule1
Check Rule 1: can we insert or delete an observation of Z?
rule2
Check Rule 2: can do(Z) be swapped for observing Z?
apply_rules
Try all three rules and return every applicable simplification.
llm_causal_assess
llm_causal_assess(level1_items: Optional[DataFrame] = None, level2_items: Optional[DataFrame] = None, *, llm_client: Callable[[str], str], llm_identifier: str = 'llm') -> LLMCausalAssessResult
Combined Level-1 + Level-2 LLM causal-reasoning assessment.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| level1_items | DataFrame | Columns: … | None |
| level2_items | DataFrame | Columns: … | None |
| llm_client | Callable[[str], str] | Function taking a prompt and returning a string. | required |
| llm_identifier | str | | 'llm' |
Returns:
| Type | Description |
|---|---|
| LLMCausalAssessResult | |
pairwise_causal_benchmark
pairwise_causal_benchmark(ground_truth: DataFrame, *, llm_client: Callable[[str], str], llm_identifier: str = 'llm', pair_a_col: str = 'A', pair_b_col: str = 'B', truth_col: str = 'a_causes_b', prompt_template: str = "Does variable {a} causally influence variable {b}? Answer 'yes' or 'no'.") -> PairwiseBenchmarkResult
Benchmark an LLM on pairwise causal-direction identification.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ground_truth | DataFrame | One row per pair; column names are set by pair_a_col, pair_b_col, and truth_col. | required |
| llm_client | callable(str) -> str | Function taking a prompt and returning a string. | required |
| llm_identifier | str | | 'llm' |
| prompt_template | str | | "Does variable {a} causally ..." |
Returns:
| Type | Description |
|---|---|
| PairwiseBenchmarkResult | |
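To show how a prompt of the documented prompt_template form might be filled and scored, here is a hypothetical scorer over plain tuples instead of a DataFrame. The yes/no parsing rule (an answer counts as yes if it starts with 'yes'), the accuracy metric, and the stub client are all assumptions, not the behaviour of pairwise_causal_benchmark.

```python
from typing import Callable, List, Tuple

# Same wording as the documented default prompt_template.
PROMPT = "Does variable {a} causally influence variable {b}? Answer 'yes' or 'no'."


def score_pairs(pairs: List[Tuple[str, str, bool]], llm_client: Callable[[str], str]) -> float:
    """Hypothetical scorer: fraction of pairs whose yes/no answer matches the truth."""
    correct = 0
    for a, b, a_causes_b in pairs:
        answer = llm_client(PROMPT.format(a=a, b=b)).strip().lower()
        predicted = answer.startswith("yes")  # anything else counts as "no"
        correct += predicted == a_causes_b
    return correct / len(pairs)


def stub_llm(prompt: str) -> str:
    # Toy "model" that always answers yes, whatever the prompt.
    return "Yes."


pairs = [("smoking", "cancer", True), ("cancer", "smoking", False)]
print(score_pairs(pairs, stub_llm))  # 0.5
```

Asking about both orderings of each pair, as in the toy list above, keeps a model that always answers yes from scoring above chance.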
recommend_estimator
recommend_estimator(dag, exposure: str, outcome: str, candidate_instruments: Optional[List[str]] = None) -> EstimatorRecommendation
Inspect a DAG and recommend a statspai estimator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dag | DAG | | required |
| exposure | str | | required |
| outcome | str | | required |
| candidate_instruments | list of str | Variable names to check as potential IVs. If omitted, all observed nodes other than exposure/outcome are considered. | None |