[1]:
%matplotlib inline

Multiway Cluster Robust DML

This example shows how multiway cluster robust DML (Chiang et al. 2020) can be implemented with the DoubleML package. Chiang et al. (2020) consider double-indexed data

\begin{equation} \lbrace W_{ij}: i \in \lbrace 1, \ldots, N \rbrace, j \in \lbrace 1, \ldots, M \rbrace \rbrace \end{equation}

and the partially linear IV regression model (PLIV)

\[\begin{split}\begin{aligned} Y_{ij} = D_{ij} \theta_0 + g_0(X_{ij}) + \epsilon_{ij}, & &\mathbb{E}(\epsilon_{ij} | X_{ij}, Z_{ij}) = 0, \\ Z_{ij} = m_0(X_{ij}) + v_{ij}, & &\mathbb{E}(v_{ij} | X_{ij}) = 0. \end{aligned}\end{split}\]

Chiang, H. D., Kato, K., Ma, Y. and Sasaki, Y. (2020), Multiway Cluster Robust Double/Debiased Machine Learning, arXiv:1909.03489, https://arxiv.org/pdf/1909.03489.pdf
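As a minimal illustration of the multiway sample-splitting idea (the function below is our own sketch, not the DoubleML internals): each index dimension is partitioned into K folds separately, and the K × K cross-products of these partitions form the sample splits.

```python
import numpy as np

def multiway_folds(N, M, K, seed=0):
    """Sketch: partition the i-index and the j-index into K folds each
    and return the K*K cross-products as (test_i, test_j) pairs."""
    rng = np.random.default_rng(seed)
    folds_i = np.array_split(rng.permutation(N), K)
    folds_j = np.array_split(rng.permutation(M), K)
    return [(fi, fj) for fi in folds_i for fj in folds_j]

splits = multiway_folds(N=6, M=6, K=3)
print(len(splits))  # → 9 cross-product folds (K * K)
```

Because the per-dimension folds partition their index sets, the 9 cross-products together cover every (i, j) cell exactly once.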

[2]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns

from sklearn.model_selection import KFold, RepeatedKFold
from sklearn.base import clone

from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression

from doubleml import DoubleMLData, DoubleMLPLIV
from doubleml.double_ml_resampling import DoubleMLMultiwayResampling

from doubleml.datasets import make_pliv_multiway_cluster_CKMS2019

Simulate multiway cluster data

We use the PLIV data generating process described in Section 4.1 of Chiang et al. (2020).

[3]:
# Set the simulation parameters
N = 25  # number of observations (first dimension)
M = 25  # number of observations (second dimension)
dim_X = 100  # dimension of X
np.random.seed(3141) # set seed

obj_dml_data = make_pliv_multiway_cluster_CKMS2019(N, M, dim_X)
[4]:
# The data comes with multi index for rows (tuples with two entries)
obj_dml_data.data.head(30)
[4]:
X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 ... X94 X95 X96 X97 X98 X99 X100 Y D Z
0 0 -0.261903 -0.195564 -0.118952 0.508459 0.226598 -0.544555 0.183888 1.694919 0.750048 0.314384 ... -0.205938 -0.996934 -1.136836 0.010269 -0.396985 -0.161141 -0.614188 0.256567 -0.113780 0.104787
1 -1.082973 -0.303835 0.760778 -0.542671 -0.601598 -0.201768 0.234910 0.212844 -0.523743 -0.267547 ... 0.427573 0.303324 0.247826 0.109273 -0.410795 -0.128408 0.633433 -1.816318 -1.002983 -0.942661
2 -0.110359 -0.679539 0.491245 -0.309772 -0.552727 0.036729 -0.673302 -1.024604 -0.635283 0.005629 ... 0.094381 -0.922996 -2.054068 -0.477474 -0.543380 -0.720664 -0.332996 -0.851366 -0.302648 0.628069
3 -0.010940 1.003427 0.412653 0.784238 0.014637 0.611269 0.323679 -0.557999 0.060568 0.607348 ... 0.698694 1.342992 -1.136089 -0.632058 -0.509958 -0.456370 -0.557595 1.081488 0.438960 0.054162
4 -0.073207 -0.370736 -0.005857 -0.833907 -0.096337 -0.714240 0.094026 0.435401 0.941408 0.260039 ... -0.149714 -0.164864 -0.756805 0.518175 0.510385 -0.681176 -1.020271 -1.708190 -1.805007 -0.850321
5 0.896538 0.178065 0.186276 0.159529 -0.774354 0.192105 -0.657249 -0.198245 -0.003188 -0.094065 ... 0.705045 0.163243 -0.345162 0.226965 0.150167 -0.158045 -0.070203 0.463671 0.872245 0.705041
6 0.108513 -0.079120 0.302778 -1.047476 -0.155363 0.298648 0.749952 -0.124667 -0.113644 0.006786 ... 1.024134 -0.904801 0.576899 0.621352 -0.860336 -1.885346 -0.966592 0.809934 0.976860 0.391837
7 0.822284 0.636223 0.923465 -0.572504 -0.506166 -0.717440 -0.223074 -0.018359 0.239189 0.830245 ... 0.062526 0.085077 -0.026525 -0.228230 -0.818160 -0.096595 -0.872666 0.133181 -0.199412 -0.220692
8 0.525072 0.223879 -0.127802 0.066327 -0.159234 0.141032 0.280494 0.650861 0.556402 -0.217081 ... 0.245080 0.671400 -0.435076 -0.184348 -0.488481 0.601852 0.741026 1.253541 0.599807 0.014742
9 -1.277828 -0.001477 -0.402283 -0.981596 -0.366864 -1.113721 -0.432596 0.671401 -0.173375 0.038081 ... -0.122195 -0.328467 -0.393734 0.438837 -1.671225 -1.519050 -0.664356 -1.535628 -1.113164 -0.167665
10 0.255616 1.049006 0.113948 -0.369794 0.326758 0.121641 0.531501 1.098936 0.535221 0.177892 ... -0.263503 -0.555764 -1.593443 0.889207 0.227092 -0.596293 0.220435 0.796065 1.106031 0.868961
11 -0.459151 -0.429218 0.286828 -0.414682 -0.024168 -0.466549 -0.543770 0.217462 0.158002 1.294238 ... -0.104150 0.405159 -0.699605 -0.252566 -0.611904 -0.492235 0.749119 -1.648823 -0.884543 0.208989
12 -0.334370 0.435261 -0.027700 0.126514 -1.175591 -1.056230 -0.710277 1.017577 0.823502 -0.287104 ... 1.022750 -0.873425 0.359529 -0.280370 0.356360 -0.452985 0.160973 -2.221900 -1.664425 -1.107739
13 -0.302661 0.646040 0.640180 0.042748 -0.113326 -0.220553 0.481621 -0.009104 0.788551 -0.335388 ... -1.037490 -0.409926 0.160269 0.047010 -1.160500 -0.403388 -1.199449 1.478980 0.949196 0.311235
14 0.266911 -0.335268 -0.778099 -0.531650 -0.574113 -0.360536 -0.546285 -0.288788 0.231988 1.066750 ... 0.844182 -0.479060 -0.233906 -0.577776 -1.053606 -0.964492 -1.277590 -1.914087 -1.306469 -0.567831
15 -0.448515 -0.798097 -0.722419 -1.081633 0.301170 -0.469064 0.682194 0.596549 0.507988 1.497058 ... 0.273602 -0.659725 -0.447759 -0.184420 -0.764353 -0.701221 0.274125 -2.001985 -1.043219 -0.773286
16 0.310542 0.377708 -0.825026 0.064748 -0.065031 -0.840901 -0.868081 -0.362104 -0.258764 -0.031331 ... -0.123091 0.402909 -1.259463 -0.075889 0.531235 -1.051569 0.017347 1.015092 1.095498 0.742751
17 -0.388185 -0.253031 0.031026 -0.182594 0.139518 0.217987 -0.102031 1.189637 -0.075355 -0.112260 ... 0.178143 -0.958899 -0.614178 0.430811 -0.438502 0.337254 -0.728490 -1.897825 -1.845649 -0.574103
18 -1.896181 -1.335367 -1.954696 -0.318999 -0.873602 -0.982294 -0.421576 0.285829 0.771820 -0.433136 ... 1.536387 -0.681559 -0.678403 0.065437 -0.448467 -1.152647 -0.928326 -6.174873 -4.216969 -1.682444
19 -0.610261 -0.472693 -1.016153 0.229175 -0.149479 -0.254136 -0.427090 -0.088191 -0.671016 0.170206 ... 0.344147 -0.973939 -0.520261 0.921434 0.094977 0.131462 -0.444317 -1.414520 -1.191680 -0.592696
20 0.354445 0.159904 0.464633 0.179123 -0.517307 0.405764 -0.164046 -0.250802 -0.380517 -0.298236 ... 0.670892 -0.475977 -0.485404 -0.596366 -0.487097 -1.587140 -0.788443 -1.059925 -0.512110 0.187438
21 -0.311117 0.619567 -0.666270 0.080992 -0.850385 0.368533 -0.556263 0.541462 0.652096 -1.036711 ... 0.364330 -0.818845 0.150383 0.099672 -1.383939 -0.690475 -0.706981 -0.925493 -0.077535 0.349462
22 -0.054551 0.268802 -1.092590 -0.479608 -0.151828 1.097881 -1.212572 -0.074945 0.048361 -0.409973 ... -0.045910 0.062560 -0.328584 1.196474 0.883488 -0.309468 -0.104805 -0.467264 -0.502017 -0.315042
23 0.290257 0.890970 0.981587 -0.206763 -0.074407 -0.288766 -0.863318 0.392271 -0.799175 0.227677 ... -0.340238 -0.397798 -0.459780 0.057891 -1.401155 -0.898479 -0.900867 0.748160 1.123161 0.568661
24 0.447506 1.025707 0.165442 -0.367202 0.179182 0.865451 0.204474 0.635203 -0.081398 1.067841 ... 1.093112 -0.710041 -0.179545 0.126526 -0.309415 -0.060628 0.279899 1.250745 0.440775 0.532491
1 0 0.085163 0.126914 0.793147 0.806835 0.348796 -0.183503 0.260567 -0.477626 -0.177070 -0.464281 ... 0.759463 -0.194517 0.334943 0.192110 0.592027 -0.181729 -0.411909 0.154185 -0.044910 0.737830
1 0.640061 0.856448 0.372337 -0.083796 0.638544 0.118188 -1.805682 -0.647788 -0.509075 -0.094675 ... -0.133795 0.075466 1.117873 0.721091 0.719767 -0.880899 -0.875224 0.293850 0.028736 0.213211
2 0.570709 0.343690 1.275316 0.302782 -0.119363 0.287774 -0.983988 -0.532366 -0.727415 0.038715 ... -0.082782 -0.419441 -0.959787 -0.051135 -0.524217 -1.089889 -0.503182 1.817116 0.834741 0.446372
3 1.025913 0.050668 0.337848 1.204428 -0.135959 -0.219208 -0.302827 -1.759298 -0.447924 -0.633410 ... -0.192812 0.335100 -0.704397 0.011564 0.190351 -0.644845 -0.186131 0.669817 -0.314738 0.148209
4 1.672162 0.433178 0.199653 -0.120461 0.019282 -0.337269 -0.725652 -1.898732 -0.994920 -1.276961 ... -0.647876 0.235043 0.723715 0.082780 0.541649 -0.920438 0.158546 2.206842 0.482398 0.392821

30 rows × 103 columns
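The MultiIndex rows can be addressed by (i, j) tuples or by the outer cluster index alone; a small self-contained sketch with toy data (not the simulated dataset above):

```python
import numpy as np
import pandas as pd

# Toy frame with the same two-level row index layout as obj_dml_data.data
idx = pd.MultiIndex.from_product([range(3), range(4)], names=['i', 'j'])
df = pd.DataFrame({'Y': np.arange(12.0)}, index=idx)

print(df.loc[(1, 2), 'Y'])  # → 6.0, the single cell addressed by tuple
print(df.loc[1].shape)      # → (4, 1), all rows of cluster i = 1
```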

Initialize the objects of class DoubleMLData and DoubleMLPLIV

[5]:
# Set machine learning methods for g, m & r
learner = RandomForestRegressor(max_depth=2, n_estimators=10)
ml_g = clone(learner)
ml_m = clone(learner)
ml_r = clone(learner)

# initialize the DoubleMLPLIV object
dml_pliv_obj = DoubleMLPLIV(obj_dml_data,
                            ml_g,
                            ml_m,
                            ml_r,
                            score='partialling out',
                            dml_procedure='dml1',
                            draw_sample_splitting=False)

Split samples and transfer the sample splitting to the object

[6]:
K = 3  # number of folds
smpl_sizes = [N, M]
obj_dml_multiway_resampling = DoubleMLMultiwayResampling(K, smpl_sizes)
smpls_multi_ind, smpls_lin_ind = obj_dml_multiway_resampling.split_samples()

dml_pliv_obj.set_sample_splitting([smpls_lin_ind])
[6]:
<doubleml.double_ml_pliv.DoubleMLPLIV at 0x7f735d944970>
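The linear indices passed to set_sample_splitting correspond to the tuple indices of the MultiIndex; assuming row-major ordering (which matches the data layout above), a tuple (i, j) maps to i * M + j:

```python
M = 25  # number of observations in the second dimension

def lin_ind(i, j, M):
    """Row-major linear index for the tuple (i, j), assuming the
    second dimension has M observations."""
    return i * M + j

print(lin_ind(0, 0, M))  # → 0: first row of the first cluster
print(lin_ind(1, 0, M))  # → 25: first row where the outer index jumps to 1
```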

Fit the model and show a summary

[7]:
dml_pliv_obj.fit()
print(dml_pliv_obj.summary)
       coef   std err          t          P>|t|     2.5 %    97.5 %
D  1.197224  0.040051  29.892863  2.436427e-196  1.118727  1.275722
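The reported 2.5 % / 97.5 % bounds are a standard normal-approximation interval and can be reproduced from the coefficient and standard error (values copied from the summary above; small differences are rounding):

```python
import numpy as np
from scipy import stats

coef, se = 1.197224, 0.040051          # from the summary table
q = stats.norm.ppf(0.975)              # ~1.96 normal quantile
ci = (coef - q * se, coef + q * se)
print(np.round(ci, 6))  # matches the 2.5 % / 97.5 % columns up to rounding
```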

Visualization of sample splitting with tuple and linear indexing

[8]:
# discrete color scheme
x = sns.color_palette("RdBu_r", 7)
cMap = ListedColormap([x[0], x[3], x[6]])
plt.rcParams['figure.figsize'] = 15, 12
sns.set(font_scale=1.3)

Visualize sample splitting with tuples (one plot per fold)

[9]:
for i_split, this_split_ind in enumerate(smpls_multi_ind):
    plt.subplot(K, K, i_split + 1)
    df = pd.DataFrame(np.zeros([N*M, 1]),
                      index=pd.MultiIndex.from_product([range(N), range(M)]),
                      columns=['value'])

    ind_array_train = [*this_split_ind[0]]
    ind_array_test = [*this_split_ind[1]]

    df.loc[ind_array_train, :] = -1.
    df.loc[ind_array_test, :] = 1.

    df_wide = df.reset_index().pivot(index="level_0", columns="level_1", values="value")
    df_wide.index.name = ''
    df_wide.columns.name = ''

    ax = sns.heatmap(df_wide, cmap=cMap);
    ax.invert_yaxis();
    ax.set_ylim([0, M]);
    colorbar = ax.collections[0].colorbar
    colorbar.set_ticks([-0.667, 0, 0.667])
    if i_split % K == (K - 1):
        colorbar.set_ticklabels(['Nuisance', '', 'Score'])
    else:
        colorbar.set_ticklabels(['', '', ''])
../_images/examples_double_ml_multiway_cluster_15_0.png

Visualize sample splitting with linear indexing (one column per fold)

[10]:
df = pd.DataFrame(np.zeros([N*M, K*K]))
for i_split, this_split_ind in enumerate(smpls_lin_ind):
    df.loc[this_split_ind[0], i_split] = -1.
    df.loc[this_split_ind[1], i_split] = 1.

ax = sns.heatmap(df, cmap=cMap);
ax.invert_yaxis();
ax.set_ylim([0, N*M]);
colorbar = ax.collections[0].colorbar
colorbar.set_ticks([-0.667, 0, 0.667])
colorbar.set_ticklabels(['Nuisance', '', 'Score'])
../_images/examples_double_ml_multiway_cluster_17_0.png