Python: Conditional Average Treatment Effects (CATEs) for IRM models#

In this simple example, we illustrate how the DoubleML package can be used to estimate conditional average treatment effects with B-splines for one- or two-dimensional effects in the DoubleMLIRM model.

[1]:
import numpy as np
import pandas as pd
import doubleml as dml

from doubleml.datasets import make_heterogeneous_data

Data#

We define a data generating process to create synthetic data, so that the estimates can be compared to the true effect. The data generating process is based on the Monte Carlo simulation from Oprescu et al. (2019).

The documentation of the data generating process can be found here.

One-dimensional Example#

We start with a one-dimensional effect and create our training data. In this example, the true effect depends only on the first covariate \(X_0\) and takes the following form

\[\theta_0(X) = \exp(2X_0) + 3\sin(4X_0).\]

The generated dictionary also contains a callable with key treatment_effect to calculate the true treatment effect for new observations.

[2]:
np.random.seed(42)
data_dict = make_heterogeneous_data(
    n_obs=2000,
    p=10,
    support_size=5,
    n_x=1,
    binary_treatment=True,
)
treatment_effect = data_dict['treatment_effect']
data = data_dict['data']
print(data.head())
          y    d       X_0       X_1       X_2       X_3       X_4       X_5  \
0  4.803300  1.0  0.259828  0.886086  0.895690  0.297287  0.229994  0.411304
1  5.655547  1.0  0.824350  0.396992  0.156317  0.737951  0.360475  0.671271
2  1.878402  0.0  0.988421  0.977280  0.793818  0.659423  0.577807  0.866102
3  6.941440  1.0  0.427486  0.330285  0.564232  0.850575  0.201528  0.934433
4  1.703049  1.0  0.016200  0.818380  0.040139  0.889913  0.991963  0.294067

        X_6       X_7       X_8       X_9
0  0.240532  0.672384  0.826065  0.673092
1  0.270644  0.081230  0.992582  0.156202
2  0.289440  0.467681  0.619390  0.411190
3  0.689088  0.823273  0.556191  0.779517
4  0.210319  0.765363  0.253026  0.865562
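
As a quick, purely illustrative check (not part of the original workflow), the returned callable can be compared to the closed-form effect \(\theta_0(X) = \exp(2X_0) + 3\sin(4X_0)\) stated above:

# Illustrative check: the treatment_effect callable should match the closed form
# theta_0(x) = exp(2 * x_0) + 3 * sin(4 * x_0) in the one-dimensional case.
x_check = data[['X_0']].to_numpy()
manual_effect = np.exp(2 * x_check[:, 0]) + 3 * np.sin(4 * x_check[:, 0])
print(np.allclose(treatment_effect(x_check), manual_effect))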

First, define the DoubleMLData object.

[3]:
data_dml_base = dml.DoubleMLData(
    data,
    y_col='y',
    d_cols='d'
)

Next, define the learners for the nuisance functions and fit the IRM model. Note that, due to the data generating process, linear learners would usually be optimal here.

[4]:
# First stage estimation
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
randomForest_reg = RandomForestRegressor(n_estimators=500)
randomForest_class = RandomForestClassifier(n_estimators=500)

np.random.seed(42)

dml_irm = dml.DoubleMLIRM(data_dml_base,
                          ml_g=randomForest_reg,
                          ml_m=randomForest_class,
                          trimming_threshold=0.05,
                          n_folds=5)
print("Training IRM Model")
dml_irm.fit()

print(dml_irm.summary)
Training IRM Model
       coef   std err           t  P>|t|     2.5 %    97.5 %
d  4.466099  0.041522  107.559895    0.0  4.384718  4.547481
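
A minimal sketch of the alternative mentioned above, assuming one wants to swap in linear nuisance learners (the object name dml_irm_linear is illustrative):

# Illustrative alternative: linear nuisance learners instead of random forests.
from sklearn.linear_model import LinearRegression, LogisticRegression
dml_irm_linear = dml.DoubleMLIRM(data_dml_base,
                                 ml_g=LinearRegression(),
                                 ml_m=LogisticRegression(max_iter=1000),
                                 trimming_threshold=0.05,
                                 n_folds=5)
# dml_irm_linear.fit()
# print(dml_irm_linear.summary)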

To estimate the CATE, we rely on the best linear predictor of the linear score as in Semenova et al. (2021). To approximate the target function \(\theta_0(x)\) with a linear form, we have to define a data frame of basis functions. Here, we rely on patsy to construct a suitable basis of B-splines.
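
Conceptually, the orthogonal signal of the IRM score, \(\psi_b(W; \eta_0)\), satisfies \(\mathbb{E}[\psi_b(W; \eta_0) \mid X] = \theta_0(X)\), and the CATE coefficients are (roughly) the best linear predictor of this signal in a vector of basis functions \(b(X)\),

\[\beta_0 = \arg\min_{\beta} \mathbb{E}\left[\left(\psi_b(W; \eta_0) - b(X)^\top \beta\right)^2\right], \qquad \hat{\theta}(x) = b(x)^\top \hat{\beta}.\]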

[5]:
import patsy
design_matrix = patsy.dmatrix("bs(x, df=5, degree=2)", {"x": data["X_0"]})
spline_basis = pd.DataFrame(design_matrix)
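
A quick look at the constructed basis (illustrative, not part of the original example): df=5 B-splines of degree 2 plus an intercept yield 6 columns, matching the 6 coefficients reported by cate() below.

# Illustrative: inspect the spline basis produced by patsy.
print(spline_basis.shape)
print(design_matrix.design_info.column_names)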

To estimate the parameters of the CATE, call the cate() method and supply the dataframe of basis elements.

[6]:
cate = dml_irm.cate(spline_basis)
print(cate)
================== DoubleMLBLP Object ==================

------------------ Fit summary ------------------
       coef   std err          t          P>|t|    [0.025    0.975]
0  0.667090  0.165286   4.035962   5.643279e-05  0.342937  0.991242
1  2.327267  0.275171   8.457524   5.200595e-17  1.787614  2.866920
2  4.952494  0.176898  27.996286  9.628846e-146  4.605569  5.299418
3  4.766177  0.211872  22.495595  4.665132e-100  4.350664  5.181690
4  3.705866  0.215236  17.217659   4.785196e-62  3.283754  4.127978
5  4.347630  0.231332  18.793887   1.116123e-72  3.893952  4.801308

To obtain the confidence intervals for the CATE, we have to call the confint() method and supply a dataframe of basis elements. This could be the same basis as for fitting the CATE model or a new basis, e.g. to evaluate the CATE model on a grid. Here, we will evaluate the CATE on a grid from 0.1 to 0.9 to plot the final results. Further, we construct uniform confidence intervals by setting the option joint and providing the number of bootstrap repetitions n_rep_boot.

[7]:
new_data = {"x": np.linspace(0.1, 0.9, 100)}
spline_grid = pd.DataFrame(patsy.build_design_matrices([design_matrix.design_info], new_data)[0])
df_cate = cate.confint(spline_grid, level=0.95, joint=True, n_rep_boot=2000)
print(df_cate)
       2.5 %    effect    97.5 %
0   2.058228  2.326356  2.594483
1   2.174366  2.447195  2.720024
2   2.287354  2.566044  2.844733
3   2.397882  2.682901  2.967921
4   2.506513  2.797769  3.089025
..       ...       ...       ...
95  4.375866  4.663404  4.950942
96  4.383952  4.665388  4.946825
97  4.394845  4.669928  4.945011
98  4.407947  4.677022  4.946098
99  4.422512  4.686673  4.950833

[100 rows x 3 columns]

Finally, we can plot our results and compare them with the true effect.

[8]:
from matplotlib import pyplot as plt
plt.rcParams['figure.figsize'] = 10., 7.5

df_cate['x'] = new_data['x']
df_cate['true_effect'] = treatment_effect(new_data["x"].reshape(-1, 1))
fig, ax = plt.subplots()
ax.plot(df_cate['x'],df_cate['effect'], label='Estimated Effect')
ax.plot(df_cate['x'],df_cate['true_effect'], color="green", label='True Effect')
ax.fill_between(df_cate['x'], df_cate['2.5 %'], df_cate['97.5 %'], color='b', alpha=.3, label='Confidence Interval')

plt.legend()
plt.title('CATE')
plt.xlabel('x')
_ =  plt.ylabel('Effect and 95%-CI')
[Figure: estimated CATE with 95% confidence interval and the true effect, plotted against x]

If the effect is not one-dimensional, the estimate still corresponds to the projection of the true effect on the basis functions.
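
A minimal sketch of what such a projection looks like, assuming the spline_basis and treatment_effect objects from above (this is only an illustration, not the estimator itself):

# Illustrative projection: OLS of the true effect on the spline basis approximates
# the population projection that the CATE estimate targets.
true_effect_obs = treatment_effect(data[['X_0']].to_numpy())
proj_coef, *_ = np.linalg.lstsq(spline_basis.to_numpy(), true_effect_obs, rcond=None)
print(proj_coef)  # comparable in spirit to the coefficients reported by cate()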

Two-Dimensional Example#

It is also possible to estimate multi-dimensional conditional effects. We will use a similar data generating process, but now the effect depends on the first two covariates \(X_0\) and \(X_1\) and takes the following form

\[\theta_0(X) = \exp(2X_0) + 3\sin(4X_1).\]

With the argument n_x=2 we can set the effect to be two-dimensional.

[9]:
np.random.seed(42)
data_dict = make_heterogeneous_data(
    n_obs=5000,
    p=10,
    support_size=5,
    n_x=2,
    binary_treatment=True,
)
treatment_effect = data_dict['treatment_effect']
data = data_dict['data']
print(data.head())
          y    d       X_0       X_1       X_2       X_3       X_4       X_5  \
0  1.286203  1.0  0.014080  0.006958  0.240127  0.100807  0.260211  0.177043
1  0.416899  1.0  0.152148  0.912230  0.892796  0.653901  0.672234  0.005339
2  2.087634  1.0  0.344787  0.893649  0.291517  0.562712  0.099731  0.921956
3  7.508433  1.0  0.619351  0.232134  0.000943  0.757151  0.985207  0.809913
4  0.567695  0.0  0.477130  0.447624  0.775191  0.526769  0.316717  0.258158

        X_6       X_7       X_8       X_9
0  0.028520  0.909304  0.008223  0.736082
1  0.984872  0.877833  0.895106  0.659245
2  0.140770  0.224897  0.558134  0.764093
3  0.460207  0.903767  0.409848  0.524934
4  0.037747  0.583195  0.229961  0.148134

As in the univariate example, we estimate the IRM model.

[10]:
data_dml_base = dml.DoubleMLData(
    data,
    y_col='y',
    d_cols='d'
)
[11]:
# First stage estimation
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
randomForest_reg = RandomForestRegressor(n_estimators=500)
randomForest_class = RandomForestClassifier(n_estimators=500)

np.random.seed(42)

dml_irm = dml.DoubleMLIRM(data_dml_base,
                          ml_g=randomForest_reg,
                          ml_m=randomForest_class,
                          trimming_threshold=0.05,
                          n_folds=5)
print("Training IRM Model")
dml_irm.fit()

print(dml_irm.summary)
Training IRM Model
       coef   std err           t  P>|t|     2.5 %    97.5 %
d  4.549577  0.038761  117.374711    0.0  4.473606  4.625547

As above, we will rely on the patsy package to construct the basis elements. In the two-dimensional case, we will construct a tensor product of B-splines (for more information see here).

[12]:
design_matrix = patsy.dmatrix("te(bs(x_0, df=7, degree=3), bs(x_1, df=7, degree=3))", {"x_0": data["X_0"], "x_1": data["X_1"]})
spline_basis = pd.DataFrame(design_matrix)

cate = dml_irm.cate(spline_basis)
print(cate)
================== DoubleMLBLP Object ==================

------------------ Fit summary ------------------
        coef   std err          t          P>|t|    [0.025     0.975]
0   2.773359  0.105416  26.308592  6.816294e-143  2.566696   2.980022
1  -1.232569  0.860988  -1.431576   1.523283e-01 -2.920487   0.455348
2   0.657257  0.851866   0.771549   4.404182e-01 -1.012779   2.327292
3   2.425408  0.751965   3.225427   1.266025e-03  0.951224   3.899593
4   0.648466  0.759430   0.853885   3.932100e-01 -0.840354   2.137286
5  -3.269591  0.955308  -3.422553   6.253869e-04 -5.142417  -1.396765
6  -4.782140  1.028094  -4.651462   3.381763e-06 -6.797660  -2.766620
7  -6.377953  0.998981  -6.384461   1.877333e-10 -8.336398  -4.419508
8  -2.365387  0.904109  -2.616263   8.916818e-03 -4.137841  -0.592932
9   3.689090  0.925798   3.984769   6.851683e-05  1.874116   5.504064
10 -0.037642  0.783487  -0.048044   9.616834e-01 -1.573623   1.498340
11  1.634924  0.823330   1.985745   4.711678e-02  0.020831   3.249017
12 -0.935979  1.022229  -0.915626   3.599077e-01 -2.940001   1.068042
13 -2.639468  1.129759  -2.336310   1.951478e-02 -4.854297  -0.424639
14 -2.452786  1.035279  -2.369202   1.786466e-02 -4.482392  -0.423180
15  0.153408  0.758453   0.202264   8.397188e-01 -1.333496   1.640311
16  1.255881  0.770409   1.630149   1.031337e-01 -0.254462   2.766223
17  3.782939  0.662249   5.712256   1.180020e-08  2.484636   5.081241
18  1.327579  0.670242   1.980745   4.767513e-02  0.013607   2.641550
19 -1.547622  0.838261  -1.846229   6.491860e-02 -3.190986   0.095741
20 -2.104664  0.892305  -2.358681   1.837867e-02 -3.853978  -0.355350
21 -2.878083  0.775480  -3.711354   2.084173e-04 -4.398369  -1.357797
22  2.259226  0.768457   2.939953   3.297806e-03  0.752711   3.765742
23  1.888339  0.801961   2.354653   1.857886e-02  0.316140   3.460537
24  4.510768  0.690189   6.535558   6.974922e-11  3.157692   5.863844
25  2.214720  0.663554   3.337664   8.510949e-04  0.913860   3.515579
26  1.849755  0.871730   2.121934   3.389279e-02  0.140777   3.558733
27 -1.355051  0.940723  -1.440436   1.498073e-01 -3.199285   0.489183
28 -1.315615  0.903196  -1.456622   1.452842e-01 -3.086279   0.455049
29  4.041362  0.979012   4.128002   3.719490e-05  2.122065   5.960659
30  4.367025  0.968430   4.509388   6.651680e-06  2.468474   6.265577
31  6.037199  0.836456   7.217589   6.094847e-13  4.397373   7.677024
32  3.810815  0.814526   4.678567   2.965845e-06  2.213983   5.407648
33  2.536485  1.051705   2.411783   1.591088e-02  0.474677   4.598294
34 -1.360917  1.176522  -1.156728   2.474391e-01 -3.667422   0.945588
35 -0.627770  1.150825  -0.545496   5.854373e-01 -2.883898   1.628358
36  7.133428  0.973696   7.326134   2.751678e-13  5.224552   9.042305
37  5.044823  1.015944   4.965652   7.077151e-07  3.053123   7.036524
38  7.220189  0.828980   8.709722   4.092509e-18  5.595020   8.845358
39  6.827332  0.908751   7.512872   6.824209e-14  5.045776   8.608887
40  2.307430  1.143246   2.018315   4.361229e-02  0.066162   4.548699
41  4.151618  1.257398   3.301752   9.676415e-04  1.686560   6.616676
42  2.181122  1.160290   1.879808   6.019292e-02 -0.093560   4.455805
43  9.825443  0.933201  10.528752   1.186673e-25  7.995955  11.654931
44  4.037505  1.127578   3.580688   3.459800e-04  1.826952   6.248058
45  9.153749  0.810367  11.295805   3.128381e-29  7.565070  10.742428
46  5.361873  0.821043   6.530559   7.209702e-11  3.752264   6.971482
47  5.894343  1.018080   5.789665   7.487508e-09  3.898454   7.890231
48  1.393749  1.044746   1.334056   1.822470e-01 -0.654416   3.441914
49  1.197660  1.034498   1.157720   2.470341e-01 -0.830416   3.225735
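
For reference (an illustrative check, assuming the spline_basis defined above), the tensor product of two df=7 spline bases has 7 * 7 = 49 interaction terms plus an intercept, i.e. 50 columns, matching the 50 coefficients above.

# Illustrative: inspect the tensor-product spline basis produced by patsy.
print(spline_basis.shape)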

Finally, we create a new grid to evaluate and plot the effects.

[13]:
grid_size = 100
x_0 = np.linspace(0.1, 0.9, grid_size)
x_1 = np.linspace(0.1, 0.9, grid_size)
x_0, x_1 = np.meshgrid(x_0, x_1)

new_data = {"x_0": x_0.ravel(), "x_1": x_1.ravel()}
[14]:
spline_grid = pd.DataFrame(patsy.build_design_matrices([design_matrix.design_info], new_data)[0])
df_cate = cate.confint(spline_grid, joint=True, n_rep_boot=2000)
print(df_cate)
         2.5 %    effect    97.5 %
0     1.691528  2.373624  3.055720
1     1.704127  2.365692  3.027257
2     1.723757  2.360766  2.997774
3     1.748352  2.358710  2.969067
4     1.775738  2.359390  2.943042
...        ...       ...       ...
9995  3.757634  4.511393  5.265152
9996  3.859517  4.655106  5.450695
9997  3.956571  4.794755  5.632939
9998  4.051531  4.928702  5.805873
9999  4.146613  5.055309  5.964005

[10000 rows x 3 columns]
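
As a quick numerical comparison with the true effect (an illustrative check, not part of the original example), one can look at how often the true effect on the grid lies inside the uniform confidence band:

# Illustrative check: share of grid points where the true effect falls inside the band.
true_effect_grid = treatment_effect(np.column_stack((x_0.ravel(), x_1.ravel())))
inside = ((df_cate['2.5 %'].to_numpy() <= true_effect_grid)
          & (true_effect_grid <= df_cate['97.5 %'].to_numpy()))
print(f"Coverage on the grid: {inside.mean():.3f}")
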
[15]:
import plotly.graph_objects as go

grid_array = np.array(list(zip(x_0.ravel(), x_1.ravel())))
true_effect = treatment_effect(grid_array).reshape(x_0.shape)
effect = np.asarray(df_cate['effect']).reshape(x_0.shape)
lower_bound = np.asarray(df_cate['2.5 %']).reshape(x_0.shape)
upper_bound = np.asarray(df_cate['97.5 %']).reshape(x_0.shape)

fig = go.Figure(data=[
    go.Surface(x=x_0,
               y=x_1,
               z=true_effect),
    go.Surface(x=x_0,
               y=x_1,
               z=upper_bound, showscale=False, opacity=0.4,colorscale='purp'),
    go.Surface(x=x_0,
               y=x_1,
               z=lower_bound, showscale=False, opacity=0.4,colorscale='purp'),
])
fig.update_traces(contours_z=dict(show=True, usecolormap=True,
                                  highlightcolor="limegreen", project_z=True))

fig.update_layout(scene = dict(
                    xaxis_title='X_0',
                    yaxis_title='X_1',
                    zaxis_title='Effect'),
                    width=700,
                    margin=dict(r=20, b=10, l=10, t=10))

fig.show()