An illustration of A/B testing.

In this notebook, we demonstrate how the DoubleML package can be used to estimate the causal effect of seeing a new ad design on customers’ purchases in a webshop. We structure the estimation steps of our analysis according to the DoubleML workflow.

0. Problem Formulation: A/B Testing

The A/B Testing Scenario

Let’s consider the following stylized scenario. The manager of a webshop performs an A/B test to estimate the average effect a new ad design \(A\) has on customers’ purchases \(Y\) (in \(100\$\)). This effect is called the Average Treatment Effect (ATE). The treatment is assigned randomly conditional on the visitors’ characteristics, which we call \(V\). Such characteristics could be collected from a customer’s shopper account, for example. They might include the number of previous purchases, the time since the last purchase, the length of stay on a page, and whether a customer has a rewards card, among others.

In the following, we use a Directed Acyclic Graph (DAG) to illustrate our assumptions on the causal structure of the scenario. As both the treatment and the outcome depend on the individual characteristics, there are arrows going from \(V\) to both \(A\) and \(Y\). In our example, we also assume that the treatment \(A\) is a direct cause of the customers’ purchases \(Y\).

Scenario illustration with a DAG

Let’s assume the conditional randomization has been conducted properly, such that a tidy data set has been collected. Now, a data scientist wants to use the DoubleML package to evaluate whether the new ad design causally affected sales.

Why control for individual characteristics?

Before we start the case study, let us briefly address why we need to include individual characteristics in our analysis at all. There are two main reasons to control for observable characteristics. First, so-called confounders, i.e., variables that have a causal effect on both the treatment variable and the outcome variable, can bias our estimate. To uncover the true causal effect of the treatment, our causal framework must take all confounding variables into account; otherwise, the average causal effect of the treatment on the outcome is not identified. The second reason is efficiency: the more variation in the outcome our causal framework can explain, the more precise the resulting estimate will be. In practical terms, greater efficiency leads to smaller standard errors, tighter confidence intervals and smaller p-values. This can improve the power of A/B tests even if the treatment is assigned to individuals unconditionally.
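The efficiency argument can be illustrated with a small simulation. This is a sketch on hypothetical simulated data, not the ACIC data used below. The treatment is assigned independently of a single characteristic `v`, so both regressions estimate the same effect. Controlling for `v` should, however, shrink the standard error of the treatment coefficient, because `v` explains much of the outcome variance.

```r
# Sketch: covariate adjustment under unconditional randomization
# (hypothetical simulated data).
set.seed(42)
n <- 1000
v <- rnorm(n)                       # a pre-treatment characteristic
a <- rbinom(n, 1, 0.5)              # treatment assigned independently of v
y <- 0.8 * a + 2 * v + rnorm(n)     # outcome depends strongly on v

fit_unadj <- lm(y ~ a)              # equivalent to a difference in means
fit_adj   <- lm(y ~ a + v)          # additionally controls for v

se_unadj <- summary(fit_unadj)$coefficients["a", "Std. Error"]
se_adj   <- summary(fit_adj)$coefficients["a", "Std. Error"]
c(unadjusted = se_unadj, adjusted = se_adj)
```

Both estimates are unbiased here. The adjusted one is simply more precise.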

Why use machine learning to analyze A/B tests?

ML methods have proven very flexible in modeling complex relationships between explanatory and dependent variables and, thus, have exhibited strong predictive performance in many applications. In the double machine learning approach (Chernozhukov et al. (2018)), ML methods are used to model so-called nuisance functions. In the A/B case study considered here, ML tools can be used to flexibly control for confounding variables. For example, a linear parametric specification as in a standard linear regression model might be misspecified and, hence, insufficient to account for the underlying confounding. Moreover, powerful ML techniques will likely allow the causal model to explain a greater share of the total variation in the outcome and, hence, lead to more precise estimates.

1. Data-Backend

The data set

As an illustrative example, we use a data set from the ACIC 2019 Data Challenge. In this challenge, a large number of data sets were generated to mimic distributional relationships found in many real-world economic applications. Although the data were not generated explicitly for an A/B testing case study, they are well suited for demonstration purposes. We focus on one of the many different data generating processes (DGP), which we picked at random; in this particular case, a data set called high42. An advantage of using the synthetic ACIC 2019 data is that we know the true average treatment effect, which is 0.8 in our data set.

# Load required packages for this tutorial
library(DoubleML)
library(mlr3)
library(mlr3learners)
library(data.table)
library(ggplot2)

# suppress messages during fitting
lgr::get_logger("mlr3")$set_threshold("warn")

First we load the data.

# Load data set from url (internet connection required)
url = "https://raw.githubusercontent.com/DoubleML/doubleml-docs/master/doc/examples/data/high42.CSV"
df = fread(url)
dim(df)
## [1] 1000  202
head(df)
##           Y A V1 V2 V3 V4       V5       V6 V7         V8        V9 V10 V11 V12
## 1: 7.358185 1 10  0  0  7 192.7938 23.67695  8 0.18544294 15.853240   6   1   0
## 2: 8.333672 1 12  0  1  4 199.6536 19.28127  7 0.51484172  9.244882   5   0   0
## 3: 7.472758 0 14  1  1  2 194.2078 24.58933  5 0.30919878 10.791593   4   1   0
## 4: 6.502319 1  0  1  0  9 201.8380 25.51392  4 0.16016010 22.639362   8   0   1
## 5: 7.043758 1 12  0  0  9 201.3604 31.16064  6 0.29197555 25.793415   3   1   1
## 6: 5.658337 0  8  0  1  6 193.2195 20.46564  9 0.05673076  9.277520   6   1   0
##          V13        V14      V15       V16      V17 V18        V19 V20 V21
## 1: 0.2008238  1.1701569 1.217621 -16.81330 8.070488   1  1.0398440  30   0
## 2: 0.2230785 -1.5365715 1.118535 -15.08600 4.482310   0 -0.5417638  -1   0
## 3: 0.1751244 -1.2845893 1.265334 -27.76753 6.190104   1  1.0213073  59   0
## 4: 0.1487592  0.3477511 1.141496 -15.75004 4.752172   0  1.2265897  25   1
## 5: 0.1381315  0.3454147 1.298493 -30.29079 7.794285   0  1.6554302   8   1
## 6: 0.2511722  0.4423286 1.124828 -33.89765 6.324069   0  0.7876613  25   0
##          V22        V23         V24        V25          V26       V27 V28
## 1: 0.2793573 -1.4206026 0.055594684 0.04032890 7.901696e-05 -58.99537  20
## 2: 0.2561681  2.7890151 0.014976614 0.04844543 8.593516e-03 -20.40609   7
## 3: 0.7500856 -0.6466704 0.006374132 0.06934194 2.614655e-05 -50.77884  16
## 4: 0.5371091  4.6331190 0.057029086 0.12232294 8.090068e-03 -58.68011  27
## 5: 0.7368710 -3.5655872 0.031656106 0.10157513 1.200738e-03 -12.73111  14
## 6: 0.4499631 -0.4609568 0.043176027 0.02658478 2.442900e-02 -48.19671  10
##         V29      V30       V31         V32        V33      V34 V35       V36
## 1: 16.64241 34.71766  4.951458  0.17461066  0.2324233 2.070571   0 20.061676
## 2: 46.53248 21.68320  7.369149  0.43136246  1.9259106 1.833840   1 28.059975
## 3: 14.45965 24.99987 14.157243  0.82646459  6.0069024 1.774421   0 11.849514
## 4: 13.68466 25.50238  7.705198 -0.77489925 32.1862806 1.873539   0 20.509810
## 5: 18.46981 28.61574  8.432233 -0.96055871  0.4755487 1.391923   0  9.656802
## 6: 12.99067 32.42152  9.737143  0.03840029 19.1461227 2.300713   0 16.668824
##             V37 V38       V39       V40 V41      V42 V43 V44 V45 V46       V47
## 1: 0.0004402816   1 0.4810670 1.2442739   5 21.34893   1  27   1   1 20.794542
## 2: 0.0007701103   0 2.1610440 1.0118577  12 12.07421   0  27   1   0 24.425323
## 3: 0.0018029212   0 1.0183895 1.2272869   8 15.07926   0  26   0   1 19.125459
## 4: 0.0026928483   0 2.3844421 0.6641995  15 29.66762   0  23   1   1  6.958642
## 5: 0.0023956495   0 0.2408289 1.1954664   8 34.22849   1  28   0   0 20.872313
## 6: 0.0003555247   1 1.0684471 1.3538031   7 42.12471   1  31   0   1 23.213361
##    V48 V49 V50         V51      V52 V53 V54 V55      V56 V57        V58
## 1:   0   1   1  0.58347949 3.585321   1   0   0 2.143131   0 -4.5760397
## 2:   0   0   1 -0.89160276 2.599451   1   0   1 2.392858   1 -0.2151697
## 3:   0   1   0  1.51972438 4.510590   1   0   1 2.113959   1 -4.0197275
## 4:   0   1   1 -0.20967451 2.151854   1   0   0 2.277861   0 -4.1827249
## 5:   0   1   1 -0.09122315 4.605654   1   0   1 2.076338   0 -3.1929933
## 6:   0   0   1  0.08143847 2.747924   1   0   0 2.123266   0 -1.7701957
##           V59 V60          V61      V62 V63 V64 V65 V66      V67       V68
## 1: -2.3393634   1 1.138766e-03 11.53770   0   4   1   1 3.904367 177.51886
## 2:  0.5914007   1 1.370902e-05 11.26127   0   2   1   0 6.468333  70.21335
## 3:  2.0021808   0 2.299705e-02 10.67112   0   2   1   0 4.796108  28.38070
## 4: -0.8959613   1 1.224616e-03 12.09844   1   5   1   0 2.624500  19.75818
## 5: -0.6357487   0 8.927531e-04 10.82328   0   0   1   0 5.195238  55.38393
## 6:  2.4273998   1 2.320485e-03 10.96484   0   2   1   0 4.032135 160.05336
##             V69        V70          V71       V72      V73 V74      V75
## 1: -0.977834002 -15.655234 -0.339091443 -4.859462 46.77102   7 27.35320
## 2: -0.008146874   2.425168  0.250900654 -4.879361 44.25859   4 22.95139
## 3: -0.093093555  -4.628594  0.493405176 -2.582342 59.11854   3 21.71225
## 4:  0.293489224 -12.201118 -0.319991248 -4.524812 11.21414   3 25.32508
## 5: -0.329338305 -19.628055 -0.007397628 -3.060636 12.19038   5 23.33296
## 6: -0.963812494 -13.956940 -0.071165907 -4.836254 21.10611   9 28.26248
##           V76 V77 V78 V79         V80        V81 V82 V83        V84
## 1: -2.2034941   0   6   1 -0.95657964 0.05625706  29   0 0.03737981
## 2: -0.9763126   0   7   0 -0.91832528 0.08790430  24   0 0.03721510
## 3: -0.7862249   0  11   0 -0.92248578 0.06542868  22   0 0.03099046
## 4: -0.2882286   0  14   0  0.05925364 0.11128509  13   0 0.03540332
## 5:  0.3681190   0   7   0 -0.07175302 0.05468009  16   0 0.02777697
## 6: -2.8403231   0   5   0  0.04554488 0.08101242  34   0 0.03868584
##             V85         V86 V87 V88 V89      V90 V91         V92 V93       V94
## 1: -0.157240114 -0.19498800   0   0  17 1.078362   1  1.25262806   0 0.9756274
## 2: -0.138141504  1.18997769   0   1  17 1.262945   1 -0.49464176   0 0.9756161
## 3:  0.996975342 -4.13585291   0   1  22 2.439541   0 -0.12586824   0 0.9808798
## 4: -2.097800685 -0.02922322   0   0   8 1.039004   0  1.06547295   0 0.9817917
## 5:  0.007976423  0.50863166   0   0  14 1.836204   1 -0.02935486   0 0.9882583
## 6:  0.179009664  0.70442709   0   1  11 1.613274   0  1.21502580   0 0.9839163
##    V95      V96      V97 V98 V99     V100     V101 V102 V103     V104      V105
## 1:   5 2.214363 1.266806   1   0 3.858811 2.248257    7    1 35.61543 0.9860864
## 2:   6 5.377556 1.252839   0   0 3.977629 1.821898    3    0 23.06962 0.9936143
## 3:   6 3.865115 1.279053   1   0 3.745750 2.142148    6    2 55.44521 0.9914764
## 4:   4 2.458335 1.286735   0   0 3.894206 1.742324    5    3 28.14201 0.9307358
## 5:   4 3.042121 1.163366   1   0 3.739791 2.188678   10    2 49.26435 0.9991959
## 6:   5 3.616383 1.276404   1   0 3.959868 2.634676    3    2 43.16583 0.9964038
##           V106       V107 V108      V109 V110     V111 V112      V113
## 1:   0.8109243 -0.9703621    0 0.2730344    0 73.50169    1 0.6634900
## 2:  10.9250096 -1.0017012    0 0.4120395    0 51.62268    1 0.2838329
## 3:  13.6599136 -0.9802746    0 0.5236362    0 62.99163    1 0.9516340
## 4: -25.5088824 -0.9934119    1 0.6180330    0 43.44839    1 0.7741553
## 5:   7.0802934 -0.8977199    1 0.5246819    0 72.31038    1 0.6754245
## 6:  -3.3171809 -0.9993140    1 0.7486420    0 60.74483    0 0.1145096
##          V114      V115     V116       V117 V118 V119 V120     V121       V122
## 1: 0.43676411 0.9991616 10.79522 -0.5385633    1    4  934 33.04630  0.4615909
## 2: 0.36474425 0.9989004 11.55450  0.2812390   15    1  941 33.86293 -5.6026849
## 3: 0.06739379 0.9989085 11.84418 -1.4635959   11    5  874 22.15443 -1.8215083
## 4: 0.13928599 0.9996213 14.81413 -0.3407824   43    8  941 45.65891 -2.9436803
## 5: 0.27403924 0.9202706 18.14229 -1.2608972   50    1  952 19.93158 -1.1604592
## 6: 0.25122805 0.9854695 14.57774 -0.8166704    9    2 1111 22.45532 -1.5620386
##    V123      V124       V125 V126 V127 V128 V129 V130 V131      V132 V133
## 1:    0 0.2877762  0.1592139    2    0    1    0    0    6 -14.26363    1
## 2:    0 0.3517710  0.5749188    0    0    0    0    0   20  44.22123    0
## 3:    1 0.4206429 -0.4930869    0    0    0    1    0   15  44.42188    1
## 4:    1 0.3484142  0.9731821    2    0    1    1    0    4  25.29488    1
## 5:    0 0.4275244 -2.7654479    1    5    0    1    1    7  15.44793    3
## 6:    0 0.4941055 -0.4603852    1    2    0    1    0   14  16.02116    1
##         V134      V135 V136 V137        V138 V139     V140 V141 V142 V143
## 1: 0.6480628 4.9721815    1    1 -0.01270437   14 2.211044    0  -88    1
## 2: 0.5464590 2.8346019    0    0 -3.34385582   15 2.780464    0 -103    1
## 3: 0.4751264 1.4166685    0    1  0.70178508   12 2.583958    0 -104    1
## 4: 0.8665180 2.3020696    0    1  3.17621778    6 2.593933    0  -94    1
## 5: 0.4346036 0.5164262    0    0  1.13260721   11 3.313170    0  -62    1
## 6: 0.4303166 1.7248318    0    1 -0.99296580    7 2.070250    0  -80    1
##         V144      V145       V146     V147     V148 V149 V150      V151
## 1: 14.899569 0.2973792 -1.4045348 22.93680 1.112065    1    0 0.7522654
## 2:  9.283158 1.5241806 -0.5923406 19.84037 1.107726    1    0 0.7412935
## 3:  8.126532 1.9441522  1.7871841 17.53077 1.109680    1    0 0.7495357
## 4:  7.990564 1.3594995  0.2679640 24.41190 1.112921    0    0 0.7961531
## 5:  4.432384 0.9491806 -0.7068222 24.55891 1.101129    0    0 0.7781383
## 6:  5.261740 0.8242018 -0.6108768 20.55625 1.109204    0    1 0.8590800
##          V152 V153     V154 V155      V156      V157 V158       V159 V160
## 1: 0.09487609    1 61.75640    0 0.1555205  1.299214    0  8.9809734    1
## 2: 0.23032948    1 45.97932    1 0.1844182  3.095729    3 14.0603393    1
## 3: 0.13815855    1 51.36620    0 0.2829257  0.810593    4  0.9319917    0
## 4: 0.14440331    0 48.24623    0 0.1138707  2.181276    5  2.5809724    0
## 5: 0.13364686    1 47.29790    1 0.2961928 -0.434977    4  7.2345673    1
## 6: 0.33345130    0 56.33525    0 0.1476424 -1.926538    3  4.4139180    1
##         V161 V162       V163 V164 V165     V166 V167       V168       V169 V170
## 1: 14.597886    0  0.3639882    0    0 2.733640    0 0.16228697  2.2205500    0
## 2:  3.380854    0  0.5129260    0    0 2.724039    0 0.10602104 -0.6852898    0
## 3:  3.109170    1 -1.2629269    0    0 2.715319    0 0.05108764  1.9828793    1
## 4: 57.611061    0  0.5685741    0    0 2.711964    0 0.13866903  1.4471618    2
## 5: 32.457049    0 -0.5958445    0    1 2.713607    0 0.06787200  0.4291680    4
## 6: 24.568721    0 -0.2389066    0    0 2.742426    0 0.11932443  2.4901435    2
##          V171 V172     V173      V174 V175 V176 V177       V178       V179 V180
## 1:  30.147837    0 5.890918 0.5035608    1    0    0 0.64153940 -0.7871431   31
## 2: -86.355520    0 6.902555 0.2127434    4    0    0 0.01900934 -1.0393244   21
## 3:  -2.545955    0 6.704886 0.4259085    3    0    0 0.80010433  0.2331956   21
## 4: -33.742974    1 7.528415 0.3282203    6    0    1 1.35364593  0.1319536   32
## 5: -20.346347    1 5.539800 0.4322122    4    0    0 0.90278942 -0.6498747   30
## 6: -17.029889    1 7.514143 0.2331520    1    0    0 1.22747635  0.8908464   20
##    V181 V182 V183 V184        V185       V186       V187        V188 V189 V190
## 1:    3    0    1   17  0.21751273 0.16146110 -0.1212995  0.46743435    3    0
## 2:    3    0    0   16  0.54973231 0.22095415 -0.5137855 -0.11908883    3    0
## 3:    0    0    0   15  0.42325170 0.11036085  0.2995553 -0.25991994    3    0
## 4:    6    1    0    0 -0.18960878 0.02418432 -0.1869971 -0.08153095    2    0
## 5:    3    0    0   17 -0.03580849 0.24367914 -0.6039507  0.53261710    1    0
## 6:    2    0    0    0 -0.70241461 0.57273510  1.8515487  0.64336691    0    0
##        V191 V192     V193 V194 V195     V196       V197 V198 V199       V200
## 1: 1.462837    1 1627.274    0    0 4.683956  0.5656669    0    3 0.02433804
## 2: 1.330522    1 1661.484    1    0 6.766661 -0.3954021    0    4 0.05651780
## 3: 1.384151    1 1658.939    0    0 5.647794  1.1127661    0    0 0.01344207
## 4: 1.220303    1 1650.802    0    0 5.370363 -0.3058420    0    4 0.03463235
## 5: 1.170094    1 1676.819    0    0 3.446532  2.4406606    0    1 0.01751381
## 6: 1.802945    1 1634.093    1    0 5.294410  1.0869714    1    0 0.04712806

We see that the data set consists of 1000 observations (= website visitors) and 202 variables:

  • Y: A customer’s purchases (in \(100\$\))
  • A: Binary treatment variable with a value 1 indicating that a customer has been exposed to the new ad design (and value 0 otherwise).
  • V1,…, V200: The remaining 200 columns \(V\) represent individual characteristics of the customers (=confounders).

To start our analysis, we initialize the data backend from the previously loaded data set, i.e., we create a new instance of a DoubleMLData object. During initialization, we specify the roles of the variables in the data set, i.e., in our example the outcome variable \(Y\) via the parameter y_col, the treatment variable \(A\) via d_cols and the confounding variables \(V\) via x_cols.

# Specify explanatory variables for data-backend: all columns V1, ..., V200
features_base = colnames(df)[grep("^V[0-9]+$", colnames(df))]

# Initialize DoubleMLData (data-backend of DoubleML)
data_dml = DoubleMLData$new(df,
                           y_col = "Y",
                           d_cols = "A",
                           x_cols = features_base)

We can print the data-backend to see the variables that we have assigned as outcome, treatment and controls.

print(data_dml)
## ================= DoubleMLData Object ==================
## 
## 
## ------------------ Data summary      ------------------
## Outcome variable: Y
## Treatment variable(s): A
## Covariates: V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, V11, V12, V13, V14, V15, V16, V17, V18, V19, V20, V21, V22, V23, V24, V25, V26, V27, V28, V29, V30, V31, V32, V33, V34, V35, V36, V37, V38, V39, V40, V41, V42, V43, V44, V45, V46, V47, V48, V49, V50, V51, V52, V53, V54, V55, V56, V57, V58, V59, V60, V61, V62, V63, V64, V65, V66, V67, V68, V69, V70, V71, V72, V73, V74, V75, V76, V77, V78, V79, V80, V81, V82, V83, V84, V85, V86, V87, V88, V89, V90, V91, V92, V93, V94, V95, V96, V97, V98, V99, V100, V101, V102, V103, V104, V105, V106, V107, V108, V109, V110, V111, V112, V113, V114, V115, V116, V117, V118, V119, V120, V121, V122, V123, V124, V125, V126, V127, V128, V129, V130, V131, V132, V133, V134, V135, V136, V137, V138, V139, V140, V141, V142, V143, V144, V145, V146, V147, V148, V149, V150, V151, V152, V153, V154, V155, V156, V157, V158, V159, V160, V161, V162, V163, V164, V165, V166, V167, V168, V169, V170, V171, V172, V173, V174, V175, V176, V177, V178, V179, V180, V181, V182, V183, V184, V185, V186, V187, V188, V189, V190, V191, V192, V193, V194, V195, V196, V197, V198, V199, V200
## Instrument(s): 
## No. Observations: 1000

2. Causal Model

The inference problem is to determine the causal effect of seeing the new ad design \(A\) on customers’ purchases \(Y\) once we control for individual characteristics \(V\). In our example, we are interested in the average treatment effect. DoubleML offers two causal models that can be used to estimate the ATE: the partially linear regression model (DoubleMLPLR) and the interactive regression model (DoubleMLIRM), on which we focus here.

The so-called interactive regression model (IRM), implemented in DoubleMLIRM, is a flexible (nonparametric) model to estimate this causal quantity. The model does not impose functional form restrictions on the underlying regression relationships, such as the linearity or additivity of a standard linear regression model. This means that the model allows for heterogeneous treatment effects, i.e., it accounts for variation in the effect of the new ad design across customers. Moreover, the IRM can also be used to estimate other causal parameters, for example, the average treatment effect on the treated (i.e., the effect for those customers who have been exposed to the new ad), which might be of interest, too.

2.1. Interactive regression model (IRM)

We briefly introduce the interactive regression model, where the main regression relationship of interest is given by

\[Y = g_0(A, V) + U_1, \quad E(U_1 | V, A) = 0,\]

where the treatment variable is binary, \(A \in \lbrace 0,1 \rbrace\). We consider estimation of the average treatment effect (ATE):

\[\theta_0 = \mathbb{E}[g_0(1, V) - g_0(0,V)],\]

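allowing for heterogeneous treatment effects. To use ML methods for estimating the nuisance functions, the estimation framework requires a score with a property called Neyman orthogonality. This property makes the estimate of \(\theta_0\) locally insensitive to estimation errors in the nuisance functions. In the IRM, such an orthogonal score, which is also doubly robust, is obtained by including the first-stage estimation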

\[A = m_0(V) + U_2, \quad E(U_2| V) = 0,\]

which amounts to estimating the propensity score, i.e., the probability that a customer is exposed to the treatment given her observed characteristics. Both predictions are then combined in the doubly robust score for the average treatment effect, which is given by

\[\psi(W; \theta, \eta) := g(1,V) - g(0,V) + \frac{A (Y - g(1,V))}{m(V)} - \frac{(1 - A)(Y - g(0,V))}{1 - m(V)} - \theta.\]
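Since the score is linear in \(\theta\), setting its sample mean to zero gives the estimate directly. The following is a minimal base-R sketch of that step, given predictions of \(g(0,V)\), \(g(1,V)\) and \(m(V)\). The function name `aipw_ate` and the simulated data are hypothetical. The standard error is a simple plug-in based on the empirical variance of the score; DoubleML's internal computation may differ in details such as aggregation across repetitions.

```r
# Doubly robust (AIPW) ATE estimate from nuisance predictions; a sketch.
aipw_ate <- function(y, a, g0, g1, m) {
  # Part of the score psi that does not involve theta
  psi_b <- g1 - g0 + a * (y - g1) / m - (1 - a) * (y - g0) / (1 - m)
  # Setting mean(psi) = mean(psi_b) - theta to zero gives the estimate
  theta <- mean(psi_b)
  # Plug-in standard error from the empirical variance of the score
  se <- sqrt(mean((psi_b - theta)^2) / length(y))
  c(estimate = theta, std_error = se)
}

# Hypothetical simulated data with heterogeneous effects and true ATE 0.8
set.seed(1)
n <- 2000
v <- rnorm(n)
m_true <- plogis(0.5 * v)
a <- rbinom(n, 1, m_true)
y <- 1 + v + a * (0.8 + 0.5 * v) + rnorm(n)
# Plug in the true nuisance functions g0(v) = 1 + v and g1(v) = 1.8 + 1.5 v
res_aipw <- aipw_ate(y, a, g0 = 1 + v, g1 = 1.8 + 1.5 * v, m = m_true)
res_aipw
```

In practice, the true nuisance functions are unknown and are replaced by cross-fitted ML predictions.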

2.2. Naive Approach: Unconditional estimate of ATE

As a naive estimate, we could calculate the unconditional average treatment effect. In other words, we simply take the difference between the average \(Y\) of the customers who have been exposed to the treatment \((A=1)\) and that of those who have not \((A=0)\).

Since the unconditional ATE does not account for the confounding variables, it will generally not correspond to the true ATE; only under unconditionally random treatment assignment does it coincide with the true ATE. For example, if the unconditional ATE estimate is greater than the actual ATE, the manager would overestimate the effect of the new ad design and probably make misguided decisions on the future marketing budget.

df[, mean(Y), by = A]
##    A       V1
## 1: 1 7.953744
## 2: 0 6.836141
ATE_uncond = df[A == 1, mean(Y)] - df[A==0, mean(Y)]
ATE_uncond
## [1] 1.117603
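In our data set, the naive estimate (about 1.12) exceeds the true ATE of 0.8. The mechanism behind such a gap can be reproduced in a minimal simulation. This is a sketch on hypothetical data in which a single characteristic `v` raises both the probability of seeing the new ad and the purchases. The unconditional difference in means then mixes the ad effect with the effect of `v`, while a regression that controls for `v` recovers the true effect.

```r
# Sketch (hypothetical simulated data): confounding biases the naive estimate.
set.seed(7)
n <- 5000
v <- rnorm(n)
a <- rbinom(n, 1, plogis(1.5 * v))   # high-v customers see the new ad more often
y <- 0.8 * a + 1.0 * v + rnorm(n)    # true ATE = 0.8

naive <- mean(y[a == 1]) - mean(y[a == 0])
adjusted <- unname(coef(lm(y ~ a + v))["a"])
c(naive = naive, adjusted = adjusted)
```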

3. ML Methods

In this step, we define the learners that will be used for estimation of the nuisance functions later.

3.1. Benchmark using linear and logistic regression

Let us start with a benchmark model based on (unpenalized) linear and logistic regression. That is, we estimate the function \(g_0(A,V)\) using a linear regression model and \(m_0(V)\) using an (unpenalized) logistic regression. In both cases, we include all available characteristics \(V\). We will later compare the performance of this model to that of models using more advanced ML methods.

# TODO: Initialize Linear and Logistic Regression learners
linreg = lrn("regr.lm")
logreg_class = lrn("classif.log_reg")

3.2. Instantiate one or several ML learners of your choice

# TODO: Initialize one ML learner of your choice
lasso = lrn("regr.cv_glmnet", nfolds = 5, s = "lambda.min")
lasso_class = lrn("classif.cv_glmnet", nfolds = 5, s = "lambda.min")
# TODO: Initialize a second ML learner of your choice
#      (proceed as long as you like)
randomForest = lrn("regr.ranger")
randomForest_class = lrn("classif.ranger")

4. DML Specifications

At this stage, we instantiate a causal model object of the class DoubleMLIRM. Provide the learners via the parameters ml_g and ml_m. You can either stick with the default settings or change the parameters. For details, see the documentation of the DoubleMLIRM class and of the abstract base class DoubleML.

Hint: Use set.seed() to set a random seed prior to your initialization. This makes the sample splits of the different models comparable. Also try to use the same DML specifications in all models to attain some comparability.
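Why does resetting the seed make sample splits comparable? The sketch below uses a hypothetical helper, `draw_folds`, which is not DoubleML's internal code. It assumes that the splits are drawn from R's random number generator when the model object is initialized. Under that assumption, drawing fold assignments after the same `set.seed()` call yields identical folds.

```r
# Hypothetical helper: randomly assign n_obs observations to n_folds folds.
draw_folds <- function(n_obs, n_folds) {
  sample(rep(seq_len(n_folds), length.out = n_obs))
}

set.seed(1234)
folds_model_1 <- draw_folds(1000, 3)
set.seed(1234)
folds_model_2 <- draw_folds(1000, 3)
identical(folds_model_1, folds_model_2)  # TRUE
```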

4.1. Linear and logistic benchmark model

# TODO: Initialize benchmark DoubleMLIRM model
set.seed(1234)
dml_irm_regression = DoubleMLIRM$new(data_dml,
                                    ml_g = linreg,
                                    ml_m = logreg_class,
                                    trimming_threshold = 0.025,
                                    n_folds = 3,
                                    n_rep = 3)
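The argument trimming_threshold = 0.025 guards against propensity scores close to 0 or 1. Such values would make the weights \(1/m(V)\) and \(1/(1-m(V))\) in the score explode. The sketch below assumes DoubleML's default rule truncates the estimates into \([0.025, 0.975]\), which caps each weight at \(1/0.025 = 40\). The function name `truncate_propensity` is hypothetical.

```r
# Sketch: truncate propensity score estimates into [threshold, 1 - threshold].
truncate_propensity <- function(m, threshold = 0.025) {
  pmin(pmax(m, threshold), 1 - threshold)
}

m_hat <- c(0.001, 0.02, 0.5, 0.99, 0.9999)
truncate_propensity(m_hat)
# [1] 0.025 0.025 0.500 0.975 0.975
```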

4.2. ML Model of your choice

# TODO: Initialize a DoubleMLIRM model using the ML learners of your choice
set.seed(1234)
dml_irm_lasso = DoubleMLIRM$new(data_dml,
                               ml_g = lasso,
                               ml_m = lasso_class,
                               trimming_threshold = 0.025,
                               n_folds = 3,
                               n_rep = 3)

4.3. - 4.X. ML Model of your choice

Proceed with the models using the other ML learners.

# TODO: Initialize a DoubleMLIRM model using the ML learners of your choice
set.seed(1234)
dml_irm_forest = DoubleMLIRM$new(data_dml,
                                ml_g = randomForest,
                                ml_m = randomForest_class,
                                trimming_threshold = 0.025,
                                n_folds = 3,
                                n_rep = 3)

# Set nuisance-part specific parameters; mtry = 200 equals the number of
# covariates, so every split considers all covariates
dml_irm_forest$set_ml_nuisance_params("ml_g0", "A",
                                      list("mtry" = 200,
                                           "num.trees" = 250))
dml_irm_forest$set_ml_nuisance_params("ml_g1", "A",
                                      list("mtry" = 200,
                                           "num.trees" = 250))
dml_irm_forest$set_ml_nuisance_params("ml_m", "A",
                                      list("mtry" = 200,
                                           "num.trees" = 250))

5. Estimation

5.1. Estimation for the Benchmark IRM

# TODO: Fit benchmark DoubleMLIRM model using the fit() method
dml_irm_regression$fit(store_predictions = TRUE)
dml_irm_regression$summary()
## Estimates and significance testing of the effect of target variables
##   Estimate. Std. Error t value Pr(>|t|)
## A    0.1728     0.7152   0.242    0.809
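The following simplified base-R sketch makes concrete what fit() computes for the benchmark specification. It produces a cross-fitted IRM estimate with a linear model for g and a logistic regression for m. It is not DoubleML's implementation: it uses a single repetition and simulated data with two hypothetical covariates v1 and v2, and the helper name `crossfit_irm` is our own. To our understanding, DoubleML similarly fits g0 and g1 on the untreated and treated observations of each training fold and truncates the propensity scores at trimming_threshold.

```r
# Simplified cross-fitted IRM (AIPW) estimate; a sketch, not DoubleML's code.
crossfit_irm <- function(y, a, X, n_folds = 3, trim = 0.025) {
  n <- length(y)
  dat <- data.frame(y = y, a = a, X)
  covs <- setdiff(names(dat), c("y", "a"))
  f_g <- reformulate(covs, response = "y")
  f_m <- reformulate(covs, response = "a")
  # Randomly assign each observation to one of n_folds folds
  folds <- sample(rep(seq_len(n_folds), length.out = n))
  g0 <- g1 <- m <- numeric(n)
  for (k in seq_len(n_folds)) {
    train <- dat[folds != k, ]
    test <- dat[folds == k, ]
    # Outcome regressions fit separately on untreated and treated training data
    fit_g0 <- lm(f_g, data = train[train$a == 0, ])
    fit_g1 <- lm(f_g, data = train[train$a == 1, ])
    # Propensity score model fit on all training data
    fit_m <- glm(f_m, family = binomial(), data = train)
    # Out-of-fold predictions for the held-out fold
    g0[folds == k] <- predict(fit_g0, newdata = test)
    g1[folds == k] <- predict(fit_g1, newdata = test)
    m[folds == k] <- predict(fit_m, newdata = test, type = "response")
  }
  m <- pmin(pmax(m, trim), 1 - trim)  # truncate extreme propensity scores
  psi_b <- g1 - g0 + a * (y - g1) / m - (1 - a) * (y - g0) / (1 - m)
  theta <- mean(psi_b)
  se <- sqrt(mean((psi_b - theta)^2) / n)
  c(estimate = theta, std_error = se)
}

# Usage on hypothetical simulated data with a true ATE of 0.8
set.seed(1234)
n <- 1000
X <- data.frame(v1 = rnorm(n), v2 = rnorm(n))
a <- rbinom(n, 1, plogis(0.5 * X$v1 - 0.5 * X$v2))
y <- 0.8 * a + X$v1 + 0.5 * X$v2 + rnorm(n)
res_sketch <- crossfit_irm(y, a, X)
res_sketch
```

With only two covariates, the linear and logistic models work well. With 200 covariates and 1000 observations, unpenalized models are much more prone to overfitting, which is one plausible reason for the large standard error of the benchmark above.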

5.2. Estimation Diagnostics for the Benchmark IRM

5.2.1. Assess the Predictive Performance in the benchmark IRM

To evaluate the different models, we can compare how well the employed estimators fit the nuisance functions \(g_0(\cdot)\) and \(m_0(\cdot)\). Use the following helper function to compare the predictive performance of your models.

# A function to calculate prediction accuracy values for every repetition
# of a Double Machine Learning model using IRM, DoubleMLIRM
pred_acc_irm = function(obj, prop) {
  # obj : DoubleML::DoubleMLIRM
  #   The fitted IRM Double Machine Learning model (store_predictions = TRUE)
  # prop : logical
  #   If FALSE, compute RMSE values for the main regression;
  #   if TRUE, compute log loss values for the propensity score
  
  if (obj$data$n_treat > 1) {
    stop("Number of treatment variables is > 1. Helper function for nuisance accuracy is only implemented for 1 treatment variable.")
  }
  w = obj$n_rep
  
  y = obj$data$data_model[[obj$data$y_col]]
  d = obj$data$data_model[[obj$data$treat_col]]
  g0 = matrix(obj$predictions[['ml_g0']][,,1], ncol = w)
  g1 = matrix(obj$predictions[['ml_g1']][,,1], ncol = w)
  m = matrix(obj$predictions[['ml_m']][,,1], ncol = w)
  
  if (!all(unique(d) %in% c(0,1))) {
    stop("Treatment must be a binary variable.")
  }

  if (!prop) {
    # Use g1 for treated and g0 for untreated observations
    export_pred = d*g1 + (1-d) * g0
    # Calculate RMSE for every repetition
    pred_acc = apply(export_pred, 2,
                     function(x) mlr3measures::rmse(y,x))
  } else {
    pred_acc = rep(NA, w)
    for (j in seq_len(w)) {
      # Class probabilities for the classes "0" and "1"
      class_probs = matrix(c(1-m[,j],m[,j]), ncol = 2)
      colnames(class_probs) = c("0", "1")
      pred_acc[j] = mlr3measures::logloss(as.factor(d),class_probs)
    }
  }
  return(pred_acc)
}
# TODO: Evaluate the predictive performance for `ml_g` and `ml_m` using the
#       helper function `pred_acc_irm()`.
rmse_main_linlog_irm = pred_acc_irm(dml_irm_regression, prop = FALSE)
rmse_main_linlog_irm_mean = mean(rmse_main_linlog_irm)
rmse_main_linlog_irm_sd = sd(rmse_main_linlog_irm)

logloss_prop_linlog_irm = pred_acc_irm(dml_irm_regression, prop = TRUE)
logloss_prop_linlog_irm_mean = mean(logloss_prop_linlog_irm)
logloss_prop_linlog_irm_sd = sd(logloss_prop_linlog_irm)
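For reference, the two accuracy measures used by `pred_acc_irm()` can be written in a few lines of base R. This is a sketch with hypothetical function names. The `eps` clipping is our own guard against `log(0)` and may differ from what mlr3measures does internally.

```r
# Root mean squared error of predictions y_hat for outcomes y
rmse <- function(y, y_hat) sqrt(mean((y - y_hat)^2))

# Binary log loss of predicted probabilities p for 0/1 outcomes d
log_loss <- function(d, p, eps = 1e-15) {
  p <- pmin(pmax(p, eps), 1 - eps)        # guard against log(0)
  -mean(d * log(p) + (1 - d) * log(1 - p))
}

rmse(c(1, 2, 3), c(1, 2, 5))    # sqrt(4/3), i.e. about 1.155
log_loss(c(1, 0), c(0.5, 0.5))  # log(2), i.e. about 0.693
```

A constant propensity prediction of 0.5 yields a log loss of \(\log 2 \approx 0.693\) whatever the outcomes are. That makes it a useful baseline when reading the log loss values of the different models.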

Optional: 5.2.2. Evaluation of Propensity Score Estimates in the Benchmark IRM

The propensity score \(m_0(V)\) plays an important role in the score of the IRM model. Try to summarize the estimates for \(m_0(V)\) using some descriptive statistics or visualization.

# (TODO): Summarize the propensity score estimates

# Function to plot propensity scores
rep_propscore_plot = function(obj) {
  # obj : DoubleML::DoubleMLIRM
  #   The fitted Double Machine Learning model (store_predictions = TRUE)
  if (obj$data$n_treat > 1) {
    stop("Number of treatment variables is > 1. Helper function for propensity score plots is only implemented for 1 treatment variable.")
  }
  # One column of propensity score predictions per repetition
  m = data.table(obj$predictions[['ml_m']][,,1])
  colnames(m) = paste("Repetition", 1:obj$n_rep)
  # Reshape to long format for faceting by repetition
  m = melt(m,
           measure.vars = names(m))
  
  hist_ps = ggplot(m) +
    geom_histogram(aes(y = after_stat(count), x = value),
                   bins = 25, fill = "darkblue",
                   col= "darkblue", alpha = 0.5) + 
    xlim(c(0,1)) + theme_minimal() + 
    facet_grid(. ~ variable )
  return(hist_ps)
}
rep_propscore_plot(dml_irm_regression)
## Warning: Removed 6 rows containing missing values (geom_bar).
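The propensity scores can also be summarized with descriptive statistics. The sketch below uses a hypothetical helper, `summarize_propscores`. It expects an n_obs x n_rep matrix of propensity score predictions, which we assume can be taken from `obj$predictions[["ml_m"]][, , 1]` as in the helpers above. The column `share_extreme` reports the share of estimates outside \([0.025, 0.975]\), i.e. those that the trimming threshold would affect.

```r
# Sketch: descriptive statistics of propensity score estimates per repetition.
summarize_propscores <- function(m_mat, threshold = 0.025) {
  m_mat <- as.matrix(m_mat)
  stats <- apply(m_mat, 2, function(p) {
    c(min = min(p),
      q25 = unname(quantile(p, 0.25)),
      median = median(p),
      q75 = unname(quantile(p, 0.75)),
      max = max(p),
      share_extreme = mean(p < threshold | p > 1 - threshold))
  })
  colnames(stats) <- paste("Repetition", seq_len(ncol(m_mat)))
  t(stats)  # one row per repetition
}

# Usage with simulated values; for a fitted model, one could pass
# dml_irm_regression$predictions[["ml_m"]][, , 1] instead
set.seed(1)
m_sim <- matrix(runif(300), ncol = 3)
ps_summary <- summarize_propscores(m_sim)
round(ps_summary, 3)
```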

5.3. Estimation for ML Model

# TODO: Fit the ML DoubleMLIRM model using the fit() method
dml_irm_lasso$fit(store_predictions = TRUE)
dml_irm_lasso$summary()
## Estimates and significance testing of the effect of target variables
##   Estimate. Std. Error t value Pr(>|t|)    
## A   0.85985    0.07177   11.98   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

5.4. Estimation Diagnostics for the IRM using ML Methods

5.4.1. Assess the Predictive Performance in the IRM using ML methods

# TODO: Evaluate the predictive performance for `ml_g` and `ml_m` using the
#       helper function `pred_acc_irm()`.
rmse_main_lasso_irm = pred_acc_irm(dml_irm_lasso, prop = FALSE)
rmse_main_lasso_irm_mean = mean(rmse_main_lasso_irm)
rmse_main_lasso_irm_sd = sd(rmse_main_lasso_irm)

logloss_prop_lasso_irm = pred_acc_irm(dml_irm_lasso, prop = TRUE)
logloss_prop_lasso_irm_mean = mean(logloss_prop_lasso_irm)
logloss_prop_lasso_irm_sd = sd(logloss_prop_lasso_irm)

Optional: 5.4.2. Evaluation of Propensity Score Estimates in the IRM using ML Methods

# (TODO): Summarize the propensity score estimates
rep_propscore_plot(dml_irm_lasso)
## Warning: Removed 6 rows containing missing values (geom_bar).

5.5. - 5.X. ML Model of your choice

Proceed with the models using the other ML learners.

dml_irm_forest$fit(store_predictions = TRUE)
dml_irm_forest$summary()
## Estimates and significance testing of the effect of target variables
##   Estimate. Std. Error t value Pr(>|t|)    
## A   0.88472    0.07656   11.56   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
rmse_main_forest_irm = pred_acc_irm(dml_irm_forest, prop = FALSE)
rmse_main_forest_irm_mean = mean(rmse_main_forest_irm)
rmse_main_forest_irm_sd = sd(rmse_main_forest_irm)

logloss_prop_forest_irm = pred_acc_irm(dml_irm_forest, prop = TRUE)
logloss_prop_forest_irm_mean = mean(logloss_prop_forest_irm)
logloss_prop_forest_irm_sd = sd(logloss_prop_forest_irm)
rep_propscore_plot(dml_irm_forest)
## Warning: Removed 6 rows containing missing values (geom_bar).

5.X+1 Summarize your Results on the Quality of Estimation

Provide a brief summary of your estimation results, for example by creating a table or figure.

# TODO: Summarize the results on the nuisance estimation in a table or figure
estimators = c("linear regression", "lasso", "random forest")
estimators = factor(estimators, levels = estimators)
irm_rmse = data.table(
  "ML" = estimators,
  "RMSE (mean)" = c(rmse_main_linlog_irm_mean,
                    rmse_main_lasso_irm_mean,
                    rmse_main_forest_irm_mean),
  "RMSE (sd)" = c(rmse_main_linlog_irm_sd,
                  rmse_main_lasso_irm_sd,
                  rmse_main_forest_irm_sd),
  "log loss (mean)" = c(logloss_prop_linlog_irm_mean,
                        logloss_prop_lasso_irm_mean,
                        logloss_prop_forest_irm_mean),
  "log loss (sd)" = c(logloss_prop_linlog_irm_sd,
                      logloss_prop_lasso_irm_sd,
                      logloss_prop_forest_irm_sd)
)

print(irm_rmse, 4)
##                   ML RMSE (mean)   RMSE (sd) log loss (mean) log loss (sd)
## 1: linear regression    1.693588 0.040248655       1.0675439   0.044847925
## 2:             lasso    1.125516 0.006438981       0.6667697   0.004654689
## 3:     random forest    1.162918 0.009308529       0.6780709   0.002578629
irm_rmse[, ':=' (
  lower_rmse = `RMSE (mean)` - `RMSE (sd)`,
  upper_rmse = `RMSE (mean)` + `RMSE (sd)`,
  lower_logloss = `log loss (mean)` - `log loss (sd)`,
  upper_logloss = `log loss (mean)`  + `log loss (sd)`)]
g_rmse_irm = ggplot(irm_rmse, aes(x = ML, y = `RMSE (mean)`, color = ML)) +
        geom_point() +
        geom_errorbar(aes(ymin = lower_rmse, ymax = upper_rmse),  color = "darkgrey")  +
        theme_minimal() + ylab("Mean RMSE +/- 1 sd") +
        xlab("") +
        theme(axis.text.x = element_text(angle = 90), legend.position = "none", text = element_text(size = 18))

g_rmse_irm

g_logloss_irm = ggplot(irm_rmse, aes(x = ML, y = `log loss (mean)`, color = ML)) +
        geom_point() +
        geom_errorbar(aes(ymin = lower_logloss, ymax = upper_logloss),  color = "darkgrey")  +
        theme_minimal() + ylab("Mean log loss +/- 1 sd") +
        xlab("") +
        theme(axis.text.x = element_text(angle = 90), legend.position = "none", text = element_text(size = 18))

g_logloss_irm

6. Inference

Summarize your results on the coefficient estimate for \(\theta_0\), together with its standard error and/or confidence interval. You can create a table or a figure illustrating your findings.

Try to answer the following questions:

Solution:

6.1. Inference for the benchmark IRM

## TODO: After calling fit(), access the coefficient parameter,
##      the standard error and confidence interval by calling the methods
##      `summary()` and `confint()`.
dml_irm_regression$summary()
## Estimates and significance testing of the effect of target variables
##   Estimate. Std. Error t value Pr(>|t|)
## A    0.1728     0.7152   0.242    0.809
dml_irm_regression$confint()
##       2.5 %   97.5 %
## A -1.228837 1.574525

6.2. Inference for the IRM using ML methods

## TODO: After calling fit(), access the coefficient parameter,
##      the standard error and confidence interval by calling the methods
##      `summary()` and `confint()`.
dml_irm_lasso$summary()
## Estimates and significance testing of the effect of target variables
##   Estimate. Std. Error t value Pr(>|t|)    
## A   0.85985    0.07177   11.98   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
dml_irm_lasso$confint()
##       2.5 %   97.5 %
## A 0.7191893 1.000519

6.3. - 6.X. ML Model of your choice

Proceed with the models using the other ML learners.

dml_irm_forest$summary()
## Estimates and significance testing of the effect of target variables
##   Estimate. Std. Error t value Pr(>|t|)    
## A   0.88472    0.07656   11.56   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
dml_irm_forest$confint()
##       2.5 %   97.5 %
## A 0.7346654 1.034768
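The intervals returned by `confint()` above agree, up to rounding of the printed values, with the normal-approximation (Wald) interval \(\hat{\theta}_0 \pm z_{0.975} \cdot \widehat{\mathrm{se}}\) at the default 95% level. A minimal base-R sketch that rebuilds them from the estimates and standard errors printed by `summary()` and collects the three IRM specifications in one table:

```r
# Point estimates and standard errors copied from the summary() output above
results <- data.frame(
  ML       = c("linear regression", "lasso", "random forest"),
  estimate = c(0.1728, 0.85985, 0.88472),
  std_err  = c(0.7152, 0.07177, 0.07656)
)

z <- qnorm(0.975)  # two-sided 95% critical value, about 1.96

# Wald-type confidence bounds
results$ci_lower <- results$estimate - z * results$std_err
results$ci_upper <- results$estimate + z * results$std_err

# Does the 95% interval include a zero effect?
results$covers_zero <- results$ci_lower <= 0 & results$ci_upper >= 0

print(results, digits = 4)
```

Only the benchmark interval covers zero. Its standard error is roughly ten times larger than those of the lasso and random forest specifications, whose intervals overlap substantially.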

Variation / Scope for extensions

Variation 1: Partially linear regression

As an alternative to the (nonparametric) IRM, the DoubleML package also implements the partially linear regression (PLR) model. The PLR model assumes that the treatment enters the outcome equation linearly and additively with a constant effect \(\theta_0\), while the characteristics \(V\) may enter through an arbitrary function \(g_0(V)\). Since we never know whether this structure actually holds for the underlying data-generating process, it is instructive to fit the PLR model as well and compare its estimates with those from the IRM.

We can estimate the nuisance functions \(g_0\) and \(m_0\) in the following PLR model:

\[\begin{eqnarray} & Y = A\theta_0 + g_0(V) + \zeta, &\quad E[\zeta \mid A,V]= 0,\\ & A = m_0(V) + U_3, &\quad E[U_3 \mid V] = 0. \end{eqnarray}\]
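To make the partialling-out idea behind the PLR model concrete, here is a minimal base-R sketch of a cross-fitted DML estimator of \(\theta_0\) on simulated data. It is not the DoubleML implementation: the data-generating process is made up, ordinary least squares on polynomial terms stands in for the ML learners, and only \(K = 2\) folds are used. Following the partialling-out approach of Chernozhukov et al. (2018), the sketch estimates \(\ell_0(V) = E[Y \mid V]\) (rather than \(g_0\) directly) and \(m_0(V) = E[A \mid V]\) out of fold, and then regresses the outcome residuals on the treatment residuals.

```r
# Minimal cross-fitted partialling-out estimator for the PLR model (sketch).
# Simulated data; OLS with polynomial terms stands in for ML learners.
set.seed(123)
n <- 2000
theta0 <- 0.5                       # true effect in the simulation

V1 <- rnorm(n)
V2 <- rnorm(n)
m0 <- 0.5 * V1 + 0.25 * V2^2        # E[A | V]
g0 <- sin(V1) + 0.5 * V2^2          # nonlinear confounding in the outcome
A  <- m0 + rnorm(n)                 # A = m0(V) + U3
Y  <- theta0 * A + g0 + rnorm(n)    # Y = A * theta0 + g0(V) + zeta
dat <- data.frame(Y, A, V1, V2)

# Randomly assign observations to K balanced folds
K <- 2
folds <- sample(rep(seq_len(K), length.out = n))
res_y <- numeric(n)
res_a <- numeric(n)

for (k in seq_len(K)) {
  train <- folds != k
  test  <- folds == k
  # Nuisance fits on the training folds only
  fit_l <- lm(Y ~ poly(V1, 3) + poly(V2, 2), data = dat[train, ])
  fit_m <- lm(A ~ poly(V1, 3) + poly(V2, 2), data = dat[train, ])
  # Out-of-fold residuals on the held-out fold
  res_y[test] <- dat$Y[test] - predict(fit_l, newdata = dat[test, ])
  res_a[test] <- dat$A[test] - predict(fit_m, newdata = dat[test, ])
}

# Solve the moment condition mean(res_a * (res_y - theta * res_a)) = 0
theta_hat <- sum(res_a * res_y) / sum(res_a^2)

# Standard error based on the score psi and its Jacobian
psi <- res_a * (res_y - theta_hat * res_a)
se_hat <- sqrt(mean(psi^2) / mean(res_a^2)^2 / n)

c(theta_hat = theta_hat, se = se_hat)
```

Within DoubleML itself, the PLR model is fitted with the `DoubleMLPLR` class instead of `DoubleMLIRM`; the names of its learner arguments are documented in the package reference.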

Variation 2: Employ an alternative learner

Instead of the learners used above, we can experiment with different learners that are available from the mlr3 ecosystem. A searchable list of all implemented learners is available here.

Variation 3: Tune a learner or experiment with pipelines

The learner section of the user guide explains how to perform parameter tuning using the mlr3tuning package.

It is also possible to implement pipelines using the mlr3pipelines package. You can find an experimental notebook here.


Notes and Acknowledgement

We would like to thank the organizers of the ACIC 2019 Data Challenge for setting up this data challenge and making the numerous synthetic data examples publicly available. Although the data examples in the ACIC 2019 Data Challenge do not explicitly address A/B testing, we place the data example in this context to provide a tractable illustration of causal machine learning in practice. The parameters of the random forest and extreme gradient boosting learners have been tuned externally. The corresponding tuning notebook will be uploaded to the examples gallery in the future.

References

Chernozhukov, V., Chetverikov, D., Demirer, M., Duflo, E., Hansen, C., Newey, W. and Robins, J. (2018), Double/debiased machine learning for treatment and structural parameters. The Econometrics Journal, 21: C1-C68. doi:10.1111/ectj.12097.