Example: Sensitivity Analysis for Causal ML#

This notebook complements the introductory paper “Sensitivity Analysis for Causal ML: A Use Case at Booking.com” by Bach et al. (2024) (forthcoming). It illustrates the causal analysis and sensitivity considerations in a simplified example.

[1]:
import doubleml as dml
from doubleml.datasets import make_confounded_irm_data

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.isotonic import IsotonicRegression
from lightgbm import LGBMRegressor, LGBMClassifier
import plotly.graph_objects as go

Load Data#

We will simulate a data set according to a data generating process (DGP) that is available and documented in the DoubleML package for Python. We parametrize the DGP such that it roughly mimics patterns of the data used in the original analysis.

[2]:
# Use a smaller number of observations in this demo example to reduce computational time
n_obs = 75000

# Parameters for the data generating process
# True average treatment effect (very similar to ATT in this example)
theta = 0.07
# Coefficient of the unobserved confounder in the outcome regression.
beta_a = 0.25
# Coefficient of the unobserved confounder in the propensity score.
gamma_a = 0.123
# Variance for outcome regression error
var_eps = 1.5
# Threshold being applied on trimming propensity score on the population level
trimming_threshold = 0.05

# Set a seed for reproducibility
np.random.seed(42)
dgp_dict = make_confounded_irm_data(n_obs=n_obs, theta=theta, beta_a=beta_a, gamma_a=gamma_a, var_eps=var_eps, trimming_threshold=trimming_threshold)

x_cols = [f'X{i + 1}' for i in np.arange(dgp_dict['x'].shape[1])]
df = pd.DataFrame(np.column_stack((dgp_dict['x'], dgp_dict['y'], dgp_dict['d'])), columns=x_cols + ['y', 'd'])

/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/doubleml/datasets.py:1050: UserWarning: Propensity score is close to 0 or 1. Trimming is at 0.05 and 0.95 is applied
  warnings.warn(f'Propensity score is close to 0 or 1. '
[3]:
df.head()
[3]:
X1 X2 X3 X4 X5 y d
0 0.496714 -0.138264 0.647689 1.523030 -0.234153 1.421200 1.0
1 -0.234137 1.579213 0.767435 -0.469474 0.542560 1.315031 1.0
2 -0.463418 -0.465730 0.241962 -1.913280 -1.724918 1.493144 1.0
3 -0.562288 -1.012831 0.314247 -0.908024 -1.412304 1.399858 1.0
4 1.465649 -0.225776 0.067528 -1.424748 -0.544383 3.092229 0.0

Causal Analysis with DoubleML#

1. Formulation of Causal Model & Identification Assumptions#

In the use case, we focused on a nonparametric model for the treatment effect, also called the Interactive Regression Model (IRM). Under the assumptions of consistency, overlap, and unconfoundedness, this model can be used to identify the Average Treatment Effect (ATE) and the Average Treatment Effect on the Treated (ATT). The identification strategy was based on a DAG.
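For reference, the IRM as documented in the DoubleML user guide takes the form

\[Y = g_0(D, X) + U, \quad E[U|X, D] = 0,\]

\[D = m_0(X) + V, \quad E[V|X] = 0,\]

where \(g_0\) denotes the outcome regression and \(m_0\) the propensity score.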

DAG underlying the causal analysis.

In the use case under consideration, the key causal quantity was the ATT, as it quantifies by how much ancillary products increase follow-up bookings on average for customers who purchased an ancillary product.
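In the notation of the IRM, the ATT is defined as

\[\theta_0 = E[g_0(1, X) - g_0(0, X) | D = 1],\]

i.e., the average difference in expected outcomes with and without treatment for the subpopulation of treated customers.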

[4]:
# Set up the data backend with treatment variable d, outcome variable y, and covariates x
dml_data = dml.DoubleMLData(df, 'y', 'd', x_cols)

2. Estimation of Causal Effect#

For estimation, we employed the DoubleML package in Python. The nuisance functions (the outcome regression and the propensity score) have been estimated with LightGBM.

[5]:
# Initialize LightGBM learners
n_estimators = 150
learning_rate = 0.05
ml_g = LGBMRegressor(n_estimators=n_estimators, learning_rate=learning_rate, verbose=-1)
ml_m = LGBMClassifier(n_estimators=n_estimators, learning_rate=learning_rate, verbose=-1)

# Initialize the DoubleMLIRM model, specify score "ATTE" for the average treatment effect on the treated
dml_obj = dml.DoubleMLIRM(dml_data, score="ATTE", ml_g=ml_g, ml_m=ml_m, n_folds=5, n_rep=2)


# fit the model
dml_obj.fit()
[5]:
<doubleml.irm.irm.DoubleMLIRM at 0x7f6307105460>

Let’s summarize the estimation results for the ATT.

[6]:
dml_obj.summary.round(3)
[6]:
coef std err t P>|t| 2.5 % 97.5 %
d 0.123 0.008 15.065 0.0 0.107 0.139

The results point to a sizeable positive effect. However, we are concerned that this effect might be driven by unobserved confounders: the large positive effect might reflect selection into treatment rather than the pure causal effect of the treatment.

3. Sensitivity Analysis#

To address the concerns regarding confounding bias, we employed sensitivity analysis. The literature has developed various approaches, which differ, among other things, in their applicability to specific estimation approaches. In the context of the use case, the approaches of VanderWeele and Arah (2011) and Chernozhukov et al. (2023) were employed. Here, we won’t go into the details of the methods but rather illustrate the application of the sensitivity analysis.
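The framework of Chernozhukov et al. (2023) is implemented directly in DoubleML through the sensitivity_analysis() method. As a minimal sketch, the benchmark values cf_y and cf_d below are illustrative and not calibrated to the use case:

[ ]:
# Sensitivity analysis in the framework of Chernozhukov et al. (2023);
# cf_y and cf_d bound the confounding strength in the outcome and treatment equation
# (illustrative values); rho=1.0 is the conservative default for the correlation
# between the confounding effects in the two equations
dml_obj.sensitivity_analysis(cf_y=0.03, cf_d=0.03, rho=1.0)
print(dml_obj.sensitivity_summary)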

VanderWeele and Arah (2011) provide a general formula for the omitted variable bias that is applicable irrespective of the estimation framework. The general formula is based on an explicit parametrization of the model in terms of the distribution of the unobserved confounder. Such a specification might be difficult to achieve in practice. Hence, the authors also offer a simplified version that employs additional assumptions. For the ATT, these assumptions impose a binary confounding variable that has an effect on \(D\) and \(Y\) which does not vary with the observed confounding variables. Under these assumptions, the bias formula is

\[\theta_s - \theta_0 = \delta \cdot \gamma\]

where \(\theta_s\) refers to the short parameter (= the ATT that is identifiable from the available data, i.e., under unobserved confounding) and \(\theta_0\) to the long or true parameter (which would be identifiable if the unobserved confounder was observed). \(\delta\) and \(\gamma\) denote the sensitivity parameters in this framework. \(\delta\) refers to the difference in the prevalence of the (binary) confounder between the treated and the untreated group (after accounting for \(X\)): \(\delta = P(U|D=1, X) - P(U|D=0,X)\). \(\gamma\) refers to the confounding effect in the main equation, i.e., \(\gamma = E[Y|D,X, U=1] - E[Y|D,X, U=0]\), which describes the average expected difference in the outcome variable due to a change in the confounding variable \(U\) (given \(D\) and \(X\)). For a more detailed treatment, we refer to the original paper by VanderWeele and Arah (2011). This sensitivity approach is appealing because of its simplicity and applicability: we can specify various scenarios in terms of values for \(\gamma\) and \(\delta\), compute the implied bias, and illustrate the results in a contour plot.

We would like to note that in the context of the original analysis, we experimented with various sensitivity frameworks. Hence, the presentation here is mainly for illustrative purposes.

[7]:
# Implement the VanderWeele and Arah corrected estimate
def adj_vanderWeeleArah(coef, gamma, delta, downward=True):
    # Bias implied by the sensitivity parameters
    bias = gamma * delta

    # Subtract the bias if we suspect an upward-biased estimate, add it otherwise
    if downward:
        adj = coef - bias
    else:
        adj = coef + bias
    return adj
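As a quick illustration, a hypothetical confounder that shifts the outcome by \(\gamma = 0.5\) and is 10 percentage points more prevalent among the treated (\(\delta = 0.1\)) would imply a bias of \(0.05\), reducing the estimate from roughly \(0.123\) to \(0.073\):

[ ]:
# Hypothetical scenario: gamma = 0.5, delta = 0.1 implies a bias of 0.05
adj_vanderWeeleArah(dml_obj.coef, gamma=0.5, delta=0.1, downward=True)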

We set up a grid of values for \(\gamma\) and \(\delta\) and compute the bias for each combination. We then illustrate the bias in a contour plot.

[8]:
# Note: delta is a difference of probabilities and hence bounded by 1 in absolute value
gamma_val, delta_val = np.linspace(0, 5, 100), np.linspace(0, 1, 100)

# all combinations of gamma_val and delta_val
gamma_val, delta_val = np.meshgrid(gamma_val, delta_val)

# Set "downward = True": We are worried that selection into the treatment leads to an upward bias of the ATT
adj_est = adj_vanderWeeleArah(dml_obj.coef, gamma_val, delta_val, downward=True)
[9]:
# set up a contour plot based on the values for gamma_val, delta_val, and adj_est
fig = go.Figure(data=go.Contour(z=adj_est, x=gamma_val[0], y=delta_val[:, 0],
                                contours=dict(coloring='heatmap', showlabels=True)))

fig.update_layout(title='Adjusted ATT estimate based on VanderWeele and Arah (downward bias)',
                  xaxis_title='gamma',
                  yaxis_title='delta')

# highlight the contour line at the level of zero
fig.add_trace(go.Contour(z=adj_est, x=gamma_val[0], y=delta_val[:, 0],
                         contours=dict(start=0, end=0, size=1, coloring='none'),
                         line=dict(color='white', width=2),
                         showscale=False, hoverinfo='skip'))
fig.show()
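The zero contour line indicates the combinations of \(\gamma\) and \(\delta\) for which the implied bias \(\gamma \cdot \delta\) would be large enough to explain away the estimated ATT entirely; scenarios below that line would still leave a positive adjusted effect.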