Python: Conditional Average Treatment Effects (CATEs)#

In this simple example, we illustrate how the DoubleML package can be used to estimate conditional average treatment effects (CATEs) with B-splines for one- or two-dimensional effect heterogeneity.

Data#

We define a data generating process to create synthetic data, so that the estimates can be compared to the true effect. The data generating process is based on the Monte Carlo simulation from Oprescu et al. (2019) and this notebook from EconML.

[1]:
import numpy as np
import pandas as pd
import doubleml as dml

The data is generated as

\[\begin{split}\begin{aligned} Y_i & = g(X_i)T_i + \langle W_i,\gamma_0\rangle + \epsilon_i \\ T_i & = \langle W_i,\beta_0\rangle +\eta_i, \end{aligned}\end{split}\]

where \(W_i\sim\mathcal{N}(0,I_{d_w})\), \(X_i\sim\mathcal{U}[0,1]^{d_x}\) and \(\epsilon_i,\eta_i\sim\mathcal{U}(-1,1)\). The coefficient vectors \(\gamma_0\) and \(\beta_0\) both have a small random support, whose non-zero values are drawn independently from \(\mathcal{U}[0,1]\). In the implementation below, the treatment is binarized: \(T_i\) is drawn from a Bernoulli distribution with success probability given by the logistic transform of \(\langle W_i,\beta_0\rangle+\eta_i\). Further, \(g(x)\) denotes the conditional treatment effect, whose form depends on the dimension of \(x\).

If \(x\) is univariate the conditional treatment effect takes the following form

\[g(x) = \exp(2x) + 3\sin(4x),\]

whereas for a two-dimensional variable \(x=(x_1,x_2)\) the conditional treatment effect is defined as

\[g(x) = \exp(2x_1) + 3\sin(4x_2).\]
[2]:
def treatment_effect_1d(x):
    te = np.exp(2 * x) + 3 * np.sin(4 * x)
    return te

def treatment_effect_2d(x):
    te = np.exp(2 * x[0]) + 3 * np.sin(4 * x[1])
    return te

def create_synthetic_data(n_samples=200, n_w=30, support_size=5, n_x=1):
    # Outcome support
    # The next two lines effectively choose the coefficient vector gamma_0 in the example
    support_y = np.random.choice(np.arange(n_w), size=support_size, replace=False)
    coefs_y = np.random.uniform(0, 1, size=support_size)
    # Define the function to generate the outcome noise epsilon
    epsilon_sample = lambda n: np.random.uniform(-1, 1, size=n)
    # Treatment support
    # Assuming the coefficient vectors gamma_0 and beta_0 share the same non-zero components
    support_t = support_y
    coefs_t = np.random.uniform(0, 1, size=support_size)
    # Define the function to generate the treatment noise eta
    eta_sample = lambda n: np.random.uniform(-1, 1, size=n)

    # Generate controls, covariates, treatments and outcomes
    w = np.random.normal(0, 1, size=(n_samples, n_w))
    x = np.random.uniform(0, 1, size=(n_samples, n_x))
    # Heterogeneous treatment effects
    if n_x == 1:
        te = np.array([treatment_effect_1d(x_i) for x_i in x]).reshape(-1)
    elif n_x == 2:
        te = np.array([treatment_effect_2d(x_i) for x_i in x]).reshape(-1)
    # Define treatment
    log_odds = np.dot(w[:, support_t], coefs_t) + eta_sample(n_samples)
    t_sigmoid = 1 / (1 + np.exp(-log_odds))
    t = np.array([np.random.binomial(1, p) for p in t_sigmoid])
    # Define the outcome
    y = te * t + np.dot(w[:, support_y], coefs_y) + epsilon_sample(n_samples)

    # Now we build the dataset
    y_df = pd.DataFrame({'y': y})
    if n_x == 1:
        x_df = pd.DataFrame({'x': x.reshape(-1)})
    elif n_x == 2:
        x_df = pd.DataFrame({'x_0': x[:,0],
                             'x_1': x[:,1]})
    t_df = pd.DataFrame({'t': t})
    w_df = pd.DataFrame(data=w, index=np.arange(w.shape[0]), columns=[f'w_{i}' for i in range(w.shape[1])])

    data = pd.concat([y_df, x_df, t_df, w_df], axis=1)

    covariates = list(w_df.columns.values) + list(x_df.columns.values)
    return data, covariates, te
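
As a quick, illustrative sanity check of the data generating process (the parameter values below are arbitrary and only for demonstration), one can draw a tiny sample and inspect the returned objects:

np.random.seed(0)
# Draw 5 observations with 4 controls, a support of size 2 and a univariate x (illustrative values)
demo_data, demo_covariates, demo_te = create_synthetic_data(n_samples=5, n_w=4, support_size=2, n_x=1)
print(demo_data.round(2))     # outcome y, covariate x, treatment t and controls w_0, ..., w_3
print(demo_covariates)        # column names that are later passed as x_cols to DoubleMLData
print(demo_te.round(2))       # true individual treatment effects g(x_i)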

One-dimensional Example#

We start with \(X\) being one-dimensional and create our training data.

[3]:
# DGP constants
np.random.seed(42)
n_samples = 2000
n_w = 10
support_size = 5
n_x = 1

# Create data
data, covariates, true_effect = create_synthetic_data(n_samples=n_samples, n_w=n_w, support_size=support_size, n_x=n_x)
data_dml_base = dml.DoubleMLData(data,
                                 y_col='y',
                                 d_cols='t',
                                 x_cols=covariates)

Next, we define the learners for the nuisance functions and fit the IRM model. Note that the learners are not optimal for the linear form of this example.

[4]:
# First stage estimation
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
randomForest_reg = RandomForestRegressor(n_estimators=500)
randomForest_class = RandomForestClassifier(n_estimators=500)

np.random.seed(42)

dml_irm = dml.DoubleMLIRM(data_dml_base,
                          ml_g=randomForest_reg,
                          ml_m=randomForest_class,
                          trimming_threshold=0.01,
                          n_folds=5)
print("Training IRM Model")
dml_irm.fit()
Training IRM Model
[4]:
<doubleml.double_ml_irm.DoubleMLIRM at 0x7f7a25c6df70>

To estimate the CATE, we rely on the best linear predictor of the linear score, as in Semenova and Chernozhukov (2021). To approximate the target function \(g(x)\) with a linear form, we have to define a data frame of basis functions. Here, we rely on patsy to construct a suitable basis of B-splines.
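
In a nutshell, writing \(p(x)\) for the vector of basis functions and \(\tilde{Y}_i\) for the orthogonal (doubly robust) signal obtained from the fitted IRM model, the best linear predictor is the projection

\[\beta_0 = \arg\min_{\beta}\, \mathbb{E}\big[\big(\tilde{Y}_i - p(X_i)^\top\beta\big)^2\big], \qquad g(x)\approx p(x)^\top\beta_0,\]

so the coefficients reported below can be read as OLS coefficients of this projection; the Neyman orthogonality of the signal is what justifies standard inference on \(\beta_0\) (see Semenova and Chernozhukov, 2021, for the formal conditions).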

[5]:
import patsy
design_matrix = patsy.dmatrix("bs(x, df=5, degree=2)", {"x":data["x"]})
spline_basis = pd.DataFrame(design_matrix)

To estimate the CATE, call the cate() method and supply the data frame of basis elements.

[6]:
cate = dml_irm.cate(spline_basis)
print(cate)
================== DoubleMLBLP Object ==================

------------------ Fit summary ------------------
       coef   std err          t          P>|t|    [0.025    0.975]
0  0.803842  0.187175   4.294598   1.834144e-05  0.436763  1.170921
1  2.313014  0.312769   7.395278   2.066986e-13  1.699626  2.926402
2  4.728587  0.199944  23.649585  3.356811e-109  4.336467  5.120708
3  4.498873  0.239346  18.796495   1.070486e-72  4.029478  4.968268
4  3.860540  0.245883  15.700725   1.816059e-52  3.378326  4.342755
5  4.111399  0.266502  15.427289   8.201758e-51  3.588748  4.634050

To obtain confidence intervals for the CATE, we have to call the confint() method and supply a data frame of basis elements. This could be the same basis as for fitting the CATE model or a new basis, e.g. to evaluate the CATE model on a grid. Here, we will evaluate the CATE on a grid from 0.1 to 0.9 to plot the final results. Further, we construct uniform confidence intervals by setting the option joint=True and providing the number of bootstrap repetitions n_rep_boot.

[7]:
new_data = {"x": np.linspace(0.1, 0.9, 100)}
spline_grid = pd.DataFrame(patsy.build_design_matrices([design_matrix.design_info], new_data)[0])
df_cate = cate.confint(spline_grid, level=0.95, joint=True, n_rep_boot=2000)
print(df_cate)
       2.5 %    effect    97.5 %
0   2.161524  2.486687  2.811850
1   2.280143  2.606998  2.933854
2   2.394160  2.724940  3.055721
3   2.504567  2.840513  3.176460
4   2.612210  2.953717  3.295225
..       ...       ...       ...
95  4.479230  4.810769  5.142308
96  4.482378  4.808534  5.134690
97  4.486453  4.807426  5.128399
98  4.490813  4.807445  5.124076
99  4.494702  4.808590  5.122478

[100 rows x 3 columns]
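
As an additional, illustrative sanity check, one can compute the share of grid points at which the uniform confidence band covers the true effect, using the treatment_effect_1d function from the data generating process above:

true_effect_grid = treatment_effect_1d(new_data["x"])
covered = (df_cate["2.5 %"].values <= true_effect_grid) & (true_effect_grid <= df_cate["97.5 %"].values)
print(f"Share of grid points covered by the 95% band: {covered.mean():.2%}")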

Finally, we can plot our results and compare them with the true effect.

[8]:
from matplotlib import pyplot as plt
plt.rcParams['figure.figsize'] = 10., 7.5

df_cate['x'] = new_data['x']
df_cate['true_effect'] = treatment_effect_1d(new_data['x'])
fig, ax = plt.subplots()
ax.plot(df_cate['x'],df_cate['effect'], label='Estimated Effect')
ax.plot(df_cate['x'],df_cate['true_effect'], color="green", label='True Effect')
ax.fill_between(df_cate['x'], df_cate['2.5 %'], df_cate['97.5 %'], color='b', alpha=.3, label='Confidence Interval')

plt.legend()
plt.title('CATE')
plt.xlabel('x')
_ =  plt.ylabel('Effect and 95%-CI')
[Figure: estimated CATE with 95% confidence band and the true effect]

Two-Dimensional Example#

It is also possible to estimate multi-dimensional conditional effects. We will use the same data-generating process as above, but let \(X\) be two-dimensional.

[9]:
# DGP constants
np.random.seed(42)
n_samples = 5000
n_w = 10
support_size = 5
n_x = 2
[10]:
# Create data
data, covariates, true_effect = create_synthetic_data(n_samples=n_samples, n_w=n_w, support_size=support_size, n_x=n_x)
data_dml_base = dml.DoubleMLData(data,
                                 y_col='y',
                                 d_cols='t',
                                 x_cols=covariates)

As in the univariate example, we estimate the IRM model.

[11]:
# First stage estimation
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
randomForest_reg = RandomForestRegressor(n_estimators=500)
randomForest_class = RandomForestClassifier(n_estimators=500)

np.random.seed(123)

dml_irm = dml.DoubleMLIRM(data_dml_base,
                          ml_g=randomForest_reg,
                          ml_m=randomForest_class,
                          trimming_threshold=0.01,
                          n_folds=5)
print("Training IRM Model")
dml_irm.fit()
Training IRM Model
[11]:
<doubleml.double_ml_irm.DoubleMLIRM at 0x7f7a21353c40>

As above, we will rely on the patsy package to construct the basis elements. In the two-dimensional case, we will construct a tensor product of B-splines (see the patsy documentation on tensor product smooths for more information).

[12]:
design_matrix = patsy.dmatrix("te(bs(x_0, df=7, degree=3), bs(x_1, df=7, degree=3))", {"x_0": data["x_0"], "x_1": data["x_1"]})
spline_basis = pd.DataFrame(design_matrix)

cate = dml_irm.cate(spline_basis)
print(cate)
================== DoubleMLBLP Object ==================

------------------ Fit summary ------------------
         coef   std err          t          P>|t|     [0.025     0.975]
0    2.919496  0.131819  22.147748  1.031651e-103   2.661073   3.177920
1   -3.123862  1.055228  -2.960368   3.087298e-03  -5.192576  -1.055148
2    1.323503  1.083102   1.221956   2.217825e-01  -0.799857   3.446864
3    3.765601  0.933075   4.035689   5.526381e-05   1.936360   5.594842
4    1.140783  0.960682   1.187471   2.350988e-01  -0.742581   3.024146
5   -4.087710  1.164123  -3.511408   4.497151e-04  -6.369906  -1.805513
6   -4.555476  1.256634  -3.625141   2.917067e-04  -7.019035  -2.091916
7   -8.171822  1.355522  -6.028542   1.774414e-09 -10.829247  -5.514398
8   -0.682767  1.129008  -0.604750   5.453728e-01  -2.896123   1.530588
9    0.592693  1.176052   0.503968   6.143062e-01  -1.712891   2.898277
10   1.118080  0.967196   1.156001   2.477364e-01  -0.778053   3.014213
11  -1.499592  1.005430  -1.491493   1.358959e-01  -3.470681   0.471497
12   1.042978  1.248387   0.835461   4.034986e-01  -1.404413   3.490369
13  -3.205079  1.301410  -2.462773   1.382050e-02  -5.756419  -0.653738
14   0.104659  1.140120   0.091796   9.268638e-01  -2.130481   2.339798
15   0.389498  0.984990   0.395433   6.925401e-01  -1.541519   2.320514
16   1.259657  1.008809   1.248657   2.118497e-01  -0.718057   3.237370
17   2.717739  0.825871   3.290756   1.006183e-03   1.098666   4.336811
18   2.045820  0.819154   2.497478   1.254023e-02   0.439914   3.651725
19  -1.586945  1.023858  -1.549967   1.212135e-01  -3.594160   0.420270
20  -3.505756  1.113612  -3.148096   1.653149e-03  -5.688928  -1.322583
21  -4.992207  0.955858  -5.222747   1.835097e-07  -6.866113  -3.118300
22   0.846579  0.982381   0.861762   3.888600e-01  -1.079323   2.772481
23   3.495717  1.048151   3.335128   8.588780e-04   1.440876   5.550557
24   4.312733  0.855997   5.038257   4.865016e-07   2.634599   5.990867
25   2.313173  0.840396   2.752481   5.936115e-03   0.665625   3.960721
26   0.306966  1.019577   0.301072   7.633726e-01  -1.691858   2.305790
27  -1.885131  1.128953  -1.669805   9.502111e-02  -4.098378   0.328117
28   0.410831  1.041205   0.394573   6.931750e-01  -1.630392   2.452055
29   5.373776  1.207812   4.449182   8.808886e-06   3.005929   7.741624
30   1.166331  1.298411   0.898276   3.690821e-01  -1.379129   3.711792
31   5.545567  1.056203   5.250476   1.580522e-07   3.474942   7.616193
32   2.899004  1.087844   2.664908   7.725937e-03   0.766347   5.031660
33   2.069761  1.315561   1.573292   1.157152e-01  -0.509322   4.648844
34   0.429075  1.401980   0.306049   7.595802e-01  -2.319428   3.177578
35  -3.048162  1.422606  -2.142661   3.218896e-02  -5.837100  -0.259224
36   5.001099  1.309061   3.820372   1.348834e-04   2.434759   7.567439
37   7.467624  1.402798   5.323377   1.063572e-07   4.717517  10.217730
38   7.677422  1.150131   6.675261   2.738734e-11   5.422656   9.932188
39   6.235872  1.192998   5.227060   1.793050e-07   3.897067   8.574677
40   2.099459  1.416893   1.481734   1.384748e-01  -0.678280   4.877197
41   3.668546  1.514738   2.421901   1.547525e-02   0.698988   6.638105
42   2.824390  1.533177   1.842181   6.550841e-02  -0.181317   5.830097
43  10.036905  1.279085   7.846943   5.186320e-15   7.529332  12.544478
44   4.331213  1.352405   3.202602   1.370548e-03   1.679901   6.982526
45   7.478441  1.130706   6.613954   4.137220e-11   5.261755   9.695127
46   6.707993  1.126226   5.956172   2.760498e-09   4.500092   8.915895
47   5.289351  1.426380   3.708234   2.109945e-04   2.493014   8.085688
48   0.106993  1.520447   0.070369   9.439026e-01  -2.873758   3.087744
49   3.106145  1.415861   2.193820   2.829466e-02   0.330429   5.881861

Finally, we create a new grid to evaluate and plot the effects.

[13]:
grid_size = 100
x_0 = np.linspace(0.1, 0.9, grid_size)
x_1 = np.linspace(0.1, 0.9, grid_size)
x_0, x_1 = np.meshgrid(x_0, x_1)

new_data = {"x_0": x_0.ravel(), "x_1": x_1.ravel()}
[14]:
spline_grid = pd.DataFrame(patsy.build_design_matrices([design_matrix.design_info], new_data)[0])
df_cate = cate.confint(spline_grid, joint=True, n_rep_boot=2000)
print(df_cate)
         2.5 %    effect    97.5 %
0     1.170194  1.998915  2.827636
1     1.211724  2.010808  2.809893
2     1.265405  2.031274  2.797142
3     1.327257  2.059397  2.791536
4     1.393231  2.094264  2.795298
...        ...       ...       ...
9995  4.100965  4.868541  5.636118
9996  4.164483  4.965244  5.766004
9997  4.222204  5.055011  5.887818
9998  4.275363  5.136758  5.998153
9999  4.324848  5.209400  6.093952

[10000 rows x 3 columns]
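
As an illustrative numerical comparison before plotting, one can compute the root mean squared error of the point estimates against the true effect on the evaluation grid:

true_effect_grid = np.array([treatment_effect_2d(x_i) for x_i in zip(new_data["x_0"], new_data["x_1"])])
rmse = np.sqrt(np.mean((df_cate["effect"].values - true_effect_grid) ** 2))
print(f"RMSE of the CATE estimate on the evaluation grid: {rmse:.3f}")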
[15]:
import plotly.graph_objects as go

true_effect = np.array([treatment_effect_2d(x_i) for x_i in zip(x_0.ravel(), x_1.ravel())]).reshape(x_0.shape)
effect = np.asarray(df_cate['effect']).reshape(x_0.shape)
lower_bound = np.asarray(df_cate['2.5 %']).reshape(x_0.shape)
upper_bound = np.asarray(df_cate['97.5 %']).reshape(x_0.shape)

fig = go.Figure(data=[
    go.Surface(x=x_0,
               y=x_1,
               z=true_effect),
    go.Surface(x=x_0,
               y=x_1,
               z=upper_bound, showscale=False, opacity=0.4, colorscale='purp'),
    go.Surface(x=x_0,
               y=x_1,
               z=lower_bound, showscale=False, opacity=0.4, colorscale='purp'),
])
fig.update_traces(contours_z=dict(show=True, usecolormap=True,
                                  highlightcolor="limegreen", project_z=True))

fig.update_layout(scene = dict(
                    xaxis_title='X_0',
                    yaxis_title='X_1',
                    zaxis_title='Effect'),
                    width=700,
                    margin=dict(r=20, b=10, l=10, t=10))

fig.show()