Double machine learning for interactive regression models.
Format
R6::R6Class object inheriting from DoubleML.
Details
Interactive regression (IRM) models take the form
\(Y = g_0(D,X) + U\),
\(D = m_0(X) + V\),
with \(E[U|X,D]=0\) and \(E[V|X] = 0\). \(Y\) is the outcome variable and \(D \in \{0,1\}\) is the binary treatment variable. We consider estimation of the average treatment effects when treatment effects are fully heterogeneous. Target parameters of interest in this model are the average treatment effect (ATE),
\(\theta_0 = E[g_0(1,X) - g_0(0,X)]\)
and the average treatment effect on the treated (ATTE),
\(\theta_0 = E[g_0(1,X) - g_0(0,X)|D=1]\).
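To make the two targets concrete, the following base R sketch evaluates them by Monte Carlo in a hypothetical design where \(g_0\) and \(m_0\) are known. The functional forms below are illustrative assumptions, not the data-generating process used by make_irm_data(). Because the individual effect \(1 + X\) grows with \(X\) and units with large \(X\) are more likely to be treated, the ATTE exceeds the ATE in this design.

```r
# Hypothetical IRM design (an assumption for illustration only):
# X ~ N(0, 1), m0(x) = plogis(x), g0(d, x) = d * (1 + x) + x^2,
# so the individual effect g0(1, x) - g0(0, x) = 1 + x is heterogeneous.
set.seed(1)
n <- 1e5
x <- rnorm(n)
m0 <- plogis(x)                          # propensity score m0(X) = P(D = 1 | X)
d <- rbinom(n, size = 1, prob = m0)      # binary treatment D
g0 <- function(d, x) d * (1 + x) + x^2   # outcome regression g0(D, X)
y <- g0(d, x) + rnorm(n)                 # observed outcome Y = g0(D, X) + U (not needed below)

effect <- g0(1, x) - g0(0, x)            # individual effects 1 + X
ate  <- mean(effect)                     # E[g0(1,X) - g0(0,X)], true value 1
atte <- mean(effect[d == 1])             # E[g0(1,X) - g0(0,X) | D = 1]
c(ATE = ate, ATTE = atte)
```

Here the targets are computed from the true \(g_0\); in practice \(g_0\) and \(m_0\) are unknown and DoubleMLIRM estimates them with the learners ml_g and ml_m.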
See also
Other DoubleML: DoubleML, DoubleMLIIVM, DoubleMLPLIV, DoubleMLPLR
Super class
DoubleML::DoubleML -> DoubleMLIRM
Active bindings
trimming_rule (character(1))
A character(1) specifying the trimming approach.
trimming_threshold (numeric(1))
The threshold used for trimming.
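A minimal sketch of what "truncate" trimming does, assuming it clips the estimated propensity scores \(\hat m(X)\) into the interval [trimming_threshold, 1 - trimming_threshold] before they enter the score. The helper trim_propensity() is hypothetical and not part of the package API.

```r
# Hypothetical helper (not a DoubleML function): truncation-based trimming.
trim_propensity <- function(m_hat, trimming_threshold = 1e-12) {
  # Clip predictions into [threshold, 1 - threshold] so that the
  # inverse weights 1 / m_hat and 1 / (1 - m_hat) stay finite.
  pmin(pmax(m_hat, trimming_threshold), 1 - trimming_threshold)
}

m_hat <- c(0, 0.005, 0.5, 0.999, 1)
trim_propensity(m_hat, trimming_threshold = 0.01)
#> [1] 0.01 0.01 0.50 0.99 0.99
```

With the default threshold of 1e-12 only predictions of exactly (or numerically) 0 or 1 are affected; a larger threshold bounds the inverse propensity weights more aggressively.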
Methods
Inherited methods
DoubleML::DoubleML$bootstrap()
DoubleML::DoubleML$confint()
DoubleML::DoubleML$fit()
DoubleML::DoubleML$get_params()
DoubleML::DoubleML$learner_names()
DoubleML::DoubleML$p_adjust()
DoubleML::DoubleML$params_names()
DoubleML::DoubleML$print()
DoubleML::DoubleML$set_ml_nuisance_params()
DoubleML::DoubleML$set_sample_splitting()
DoubleML::DoubleML$split_samples()
DoubleML::DoubleML$summary()
DoubleML::DoubleML$tune()
Method new()
Creates a new instance of this R6 class.
Usage
DoubleMLIRM$new(
data,
ml_g,
ml_m,
n_folds = 5,
n_rep = 1,
score = "ATE",
trimming_rule = "truncate",
trimming_threshold = 1e-12,
dml_procedure = "dml2",
draw_sample_splitting = TRUE,
apply_cross_fitting = TRUE
)
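The n_folds and n_rep arguments control how the sample is split for cross-fitting: with the defaults, the data are partitioned once into 5 folds, and each nuisance learner is fit on four folds and used to predict on the held-out fold. The base R sketch below illustrates a repeated K-fold split; draw_kfold_splits() and its train_ids/test_ids layout are illustrative assumptions and are not meant to reproduce the package's internal sample-splitting code.

```r
# Hypothetical helper (not part of DoubleML): draw n_rep independent
# K-fold partitions of the indices 1..n. Each repetition is a list with
# train_ids and test_ids, one entry per fold.
draw_kfold_splits <- function(n, n_folds = 5, n_rep = 1) {
  lapply(seq_len(n_rep), function(r) {
    # Randomly assign every observation to one of n_folds (near) equal folds
    fold_id <- sample(rep(seq_len(n_folds), length.out = n))
    test_ids <- lapply(seq_len(n_folds), function(k) which(fold_id == k))
    # The training set of a fold is the complement of its test set
    train_ids <- lapply(test_ids, function(test) setdiff(seq_len(n), test))
    list(train_ids = train_ids, test_ids = test_ids)
  })
}

set.seed(3)
smpls <- draw_kfold_splits(n = 20, n_folds = 5, n_rep = 2)
lengths(smpls[[1]]$test_ids)
#> [1] 4 4 4 4 4
```

Every observation appears in exactly one test set per repetition, so each unit receives exactly one out-of-fold nuisance prediction per repetition.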
Arguments
data (DoubleMLData)
The DoubleMLData object providing the data and specifying the variables of the causal model.
ml_g (LearnerRegr, LearnerClassif, Learner, character(1))
A learner of the class LearnerRegr, which is available from mlr3 or its extension packages mlr3learners or mlr3extralearners. For binary outcome variables, an object of the class LearnerClassif can be passed, for example lrn("classif.cv_glmnet", s = "lambda.min"). Alternatively, a Learner object with public field task_type = "regr" or task_type = "classif" can be passed, respectively, for example of class GraphLearner. ml_g refers to the nuisance function \(g_0(D,X) = E[Y|X,D]\).
ml_m (LearnerClassif, Learner, character(1))
A learner of the class LearnerClassif, which is available from mlr3 or its extension packages mlr3learners or mlr3extralearners. Alternatively, a Learner object with public field task_type = "classif" can be passed, for example of class GraphLearner. The learner can be passed with specified parameters, for example lrn("classif.cv_glmnet", s = "lambda.min"). ml_m refers to the nuisance function \(m_0(X) = E[D|X]\).
n_folds (integer(1))
Number of folds. Default is 5.
n_rep (integer(1))
Number of repetitions for the sample splitting. Default is 1.
score (character(1), function())
A character(1) ("ATE" or "ATTE") or a function() specifying the score function. If a function() is provided, it must be of the form function(y, d, g0_hat, g1_hat, m_hat, smpls) and the returned output must be a named list() with elements psi_a and psi_b. Default is "ATE".
trimming_rule (character(1))
A character(1) ("truncate" is the only choice) specifying the trimming approach. Default is "truncate".
trimming_threshold (numeric(1))
The threshold used for trimming. Default is 1e-12.
dml_procedure (character(1))
A character(1) ("dml1" or "dml2") specifying the double machine learning algorithm. Default is "dml2".
draw_sample_splitting (logical(1))
Indicates whether the sample splitting should be drawn during initialization of the object. Default is TRUE.
apply_cross_fitting (logical(1))
Indicates whether cross-fitting should be applied. Default is TRUE.
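As an illustration of the score argument, the sketch below writes a user-defined score with the documented signature function(y, d, g0_hat, g1_hat, m_hat, smpls). It implements the standard doubly robust (AIPW) score for the ATE, assuming the score is linear in the parameter, \(\psi = \psi_a \theta + \psi_b\), so that the estimate solves the empirical moment condition. The name score_ate_aipw() and the simulated design are illustrative assumptions, and oracle nuisance values stand in for the cross-fitted predictions the package would supply.

```r
# Hypothetical user-defined score with the documented signature.
# smpls is accepted but unused, since this score needs no fold information.
score_ate_aipw <- function(y, d, g0_hat, g1_hat, m_hat, smpls) {
  psi_a <- rep(-1, length(y))
  psi_b <- g1_hat - g0_hat +
    d * (y - g1_hat) / m_hat -
    (1 - d) * (y - g0_hat) / (1 - m_hat)
  list(psi_a = psi_a, psi_b = psi_b)
}

# Simulated data with true ATE = 1 (assumed design: g0(d, x) = d * (1 + x) + x^2)
set.seed(4)
n <- 10000
x <- rnorm(n)
m0 <- plogis(x)
d <- rbinom(n, 1, m0)
y <- d * (1 + x) + x^2 + rnorm(n)

# Oracle nuisance values replace the ML predictions of ml_g and ml_m here
psi <- score_ate_aipw(y, d, g0_hat = x^2, g1_hat = 1 + x + x^2,
                      m_hat = m0, smpls = NULL)

# Solving mean(psi_a * theta + psi_b) = 0 gives theta_hat = -mean(psi_b) / mean(psi_a)
theta_hat <- -mean(psi$psi_b) / mean(psi$psi_a)
theta_hat
```

Such a function would be passed as DoubleMLIRM$new(obj_dml_data, ml_g, ml_m, score = score_ate_aipw). Under "dml2" the ratio above is formed once from the pooled scores of all folds, whereas "dml1" forms it fold by fold and averages the fold-wise estimates.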
Examples
# \donttest{
library(DoubleML)
library(mlr3)
library(mlr3learners)
library(data.table)
set.seed(2)
ml_g = lrn("regr.ranger",
num.trees = 100, mtry = 20,
min.node.size = 2, max.depth = 5)
ml_m = lrn("classif.ranger",
num.trees = 100, mtry = 20,
min.node.size = 2, max.depth = 5)
obj_dml_data = make_irm_data(theta = 0.5)
dml_irm_obj = DoubleMLIRM$new(obj_dml_data, ml_g, ml_m)
dml_irm_obj$fit()
dml_irm_obj$summary()
#> Estimates and significance testing of the effect of target variables
#> Estimate. Std. Error t value Pr(>|t|)
#> d 0.6722 0.2851 2.358 0.0184 *
#> ---
#> Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
#>
#>
# }
if (FALSE) {
library(DoubleML)
library(mlr3)
library(mlr3learners)
library(mlr3tuning)
library(data.table)
set.seed(2)
ml_g = lrn("regr.rpart")
ml_m = lrn("classif.rpart")
obj_dml_data = make_irm_data(theta = 0.5)
dml_irm_obj = DoubleMLIRM$new(obj_dml_data, ml_g, ml_m)
param_grid = list(
"ml_g" = paradox::ParamSet$new(list(
paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02),
paradox::ParamInt$new("minsplit", lower = 1, upper = 2))),
"ml_m" = paradox::ParamSet$new(list(
paradox::ParamDbl$new("cp", lower = 0.01, upper = 0.02),
paradox::ParamInt$new("minsplit", lower = 1, upper = 2))))
# minimum requirements for tune_settings
tune_settings = list(
terminator = mlr3tuning::trm("evals", n_evals = 5),
algorithm = mlr3tuning::tnr("grid_search", resolution = 5))
dml_irm_obj$tune(param_set = param_grid, tune_settings = tune_settings)
dml_irm_obj$fit()
dml_irm_obj$summary()
}