In this example, we demonstrate the use of the DoubleML package in a real-data industry example: the estimation of the price elasticity of demand. This notebook is based on a blog post by Lars Roemheld (Roemheld, 2021), with code and preprocessed data available from GitHub. The original data file is available as a public domain (CC0 1.0 Universal) data set and is shared on Kaggle. It contains data on sales from an online retailer for the period from December 2010 to December 2011.
The data preprocessing is performed in a separate notebook that is available online. To keep the computational effort at a moderate level, we will only use a subset of the data that is used in Roemheld (2021). Our main goal is to illustrate the main steps of elasticity estimation with DoubleML.
The following case study is organized according to the steps of the DoubleML workflow.
"Supply" and "demand" are probably the very first terms that economics and business students hear in their studies. In industry, the price elasticity of demand is a very important quantity: It indicates how much the demand for a product (= the quantity sold by the firm) changes due to a change in its price. As a retailer, this quantity is of great interest because it makes it possible to increase revenues, and eventually profits, by optimally adjusting prices according to elasticities.
The price elasticity of demand is formally defined as the relative change in the demanded quantity ($q$) of a product in response to a relative change in its price ($p$):
$$\theta_0 = \frac{\partial q/q}{\partial p/p}.$$
In words, the parameter $\theta_0$ can be interpreted as follows: If the price of a product increases by $1\%$, the demanded quantity changes by approximately $\theta_0\%$.
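As a quick numerical illustration of this definition, the following sketch uses a hypothetical constant-elasticity demand curve $q = a\,p^{\theta_0}$. The values `a = 100` and `theta0 = -1.8` are assumptions chosen for illustration and are not taken from the data set. A finite-difference approximation of the ratio of relative changes recovers $\theta_0$, and a $1\%$ price increase changes the quantity by roughly $\theta_0\%$.

```python
def demand(p, a=100.0, theta0=-1.8):
    """Hypothetical constant-elasticity demand curve q = a * p**theta0."""
    return a * p ** theta0

def elasticity(p, rel_step=1e-6):
    """Finite-difference approximation of (dq/q) / (dp/p) at price p."""
    q0 = demand(p)
    q1 = demand(p * (1.0 + rel_step))
    return ((q1 - q0) / q0) / rel_step

# Effect of a 1% price increase at an (arbitrary) price of 2.0
p = 2.0
pct_change_q = 100.0 * (demand(1.01 * p) - demand(p)) / demand(p)

print("approximate elasticity:", round(elasticity(p), 3))
print("percent change in quantity after a 1% price increase:", round(pct_change_q, 2))
```

The percent change is only approximately equal to $\theta_0$ because the elasticity is defined for infinitesimal changes.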
In general, it would be possible to estimate $\theta_0$ based on an experiment or A/B test. However, this is not possible in our case as the data set only contains information on actual purchases in the period of consideration.
The causal problem of estimating price elasticities from observational data is quite complex: It involves many (simultaneous) decisions made by the customers and the sellers. One approach to estimating the causal parameter $\theta_0$ is to account for confounding variables that might affect both the price and the quantity sold. The approach taken in Roemheld (2021) is to flexibly construct and account for confounding variables, for example similarities in product descriptions or seasonal patterns, and thereby to justify the identification of $\theta_0$.
We can use a partially linear regression (PLR) model for estimation of $\theta_0$:
$$\log Q = \theta_0 \log P + g_0(X) + \zeta,$$
with $\mathbb{E}(\zeta|D,X)=0$, where $D = \log P$ denotes the treatment variable. The confounders can enter the regression equation nonlinearly via the function $g_0(X)$. In order to equip $\theta_0$ (approximately) with the interpretation of a price elasticity, we apply the $\log()$ to both the demanded quantity ($Q$) and the prices ($P$), i.e., we set up a $\log$-$\log$-regression.
Before we proceed with the data analysis, it is important to mention a potential drawback to our analysis: The data only contains information on sales, not on stock days. Hence, based on this data set, it is not possible to assess what happened on days without sales (sales = 0). This drawback must be kept in mind when we draw causal conclusions from this analysis.
To give an idea of the general setting, we briefly load an excerpt from the original data set. We can see that the data list the transactions of an online retailer selling products like inflatable political globes or fancy pens.
# Load required modules
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import doubleml as dml
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression, LassoCV
from sklearn.ensemble import RandomForestRegressor
# Load example data set
url = 'https://raw.githubusercontent.com/DoubleML/doubleml-docs/master/doc/examples/data/orig_demand_data_example.csv'
data_example = pd.read_csv(url)
data_example
| | Unnamed: 0 | Date | StockCode | Country | Description | Quantity | revenue | UnitPrice |
---|---|---|---|---|---|---|---|---|
0 | 0 | 2010-12-01 | 10002 | France | INFLATABLE POLITICAL GLOBE | 48 | 40.80 | 0.850 |
1 | 1 | 2010-12-01 | 10002 | United Kingdom | INFLATABLE POLITICAL GLOBE | 12 | 10.20 | 0.850 |
2 | 2 | 2010-12-01 | 10125 | United Kingdom | MINI FUNKY DESIGN TAPES | 2 | 1.70 | 0.850 |
3 | 3 | 2010-12-01 | 10133 | United Kingdom | COLOURING PENCILS BROWN TUBE | 5 | 4.25 | 0.850 |
4 | 4 | 2010-12-01 | 10135 | United Kingdom | COLOURING PENCILS BROWN TUBE | 1 | 2.51 | 2.510 |
5 | 5 | 2010-12-01 | 11001 | United Kingdom | ASSTD DESIGN RACING CAR PEN | 3 | 10.08 | 3.360 |
6 | 6 | 2010-12-01 | 15044B | United Kingdom | BLUE PAPER PARASOL | 1 | 2.95 | 2.950 |
7 | 7 | 2010-12-01 | 15056BL | United Kingdom | EDWARDIAN PARASOL BLACK | 20 | 113.00 | 5.650 |
8 | 8 | 2010-12-01 | 15056N | United Kingdom | EDWARDIAN PARASOL NATURAL | 50 | 236.30 | 4.726 |
9 | 9 | 2010-12-01 | 15056P | United Kingdom | EDWARDIAN PARASOL PINK | 48 | 220.80 | 4.600 |
In our analysis, we will use a preprocessed data set. Each row corresponds to the sales of a product at a specific date $t$.
In the data, we have:
Quantity
: Quantity demanded
revenue
: Revenue
UnitPrice
: Price per unit
month
: Month
DoM
: Day of month
DoW
: Day of week
stock_age_days
: Number of days product has been sold / observed in the data
sku_avg_p
: Average (=median) price of the product
2010-12-01, ...
: Date dummies
Australia, ...
: Country dummies
0, 1, 2, ...
: Numerical variables constructed to capture similarities in product descriptions (n-grams)
dLnP
: Change in Price
dLnQ
: Change in Quantity
Note that we do not include product dummies as the price and quantity variables have been demeaned to account for product characteristics.
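The demeaning mentioned above can be sketched as follows. This is a minimal illustration on made-up numbers, assuming that the log price is centered by its product-level mean; the exact construction of dLnP and dLnQ is defined in the separate preprocessing notebook and may differ in its details. The array names are hypothetical.

```python
import numpy as np

# Toy data: three sales of product "A", two of product "B" (made-up values)
stock_code = np.array(["A", "A", "A", "B", "B"])
unit_price = np.array([1.0, 2.0, 4.0, 10.0, 10.0])

log_p = np.log(unit_price)

# Map each row to an integer product id, then compute per-product means
_, product_id = np.unique(stock_code, return_inverse=True)
group_sum = np.bincount(product_id, weights=log_p)
group_n = np.bincount(product_id)
group_mean = group_sum / group_n

# Demeaned log price: deviation from the product's own average log price
d_ln_p = log_p - group_mean[product_id]
print(d_ln_p.round(3))
```

Because the deviations of each product sum to zero, time-invariant product characteristics are absorbed by the demeaning, which is why no product dummies are needed.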
url2 = 'https://raw.githubusercontent.com/DoubleML/doubleml-docs/master/doc/examples/data/elasticity_subset.csv'
demand_data = pd.read_csv(url2)
print(demand_data.columns)
Index(['Unnamed: 0', 'Quantity', 'revenue', 'UnitPrice', 'month', 'DoM', 'DoW', 'stock_age_days', 'sku_avg_p', '2010-12-01 00:00:00', ... '544', '545', '546', '547', '548', '549', '550', '551', 'dLnP', 'dLnQ'], dtype='object', length=906)
# Print dimensions of data set
print(demand_data.shape)
(10000, 906)
# Glimpse at first rows of data set
demand_data.head()
| | Unnamed: 0 | Quantity | revenue | UnitPrice | month | DoM | DoW | stock_age_days | sku_avg_p | 2010-12-01 00:00:00 | ... | 544 | 545 | 546 | 547 | 548 | 549 | 550 | 551 | dLnP | dLnQ |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 189628 | 5 | 8.15 | 1.630000 | 9 | 6 | 1 | 278 | 0.85 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0.408612 | -1.013840 |
1 | 37914 | 19 | 41.93 | 2.206842 | 1 | 25 | 1 | 55 | 2.10 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | -0.048811 | -0.306368 |
2 | 80103 | 24 | 20.40 | 0.850000 | 3 | 31 | 3 | 120 | 0.85 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | -0.127108 | 1.371479 |
3 | 75019 | 12 | 23.40 | 1.950000 | 3 | 24 | 3 | 113 | 2.08 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | -0.228256 | -1.034483 |
4 | 99878 | 4 | 39.80 | 9.950000 | 5 | 5 | 3 | 155 | 9.95 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | -0.034922 | -0.050881 |
5 rows × 906 columns
To initiate the data backend, we create a new DoubleMLData object. During instantiation, we assign the roles of the variables, i.e., dLnQ as the dependent variable, dLnP as the treatment variable, and the remaining variables as confounders.
# Confounders: all columns from 'month' up to (but excluding) 'dLnP' and 'dLnQ'
feature_names = list(demand_data.columns[4:-2])
data_dml = dml.DoubleMLData(demand_data,
                            y_col='dLnQ',
                            d_cols='dLnP',
                            x_cols=feature_names)
print(data_dml)
================== DoubleMLData Object ================== ------------------ Data summary ------------------ Outcome variable: dLnQ Treatment variable(s): ['dLnP'] Covariates: ['month', 'DoM', 'DoW', 'stock_age_days', 'sku_avg_p', '2010-12-01 00:00:00', '2010-12-02 00:00:00', '2010-12-03 00:00:00', '2010-12-05 00:00:00', '2010-12-06 00:00:00', '2010-12-07 00:00:00', '2010-12-08 00:00:00', '2010-12-09 00:00:00', '2010-12-10 00:00:00', '2010-12-12 00:00:00', '2010-12-13 00:00:00', '2010-12-14 00:00:00', '2010-12-15 00:00:00', '2010-12-16 00:00:00', '2010-12-17 00:00:00', '2010-12-19 00:00:00', '2010-12-20 00:00:00', '2010-12-21 00:00:00', '2010-12-22 00:00:00', '2010-12-23 00:00:00', '2011-01-04 00:00:00', '2011-01-05 00:00:00', '2011-01-06 00:00:00', '2011-01-07 00:00:00', '2011-01-09 00:00:00', '2011-01-10 00:00:00', '2011-01-11 00:00:00', '2011-01-12 00:00:00', '2011-01-13 00:00:00', '2011-01-14 00:00:00', '2011-01-16 00:00:00', '2011-01-17 00:00:00', '2011-01-18 00:00:00', '2011-01-19 00:00:00', '2011-01-20 00:00:00', '2011-01-21 00:00:00', '2011-01-23 00:00:00', '2011-01-24 00:00:00', '2011-01-25 00:00:00', '2011-01-26 00:00:00', '2011-01-27 00:00:00', '2011-01-28 00:00:00', '2011-01-30 00:00:00', '2011-01-31 00:00:00', '2011-02-01 00:00:00', '2011-02-02 00:00:00', '2011-02-03 00:00:00', '2011-02-04 00:00:00', '2011-02-06 00:00:00', '2011-02-07 00:00:00', '2011-02-08 00:00:00', '2011-02-09 00:00:00', '2011-02-10 00:00:00', '2011-02-11 00:00:00', '2011-02-13 00:00:00', '2011-02-14 00:00:00', '2011-02-15 00:00:00', '2011-02-16 00:00:00', '2011-02-17 00:00:00', '2011-02-18 00:00:00', '2011-02-20 00:00:00', '2011-02-21 00:00:00', '2011-02-22 00:00:00', '2011-02-23 00:00:00', '2011-02-24 00:00:00', '2011-02-25 00:00:00', '2011-02-27 00:00:00', '2011-02-28 00:00:00', '2011-03-01 00:00:00', '2011-03-02 00:00:00', '2011-03-03 00:00:00', '2011-03-04 00:00:00', '2011-03-06 00:00:00', '2011-03-07 00:00:00', '2011-03-08 00:00:00', '2011-03-09 00:00:00', '2011-03-10 
00:00:00', '2011-03-11 00:00:00', '2011-03-13 00:00:00', '2011-03-14 00:00:00', '2011-03-15 00:00:00', '2011-03-16 00:00:00', '2011-03-17 00:00:00', '2011-03-18 00:00:00', '2011-03-20 00:00:00', '2011-03-21 00:00:00', '2011-03-22 00:00:00', '2011-03-23 00:00:00', '2011-03-24 00:00:00', '2011-03-25 00:00:00', '2011-03-27 00:00:00', '2011-03-28 00:00:00', '2011-03-29 00:00:00', '2011-03-30 00:00:00', '2011-03-31 00:00:00', '2011-04-01 00:00:00', '2011-04-03 00:00:00', '2011-04-04 00:00:00', '2011-04-05 00:00:00', '2011-04-06 00:00:00', '2011-04-07 00:00:00', '2011-04-08 00:00:00', '2011-04-10 00:00:00', '2011-04-11 00:00:00', '2011-04-12 00:00:00', '2011-04-13 00:00:00', '2011-04-14 00:00:00', '2011-04-15 00:00:00', '2011-04-17 00:00:00', '2011-04-18 00:00:00', '2011-04-19 00:00:00', '2011-04-20 00:00:00', '2011-04-21 00:00:00', '2011-04-26 00:00:00', '2011-04-27 00:00:00', '2011-04-28 00:00:00', '2011-05-01 00:00:00', '2011-05-03 00:00:00', '2011-05-04 00:00:00', '2011-05-05 00:00:00', '2011-05-06 00:00:00', '2011-05-08 00:00:00', '2011-05-09 00:00:00', '2011-05-10 00:00:00', '2011-05-11 00:00:00', '2011-05-12 00:00:00', '2011-05-13 00:00:00', '2011-05-15 00:00:00', '2011-05-16 00:00:00', '2011-05-17 00:00:00', '2011-05-18 00:00:00', '2011-05-19 00:00:00', '2011-05-20 00:00:00', '2011-05-22 00:00:00', '2011-05-23 00:00:00', '2011-05-24 00:00:00', '2011-05-25 00:00:00', '2011-05-26 00:00:00', '2011-05-27 00:00:00', '2011-05-29 00:00:00', '2011-05-31 00:00:00', '2011-06-01 00:00:00', '2011-06-02 00:00:00', '2011-06-03 00:00:00', '2011-06-05 00:00:00', '2011-06-06 00:00:00', '2011-06-07 00:00:00', '2011-06-08 00:00:00', '2011-06-09 00:00:00', '2011-06-10 00:00:00', '2011-06-12 00:00:00', '2011-06-13 00:00:00', '2011-06-14 00:00:00', '2011-06-15 00:00:00', '2011-06-16 00:00:00', '2011-06-17 00:00:00', '2011-06-19 00:00:00', '2011-06-20 00:00:00', '2011-06-21 00:00:00', '2011-06-22 00:00:00', '2011-06-23 00:00:00', '2011-06-24 00:00:00', '2011-06-26 00:00:00', 
'2011-06-27 00:00:00', '2011-06-28 00:00:00', '2011-06-29 00:00:00', '2011-06-30 00:00:00', '2011-07-01 00:00:00', '2011-07-03 00:00:00', '2011-07-04 00:00:00', '2011-07-05 00:00:00', '2011-07-06 00:00:00', '2011-07-07 00:00:00', '2011-07-08 00:00:00', '2011-07-10 00:00:00', '2011-07-11 00:00:00', '2011-07-12 00:00:00', '2011-07-13 00:00:00', '2011-07-14 00:00:00', '2011-07-15 00:00:00', '2011-07-17 00:00:00', '2011-07-18 00:00:00', '2011-07-19 00:00:00', '2011-07-20 00:00:00', '2011-07-21 00:00:00', '2011-07-22 00:00:00', '2011-07-24 00:00:00', '2011-07-25 00:00:00', '2011-07-26 00:00:00', '2011-07-27 00:00:00', '2011-07-28 00:00:00', '2011-07-29 00:00:00', '2011-07-31 00:00:00', '2011-08-01 00:00:00', '2011-08-02 00:00:00', '2011-08-03 00:00:00', '2011-08-04 00:00:00', '2011-08-05 00:00:00', '2011-08-07 00:00:00', '2011-08-08 00:00:00', '2011-08-09 00:00:00', '2011-08-10 00:00:00', '2011-08-11 00:00:00', '2011-08-12 00:00:00', '2011-08-14 00:00:00', '2011-08-15 00:00:00', '2011-08-16 00:00:00', '2011-08-17 00:00:00', '2011-08-18 00:00:00', '2011-08-19 00:00:00', '2011-08-21 00:00:00', '2011-08-22 00:00:00', '2011-08-23 00:00:00', '2011-08-24 00:00:00', '2011-08-25 00:00:00', '2011-08-26 00:00:00', '2011-08-28 00:00:00', '2011-08-30 00:00:00', '2011-08-31 00:00:00', '2011-09-01 00:00:00', '2011-09-02 00:00:00', '2011-09-04 00:00:00', '2011-09-05 00:00:00', '2011-09-06 00:00:00', '2011-09-07 00:00:00', '2011-09-08 00:00:00', '2011-09-09 00:00:00', '2011-09-11 00:00:00', '2011-09-12 00:00:00', '2011-09-13 00:00:00', '2011-09-14 00:00:00', '2011-09-15 00:00:00', '2011-09-16 00:00:00', '2011-09-18 00:00:00', '2011-09-19 00:00:00', '2011-09-20 00:00:00', '2011-09-21 00:00:00', '2011-09-22 00:00:00', '2011-09-23 00:00:00', '2011-09-25 00:00:00', '2011-09-26 00:00:00', '2011-09-27 00:00:00', '2011-09-28 00:00:00', '2011-09-29 00:00:00', '2011-09-30 00:00:00', '2011-10-02 00:00:00', '2011-10-03 00:00:00', '2011-10-04 00:00:00', '2011-10-05 00:00:00', '2011-10-06 
00:00:00', '2011-10-07 00:00:00', '2011-10-09 00:00:00', '2011-10-10 00:00:00', '2011-10-11 00:00:00', '2011-10-12 00:00:00', '2011-10-13 00:00:00', '2011-10-14 00:00:00', '2011-10-16 00:00:00', '2011-10-17 00:00:00', '2011-10-18 00:00:00', '2011-10-19 00:00:00', '2011-10-20 00:00:00', '2011-10-21 00:00:00', '2011-10-23 00:00:00', '2011-10-24 00:00:00', '2011-10-25 00:00:00', '2011-10-26 00:00:00', '2011-10-27 00:00:00', '2011-10-28 00:00:00', '2011-10-30 00:00:00', '2011-10-31 00:00:00', '2011-11-01 00:00:00', '2011-11-02 00:00:00', '2011-11-03 00:00:00', '2011-11-04 00:00:00', '2011-11-06 00:00:00', '2011-11-07 00:00:00', '2011-11-08 00:00:00', '2011-11-09 00:00:00', '2011-11-10 00:00:00', '2011-11-11 00:00:00', '2011-11-13 00:00:00', '2011-11-14 00:00:00', '2011-11-15 00:00:00', '2011-11-16 00:00:00', '2011-11-17 00:00:00', '2011-11-18 00:00:00', '2011-11-20 00:00:00', '2011-11-21 00:00:00', '2011-11-22 00:00:00', '2011-11-23 00:00:00', '2011-11-24 00:00:00', '2011-11-25 00:00:00', '2011-11-27 00:00:00', '2011-11-28 00:00:00', '2011-11-29 00:00:00', '2011-11-30 00:00:00', '2011-12-01 00:00:00', '2011-12-02 00:00:00', '2011-12-04 00:00:00', '2011-12-05 00:00:00', '2011-12-06 00:00:00', '2011-12-07 00:00:00', '2011-12-08 00:00:00', '2011-12-09 00:00:00', 'Australia', 'Austria', 'Bahrain', 'Belgium', 'Brazil', 'Canada', 'Channel Islands', 'Cyprus', 'Czech Republic', 'Denmark', 'EIRE', 'European Community', 'Finland', 'France', 'Germany', 'Greece', 'Hong Kong', 'Iceland', 'Israel', 'Italy', 'Japan', 'Lebanon', 'Lithuania', 'Malta', 'Netherlands', 'Norway', 'Poland', 'Portugal', 'RSA', 'Saudi Arabia', 'Singapore', 'Spain', 'Sweden', 'Switzerland', 'USA', 'United Arab Emirates', 'United Kingdom', 'Unspecified', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', 
'45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', '109', '110', '111', '112', '113', '114', '115', '116', '117', '118', '119', '120', '121', '122', '123', '124', '125', '126', '127', '128', '129', '130', '131', '132', '133', '134', '135', '136', '137', '138', '139', '140', '141', '142', '143', '144', '145', '146', '147', '148', '149', '150', '151', '152', '153', '154', '155', '156', '157', '158', '159', '160', '161', '162', '163', '164', '165', '166', '167', '168', '169', '170', '171', '172', '173', '174', '175', '176', '177', '178', '179', '180', '181', '182', '183', '184', '185', '186', '187', '188', '189', '190', '191', '192', '193', '194', '195', '196', '197', '198', '199', '200', '201', '202', '203', '204', '205', '206', '207', '208', '209', '210', '211', '212', '213', '214', '215', '216', '217', '218', '219', '220', '221', '222', '223', '224', '225', '226', '227', '228', '229', '230', '231', '232', '233', '234', '235', '236', '237', '238', '239', '240', '241', '242', '243', '244', '245', '246', '247', '248', '249', '250', '251', '252', '253', '254', '255', '256', '257', '258', '259', '260', '261', '262', '263', '264', '265', '266', '267', '268', '269', '270', '271', '272', '273', '274', '275', '276', '277', '278', '279', '280', '281', '282', '283', '284', '285', '286', '287', '288', '289', '290', '291', '292', '293', '294', '295', '296', '297', '298', '299', '300', '301', '302', '303', '304', '305', '306', '307', '308', '309', '310', '311', '312', '313', '314', '315', '316', '317', '318', '319', '320', '321', '322', '323', '324', '325', '326', '327', '328', '329', '330', '331', '332', '333', '334', '335', '336', '337', 
'338', '339', '340', '341', '342', '343', '344', '345', '346', '347', '348', '349', '350', '351', '352', '353', '354', '355', '356', '357', '358', '359', '360', '361', '362', '363', '364', '365', '366', '367', '368', '369', '370', '371', '372', '373', '374', '375', '376', '377', '378', '379', '380', '381', '382', '383', '384', '385', '386', '387', '388', '389', '390', '391', '392', '393', '394', '395', '396', '397', '398', '399', '400', '401', '402', '403', '404', '405', '406', '407', '408', '409', '410', '411', '412', '413', '414', '415', '416', '417', '418', '419', '420', '421', '422', '423', '424', '425', '426', '427', '428', '429', '430', '431', '432', '433', '434', '435', '436', '437', '438', '439', '440', '441', '442', '443', '444', '445', '446', '447', '448', '449', '450', '451', '452', '453', '454', '455', '456', '457', '458', '459', '460', '461', '462', '463', '464', '465', '466', '467', '468', '469', '470', '471', '472', '473', '474', '475', '476', '477', '478', '479', '480', '481', '482', '483', '484', '485', '486', '487', '488', '489', '490', '491', '492', '493', '494', '495', '496', '497', '498', '499', '500', '501', '502', '503', '504', '505', '506', '507', '508', '509', '510', '511', '512', '513', '514', '515', '516', '517', '518', '519', '520', '521', '522', '523', '524', '525', '526', '527', '528', '529', '530', '531', '532', '533', '534', '535', '536', '537', '538', '539', '540', '541', '542', '543', '544', '545', '546', '547', '548', '549', '550', '551'] Instrument variable(s): None No. Observations: 10000 ------------------ DataFrame info ------------------ <class 'pandas.core.frame.DataFrame'> RangeIndex: 10000 entries, 0 to 9999 Columns: 906 entries, Unnamed: 0 to dLnQ dtypes: float64(5), int64(901) memory usage: 69.1 MB
We already stated that a partially linear regression model in a $\log$-$\log$-specification allows us to interpret the regression coefficient $\theta_0$ as the price elasticity of demand. We restate the main regression as well as the auxiliary regression that is required for orthogonality:
$$\begin{aligned}\log Q &= \theta_0 \log P + g_0(X) + \zeta,\\ \log P &= m_0(X) + V,\end{aligned}$$
with $\mathbb{E}(\zeta|D,X)=0$ and $\mathbb{E}(V|X)=0$, where $D = \log P$ denotes the treatment variable. As stated above, we hope to justify the assumption $\mathbb{E}(\zeta|D,X)=0$ by sufficiently accounting for the confounding variables $X$.
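The partialling-out idea behind these two equations can be sketched on simulated data. The snippet below is a minimal illustration and not DoubleML's internal implementation: the data-generating process, the polynomial least-squares "learner" and all names are assumptions made for this sketch. The outcome and the treatment are each predicted from $X$ using 5-fold cross-fitting, and $\theta_0$ is then estimated by regressing the outcome residuals on the treatment residuals. A naive regression that ignores $X$ is included for comparison.

```python
import numpy as np

rng = np.random.default_rng(42)
n, theta0 = 2000, -1.8

# Simulated PLR data (all functional forms are assumptions for this sketch)
x = rng.normal(size=n)
d = 0.5 * x + 0.25 * x**2 + rng.normal(size=n)    # treatment: m_0(X) + V
y = theta0 * d + x + x**2 + rng.normal(size=n)    # outcome: theta_0 * D + g_0(X) + zeta

def poly_features(values, degree=3):
    """Polynomial basis [1, x, ..., x**degree], a simple stand-in for an ML learner."""
    return np.vander(values, degree + 1, increasing=True)

def fit_predict(x_train, target_train, x_test):
    """Least-squares fit on the training folds, prediction on the held-out fold."""
    beta, *_ = np.linalg.lstsq(poly_features(x_train), target_train, rcond=None)
    return poly_features(x_test) @ beta

# 5-fold cross-fitting: every nuisance prediction is made out-of-fold
folds = np.array_split(rng.permutation(n), 5)
l_hat = np.empty(n)   # prediction of the outcome given X
m_hat = np.empty(n)   # prediction of the treatment given X
for test_idx in folds:
    train_idx = np.setdiff1d(np.arange(n), test_idx)
    l_hat[test_idx] = fit_predict(x[train_idx], y[train_idx], x[test_idx])
    m_hat[test_idx] = fit_predict(x[train_idx], d[train_idx], x[test_idx])

# Partialling out: regress the outcome residuals on the treatment residuals
u_hat = y - l_hat
v_hat = d - m_hat
theta_hat = np.sum(v_hat * u_hat) / np.sum(v_hat**2)

# Naive comparison: slope of y on d without accounting for X
theta_naive = np.cov(d, y)[0, 1] / np.var(d, ddof=1)

print("partialling out:", round(theta_hat, 2))
print("naive:", round(theta_naive, 2))
```

In this simulated setting, the naive slope is pulled away from the true value because $X$ affects both the treatment and the outcome, whereas the residual-on-residual estimate stays close to it.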
We start with the linear regression model as a benchmark learner for the nuisance functions $g_0(X)$ and $m_0(X)$. We additionally set up two models based on a lasso learner and a random forest learner and compare the results.
# Benchmark: linear regression for both nuisance components
ml_l_lin_reg = LinearRegression()
ml_m_lin_reg = LinearRegression()
# Lasso with cross-validated penalty; features are standardized first
ml_l_lasso = make_pipeline(StandardScaler(),
                           LassoCV(cv=5, max_iter=10000, n_jobs=-1))
ml_m_lasso = make_pipeline(StandardScaler(),
                           LassoCV(cv=5, max_iter=10000, n_jobs=-1))
# Random forest
ml_l_forest = RandomForestRegressor(n_estimators=50,
                                    min_samples_leaf=3,
                                    n_jobs=-1, verbose=0)
ml_m_forest = RandomForestRegressor(n_estimators=50,
                                    min_samples_leaf=3,
                                    n_jobs=-1, verbose=0)
For each learner configuration, we initialize a new DoubleMLPLR object. We stick to the default options, i.e., dml_procedure = 'dml2', score = "partialling out", n_folds = 5.
# Reset the seed before each model so that all three use the same sample splits
np.random.seed(123)
dml_plr_lin_reg = dml.DoubleMLPLR(data_dml,
                                  ml_l=ml_l_lin_reg,
                                  ml_m=ml_m_lin_reg)
np.random.seed(123)
dml_plr_lasso = dml.DoubleMLPLR(data_dml,
                                ml_l=ml_l_lasso,
                                ml_m=ml_m_lasso)
np.random.seed(123)
dml_plr_forest = dml.DoubleMLPLR(data_dml,
                                 ml_l=ml_l_forest,
                                 ml_m=ml_m_forest)
To estimate our target parameter $\theta_0$, we call the fit() method. The results can be summarized by accessing the summary field.
dml_plr_lin_reg.fit(store_predictions = True)
summary_plr_lin_reg = dml_plr_lin_reg.summary
summary_plr_lin_reg
| | coef | std err | t | P>\|t\| | 2.5 % | 97.5 % |
---|---|---|---|---|---|---|
dLnP | -0.725216 | 0.003777 | -191.994849 | 0.0 | -0.732619 | -0.717813 |
dml_plr_lasso.fit(store_predictions = True)
summary_plr_lasso = dml_plr_lasso.summary
summary_plr_lasso
| | coef | std err | t | P>\|t\| | 2.5 % | 97.5 % |
---|---|---|---|---|---|---|
dLnP | -1.81883 | 0.04411 | -41.233963 | 0.0 | -1.905284 | -1.732376 |
dml_plr_forest.fit(store_predictions = True)
summary_plr_forest = dml_plr_forest.summary
summary_plr_forest
| | coef | std err | t | P>\|t\| | 2.5 % | 97.5 % |
---|---|---|---|---|---|---|
dLnP | -1.800381 | 0.043246 | -41.631025 | 0.0 | -1.885142 | -1.71562 |
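The columns of these summary tables are linked by a normal approximation: the t statistic is the coefficient divided by its standard error, and the 95% bounds are the coefficient plus or minus the 0.975 normal quantile times the standard error. The following check uses the (rounded) lasso values from the table above and only the standard library. It is a consistency check that reproduces the reported numbers up to rounding, not a description of DoubleML's internal code.

```python
from statistics import NormalDist

# Rounded values copied from the lasso summary table
coef, std_err = -1.81883, 0.04411

z = NormalDist().inv_cdf(0.975)   # two-sided 95% normal quantile, about 1.96
t_stat = coef / std_err
ci_lower = coef - z * std_err
ci_upper = coef + z * std_err
p_value = 2 * (1 - NormalDist().cdf(abs(t_stat)))

print(f"t = {t_stat:.3f}")
print(f"95% CI = [{ci_lower:.4f}, {ci_upper:.4f}]")
print(f"p-value = {p_value:.3g}")
```

The same arithmetic applies to the linear regression and random forest rows.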
Let us now compare how well the three models approximate the nuisance functions $g_0(X)$ and $m_0(X)$. We first define a helper function that calculates the RMSE for both.
from sklearn.metrics import mean_squared_error

def pred_acc_plr(dml_plr_obj, nuis):
    """
    Calculate the RMSE of a nuisance component of a fitted DoubleMLPLR model.

    Parameters
    ----------
    dml_plr_obj : doubleml.DoubleMLPLR
        The fitted PLR model; fit() must have been called with store_predictions=True.
    nuis : str
        Nuisance component for which the RMSE is evaluated, either 'ml_l' or 'ml_m'.
    """
    # Export data, fitted coefficient and predictions of the DoubleML model
    y = dml_plr_obj._dml_data.y
    d = dml_plr_obj._dml_data.d
    theta = dml_plr_obj.coef
    # Predictions have shape (n_obs, n_rep, n_coefs); use the first repetition and treatment
    ml_nuis = dml_plr_obj.predictions.get(nuis)[:, 0, 0]

    if nuis == 'ml_l':
        export_pred_array = theta * d + ml_nuis
    elif nuis == 'ml_m':
        export_pred_array = ml_nuis
    else:
        raise ValueError("nuis must be either 'ml_l' or 'ml_m'")

    # Both components are compared with the outcome y;
    # np.sqrt replaces the `squared=False` argument, which newer scikit-learn versions no longer accept
    rmse = np.sqrt(mean_squared_error(y, export_pred_array))
    return rmse
rmse_lin_reg_ml_l = pred_acc_plr(dml_plr_lin_reg, 'ml_l')
rmse_lin_reg_ml_m = pred_acc_plr(dml_plr_lin_reg, 'ml_m')
rmse_lasso_ml_l = pred_acc_plr(dml_plr_lasso, 'ml_l')
rmse_lasso_ml_m = pred_acc_plr(dml_plr_lasso, 'ml_m')
rmse_forest_ml_l = pred_acc_plr(dml_plr_forest, 'ml_l')
rmse_forest_ml_m = pred_acc_plr(dml_plr_forest, 'ml_m')
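Note that the helper above evaluates both nuisance components against the outcome. An alternative convention, sketched below under the assumption that ml_l predicts the outcome given $X$ and ml_m predicts the treatment given $X$ (as in the partialling-out score), compares each prediction with its own target variable. The function name `nuisance_rmse` and the toy arrays are hypothetical; with a fitted model, the arrays would be taken from the data and the stored predictions instead. The RMSE table below was produced with the original helper, not with this variant.

```python
import numpy as np

def nuisance_rmse(target, prediction):
    """Root mean squared error between a target and its prediction."""
    target = np.asarray(target, dtype=float)
    prediction = np.asarray(prediction, dtype=float)
    return float(np.sqrt(np.mean((target - prediction) ** 2)))

# Toy arrays standing in for dLnQ, dLnP and the two nuisance predictions
y_toy = np.array([0.2, -1.0, 0.5, 0.1])
d_toy = np.array([0.1, 0.3, -0.2, 0.0])
pred_l = np.array([0.0, -0.5, 0.5, 0.1])   # hypothetical predictions of the outcome given X
pred_m = np.array([0.1, 0.1, -0.2, 0.2])   # hypothetical predictions of the treatment given X

rmse_l = nuisance_rmse(y_toy, pred_l)   # outcome model: compare with the outcome
rmse_m = nuisance_rmse(d_toy, pred_m)   # treatment model: compare with the treatment
print(round(rmse_l, 4), round(rmse_m, 4))
```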
We visualize and compare the results in terms of the predictive performance.
plr_rmse_index = ['regression','lasso', 'forest']
plr_rmse = pd.DataFrame([[rmse_lin_reg_ml_l, rmse_lin_reg_ml_m],
                         [rmse_lasso_ml_l, rmse_lasso_ml_m],
                         [rmse_forest_ml_l, rmse_forest_ml_m]],
                        index=plr_rmse_index,
                        columns=['ml_l', 'ml_m'])
plr_rmse.round(3)
| | ml_l | ml_m |
---|---|---|
regression | 77729.446 | 107174.807 |
lasso | 1.018 | 1.287 |
forest | 1.019 | 1.292 |
plt.scatter(x = plr_rmse_index, y= plr_rmse['ml_l'])
plt.title('RMSE, ml_l')
plt.xlabel('learner')
_ = plt.ylabel('RMSE')
plt.scatter(x = plr_rmse_index, y= plr_rmse['ml_m'])
plt.title('RMSE, ml_m')
plt.xlabel('learner')
_ = plt.ylabel('RMSE')
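We can visualize and summarize our findings so far. The lasso and random forest learners, which predict far more accurately than the linear regression benchmark (RMSE close to 1 versus roughly 78,000 and 107,000), both indicate that the price elasticity of demand, as measured by the causal parameter $\theta_0$, is around $-1.8$. The linear regression benchmark yields an estimate of about $-0.73$, which should be treated with caution given its poor predictive performance. In all models, the coefficient is significantly different from zero.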
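To make the economic meaning of such an estimate concrete, consider the revenue $R = P \cdot Q$. The following back-of-the-envelope derivation assumes that the elasticity is constant and that an estimate of roughly $-1.8$ can be taken at face value; it is an illustration, not an additional result of the analysis.

$$\begin{aligned}\log R &= \log P + \log Q,\\ \frac{\partial \log R}{\partial \log P} &= 1 + \theta_0 \approx 1 - 1.8 = -0.8.\end{aligned}$$

Under these assumptions, a $1\%$ price increase would lower revenue by roughly $0.8\%$, because demand is elastic ($|\theta_0| > 1$).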
plr_summary = pd.concat((summary_plr_lin_reg,
                         summary_plr_lasso,
                         summary_plr_forest))
plr_summary.index = ['linear regression', 'lasso', 'forest']
plr_summary[['coef', '2.5 %', '97.5 %']]
| | coef | 2.5 % | 97.5 % |
---|---|---|---|
linear regression | -0.725216 | -0.732619 | -0.717813 |
lasso | -1.818830 | -1.905284 | -1.732376 |
forest | -1.800381 | -1.885142 | -1.715620 |
errors = np.full((2, plr_summary.shape[0]), np.nan)
errors[0, :] = plr_summary['coef'] - plr_summary['2.5 %']
errors[1, :] = plr_summary['97.5 %'] - plr_summary['coef']
plt.errorbar(plr_summary.index, plr_summary.coef, fmt='o', yerr=errors)
plt.axhline(y=0, color='gray')
plt.ylim([-3, 0.1])
plt.title('Partially Linear Regression Model (PLR)')
plt.xlabel('ML method')
_ = plt.ylabel('Coefficients and 95%-CI')
Acknowledgement
We would like to thank Lars Roemheld for setting up the blog post on demand estimation using double machine learning as well as for sharing the code and preprocessed data set. We hope that with this notebook, we illustrate how to run such an analysis using DoubleML. Moreover, we would like to thank Anzony Quispe for excellent assistance in creating this notebook.