In this notebook, we demonstrate how the DoubleML package can be used to estimate the causal effect of seeing a new ad design on customers’ purchases in a webshop. The estimation steps of our analysis follow the DoubleML workflow.
Let’s consider the following stylized scenario. The manager of a
webshop performs an A/B test to estimate the effect a new ad design
\(A\) has on customers’ purchases (in
\(100\$\)), \(Y\), on average. This effect is called the
Average Treatment
Effect (ATE). The treatment is
assigned randomly conditional on the visitors’ characteristics, which we
call \(V\). Such characteristics could
be collected from a customer’s shopper account, for example. These
might include the number of previous purchases, time since the last
purchase, length of stay on a page as well as whether a customer has a
rewards card, among other characteristics.
In the following, we use a Directed Acyclic Graph (DAG) to illustrate our assumptions on the causal structure of the scenario. As not only the outcome but also the treatment depends on the individual characteristics, there are arrows going from \(V\) to both \(A\) and \(Y\). In our example, we also assume that the treatment \(A\) is a direct cause of the customers’ purchases \(Y\).
Let’s assume the conditional randomization has been conducted properly, such that a tidy data set has been collected. Now, a data scientist wants to evaluate whether the new ad design causally affected the sales, by using the DoubleML package.
Before we start the case study, let us briefly address the question of why we need to include individual characteristics in our analysis at all. There are two main reasons to control for observable characteristics. First, so-called confounders, i.e., variables that have a causal effect on both the treatment variable and the outcome variable, can bias our estimate. To uncover the true causal effect of the treatment, our causal framework must take all confounding variables into account. Otherwise, the average causal effect of the treatment on the outcome is not identified. A second reason to include individual characteristics is efficiency. The more of the variation in the outcome our causal framework explains, the more precise the resulting estimate will be. In practical terms, greater efficiency leads to smaller standard errors and, hence, tighter confidence intervals and smaller p-values. This can improve the power of A/B tests even if the treatment is assigned to individuals unconditionally at random.
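To make both points concrete, the following minimal sketch simulates a single confounder in base R. All variable names, the sample size and the coefficient values are illustrative assumptions and are unrelated to the ACIC data used in this notebook.

```r
# Illustrative simulation (all names and numbers are assumptions)
set.seed(42)
n = 5000
v = rnorm(n)                          # confounder
a = rbinom(n, 1, plogis(1.5 * v))     # treatment probability depends on v
y = 0.8 * a + 2 * v + rnorm(n)        # outcome depends on a and v; true effect is 0.8

# Bias: the naive difference in means ignores v
naive = mean(y[a == 1]) - mean(y[a == 0])

# Regression adjustment for v removes the confounding bias
adjusted = unname(coef(lm(y ~ a + v))["a"])

# Efficiency: even with randomized treatment, adjusting for v shrinks the standard error
a_rand = rbinom(n, 1, 0.5)
y_rand = 0.8 * a_rand + 2 * v + rnorm(n)
se_unadj = summary(lm(y_rand ~ a_rand))$coefficients["a_rand", "Std. Error"]
se_adj = summary(lm(y_rand ~ a_rand + v))$coefficients["a_rand", "Std. Error"]

print(c(naive = naive, adjusted = adjusted))
print(c(se_unadj = se_unadj, se_adj = se_adj))
```

In this toy setting the naive estimate is far from 0.8 while the adjusted one is close to it, and in the randomized variant the adjusted standard error is clearly smaller than the unadjusted one.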
ML methods have turned out to be very flexible in terms of modeling complex relationships of explanatory variables and dependent variables and, thus, have exhibited a great predictive performance in many applications. In the double machine learning approach (Chernozhukov et al. (2018)), ML methods are used for modelling so-called nuisance functions. In terms of the A/B case study considered here, ML tools can be used to flexibly control for confounding variables. For example, a linear parametric specification as in a standard linear regression model might not be correct and, hence, not sufficient to account for the underlying confounding. Moreover, by using powerful ML techniques, the causal model will likely be able to explain a greater share of the total variation and, hence, lead to more precise estimation.
As an illustrative example, we use a data set from the ACIC 2019 Data Challenge. In this challenge, a large number of data sets were generated to mimic distributional relationships that are found in many real economic data applications. Although the data were not generated explicitly to address an A/B testing case study, they are well suited for demonstration purposes. We focus on one of the many data generating processes (DGPs), picked at random, in this particular case a data set called high42. An advantage of using the synthetic ACIC 2019 data is that we know the true average treatment effect, which is 0.8 in our data set.
# Load required packages for this tutorial
library(DoubleML)
library(mlr3)
library(mlr3learners)
library(data.table)
library(ggplot2)
# suppress messages during fitting
lgr::get_logger("mlr3")$set_threshold("warn")
First we load the data.
# Load data set from url (internet connection required)
url = "https://raw.githubusercontent.com/DoubleML/doubleml-docs/master/doc/examples/data/high42.CSV"
df = fread(url)
dim(df)
## [1] 1000 202
head(df)
## Y A V1 V2 V3 V4 V5 V6 V7 V8 V9 V10 V11 V12
## 1: 7.358185 1 10 0 0 7 192.7938 23.67695 8 0.18544294 15.853240 6 1 0
## 2: 8.333672 1 12 0 1 4 199.6536 19.28127 7 0.51484172 9.244882 5 0 0
## 3: 7.472758 0 14 1 1 2 194.2078 24.58933 5 0.30919878 10.791593 4 1 0
## 4: 6.502319 1 0 1 0 9 201.8380 25.51392 4 0.16016010 22.639362 8 0 1
## 5: 7.043758 1 12 0 0 9 201.3604 31.16064 6 0.29197555 25.793415 3 1 1
## 6: 5.658337 0 8 0 1 6 193.2195 20.46564 9 0.05673076 9.277520 6 1 0
## V13 V14 V15 V16 V17 V18 V19 V20 V21
## 1: 0.2008238 1.1701569 1.217621 -16.81330 8.070488 1 1.0398440 30 0
## 2: 0.2230785 -1.5365715 1.118535 -15.08600 4.482310 0 -0.5417638 -1 0
## 3: 0.1751244 -1.2845893 1.265334 -27.76753 6.190104 1 1.0213073 59 0
## 4: 0.1487592 0.3477511 1.141496 -15.75004 4.752172 0 1.2265897 25 1
## 5: 0.1381315 0.3454147 1.298493 -30.29079 7.794285 0 1.6554302 8 1
## 6: 0.2511722 0.4423286 1.124828 -33.89765 6.324069 0 0.7876613 25 0
## V22 V23 V24 V25 V26 V27 V28
## 1: 0.2793573 -1.4206026 0.055594684 0.04032890 7.901696e-05 -58.99537 20
## 2: 0.2561681 2.7890151 0.014976614 0.04844543 8.593516e-03 -20.40609 7
## 3: 0.7500856 -0.6466704 0.006374132 0.06934194 2.614655e-05 -50.77884 16
## 4: 0.5371091 4.6331190 0.057029086 0.12232294 8.090068e-03 -58.68011 27
## 5: 0.7368710 -3.5655872 0.031656106 0.10157513 1.200738e-03 -12.73111 14
## 6: 0.4499631 -0.4609568 0.043176027 0.02658478 2.442900e-02 -48.19671 10
## V29 V30 V31 V32 V33 V34 V35 V36
## 1: 16.64241 34.71766 4.951458 0.17461066 0.2324233 2.070571 0 20.061676
## 2: 46.53248 21.68320 7.369149 0.43136246 1.9259106 1.833840 1 28.059975
## 3: 14.45965 24.99987 14.157243 0.82646459 6.0069024 1.774421 0 11.849514
## 4: 13.68466 25.50238 7.705198 -0.77489925 32.1862806 1.873539 0 20.509810
## 5: 18.46981 28.61574 8.432233 -0.96055871 0.4755487 1.391923 0 9.656802
## 6: 12.99067 32.42152 9.737143 0.03840029 19.1461227 2.300713 0 16.668824
## V37 V38 V39 V40 V41 V42 V43 V44 V45 V46 V47
## 1: 0.0004402816 1 0.4810670 1.2442739 5 21.34893 1 27 1 1 20.794542
## 2: 0.0007701103 0 2.1610440 1.0118577 12 12.07421 0 27 1 0 24.425323
## 3: 0.0018029212 0 1.0183895 1.2272869 8 15.07926 0 26 0 1 19.125459
## 4: 0.0026928483 0 2.3844421 0.6641995 15 29.66762 0 23 1 1 6.958642
## 5: 0.0023956495 0 0.2408289 1.1954664 8 34.22849 1 28 0 0 20.872313
## 6: 0.0003555247 1 1.0684471 1.3538031 7 42.12471 1 31 0 1 23.213361
## V48 V49 V50 V51 V52 V53 V54 V55 V56 V57 V58
## 1: 0 1 1 0.58347949 3.585321 1 0 0 2.143131 0 -4.5760397
## 2: 0 0 1 -0.89160276 2.599451 1 0 1 2.392858 1 -0.2151697
## 3: 0 1 0 1.51972438 4.510590 1 0 1 2.113959 1 -4.0197275
## 4: 0 1 1 -0.20967451 2.151854 1 0 0 2.277861 0 -4.1827249
## 5: 0 1 1 -0.09122315 4.605654 1 0 1 2.076338 0 -3.1929933
## 6: 0 0 1 0.08143847 2.747924 1 0 0 2.123266 0 -1.7701957
## V59 V60 V61 V62 V63 V64 V65 V66 V67 V68
## 1: -2.3393634 1 1.138766e-03 11.53770 0 4 1 1 3.904367 177.51886
## 2: 0.5914007 1 1.370902e-05 11.26127 0 2 1 0 6.468333 70.21335
## 3: 2.0021808 0 2.299705e-02 10.67112 0 2 1 0 4.796108 28.38070
## 4: -0.8959613 1 1.224616e-03 12.09844 1 5 1 0 2.624500 19.75818
## 5: -0.6357487 0 8.927531e-04 10.82328 0 0 1 0 5.195238 55.38393
## 6: 2.4273998 1 2.320485e-03 10.96484 0 2 1 0 4.032135 160.05336
## V69 V70 V71 V72 V73 V74 V75
## 1: -0.977834002 -15.655234 -0.339091443 -4.859462 46.77102 7 27.35320
## 2: -0.008146874 2.425168 0.250900654 -4.879361 44.25859 4 22.95139
## 3: -0.093093555 -4.628594 0.493405176 -2.582342 59.11854 3 21.71225
## 4: 0.293489224 -12.201118 -0.319991248 -4.524812 11.21414 3 25.32508
## 5: -0.329338305 -19.628055 -0.007397628 -3.060636 12.19038 5 23.33296
## 6: -0.963812494 -13.956940 -0.071165907 -4.836254 21.10611 9 28.26248
## V76 V77 V78 V79 V80 V81 V82 V83 V84
## 1: -2.2034941 0 6 1 -0.95657964 0.05625706 29 0 0.03737981
## 2: -0.9763126 0 7 0 -0.91832528 0.08790430 24 0 0.03721510
## 3: -0.7862249 0 11 0 -0.92248578 0.06542868 22 0 0.03099046
## 4: -0.2882286 0 14 0 0.05925364 0.11128509 13 0 0.03540332
## 5: 0.3681190 0 7 0 -0.07175302 0.05468009 16 0 0.02777697
## 6: -2.8403231 0 5 0 0.04554488 0.08101242 34 0 0.03868584
## V85 V86 V87 V88 V89 V90 V91 V92 V93 V94
## 1: -0.157240114 -0.19498800 0 0 17 1.078362 1 1.25262806 0 0.9756274
## 2: -0.138141504 1.18997769 0 1 17 1.262945 1 -0.49464176 0 0.9756161
## 3: 0.996975342 -4.13585291 0 1 22 2.439541 0 -0.12586824 0 0.9808798
## 4: -2.097800685 -0.02922322 0 0 8 1.039004 0 1.06547295 0 0.9817917
## 5: 0.007976423 0.50863166 0 0 14 1.836204 1 -0.02935486 0 0.9882583
## 6: 0.179009664 0.70442709 0 1 11 1.613274 0 1.21502580 0 0.9839163
## V95 V96 V97 V98 V99 V100 V101 V102 V103 V104 V105
## 1: 5 2.214363 1.266806 1 0 3.858811 2.248257 7 1 35.61543 0.9860864
## 2: 6 5.377556 1.252839 0 0 3.977629 1.821898 3 0 23.06962 0.9936143
## 3: 6 3.865115 1.279053 1 0 3.745750 2.142148 6 2 55.44521 0.9914764
## 4: 4 2.458335 1.286735 0 0 3.894206 1.742324 5 3 28.14201 0.9307358
## 5: 4 3.042121 1.163366 1 0 3.739791 2.188678 10 2 49.26435 0.9991959
## 6: 5 3.616383 1.276404 1 0 3.959868 2.634676 3 2 43.16583 0.9964038
## V106 V107 V108 V109 V110 V111 V112 V113
## 1: 0.8109243 -0.9703621 0 0.2730344 0 73.50169 1 0.6634900
## 2: 10.9250096 -1.0017012 0 0.4120395 0 51.62268 1 0.2838329
## 3: 13.6599136 -0.9802746 0 0.5236362 0 62.99163 1 0.9516340
## 4: -25.5088824 -0.9934119 1 0.6180330 0 43.44839 1 0.7741553
## 5: 7.0802934 -0.8977199 1 0.5246819 0 72.31038 1 0.6754245
## 6: -3.3171809 -0.9993140 1 0.7486420 0 60.74483 0 0.1145096
## V114 V115 V116 V117 V118 V119 V120 V121 V122
## 1: 0.43676411 0.9991616 10.79522 -0.5385633 1 4 934 33.04630 0.4615909
## 2: 0.36474425 0.9989004 11.55450 0.2812390 15 1 941 33.86293 -5.6026849
## 3: 0.06739379 0.9989085 11.84418 -1.4635959 11 5 874 22.15443 -1.8215083
## 4: 0.13928599 0.9996213 14.81413 -0.3407824 43 8 941 45.65891 -2.9436803
## 5: 0.27403924 0.9202706 18.14229 -1.2608972 50 1 952 19.93158 -1.1604592
## 6: 0.25122805 0.9854695 14.57774 -0.8166704 9 2 1111 22.45532 -1.5620386
## V123 V124 V125 V126 V127 V128 V129 V130 V131 V132 V133
## 1: 0 0.2877762 0.1592139 2 0 1 0 0 6 -14.26363 1
## 2: 0 0.3517710 0.5749188 0 0 0 0 0 20 44.22123 0
## 3: 1 0.4206429 -0.4930869 0 0 0 1 0 15 44.42188 1
## 4: 1 0.3484142 0.9731821 2 0 1 1 0 4 25.29488 1
## 5: 0 0.4275244 -2.7654479 1 5 0 1 1 7 15.44793 3
## 6: 0 0.4941055 -0.4603852 1 2 0 1 0 14 16.02116 1
## V134 V135 V136 V137 V138 V139 V140 V141 V142 V143
## 1: 0.6480628 4.9721815 1 1 -0.01270437 14 2.211044 0 -88 1
## 2: 0.5464590 2.8346019 0 0 -3.34385582 15 2.780464 0 -103 1
## 3: 0.4751264 1.4166685 0 1 0.70178508 12 2.583958 0 -104 1
## 4: 0.8665180 2.3020696 0 1 3.17621778 6 2.593933 0 -94 1
## 5: 0.4346036 0.5164262 0 0 1.13260721 11 3.313170 0 -62 1
## 6: 0.4303166 1.7248318 0 1 -0.99296580 7 2.070250 0 -80 1
## V144 V145 V146 V147 V148 V149 V150 V151
## 1: 14.899569 0.2973792 -1.4045348 22.93680 1.112065 1 0 0.7522654
## 2: 9.283158 1.5241806 -0.5923406 19.84037 1.107726 1 0 0.7412935
## 3: 8.126532 1.9441522 1.7871841 17.53077 1.109680 1 0 0.7495357
## 4: 7.990564 1.3594995 0.2679640 24.41190 1.112921 0 0 0.7961531
## 5: 4.432384 0.9491806 -0.7068222 24.55891 1.101129 0 0 0.7781383
## 6: 5.261740 0.8242018 -0.6108768 20.55625 1.109204 0 1 0.8590800
## V152 V153 V154 V155 V156 V157 V158 V159 V160
## 1: 0.09487609 1 61.75640 0 0.1555205 1.299214 0 8.9809734 1
## 2: 0.23032948 1 45.97932 1 0.1844182 3.095729 3 14.0603393 1
## 3: 0.13815855 1 51.36620 0 0.2829257 0.810593 4 0.9319917 0
## 4: 0.14440331 0 48.24623 0 0.1138707 2.181276 5 2.5809724 0
## 5: 0.13364686 1 47.29790 1 0.2961928 -0.434977 4 7.2345673 1
## 6: 0.33345130 0 56.33525 0 0.1476424 -1.926538 3 4.4139180 1
## V161 V162 V163 V164 V165 V166 V167 V168 V169 V170
## 1: 14.597886 0 0.3639882 0 0 2.733640 0 0.16228697 2.2205500 0
## 2: 3.380854 0 0.5129260 0 0 2.724039 0 0.10602104 -0.6852898 0
## 3: 3.109170 1 -1.2629269 0 0 2.715319 0 0.05108764 1.9828793 1
## 4: 57.611061 0 0.5685741 0 0 2.711964 0 0.13866903 1.4471618 2
## 5: 32.457049 0 -0.5958445 0 1 2.713607 0 0.06787200 0.4291680 4
## 6: 24.568721 0 -0.2389066 0 0 2.742426 0 0.11932443 2.4901435 2
## V171 V172 V173 V174 V175 V176 V177 V178 V179 V180
## 1: 30.147837 0 5.890918 0.5035608 1 0 0 0.64153940 -0.7871431 31
## 2: -86.355520 0 6.902555 0.2127434 4 0 0 0.01900934 -1.0393244 21
## 3: -2.545955 0 6.704886 0.4259085 3 0 0 0.80010433 0.2331956 21
## 4: -33.742974 1 7.528415 0.3282203 6 0 1 1.35364593 0.1319536 32
## 5: -20.346347 1 5.539800 0.4322122 4 0 0 0.90278942 -0.6498747 30
## 6: -17.029889 1 7.514143 0.2331520 1 0 0 1.22747635 0.8908464 20
## V181 V182 V183 V184 V185 V186 V187 V188 V189 V190
## 1: 3 0 1 17 0.21751273 0.16146110 -0.1212995 0.46743435 3 0
## 2: 3 0 0 16 0.54973231 0.22095415 -0.5137855 -0.11908883 3 0
## 3: 0 0 0 15 0.42325170 0.11036085 0.2995553 -0.25991994 3 0
## 4: 6 1 0 0 -0.18960878 0.02418432 -0.1869971 -0.08153095 2 0
## 5: 3 0 0 17 -0.03580849 0.24367914 -0.6039507 0.53261710 1 0
## 6: 2 0 0 0 -0.70241461 0.57273510 1.8515487 0.64336691 0 0
## V191 V192 V193 V194 V195 V196 V197 V198 V199 V200
## 1: 1.462837 1 1627.274 0 0 4.683956 0.5656669 0 3 0.02433804
## 2: 1.330522 1 1661.484 1 0 6.766661 -0.3954021 0 4 0.05651780
## 3: 1.384151 1 1658.939 0 0 5.647794 1.1127661 0 0 0.01344207
## 4: 1.220303 1 1650.802 0 0 5.370363 -0.3058420 0 4 0.03463235
## 5: 1.170094 1 1676.819 0 0 3.446532 2.4406606 0 1 0.01751381
## 6: 1.802945 1 1634.093 1 0 5.294410 1.0869714 1 0 0.04712806
We see that the data set consists of 1000 observations (= website visitors) and 202 variables:
Y: A customer’s purchases (in \(100\$\)).
A: Binary treatment variable with a value 1 indicating that a customer has been exposed to the new ad design (and value 0 otherwise).
V1, …, V200: The remaining 200 columns \(V\) represent individual characteristics of the customers (= confounders).
To start our analysis, we initialize the data backend from the previously loaded data set, i.e., we create a new instance of a DoubleMLData object. During initialization, we specify the roles of the variables in the data set, i.e., in our example the outcome variable \(Y\) via the parameter y_col, the treatment variable \(A\) via d_cols and the confounding variables \(V\) via x_cols.
# Specify explanatory variables for data-backend
features_base = colnames(df)[grep("V", colnames(df))]
# Initialize DoubleMLData (data-backend of DoubleML)
data_dml = DoubleMLData$new(df,
y_col = "Y",
d_cols = "A",
x_cols = features_base)
We can print the data-backend to see the variables, which we have assigned as outcome, treatment and controls.
print(data_dml)
## ================= DoubleMLData Object ==================
##
##
## ------------------ Data summary ------------------
## Outcome variable: Y
## Treatment variable(s): A
## Covariates: V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, V11, V12, V13, V14, V15, V16, V17, V18, V19, V20, V21, V22, V23, V24, V25, V26, V27, V28, V29, V30, V31, V32, V33, V34, V35, V36, V37, V38, V39, V40, V41, V42, V43, V44, V45, V46, V47, V48, V49, V50, V51, V52, V53, V54, V55, V56, V57, V58, V59, V60, V61, V62, V63, V64, V65, V66, V67, V68, V69, V70, V71, V72, V73, V74, V75, V76, V77, V78, V79, V80, V81, V82, V83, V84, V85, V86, V87, V88, V89, V90, V91, V92, V93, V94, V95, V96, V97, V98, V99, V100, V101, V102, V103, V104, V105, V106, V107, V108, V109, V110, V111, V112, V113, V114, V115, V116, V117, V118, V119, V120, V121, V122, V123, V124, V125, V126, V127, V128, V129, V130, V131, V132, V133, V134, V135, V136, V137, V138, V139, V140, V141, V142, V143, V144, V145, V146, V147, V148, V149, V150, V151, V152, V153, V154, V155, V156, V157, V158, V159, V160, V161, V162, V163, V164, V165, V166, V167, V168, V169, V170, V171, V172, V173, V174, V175, V176, V177, V178, V179, V180, V181, V182, V183, V184, V185, V186, V187, V188, V189, V190, V191, V192, V193, V194, V195, V196, V197, V198, V199, V200
## Instrument(s):
## No. Observations: 1000
The inference problem is to determine the causal effect of seeing the new ad design \(A\) on customers’ purchases \(Y\) once we control for individual characteristics \(V\). In our example, we are interested in the average treatment effect. Basically, there are two causal models available in DoubleML that can be used to estimate the ATE.
The so-called interactive regression model (IRM), implemented in the class DoubleMLIRM, is a flexible (nonparametric) model to estimate this causal quantity. The model does not impose functional form restrictions on the underlying regression relationships, for example, linearity or additivity as in a standard linear regression model. This means that the model accommodates heterogeneous treatment effects, i.e., it accounts for variation in the effect of the new ad design across customers. Moreover, it is also possible to estimate other causal parameters with the IRM, for example, the average treatment effect on the treated (= those customers who have been exposed to the new ad), which might be of interest too.
We briefly introduce the interactive regression model where the main regression relationship of interest is provided by
\[Y = g_0(A, V) + U_1, \quad E(U_1 | V, A) = 0,\]
where the treatment variable is binary, \(A \in \lbrace 0,1 \rbrace\). We consider estimation of the average treatment effect (ATE):
\[\theta_0 = \mathbb{E}[g_0(1, V) - g_0(0,V)],\]
when treatment effects are heterogeneous. In order to be able to use ML methods, the estimation framework generally requires a property called “double robustness” or “Neyman orthogonality”. In the IRM, double robustness can be achieved by including the first-stage estimation
\[A = m_0(V) + U_2, \quad E(U_2| V) = 0,\]
which amounts to estimation of the propensity score, i.e., the probability that a customer is exposed to the treatment given their observed characteristics. Both predictions are then combined in the doubly robust score for the average treatment effect, which is given by
\[\psi(W; \theta, \eta) := g(1,V) - g(0,V) + \frac{A (Y - g(1,V))}{m(V)} - \frac{(1 - A)(Y - g(0,V))}{1 - m(V)} - \theta.\]
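The estimate of \(\theta_0\) solves the empirical analogue of \(\mathbb{E}[\psi(W; \theta, \eta)] = 0\), so it is the sample mean of the score without the \(-\theta\) term. The base R sketch below shows this on simulated data. It is a simplified illustration under stated assumptions: the helper name aipw_ate, the simulated data and the use of lm/glm for the nuisance functions are hypothetical choices, and the nuisance models are fitted on the full sample without the cross-fitting that DoubleML performs.

```r
# Hypothetical helper: doubly robust ATE estimate from given nuisance predictions
aipw_ate = function(y, a, g0_hat, g1_hat, m_hat) {
  # Score from the text, without the "- theta" term
  psi_b = g1_hat - g0_hat +
    a * (y - g1_hat) / m_hat -
    (1 - a) * (y - g0_hat) / (1 - m_hat)
  theta = mean(psi_b)                              # solves mean(psi) = 0
  se = sqrt(mean((psi_b - theta)^2) / length(y))   # plug-in standard error
  c(theta = theta, se = se)
}

# Simulated data with a heterogeneous effect; the ATE is 0.8 because E[v] = 0
set.seed(1)
n = 5000
v = rnorm(n)
a = rbinom(n, 1, plogis(0.5 * v))
y = 0.8 * a + v + 0.5 * a * v + rnorm(n)
dat = data.frame(y = y, a = a, v = v)

# Nuisance fits (no cross-fitting in this sketch)
fit_g0 = lm(y ~ v, data = dat[dat$a == 0, ])            # g(0, V)
fit_g1 = lm(y ~ v, data = dat[dat$a == 1, ])            # g(1, V)
fit_m = glm(a ~ v, data = dat, family = binomial())     # m(V)

res_aipw = aipw_ate(y, a,
  g0_hat = predict(fit_g0, newdata = dat),
  g1_hat = predict(fit_g1, newdata = dat),
  m_hat = predict(fit_m, newdata = dat, type = "response"))
print(res_aipw)
```

The two inverse-probability terms correct the regression-based difference \(g(1,V) - g(0,V)\) with weighted residuals, which is why very small or very large propensity scores can make the score unstable.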
As a naive estimate, we could calculate the unconditional average treatment effect. In other words, we simply take the difference between the average of \(Y\) for the customers who have been exposed to the treatment \((A=1)\) and the average for those who haven’t been exposed \((A=0)\).
Since the unconditional ATE does not account for the confounding variables, it will generally not correspond to the true ATE (only in the case of unconditionally random treatment assignment, the unconditional ATE will correspond to the true ATE). For example, if the unconditional ATE estimate is greater than the actual ATE, the manager would erroneously overinterpret the effect of the new ad design and probably make misleading decisions for the marketing budget in the future.
df[, mean(Y), by = A]
## A V1
## 1: 1 7.953744
## 2: 0 6.836141
ATE_uncond = df[A == 1, mean(Y)] - df[A==0, mean(Y)]
ATE_uncond
## [1] 1.117603
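Because the true ATE of this synthetic data set is known to be 0.8, the size of the error in the naive estimate can be computed directly. The short sketch below only reuses the two numbers reported in this notebook; the variable names are illustrative.

```r
# Values reported in this notebook
ate_true = 0.8          # known ATE of the synthetic data set
ate_naive = 1.117603    # unconditional difference in means

bias_abs = ate_naive - ate_true
bias_rel = bias_abs / ate_true
print(c(absolute = bias_abs, relative = bias_rel))
```

The naive estimate overstates the effect by about 0.32 (in \(100\$\)), which is roughly 40% of the true effect.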
In this step, we define the learners that will be used for estimation of the nuisance functions later.
Let us first start with a benchmark model that is based on (unpenalized) linear and logistic regression. Hence, we estimate the functions \(g_0(A,V)\) using a linear regression model and \(m_0(V)\) by using an (unpenalized) logistic regression. In both cases, we include all available characteristics \(V\). We will later compare the performance of this model to that using more advanced ML methods.
# TODO: Initialize Linear and Logistic Regression learners
linreg = lrn("regr.lm")
logreg_class = lrn("classif.log_reg")
# TODO: Initialize one ML learner of your choice
lasso = lrn("regr.cv_glmnet", nfolds = 5, s = "lambda.min")
lasso_class = lrn("classif.cv_glmnet", nfolds = 5, s = "lambda.min")
# TODO: Initialize a second ML learner of your choice
# (proceed as long as you like)
randomForest = lrn("regr.ranger")
randomForest_class = lrn("classif.ranger")
At this stage, we instantiate a causal model object of the class DoubleMLIRM. Provide the learners via the parameters ml_g and ml_m. You can either stick with the default settings or change the parameters. The documentation for the DoubleMLIRM class is available in the package documentation. Also have a look at the documentation of the abstract base class DoubleML.
Hint: Use set.seed() to set a random seed prior to your initialization. This makes the sample splits of the different models comparable. Also try to use the same DML specifications in all models to attain some comparability.
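The role of the seed can be illustrated with a small base R sketch of repeated sample splitting. The helper draw_folds is hypothetical and is not how DoubleML draws its splits internally; it only shows that setting the same seed before drawing the partitions reproduces the same folds, here for 1000 observations with 3 folds and 3 repetitions.

```r
# Hypothetical helper: draw n_rep random partitions of 1..n_obs into n_folds folds
draw_folds = function(n_obs, n_folds, n_rep) {
  lapply(seq_len(n_rep), function(r) {
    # Balanced fold labels in random order
    sample(rep(seq_len(n_folds), length.out = n_obs))
  })
}

set.seed(1234)
folds_a = draw_folds(n_obs = 1000, n_folds = 3, n_rep = 3)
set.seed(1234)
folds_b = draw_folds(n_obs = 1000, n_folds = 3, n_rep = 3)

same_splits = identical(folds_a, folds_b)   # same seed, same splits
fold_sizes = table(folds_a[[1]])            # fold sizes in the first repetition
print(same_splits)
print(fold_sizes)
```

With identical splits, differences between the fitted models can be attributed to the learners rather than to the random partitioning of the data.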
# TODO: Initialize benchmark DoubleMLIRM model
set.seed(1234)
dml_irm_regression = DoubleMLIRM$new(data_dml,
ml_g = linreg,
ml_m = logreg_class,
trimming_threshold = 0.025,
n_folds = 3,
n_rep = 3)
# TODO: Initialize a DoubleMLIRM model using the ML learners of your choice
set.seed(1234)
dml_irm_lasso = DoubleMLIRM$new(data_dml,
ml_g = lasso,
ml_m = lasso_class,
trimming_threshold = 0.025,
n_folds = 3,
n_rep = 3)
Proceed with the models using the other ML learners.
# TODO: Initialize a DoubleMLIRM model using the ML learners of your choice
set.seed(1234)
dml_irm_forest = DoubleMLIRM$new(data_dml,
ml_g = randomForest,
ml_m = randomForest_class,
trimming_threshold = 0.025,
n_folds = 3,
n_rep = 3)
# Set nuisance-part specific parameters
dml_irm_forest$set_ml_nuisance_params("ml_g0", "A",
list("mtry" = 200,
"num.trees" = 250))
dml_irm_forest$set_ml_nuisance_params("ml_g1", "A",
list("mtry" = 200,
"num.trees" = 250))
dml_irm_forest$set_ml_nuisance_params("ml_m", "A",
list("mtry" = 200,
"num.trees" = 250))
# TODO: Fit benchmark DoubleMLIRM model using the fit() method
dml_irm_regression$fit(store_predictions = TRUE)
dml_irm_regression$summary()
## Estimates and significance testing of the effect of target variables
## Estimate. Std. Error t value Pr(>|t|)
## A 0.1728 0.7152 0.242 0.809
To evaluate the different models we can compare how well the employed estimators fit the nuisance functions \(g_0(\cdot)\) and \(m_0(\cdot)\). Use the following helper function to compare the predictive performance of your models.
# A function to calculate prediction accuracy values for every repetition
# of a Double Machine Learning model using IRM, DoubleMLIRM
pred_acc_irm = function(obj, prop) {
  # obj : DoubleML::DoubleMLIRM
  #   The IRM Double Machine Learning model
  # prop : logical
  #   If FALSE, RMSE values are computed for the main regression;
  #   if TRUE, log loss values are computed for the propensity score
  if (obj$data$n_treat > 1) {
    stop("Number of treatment variables is > 1. Helper function for nuisance accuracy is only implemented for 1 treatment variable.")
  }
  h = obj$data$n_obs
  w = obj$n_rep
  y = obj$data$data_model[[obj$data$y_col]]
  d = obj$data$data_model[[obj$data$treat_col]]
  # Stored predictions: one column per repetition
  g0 = matrix(obj$predictions[['ml_g0']][,,1], ncol = w)
  g1 = matrix(obj$predictions[['ml_g1']][,,1], ncol = w)
  m = matrix(obj$predictions[['ml_m']][,,1], ncol = w)
  if (!all(unique(d) %in% c(0,1))) {
    stop("Treatment must be a binary variable.")
  }
  if (!prop) {
    # Prediction for the observed treatment status
    export_pred = d*g1 + (1-d) * g0
    # Calculate RMSE for every repetition
    pred_acc = apply(export_pred, 2,
                     function(x) mlr3measures::rmse(y,x))
  } else {
    # Calculate log loss for every repetition
    pred_acc = rep(NA, w)
    for (j in seq_len(w)) {
      class_probs = matrix(c(1-m[,j],m[,j]), ncol = 2)
      colnames(class_probs) = c("0", "1")
      pred_acc[j] = mlr3measures::logloss(as.factor(d),class_probs)
    }
  }
  return(pred_acc)
}
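For readers who want to see what the two measures compute, here is a minimal base R sketch with hand-written versions. The function names rmse_manual and logloss_manual, the clipping constant eps and the toy inputs are illustrative assumptions; the helper above relies on the mlr3measures implementations instead.

```r
# Hand-written versions of the two measures (for illustration only)
rmse_manual = function(truth, response) {
  sqrt(mean((truth - response)^2))
}
logloss_manual = function(d, prob1, eps = 1e-15) {
  p = pmin(pmax(prob1, eps), 1 - eps)   # guard against log(0)
  -mean(d * log(p) + (1 - d) * log(1 - p))
}

# Toy inputs
rmse_toy = rmse_manual(c(1, 2, 3), c(1, 2, 5))               # sqrt(4 / 3)
d_toy = c(0, 1, 1, 0)
ll_const = logloss_manual(d_toy, rep(0.5, 4))                # log(2)
ll_inform = logloss_manual(d_toy, c(0.1, 0.9, 0.8, 0.2))     # lower than log(2)
print(c(rmse = rmse_toy, logloss_constant = ll_const, logloss_informative = ll_inform))
```

A constant prediction of 0.5 always yields a log loss of log(2), about 0.693, which is a useful reference point when reading the log loss values of the fitted propensity models.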
# TODO: Evaluate the predictive performance for `ml_g` and `ml_m` using the
# helper function `pred_acc_irm()`.
rmse_main_linlog_irm = pred_acc_irm(dml_irm_regression, prop = FALSE)
rmse_main_linlog_irm_mean = mean(rmse_main_linlog_irm)
rmse_main_linlog_irm_sd = sd(rmse_main_linlog_irm)
logloss_prop_linlog_irm = pred_acc_irm(dml_irm_regression, prop = TRUE)
logloss_prop_linlog_irm_mean = mean(logloss_prop_linlog_irm)
logloss_prop_linlog_irm_sd = sd(logloss_prop_linlog_irm)
The propensity score \(m_0(V)\) plays an important role in the score of the IRM model. Try to summarize the estimates for \(m_0(V)\) using some descriptive statistics or visualization.
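As a starting point, the sketch below computes a few descriptive statistics and illustrates why extreme propensity scores matter. The vector m_hat is a simulated stand-in, not output of a fitted model, and the clipping step assumes a truncation rule that restricts the scores to the interval given by the trimming threshold of 0.025 used in the model specifications above.

```r
# Simulated stand-in for estimated propensity scores (assumption, not model output)
set.seed(123)
m_hat = plogis(rnorm(1000, mean = 0, sd = 2))

print(summary(m_hat))                   # descriptive statistics
thr = 0.025                             # same value as trimming_threshold
share_extreme = mean(m_hat < thr | m_hat > 1 - thr)

# Truncation: clip the scores to [thr, 1 - thr] to bound the inverse-probability weights
m_trunc = pmin(pmax(m_hat, thr), 1 - thr)
max_weight_raw = max(1 / m_hat, 1 / (1 - m_hat))
max_weight_trunc = max(1 / m_trunc, 1 / (1 - m_trunc))
print(c(share_extreme = share_extreme,
        max_weight_raw = max_weight_raw,
        max_weight_trunc = max_weight_trunc))
```

After clipping, no observation can receive a weight larger than 1 / 0.025 = 40 in the score, whereas the raw weights are unbounded as the scores approach 0 or 1.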
# (TODO): Summarize the propensity score estimates
# Function to plot propensity scores
rep_propscore_plot = function(obj) {
  # obj : doubleml
  #   The Double Machine Learning model
  if (obj$data$n_treat > 1) {
    stop("Number of treatment variables is > 1. Helper function for the propensity score plot is only implemented for 1 treatment variable.")
  }
  # One column of propensity score predictions per repetition
  m = data.table(obj$predictions[['ml_m']][,,1])
  colnames(m) = paste("Repetition", 1:obj$n_rep)
  m = melt(m,
           measure.vars = names(m))
  # after_stat(count) replaces the deprecated ..count.. notation (ggplot2 >= 3.3.0)
  hist_ps = ggplot(m) +
    geom_histogram(aes(y = after_stat(count), x = value),
                   bins = 25, fill = "darkblue",
                   col = "darkblue", alpha = 0.5) +
    xlim(c(0,1)) + theme_minimal() +
    facet_grid(. ~ variable)
  return(hist_ps)
}
rep_propscore_plot(dml_irm_regression)
## Warning: Removed 6 rows containing missing values (geom_bar).
# TODO: Fit the ML DoubleMLIRM model using the fit() method
dml_irm_lasso$fit(store_predictions = TRUE)
dml_irm_lasso$summary()
## Estimates and significance testing of the effect of target variables
## Estimate. Std. Error t value Pr(>|t|)
## A 0.85985 0.07177 11.98 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
# TODO: Evaluate the predictive performance for `ml_g` and `ml_m` using the
# helper function `pred_acc_irm()`.
rmse_main_lasso_irm = pred_acc_irm(dml_irm_lasso, prop = FALSE)
rmse_main_lasso_irm_mean = mean(rmse_main_lasso_irm)
rmse_main_lasso_irm_sd = sd(rmse_main_lasso_irm)
logloss_prop_lasso_irm = pred_acc_irm(dml_irm_lasso, prop = TRUE)
logloss_prop_lasso_irm_mean = mean(logloss_prop_lasso_irm)
logloss_prop_lasso_irm_sd = sd(logloss_prop_lasso_irm)
# (TODO): Summarize the propensity score estimates
rep_propscore_plot(dml_irm_lasso)
## Warning: Removed 6 rows containing missing values (geom_bar).
Proceed with the models using the other ML learners.
dml_irm_forest$fit(store_predictions = TRUE)
dml_irm_forest$summary()
## Estimates and significance testing of the effect of target variables
## Estimate. Std. Error t value Pr(>|t|)
## A 0.88472 0.07656 11.56 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
rmse_main_forest_irm = pred_acc_irm(dml_irm_forest, prop = FALSE)
rmse_main_forest_irm_mean = mean(rmse_main_forest_irm)
rmse_main_forest_irm_sd = sd(rmse_main_forest_irm)
logloss_prop_forest_irm = pred_acc_irm(dml_irm_forest, prop = TRUE)
logloss_prop_forest_irm_mean = mean(logloss_prop_forest_irm)
logloss_prop_forest_irm_sd = sd(logloss_prop_forest_irm)
rep_propscore_plot(dml_irm_forest)
## Warning: Removed 6 rows containing missing values (geom_bar).
Provide a brief summary of your estimation results, for example by creating a table or figure.
# TODO: Summarize the results on the nuisance estimation in a table or figure
estimators = c("linear regression", "lasso", "random forest")
estimators = factor(estimators, levels = estimators)
irm_rmse = data.table(
"ML" = estimators,
"RMSE (mean)" = c(rmse_main_linlog_irm_mean,
rmse_main_lasso_irm_mean,
rmse_main_forest_irm_mean),
"RMSE (sd)" = c(rmse_main_linlog_irm_sd,
rmse_main_lasso_irm_sd,
rmse_main_forest_irm_sd),
"log loss (mean)" = c(logloss_prop_linlog_irm_mean,
logloss_prop_lasso_irm_mean,
logloss_prop_forest_irm_mean),
"log loss (sd)" = c(logloss_prop_linlog_irm_sd,
logloss_prop_lasso_irm_sd,
logloss_prop_forest_irm_sd)
)
print(irm_rmse, 4)
## ML RMSE (mean) RMSE (sd) log loss (mean) log loss (sd)
## 1: linear regression 1.693588 0.040248655 1.0675439 0.044847925
## 2: lasso 1.125516 0.006438981 0.6667697 0.004654689
## 3: random forest 1.162918 0.009308529 0.6780709 0.002578629
irm_rmse[, ':=' (
lower_rmse = `RMSE (mean)` - `RMSE (sd)`,
upper_rmse = `RMSE (mean)` + `RMSE (sd)`,
lower_logloss = `log loss (mean)` - `log loss (sd)`,
upper_logloss = `log loss (mean)` + `log loss (sd)`)]
g_rmse_irm = ggplot(irm_rmse, aes(x = ML, y = `RMSE (mean)`, color = ML)) +
geom_point() +
geom_errorbar(aes(ymin = lower_rmse, ymax = upper_rmse), color = "darkgrey") +
theme_minimal() + ylab("Mean RMSE +/- 1 sd") +
xlab("") +
theme(axis.text.x = element_text(angle = 90), legend.position = "none", text = element_text(size = 18))
g_rmse_irm
g_logloss_irm = ggplot(irm_rmse, aes(x = ML, y = `log loss (mean)`, color = ML)) +
geom_point() +
geom_errorbar(aes(ymin = lower_logloss, ymax = upper_logloss), color = "darkgrey") +
theme_minimal() + ylab("Mean log loss +/- 1 sd") +
xlab("") +
theme(axis.text.x = element_text(angle = 90), legend.position = "none", text = element_text(size = 18))
g_logloss_irm
Summarize your results on the coefficient estimate for \(\theta_0\) as well as the standard errors and / or confidence intervals, respectively. You can create a table or a figure illustrating your findings.
Try to answer the following questions:
Solution:
## TODO: After calling fit(), access the coefficient parameter,
## the standard error and confidence interval by calling the methods
## `summary()` and `confint()`.
dml_irm_regression$summary()
## Estimates and significance testing of the effect of target variables
## Estimate. Std. Error t value Pr(>|t|)
## A 0.1728 0.7152 0.242 0.809
dml_irm_regression$confint()
## 2.5 % 97.5 %
## A -1.228837 1.574525
## TODO: After calling fit(), access the coefficient parameter,
## the standard error and confidence interval by calling the methods
## `summary()` and `confint()`.
dml_irm_lasso$summary()
## Estimates and significance testing of the effect of target variables
## Estimate. Std. Error t value Pr(>|t|)
## A 0.85985 0.07177 11.98 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
dml_irm_lasso$confint()
## 2.5 % 97.5 %
## A 0.7191893 1.000519
Proceed with the models using the other ML learners.
dml_irm_forest$summary()
## Estimates and significance testing of the effect of target variables
## Estimate. Std. Error t value Pr(>|t|)
## A 0.88472 0.07656 11.56 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
dml_irm_forest$confint()
## 2.5 % 97.5 %
## A 0.7346654 1.034768
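The confidence intervals can be reproduced approximately from the printed estimates and standard errors using the normal approximation \(\hat\theta \pm z_{0.975} \cdot \widehat{se}\). The sketch below uses only the rounded numbers shown by summary() above, so the bounds agree with confint() only up to rounding; the data frame and its column names are illustrative.

```r
# Estimates and standard errors as printed by summary() above
res_ci = data.frame(
  model = c("linear/logistic regression", "lasso", "random forest"),
  theta = c(0.1728, 0.85985, 0.88472),
  se = c(0.7152, 0.07177, 0.07656)
)
z = qnorm(0.975)                        # approx. 1.96
res_ci$ci_lower = res_ci$theta - z * res_ci$se
res_ci$ci_upper = res_ci$theta + z * res_ci$se
res_ci$width = res_ci$ci_upper - res_ci$ci_lower
# True ATE of the synthetic data set is 0.8
res_ci$covers_truth = res_ci$ci_lower <= 0.8 & 0.8 <= res_ci$ci_upper
print(res_ci)
```

All three intervals contain the true value 0.8, but the interval of the benchmark model is roughly ten times as wide as those of the lasso and random forest models and also contains zero.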
As an alternative to the (nonparametric) IRM model, the DoubleML package also includes the partially linear regression (PLR) model, which assumes that the treatment enters the population regression linearly and additively. Although in reality we never know whether this structure really holds for the underlying data generating process, we can apply this model and see how the estimates compare to those from the IRM.
We can estimate the nuisance functions \(g_0\) and \(m_0\) in the following PLR model:
\[\begin{eqnarray} & Y = A\theta_0 + g_0(V) + \zeta, &\quad E[\zeta \mid A,V]= 0,\\ & A = m_0(V) + U_3, &\quad E[U_3 \mid V] = 0. \end{eqnarray}\]
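One way to see how \(\theta_0\) can be recovered in this model is a residual-on-residual ("partialling-out") regression: both \(Y\) and \(A\) are residualized on \(V\), and the outcome residual is regressed on the treatment residual. The base R sketch below illustrates the idea on simulated data. It is a simplified illustration under stated assumptions: it uses the conditional mean \(\ell_0(V) = E[Y \mid V] = \theta_0 m_0(V) + g_0(V)\), which is notation not used elsewhere in this notebook, fits the nuisance functions with lm/glm on the full sample without cross-fitting, and all names and numbers are hypothetical.

```r
# Simulated data from a partially linear model (all values are assumptions)
set.seed(7)
n = 5000
v = rnorm(n)
a = rbinom(n, 1, plogis(v))             # A = m_0(V) + U_3 with m_0(V) = plogis(V)
y = 0.8 * a + 2 * v + rnorm(n)          # Y = A * theta_0 + g_0(V) + zeta, theta_0 = 0.8

# Nuisance fits (no cross-fitting in this sketch)
l_hat = fitted(lm(y ~ poly(v, 3)))                     # approximates E[Y | V]
m_hat_plr = fitted(glm(a ~ v, family = binomial()))    # approximates E[A | V]

# Residual-on-residual regression without intercept
res_y = y - l_hat
res_a = a - m_hat_plr
theta_plr = sum(res_a * res_y) / sum(res_a^2)
print(theta_plr)
```

In DoubleML, the lm and glm fits of this sketch would be replaced by ML learners whose predictions are obtained with cross-fitting.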
Instead of the learners used above, we can experiment with different learners that are available from the mlr3 ecosystem. A searchable list of all implemented learners is available in the mlr3 documentation.
The learner section of the user guide explains how to perform parameter tuning using the mlr3tuning package.
It is also possible to implement pipelines using the mlr3pipelines package. You can find an experimental notebook here.
Notes and Acknowledgement
We would like to thank the organizers of the ACIC 2019 Data Challenge for setting up this data challenge and making the numerous synthetic data examples publicly available. Although the data examples in the ACIC 2019 Data Challenge do not explicitly address A/B testing, we put the data example here in this context to give a tractable example of the use of causal machine learning in practice. The parameters for the random forests and extreme gradient boosting learners have been tuned externally. The corresponding tuning notebook will be uploaded to the examples gallery in the future.
Chernozhukov, V., Chetverikov, D., Demirer, M., Duflo, E., Hansen, C., Newey, W. and Robins, J. (2018), Double/debiased machine learning for treatment and structural parameters. The Econometrics Journal, 21: C1-C68. doi:10.1111/ectj.12097.