{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "126e2f75", "metadata": {}, "source": [ "# R: Ensemble Learners and More with `mlr3pipelines`" ] }, { "attachments": {}, "cell_type": "markdown", "id": "aa30571a", "metadata": {}, "source": [ "This notebook illustrates how to use the tools provided by the [mlr3pipelines](https://mlr3pipelines.mlr-org.com/) package (Binder et al. 2021). For example, [mlr3pipelines](https://mlr3pipelines.mlr-org.com/) can be used in combination with [DoubleML](https://docs.doubleml.org/stable/index.html) for feature engineering, combination of learners (ensemble learners, stacking), subsampling and hyperparameter tuning. The underlying idea of [mlr3pipelines](https://mlr3pipelines.mlr-org.com/) is to define a pipeline that incorporates a user's desired operations. The pipeline is then converted into an object of class [Learner](https://mlr3.mlr-org.com/reference/Learner.html), which can easily be passed to [DoubleML](https://docs.doubleml.org/stable/index.html). For an introduction to [mlr3pipelines](https://mlr3pipelines.mlr-org.com/), we refer to the [Pipelines Chapter](https://mlr3book.mlr-org.com/pipelines.html) in the [mlr3book](https://mlr3book.mlr-org.com) (Becker et al. 2020) and to the [package website](https://mlr3pipelines.mlr-org.com/).\n", "\n", "\n", "We intend to illustrate the main idea of how to use [mlr3pipelines](https://mlr3pipelines.mlr-org.com/) in combination with [DoubleML](https://docs.doubleml.org/stable/index.html) with very simple examples. We use pipelines that are identical or very similar to the ones in the [Pipelines Chapter](https://mlr3book.mlr-org.com/pipelines.html) in the [mlr3book](https://mlr3book.mlr-org.com).
Hence, we do not claim that the proposed learners are optimal in terms of their performance.\n", "\n", "We start with the simple simulated data example and the [Bonus data set](https://docs.doubleml.org/r/stable/reference/fetch_bonus.html) from the [Getting Started Section in the DoubleML user guide](https://docs.doubleml.org/dev/intro/intro.html)." ] }, { "cell_type": "code", "execution_count": 1, "id": "fc314dcf", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "data": { "text/plain": [ "================= DoubleMLData Object ==================\n", "\n", "\n", "------------------ Data summary ------------------\n", "Outcome variable: y\n", "Treatment variable(s): d\n", "Covariates: X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X12, X13, X14, X15, X16, X17, X18, X19, X20, X21, X22, X23, X24, X25, X26, X27, X28, X29, X30, X31, X32, X33, X34, X35, X36, X37, X38, X39, X40, X41, X42, X43, X44, X45, X46, X47, X48, X49, X50, X51, X52, X53, X54, X55, X56, X57, X58, X59, X60, X61, X62, X63, X64, X65, X66, X67, X68, X69, X70, X71, X72, X73, X74, X75, X76, X77, X78, X79, X80, X81, X82, X83, X84, X85, X86, X87, X88, X89, X90, X91, X92, X93, X94, X95, X96, X97, X98, X99, X100\n", "Instrument(s): \n", "No. 
Observations: 500" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "library(DoubleML)\n", "\n", "# Simulate data\n", "set.seed(3141)\n", "n_obs = 500\n", "n_vars = 100\n", "theta = 3\n", "X = matrix(rnorm(n_obs*n_vars), nrow=n_obs, ncol=n_vars)\n", "d = X[,1:3]%*%c(5,5,5) + rnorm(n_obs)\n", "y = theta*d + X[, 1:3]%*%c(5,5,5) + rnorm(n_obs)\n", "\n", "\n", "# Specify the data and variables for the causal model\n", "# matrix interface to DoubleMLData\n", "dml_data_sim = double_ml_data_from_matrix(X=X, y=y, d=d)\n", "dml_data_sim" ] }, { "attachments": {}, "cell_type": "markdown", "id": "5913f88c", "metadata": {}, "source": [ "To have an example with a classification learner, we load the [Bonus data set](https://docs.doubleml.org/r/stable/reference/fetch_bonus.html)." ] }, { "cell_type": "code", "execution_count": 2, "id": "dea5e13a", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\t\n", "\t\n", "\n", "\n", "\t\n", "\t\n", "\t\n", "\t\n", "\t\n", "\t\n", "\n", "
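A data.table: 6 × 17

| inuidur1 <dbl> | female <dbl> | black <dbl> | othrace <dbl> | dep1 <dbl> | dep2 <dbl> | q2 <dbl> | q3 <dbl> | q4 <dbl> | q5 <dbl> | q6 <dbl> | agelt35 <dbl> | agegt54 <dbl> | durable <dbl> | lusd <dbl> | husd <dbl> | tg <dbl> |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2.890372 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
| 0.000000 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
| 3.295837 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
| 2.197225 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 |
| 3.295837 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 1 | 1 | 0 | 0 |
| 3.295837 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 1 | 0 | 0 |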
\n" ], "text/latex": [ "A data.table: 6 × 17\n", "\\begin{tabular}{lllllllllllllllll}\n", " inuidur1 & female & black & othrace & dep1 & dep2 & q2 & q3 & q4 & q5 & q6 & agelt35 & agegt54 & durable & lusd & husd & tg\\\\\n", " & & & & & & & & & & & & & & & & \\\\\n", "\\hline\n", "\t 2.890372 & 0 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & 0 & 0 & 1 & 0\\\\\n", "\t 0.000000 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & 0 & 1 & 0 & 0\\\\\n", "\t 3.295837 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0\\\\\n", "\t 2.197225 & 0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & 0 & 1\\\\\n", "\t 3.295837 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & 0 & 1 & 0 & 0 & 1 & 1 & 1 & 0 & 0\\\\\n", "\t 3.295837 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0 & 1 & 0 & 1 & 0 & 0\\\\\n", "\\end{tabular}\n" ], "text/markdown": [ "\n", "A data.table: 6 × 17\n", "\n", "| inuidur1 <dbl> | female <dbl> | black <dbl> | othrace <dbl> | dep1 <dbl> | dep2 <dbl> | q2 <dbl> | q3 <dbl> | q4 <dbl> | q5 <dbl> | q6 <dbl> | agelt35 <dbl> | agegt54 <dbl> | durable <dbl> | lusd <dbl> | husd <dbl> | tg <dbl> |\n", "|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n", "| 2.890372 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |\n", "| 0.000000 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |\n", "| 3.295837 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |\n", "| 2.197225 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 |\n", "| 3.295837 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 1 | 1 | 0 | 0 |\n", "| 3.295837 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 1 | 0 | 0 |\n", "\n" ], "text/plain": [ " inuidur1 female black othrace dep1 dep2 q2 q3 q4 q5 q6 agelt35 agegt54\n", "1 2.890372 0 0 0 0 1 0 0 0 1 0 0 0 \n", "2 0.000000 0 0 0 0 0 0 0 0 1 0 0 0 \n", "3 3.295837 0 0 0 0 0 0 0 1 0 0 0 0 \n", "4 2.197225 0 0 0 0 0 0 1 0 0 0 1 0 \n", "5 3.295837 0 0 0 1 0 0 0 0 1 0 0 1 \n", "6 3.295837 1 0 0 0 0 
0 0 0 1 0 0 1 \n", " durable lusd husd tg\n", "1 0 0 1 0 \n", "2 0 1 0 0 \n", "3 0 1 0 0 \n", "4 0 0 0 1 \n", "5 1 1 0 0 \n", "6 0 1 0 0 " ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "================= DoubleMLData Object ==================\n", "\n", "\n", "------------------ Data summary ------------------\n", "Outcome variable: inuidur1\n", "Treatment variable(s): tg\n", "Covariates: female, black, othrace, dep1, dep2, q2, q3, q4, q5, q6, agelt35, agegt54, durable, lusd, husd\n", "Instrument(s): \n", "No. Observations: 5099\n" ] } ], "source": [ "# Load bonus data\n", "df_bonus = fetch_bonus(return_type=\"data.table\")\n", "head(df_bonus)\n", "\n", "# Specify the data and variables for the causal model\n", "x_vars = c(\"female\", \"black\", \"othrace\", \"dep1\", \"dep2\",\n", " \"q2\", \"q3\", \"q4\", \"q5\", \"q6\", \"agelt35\", \"agegt54\",\n", " \"durable\", \"lusd\", \"husd\")\n", "dim_x = length(x_vars)\n", "dml_data_bonus = DoubleMLData$new(df_bonus,\n", " y_col = \"inuidur1\",\n", " d_cols = \"tg\",\n", " x_cols = x_vars)\n", "print(dml_data_bonus)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "1bfd66c6", "metadata": {}, "source": [ "To specify a learner for the nuisance part in a causal model, we can either use [mlr3](https://mlr3.mlr-org.com/)'s [LearnerRegr](https://mlr3.mlr-org.com/reference/LearnerRegr.html) for a model's nuisance part with a continuous dependent variable or a [LearnerClassif](https://mlr3.mlr-org.com/reference/LearnerClassif.html) if the corresponding outcome variable is binary. \n", "\n", "Moreover, it's possible to create a learner based on a pipeline. For example, we could think of [ensemble learners](https://mlr3book.mlr-org.com/pipelines.html#sec-pipelines-intro) which combine several estimators." 
] }, { "attachments": {}, "cell_type": "markdown", "id": "5e23aab7", "metadata": {}, "source": [ "## Using learners from `mlr3`, `mlr3learners` and `mlr3extralearners`\n", "\n", "Let's begin with a \"standard\" example of how to use any of the learners provided by [mlr3](https://mlr3.mlr-org.com/) (Lang et al. 2020), [mlr3learners](https://mlr3learners.mlr-org.com/) (Lang et al. 2021) and [mlr3extralearners](https://mlr3extralearners.mlr-org.com/) (Sonabend and Schratz 2021) in [DoubleML](https://docs.doubleml.org/stable/index.html): We create an object of the class [Learner](https://mlr3.mlr-org.com/reference/Learner.html), which [DoubleML](https://docs.doubleml.org/stable/index.html) uses internally for model training and generation of predictions.\n", "\n", "In the simulated example, the treatment variable is continuous, so we use a lasso estimator based on the [glmnet package](https://glmnet.stanford.edu/index.html) for both nuisance parts. In the Bonus data example, the treatment variable is binary, so we use a random forest classifier as provided by [ranger](https://github.com/imbs-hl/ranger) for the nuisance part `ml_m` and a random forest regressor for `ml_l`." ] }, { "cell_type": "code", "execution_count": 3, "id": "c0260edd", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Warning message:\n", "\"package 'mlr3' was built under R version 4.2.3\"\n" ] }, { "data": { "text/html": [ "\n", "
  1. 'LearnerRegrCVGlmnet'
  2. 'LearnerRegr'
  3. 'Learner'
  4. 'R6'
\n" ], "text/latex": [ "\\begin{enumerate*}\n", "\\item 'LearnerRegrCVGlmnet'\n", "\\item 'LearnerRegr'\n", "\\item 'Learner'\n", "\\item 'R6'\n", "\\end{enumerate*}\n" ], "text/markdown": [ "1. 'LearnerRegrCVGlmnet'\n", "2. 'LearnerRegr'\n", "3. 'Learner'\n", "4. 'R6'\n", "\n", "\n" ], "text/plain": [ "[1] \"LearnerRegrCVGlmnet\" \"LearnerRegr\" \"Learner\" \n", "[4] \"R6\" " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "library(mlr3)\n", "library(mlr3learners)\n", "\n", "# suppress messages during fitting\n", "lgr::get_logger(\"mlr3\")$set_threshold(\"warn\")\n", "\n", "learner_lasso = lrn(\"regr.cv_glmnet\", s=\"lambda.min\")\n", "ml_l_lasso = learner_lasso$clone()\n", "ml_m_lasso = learner_lasso$clone()\n", "class(ml_l_lasso)" ] }, { "cell_type": "code", "execution_count": 4, "id": "51ae80df", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "data": { "text/html": [ "\n", "
  1. 'LearnerRegrRanger'
  2. 'LearnerRegr'
  3. 'Learner'
  4. 'R6'
\n" ], "text/latex": [ "\\begin{enumerate*}\n", "\\item 'LearnerRegrRanger'\n", "\\item 'LearnerRegr'\n", "\\item 'Learner'\n", "\\item 'R6'\n", "\\end{enumerate*}\n" ], "text/markdown": [ "1. 'LearnerRegrRanger'\n", "2. 'LearnerRegr'\n", "3. 'Learner'\n", "4. 'R6'\n", "\n", "\n" ], "text/plain": [ "[1] \"LearnerRegrRanger\" \"LearnerRegr\" \"Learner\" \n", "[4] \"R6\" " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Random forest learner for nuisance part ml_l\n", "learner_forest_regr = lrn(\"regr.ranger\",\n", " num.trees=500, mtry=floor(sqrt(dim_x)),\n", " max.depth=5, min.node.size=2)\n", "\n", "# Random forest learner for nuisance part ml_m (binary outcome)\n", "learner_forest_classif = lrn(\"classif.ranger\",\n", " num.trees=500,\n", " mtry=floor(sqrt(dim_x)),\n", " max.depth=5, min.node.size=2)\n", "\n", "ml_l_forest = learner_forest_regr$clone()\n", "ml_m_forest = learner_forest_classif$clone()\n", "class(ml_l_forest)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "28795baa", "metadata": {}, "source": [ "Next, we set up the causal model. Here, we specify a partially linear regression (PLR) model and pass the learners as inputs. Let's fit the models."
] }, { "cell_type": "code", "execution_count": 5, "id": "2a231d42", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================= DoubleMLPLR Object ==================\n", "\n", "\n", "------------------ Data summary ------------------\n", "Outcome variable: y\n", "Treatment variable(s): d\n", "Covariates: X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X12, X13, X14, X15, X16, X17, X18, X19, X20, X21, X22, X23, X24, X25, X26, X27, X28, X29, X30, X31, X32, X33, X34, X35, X36, X37, X38, X39, X40, X41, X42, X43, X44, X45, X46, X47, X48, X49, X50, X51, X52, X53, X54, X55, X56, X57, X58, X59, X60, X61, X62, X63, X64, X65, X66, X67, X68, X69, X70, X71, X72, X73, X74, X75, X76, X77, X78, X79, X80, X81, X82, X83, X84, X85, X86, X87, X88, X89, X90, X91, X92, X93, X94, X95, X96, X97, X98, X99, X100\n", "Instrument(s): \n", "No. Observations: 500\n", "\n", "------------------ Score & algorithm ------------------\n", "Score function: partialling out\n", "DML algorithm: dml2\n", "\n", "------------------ Machine learner ------------------\n", "ml_l: regr.cv_glmnet\n", "ml_m: regr.cv_glmnet\n", "\n", "------------------ Resampling ------------------\n", "No. folds: 5\n", "No. repeated sample splits: 1\n", "Apply cross-fitting: TRUE\n", "\n", "------------------ Fit summary ------------------\n", " Estimates and significance testing of the effect of target variables\n", " Estimate. Std. Error t value Pr(>|t|) \n", "d 3.01219 0.04415 68.22 <2e-16 ***\n", "---\n", "Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 
0.1 ' ' 1\n", "\n", "\n" ] } ], "source": [ "set.seed(123)\n", "obj_dml_plr_sim = DoubleMLPLR$new(dml_data_sim,\n", " ml_l=ml_l_lasso,\n", " ml_m=ml_m_lasso)\n", "obj_dml_plr_sim$fit()\n", "print(obj_dml_plr_sim)" ] }, { "cell_type": "code", "execution_count": 6, "id": "fdd9e0f8", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================= DoubleMLPLR Object ==================\n", "\n", "\n", "------------------ Data summary ------------------\n", "Outcome variable: inuidur1\n", "Treatment variable(s): tg\n", "Covariates: female, black, othrace, dep1, dep2, q2, q3, q4, q5, q6, agelt35, agegt54, durable, lusd, husd\n", "Instrument(s): \n", "No. Observations: 5099\n", "\n", "------------------ Score & algorithm ------------------\n", "Score function: partialling out\n", "DML algorithm: dml2\n", "\n", "------------------ Machine learner ------------------\n", "ml_l: regr.ranger\n", "ml_m: classif.ranger\n", "\n", "------------------ Resampling ------------------\n", "No. folds: 5\n", "No. repeated sample splits: 1\n", "Apply cross-fitting: TRUE\n", "\n", "------------------ Fit summary ------------------\n", " Estimates and significance testing of the effect of target variables\n", " Estimate. Std. Error t value Pr(>|t|) \n", "tg -0.0765 0.0354 -2.161 0.0307 *\n", "---\n", "Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1\n", "\n", "\n" ] } ], "source": [ "set.seed(123)\n", "obj_dml_plr_bonus = DoubleMLPLR$new(dml_data_bonus,\n", " ml_l=ml_l_forest,\n", " ml_m=ml_m_forest)\n", "obj_dml_plr_bonus$fit()\n", "print(obj_dml_plr_bonus)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "4da0184d", "metadata": {}, "source": [ "## Set up learners based on `mlr3pipelines`" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c26bbf84", "metadata": {}, "source": [ "These learners can also be constructed using [mlr3pipelines](https://mlr3pipelines.mlr-org.com/). 
We first use the PipeOp constructor [po()](https://mlr3pipelines.mlr-org.com/reference/po.html) to wrap the learner in a PipeOp and then convert it into a new instance of the [Learner](https://mlr3.mlr-org.com/reference/Learner.html) class with `as_learner()`. A PipeOp, as constructed by [po()](https://mlr3pipelines.mlr-org.com/reference/po.html), represents a single computational step in a pipeline. For more information, we refer to the [Pipelines Chapter in the mlr3book](https://mlr3book.mlr-org.com/pipelines.html)." ] }, { "cell_type": "code", "execution_count": 7, "id": "4b09ebeb", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "data": { "text/html": [ "\n", "
  1. 'GraphLearner'
  2. 'Learner'
  3. 'R6'
\n" ], "text/latex": [ "\\begin{enumerate*}\n", "\\item 'GraphLearner'\n", "\\item 'Learner'\n", "\\item 'R6'\n", "\\end{enumerate*}\n" ], "text/markdown": [ "1. 'GraphLearner'\n", "2. 'Learner'\n", "3. 'R6'\n", "\n", "\n" ], "text/plain": [ "[1] \"GraphLearner\" \"Learner\" \"R6\" " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Lasso learner\n", "library(mlr3pipelines)\n", "pipe_lasso = po(lrn(\"regr.cv_glmnet\"), s = \"lambda.min\")\n", "ml_l_lasso_pipe = as_learner(pipe_lasso)\n", "ml_m_lasso_pipe = as_learner(pipe_lasso)\n", "\n", "# Class of the lasso learner\n", "class(ml_l_lasso_pipe)" ] }, { "cell_type": "code", "execution_count": 8, "id": "52e41460", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "data": { "text/html": [ "\n", "
  1. 'GraphLearner'
  2. 'Learner'
  3. 'R6'
\n" ], "text/latex": [ "\\begin{enumerate*}\n", "\\item 'GraphLearner'\n", "\\item 'Learner'\n", "\\item 'R6'\n", "\\end{enumerate*}\n" ], "text/markdown": [ "1. 'GraphLearner'\n", "2. 'Learner'\n", "3. 'R6'\n", "\n", "\n" ], "text/plain": [ "[1] \"GraphLearner\" \"Learner\" \"R6\" " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
  1. 'GraphLearner'
  2. 'Learner'
  3. 'R6'
\n" ], "text/latex": [ "\\begin{enumerate*}\n", "\\item 'GraphLearner'\n", "\\item 'Learner'\n", "\\item 'R6'\n", "\\end{enumerate*}\n" ], "text/markdown": [ "1. 'GraphLearner'\n", "2. 'Learner'\n", "3. 'R6'\n", "\n", "\n" ], "text/plain": [ "[1] \"GraphLearner\" \"Learner\" \"R6\" " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Random forest learner for nuisance part ml_l\n", "pipe_forest_regr = po(lrn(\"regr.ranger\"),\n", " num.trees=500, mtry=floor(sqrt(dim_x)),\n", " max.depth=5, min.node.size=2)\n", "\n", "# Random forest learner for nuisance part ml_m (binary outcome)\n", "pipe_forest_classif = po(lrn(\"classif.ranger\"),\n", " num.trees=500,\n", " mtry=floor(sqrt(dim_x)),\n", " max.depth=5, min.node.size=2)\n", "\n", "ml_l_forest_pipe = as_learner(pipe_forest_regr)\n", "ml_m_forest_pipe = as_learner(pipe_forest_classif)\n", "\n", "# Class of the random forest learners\n", "class(ml_l_forest_pipe)\n", "class(ml_m_forest_pipe)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "14a6bfb6", "metadata": {}, "source": [ "Let's use these learners to fit the PLR in both examples." 
] }, { "cell_type": "code", "execution_count": 9, "id": "bcf86234", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================= DoubleMLPLR Object ==================\n", "\n", "\n", "------------------ Data summary ------------------\n", "Outcome variable: y\n", "Treatment variable(s): d\n", "Covariates: X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X12, X13, X14, X15, X16, X17, X18, X19, X20, X21, X22, X23, X24, X25, X26, X27, X28, X29, X30, X31, X32, X33, X34, X35, X36, X37, X38, X39, X40, X41, X42, X43, X44, X45, X46, X47, X48, X49, X50, X51, X52, X53, X54, X55, X56, X57, X58, X59, X60, X61, X62, X63, X64, X65, X66, X67, X68, X69, X70, X71, X72, X73, X74, X75, X76, X77, X78, X79, X80, X81, X82, X83, X84, X85, X86, X87, X88, X89, X90, X91, X92, X93, X94, X95, X96, X97, X98, X99, X100\n", "Instrument(s): \n", "No. Observations: 500\n", "\n", "------------------ Score & algorithm ------------------\n", "Score function: partialling out\n", "DML algorithm: dml2\n", "\n", "------------------ Machine learner ------------------\n", "ml_l: regr.cv_glmnet\n", "ml_m: regr.cv_glmnet\n", "\n", "------------------ Resampling ------------------\n", "No. folds: 5\n", "No. repeated sample splits: 1\n", "Apply cross-fitting: TRUE\n", "\n", "------------------ Fit summary ------------------\n", " Estimates and significance testing of the effect of target variables\n", " Estimate. Std. Error t value Pr(>|t|) \n", "d 3.01219 0.04415 68.22 <2e-16 ***\n", "---\n", "Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 
0.1 ' ' 1\n", "\n", "\n" ] } ], "source": [ "set.seed(123)\n", "obj_dml_plr_sim_pipe = DoubleMLPLR$new(dml_data_sim,\n", " ml_l=ml_l_lasso_pipe,\n", " ml_m=ml_m_lasso_pipe)\n", "obj_dml_plr_sim_pipe$fit()\n", "print(obj_dml_plr_sim_pipe)" ] }, { "cell_type": "code", "execution_count": 10, "id": "b9decfad", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================= DoubleMLPLR Object ==================\n", "\n", "\n", "------------------ Data summary ------------------\n", "Outcome variable: inuidur1\n", "Treatment variable(s): tg\n", "Covariates: female, black, othrace, dep1, dep2, q2, q3, q4, q5, q6, agelt35, agegt54, durable, lusd, husd\n", "Instrument(s): \n", "No. Observations: 5099\n", "\n", "------------------ Score & algorithm ------------------\n", "Score function: partialling out\n", "DML algorithm: dml2\n", "\n", "------------------ Machine learner ------------------\n", "ml_l: regr.ranger\n", "ml_m: classif.ranger\n", "\n", "------------------ Resampling ------------------\n", "No. folds: 5\n", "No. repeated sample splits: 1\n", "Apply cross-fitting: TRUE\n", "\n", "------------------ Fit summary ------------------\n", " Estimates and significance testing of the effect of target variables\n", " Estimate. Std. Error t value Pr(>|t|) \n", "tg -0.0765 0.0354 -2.161 0.0307 *\n", "---\n", "Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 
0.1 ' ' 1\n", "\n", "\n" ] } ], "source": [ "set.seed(123)\n", "obj_dml_plr_bonus_pipe = DoubleMLPLR$new(dml_data_bonus,\n", " ml_l=ml_l_forest_pipe,\n", " ml_m=ml_m_forest_pipe)\n", "obj_dml_plr_bonus_pipe$fit()\n", "print(obj_dml_plr_bonus_pipe)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "952de1c7", "metadata": {}, "source": [ "## Use ensemble learners based on `mlr3pipelines`\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "e0ad79d6", "metadata": {}, "source": [ "First, let's see how to use more complex [GraphLearner](https://mlr3pipelines.mlr-org.com/reference/mlr_learners_graph.html)s, such as ensemble learners, in [DoubleML](https://docs.doubleml.org/). As an example, we create a learner whose predictions are the average of the predictions of three different learners. In the first step, the pipeline is split into three branches, one per learner. The learners estimate the nuisance parts independently of each other. In the last step, the predictions are averaged by the `regravg` or `classifavg` PipeOp. The steps are connected with the pipe operator `%>>%`. For more details, we refer to the [Pipelines Chapter in the mlr3book](https://mlr3book.mlr-org.com/pipelines.html)." ] }, { "cell_type": "code", "execution_count": 11, "id": "dd0a7ab6", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "data": { "text/html": [ "\n", "
  1. 'Graph'
  2. 'R6'
\n" ], "text/latex": [ "\\begin{enumerate*}\n", "\\item 'Graph'\n", "\\item 'R6'\n", "\\end{enumerate*}\n" ], "text/markdown": [ "1. 'Graph'\n", "2. 'R6'\n", "\n", "\n" ], "text/plain": [ "[1] \"Graph\" \"R6\" " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# For regression (nuisance parts with continuous outcome)\n", "graph_ensemble_regr = gunion(list(\n", " po(\"learner\", lrn(\"regr.cv_glmnet\", s = \"lambda.min\")),\n", " po(\"learner\", lrn(\"regr.ranger\")),\n", " po(\"learner\", lrn(\"regr.rpart\", cp = 0.01))\n", " )) %>>%\n", " # Average the predictions of the three learners\n", " po(\"regravg\", 3)\n", "\n", "# Class of 'graph_ensemble_regr'\n", "class(graph_ensemble_regr)" ] }, { "cell_type": "code", "execution_count": 12, "id": "d931817d", "metadata": { "tags": [ "nbsphinx-thumbnail" ], "vscode": { "languageId": "r" } }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "Plot with title \"\"" ] }, "metadata": { "image/png": { "height": 420, "width": 420 } }, "output_type": "display_data" } ], "source": [ "# Plot the graph\n", "graph_ensemble_regr$plot()" ] }, { "cell_type": "code", "execution_count": 13, "id": "94f8e41b", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "data": { "text/html": [ "\n", "
  1. 'Graph'
  2. 'R6'
\n" ], "text/latex": [ "\\begin{enumerate*}\n", "\\item 'Graph'\n", "\\item 'R6'\n", "\\end{enumerate*}\n" ], "text/markdown": [ "1. 'Graph'\n", "2. 'R6'\n", "\n", "\n" ], "text/plain": [ "[1] \"Graph\" \"R6\" " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# For classification (nuisance part ml_m in the Bonus example)\n", "graph_ensemble_classif = gunion(list(\n", " po(\"learner\", lrn(\"classif.cv_glmnet\", s = \"lambda.min\")),\n", " po(\"learner\", lrn(\"classif.ranger\")),\n", " po(\"learner\", lrn(\"classif.rpart\", cp = 0.01))\n", " )) %>>%\n", " po(\"classifavg\", 3)\n", "\n", "# Class of 'graph_ensemble_classif'\n", "class(graph_ensemble_classif)" ] }, { "cell_type": "code", "execution_count": 14, "id": "34c3e25a", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "Plot with title \"\"" ] }, "metadata": { "image/png": { "height": 420, "width": 420 } }, "output_type": "display_data" } ], "source": [ "# Plot the graph\n", "graph_ensemble_classif$plot()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "ac7308a1", "metadata": {}, "source": [ "We create a new instance of a [GraphLearner](https://mlr3pipelines.mlr-org.com/reference/mlr_learners_graph.html) which is later used in [DoubleML](https://docs.doubleml.org)." ] }, { "cell_type": "code", "execution_count": 15, "id": "e15ae052", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [], "source": [ "ensemble_pipe_regr = as_learner(graph_ensemble_regr)\n", "ensemble_pipe_classif = as_learner(graph_ensemble_classif)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "32724fc8", "metadata": {}, "source": [ "Let's estimate the two PLR examples with the ensemble learner." 
] }, { "cell_type": "code", "execution_count": 16, "id": "4cfdd413", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================= DoubleMLPLR Object ==================\n", "\n", "\n", "------------------ Data summary ------------------\n", "Outcome variable: y\n", "Treatment variable(s): d\n", "Covariates: X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X12, X13, X14, X15, X16, X17, X18, X19, X20, X21, X22, X23, X24, X25, X26, X27, X28, X29, X30, X31, X32, X33, X34, X35, X36, X37, X38, X39, X40, X41, X42, X43, X44, X45, X46, X47, X48, X49, X50, X51, X52, X53, X54, X55, X56, X57, X58, X59, X60, X61, X62, X63, X64, X65, X66, X67, X68, X69, X70, X71, X72, X73, X74, X75, X76, X77, X78, X79, X80, X81, X82, X83, X84, X85, X86, X87, X88, X89, X90, X91, X92, X93, X94, X95, X96, X97, X98, X99, X100\n", "Instrument(s): \n", "No. Observations: 500\n", "\n", "------------------ Score & algorithm ------------------\n", "Score function: partialling out\n", "DML algorithm: dml2\n", "\n", "------------------ Machine learner ------------------\n", "ml_l: regr.cv_glmnet.regr.ranger.regr.rpart.regravg\n", "ml_m: regr.cv_glmnet.regr.ranger.regr.rpart.regravg\n", "\n", "------------------ Resampling ------------------\n", "No. folds: 5\n", "No. repeated sample splits: 1\n", "Apply cross-fitting: TRUE\n", "\n", "------------------ Fit summary ------------------\n", " Estimates and significance testing of the effect of target variables\n", " Estimate. Std. Error t value Pr(>|t|) \n", "d 3.88664 0.02584 150.4 <2e-16 ***\n", "---\n", "Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 
0.1 ' ' 1\n", "\n", "\n" ] } ], "source": [ "# Initiate new DoubleML object and estimate with graph learner\n", "set.seed(123)\n", "obj_dml_plr_sim_pipe_ensemble = DoubleMLPLR$new(dml_data_sim,\n", " ml_l = ensemble_pipe_regr,\n", " ml_m = ensemble_pipe_regr)\n", "obj_dml_plr_sim_pipe_ensemble$fit()\n", "print(obj_dml_plr_sim_pipe_ensemble)" ] }, { "cell_type": "code", "execution_count": 17, "id": "876aee8e", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================= DoubleMLPLR Object ==================\n", "\n", "\n", "------------------ Data summary ------------------\n", "Outcome variable: inuidur1\n", "Treatment variable(s): tg\n", "Covariates: female, black, othrace, dep1, dep2, q2, q3, q4, q5, q6, agelt35, agegt54, durable, lusd, husd\n", "Instrument(s): \n", "No. Observations: 5099\n", "\n", "------------------ Score & algorithm ------------------\n", "Score function: partialling out\n", "DML algorithm: dml2\n", "\n", "------------------ Machine learner ------------------\n", "ml_l: regr.cv_glmnet.regr.ranger.regr.rpart.regravg\n", "ml_m: classif.cv_glmnet.classif.ranger.classif.rpart.classifavg\n", "\n", "------------------ Resampling ------------------\n", "No. folds: 5\n", "No. repeated sample splits: 1\n", "Apply cross-fitting: TRUE\n", "\n", "------------------ Fit summary ------------------\n", " Estimates and significance testing of the effect of target variables\n", " Estimate. Std. Error t value Pr(>|t|) \n", "tg -0.07689 0.03545 -2.169 0.0301 *\n", "---\n", "Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 
0.1 ' ' 1\n", "\n", "\n" ] } ], "source": [ "set.seed(123)\n", "obj_dml_plr_bonus_pipe_ensemble = DoubleMLPLR$new(dml_data_bonus,\n", " ml_l = ensemble_pipe_regr,\n", " ml_m = ensemble_pipe_classif)\n", "obj_dml_plr_bonus_pipe_ensemble$fit()\n", "print(obj_dml_plr_bonus_pipe_ensemble)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "e4ddb48a", "metadata": {}, "source": [ "Alternatively, different learners could also be stacked. Here we simply repeat the example from the [Pipelines Chapter in the mlr3book](https://mlr3book.mlr-org.com/pipelines.html#sec-pipelines-stack) in our Bonus data example." ] }, { "cell_type": "code", "execution_count": 18, "id": "25dc9006", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [], "source": [ "lrn = lrn(\"classif.rpart\")\n", "lrn_0 = po(\"learner_cv\", lrn$clone())\n", "lrn_0$id = \"rpart_cv\"" ] }, { "cell_type": "code", "execution_count": 19, "id": "2555c7a5", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [], "source": [ "# Pass original features to final estimation step\n", "level_0 = gunion(list(lrn_0, po(\"nop\")))" ] }, { "cell_type": "code", "execution_count": 20, "id": "99422ee6", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [], "source": [ "combined = level_0 %>>% po(\"featureunion\", 2)" ] }, { "cell_type": "code", "execution_count": 21, "id": "bd7357b2", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "Plot with title \"\"" ] }, "metadata": { "image/png": { "height": 420, "width": 420 } }, "output_type": "display_data" } ], "source": [ "stack = combined %>>% po(\"learner\", lrn$clone())\n", "stack$plot(html = FALSE)" ] }, { "cell_type": "code", "execution_count": 22, "id": "d0d13d0c", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================= DoubleMLPLR Object ==================\n", "\n", "\n", "------------------ Data summary 
------------------\n", "Outcome variable: inuidur1\n", "Treatment variable(s): tg\n", "Covariates: female, black, othrace, dep1, dep2, q2, q3, q4, q5, q6, agelt35, agegt54, durable, lusd, husd\n", "Instrument(s): \n", "No. Observations: 5099\n", "\n", "------------------ Score & algorithm ------------------\n", "Score function: partialling out\n", "DML algorithm: dml2\n", "\n", "------------------ Machine learner ------------------\n", "ml_l: regr.ranger\n", "ml_m: rpart_cv.nop.featureunion.classif.rpart\n", "\n", "------------------ Resampling ------------------\n", "No. folds: 5\n", "No. repeated sample splits: 1\n", "Apply cross-fitting: TRUE\n", "\n", "------------------ Fit summary ------------------\n", " Estimates and significance testing of the effect of target variables\n", " Estimate. Std. Error t value Pr(>|t|) \n", "tg -0.07915 0.03538 -2.237 0.0253 *\n", "---\n", "Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1\n", "\n", "\n" ] } ], "source": [ "# Create a stacked learner and pass it to a DoubleML object\n", "stacklrn = as_learner(stack)\n", "\n", "set.seed(123)\n", "obj_dml_plr_bonus_pipe = DoubleMLPLR$new(dml_data_bonus,\n", " ml_l=ml_l_forest,\n", " ml_m=stacklrn)\n", "obj_dml_plr_bonus_pipe$fit()\n", "print(obj_dml_plr_bonus_pipe)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "6f99764a", "metadata": {}, "source": [ "## How to exploit more features of `mlr3pipelines` in `DoubleML`\n", "\n", "[mlr3pipelines](https://mlr3pipelines.mlr-org.com/reference/Graph.html) can do much more. For example, we could use it to perform some [feature engineering](https://mlr3book.mlr-org.com/pipelines.html#sec-pipelines-combined) and even perform pipeline-based [parameter tuning](https://mlr3book.mlr-org.com/pipelines.html#sec-pipelines-tuning). 
We just have to define the steps we want to have in our pipeline by using the PipeOps.\n", "\n", "Let's have a look at two more examples from the [Pipelines Chapter in the mlr3book](https://mlr3book.mlr-org.com/pipelines.html). In the first one, we will do some data manipulation. The second example illustrates how we could use [mlr3pipelines](https://mlr3pipelines.mlr-org.com/reference/Graph.html) for parameter tuning." ] }, { "attachments": {}, "cell_type": "markdown", "id": "347a53d1", "metadata": {}, "source": [ "### Data preprocessing\n", "\n", "Let's perform some data preprocessing and then use a classification tree for prediction." ] }, { "cell_type": "code", "execution_count": 23, "id": "2f425d59", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [], "source": [ "mutate = po(\"mutate\")\n", "filter = po(\"filter\",\n", " filter = mlr3filters::flt(\"variance\"),\n", " param_vals = list(filter.frac = 0.5))" ] }, { "attachments": {}, "cell_type": "markdown", "id": "5a162782", "metadata": {}, "source": [ "Collect them in a graph and plot it." ] }, { "cell_type": "code", "execution_count": 24, "id": "a795412a", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "data": { "text/html": [ "\n", "
  1. 'Graph'
  2. 'R6'
\n" ], "text/latex": [ "\\begin{enumerate*}\n", "\\item 'Graph'\n", "\\item 'R6'\n", "\\end{enumerate*}\n" ], "text/markdown": [ "1. 'Graph'\n", "2. 'R6'\n", "\n", "\n" ], "text/plain": [ "[1] \"Graph\" \"R6\" " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "graph = mutate %>>%\n", " filter %>>%\n", " po(\"learner\",\n", " learner = lrn(\"classif.rpart\"))\n", "\n", "class(graph)" ] }, { "cell_type": "code", "execution_count": 25, "id": "d2ce3898", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA0gAAANICAMAAADKOT/pAAAAflBMVEUAAAAAADgAAEMAAEwAAGEAAHEAAHYAAHoAAH8AAIMAAIcAAItFMIdNTa5eQYNoaLpwTX98fMN+V3qLYHaMjMqWaHGamtGhb2ynp9epqamystyzfGG7gVu9veHDh1THx+XLjEzQ0OnSkUPZljjZ2e3h4fHmnwDp6fXw8Pj////WzieCAAAACXBIWXMAABJ0AAASdAHeZh94AAAVjElEQVR4nO3dC1MiSb7GYXd29+wlQbRtbfuqYgvW9/+Cpwq8ooNGnHeszFPPEzHcbJysiP8vEhDloAP+zw7GXgD8fyAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIEBIECAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIEBIECAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIEBIECAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIEBIECAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIKC1kA56Y68BXmhoKoeGDla9AzVRm2YGctvQIylRk0bGcbWT0SalsRcFD9qYxt3tyKZEZVqYxde2I5sSVWlgFP+8IyVRiwYmcU9HSqIS9Q/i3o6URB2qn8M3OlISVah+DIVEC2ofwzc7UhI1qHwK39HRqvZjYAoqH8L3hGRLYnx1DmG5M/9jU0p5dHh6fX/92e1bY6+byaozpO62lNvu9rz8Z7vnLMtyOOnbuT6abS8u728vJwdDef2dzoTEWCoM6fbi+K6M/vRoW1LZbkz9ya9ycn9xe3u5vg/pthxf3I68eCaqupCuzhbfbh9C+lROd0Lanj4J6XJ4tWH7z2/W3xZnV+Oun2mqK6Tf54vzm82lu5BWR+XLGyFtXm14fHo0fIvfIy2f6aoopGfbyX1Iy1n5sfPQ7mhvSN12U1t/+PKZtIpCKosn038fUh9OuXwI6cfq+qj8eiukvsmF1x34UBWF9OqOtFp9L7PlfUi9o1+r3ZA6OxJjqyik7ukTnMeQVqfl8OlzpNX+kG48R2IEdYXUPWwndyFt3tlwVI52Qzos15vzT3fvbLgL6dardoyjupCe/xxp+xahWdkN6XTzEkT/uO95SH6OxEgqDGmweWdDty5/3L2xYUjo+
u7tDIPr2exytfp1cr0NaV2KJ0WMqc6Qnrx9btPNj4c31t2XtDyZldnp9fYpkjfaMbY6Q3pw/+7vh63oFbUfA1NQ+xC+/XsU1R8CU1D9FPpVc1pQ/Ri+FVL9R8AU1D+G+0tq4ACYggbm0B+IpH4NDOK+kFpYP1PQwiC+/lkUm4xaWD5T0MYkvl5SI4tnChqZxVc2JdsRFWlmGHdSkhFVaWgc7z+MufNhzFSntYHUEFUylRAgJAhoL6SvYy8AXhISBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIEBIECAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIEBIECAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgS0E9LXXWMvCB61E9JuSWMvB55oNqSxVwNPNRRSpyOqJSQIaCmkTkfUSkgQ0GZIYy8EdjQVUickKtVWSJ2QqFOTIY29CtjVWEidkKiSkCCgxZDGXgO80FpInZCokZAgQEgQICQIEBIECAkCGgxp7BXAS82F5K8IUSMhQYCQIEBIENBaSAe9sdcALzQ0lUNDB6vegZqoTTMDuW3okZSoSSPjuNrJaJPS2IuCB21M4+52ZFOiMi3M4mvbkU2JqjQwin/ekZKoRQOTuKcjJVGJ+gdxb0dKog7Vz+EbHSmJKlQ/hkKiBbWP4ZsdKYkaVD6F7+hoVfsxMAWVD+F7QrIlMb7Kh1BItKHyIRQSbah8CF+EdF1eXqv8GJiCuofw5Yb0vby8VvlBMAV1z+CLkH6U8sq1ug+CKah1Btfn8+5b+c/l6rQcXq7KUMxwUsrm4umslJPl/bXlv8r859gLZtpqDWleylV3UQ6/ry7L4Wo59HK5ranfgY7Kst+NDu+uLWf/7H72/xrGU2tIXSmbk9W2lucXjg6fXjst/UGU47EXzKQ1GdLwMsPRw7XZ5hFeGXvBTFqbIX2ZnS6f3Dz2WqHJkE7L92c3/x57sUxekyE9v3ZSjm+79WLsBTNptYa0LuV2OFkOr9gtV4fl1+p7KYf9E6Llp/7aj+Fa/whvuLbePEW6GXvFTFqtIW1ePij3PzpaXc7Kl1X/xGj1oxwth2un/Ub0abm51v39uJytx14w01ZrSPfeftdq9YfAFFQ/hX7VnBZUP4ZvhVT/ETAF9Y/h/pIaOACmoIE59AciqV8Dg7gvpBbWzxS0MIivfxbFJqMWls8UtDGJr5fUyOKZgkZm8ZVNyXZERZoZxp2UZERVGhrH+w9j7nwYM9VpbSA1RJVMJQQICQKEBAHthfR17AXAS0KCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIEBIECAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIEBIECAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAHthPR119gLgkfthjT2euCJdkLaLWns5cATzYY09mrgqYZC6nREtRoNaeylwHMthdQJiVoJCQKaCqnTEZUSEgS0FVInJOrUZEhjrwJ2NRZSJySqJCQIaDGksdcAL7QWUickaiQkCBASBAgJAoQEAUKCgOZC8leEqJGQIEBIECAkCGgtpIPe2GuAFxqayqGhg1XvQE3UppmB3Db0SErUpJFxXO1ktElp7EXBgzamcXc7silRmRZm8bXtyKZEVRoYxT/vSEnUooFJ3NORkqhE/YO4tyMlUYfq5/CNjpREFaofQyHRgtrH8M2OlEQNKp/Cd3S0qv0YmILKh/A9IdmSG
F/lQygk2lD5ED4L6ehISFSq8iEUEm2oewjf9cjOqw2Mr+4ZfF9ItiRGV/cMColGVDaDP0vvpuv60647n5dyslwtv8xWX8rJank6G6o5nW1vPR1unV32t1z/s5Sz4d7rz2X+c+xDYJIqC6m7OSvD2e3nrjsu64Mf5XDVh/P9+uRoOO+rOSrL1f2tq8v+wmo1/3fXLY77jubnfYlXYx8CU1RbSF03P77tum/9f8eL/qHdEE/p2xlsQzq8u7S5Npyclj+67mLe72BDg+V47ANgiuoL6Vv53XWft5cv/v2QzOrx/PvRs5Bm5e6e87Ix2sqZsPpCui0/u5+bx2ff5ud/vAzpy+x0+Sykh3Q0xGjqC6k7+9xtHp6dl4vu4EVIp/1To9XOjrTe3rEMexmMocKQbuYXm5fehg3mZUhPNqK7C2fD06Kb/sFgf+G2Wy/GXj9TVGFI3WLbwqJcdf/oSzm5C2m5edHhsPxYfd/eer26Hm5a/7d/YrTod6V1uXvxHD5ajSFdnW/ObublvPuf8qmPY7bdkIaiLmfltK/o0+ba5rWFg/VxWWyeVN0cl7P1qEtnqmoM6Rm/ak4Lqh/Dt0Kq/wiYgvrHcH9JDRwAU9DAHPoDkdSvgUHcF1IL62cKWhjE1z+LYpNRC8tnCtqYxNdLamTxTEEjs/jKpmQ7oiLNDONOSjKiKg2N4/2HMXc+jJnqtDaQGqJKphIChAQBQoKA9kL6OvYC4CUhQYCQIEBIECAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIEBIECAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIEBIECAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoKAdkL6umvsBcGjdkMaez3wRDsh7ZY09nLgiWZDGns18FRDIXU6olqNhjT2UuC5lkLqhESthAQBTYXU6YhKCQkC2gqpExJ1ajKksVcBuxoLqRMSVRISBLQY0thrgBdaC6kTEjUSEgQICQKEBAFCggAhQUCDIY29AnipuZD8OS5qJCQIEBIECAkCWgvpoDf2GuCFhqZyaOhg1TtQE7VpZiC3DT2SEjVpZBxXOxltUhp7UfCgjWnc3Y5sSlSmhVl8bTuyKVGVBkbxzztSErVoYBL3dKQkKlH/IO7tSEnUofo5fKMjJVGF6sdQSLSg9jF8syMlUYOapnB9Pt+96VlHR0ebs+uTUo5+Pd5c1TEwUTUN4aKU3ZteCWk5+766PDy0JVGTqobwjZC2Tsq7H9vdvvh+8NdoL6TZ+0O6EBIfpL2QyrtDunr5/eCvUUNIt+elnA0XtoN/Pu+vrvsLn8v/LIdqPpWT4Xx5OhsyGmz6WX6Zrb70X/oyPzgr86snd11/m3ffytnmn454XExIDSHNj7tu0f+3Dem4rPu9ZNF3dNP97ahv5tPl6no4n20DetiR+uvfr0+O+lj+0d0uyu/Hu/Y5XdyeHb+yw8Ffo4KQzvvx7y6GV763IS3uLvU3Hww7UVn2u095SOjxoV0ZvjKcHQwP4z4/vevwLV97qAh/jQpCmj+M+/3gXxwPlxbl/G9DKIfl9PrJs6OnIa0eQrq/7/aud1eExEepIKSyE9K3+fl6uPS7f4T2qQ/l12x7/o6Q7u8qJD5YBSHNt4/DurvBPy8X9wlc/Gdb0PfDzfn+kBZP7iokPlgFIZ2V46676Z/ibAf/4aTcdt0/hudI/QO7768/R7o76+/wu3x7elch8bEqCGndP4Qri35XWm9eI1iUq+6iT+BbOb7p/t7vROXocnX5aXjBYXhx4Ucpv7bvFlo+hHTe3R4PL1bc3/XsrqB+r/s87rExFRWE1K2Py2L4MdD2xz438z6Ms/J53VdQ/jWEsvx098iuL+fu50hDSP35b
BvS1byc9dvXw137r2ze/npVjtf7/9eQUUNIe7zjtyhWHsAxPiFBQOUhvaOkdSkevzG22kN6uyRvqKMCzYdU/xEwBfWP4f6SGjgApqCBOfQHIqlfA4O4L6QW1s8UtDCIr38WxSajFpbPFLQxia+X1MjimYJGZvGVTcl2REWaGcadlGREVRoax/sPY+58GDPVaW0gNUSVTCUECAkC2gvp69gLgJeEBAFCggAhQYCQIEBIECAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIEBIECAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIEBIECAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQB7YT0ddfYC4JH7YS0W9LYy4Enmg1p7NXAUw2F1OmIagkJAloKqdMRtRISBLQZ0tgLgR1NhdQJiUq1FVInJOrUZEhjrwJ2NRZSJySqJCQIaDGksdcAL7QWUickaiQkCBASBAgJAoQEAUKCgAZDGnsF8FJzIfkrQtRISBAgJAgQEgS0FtJBb+w1wAsNTeXQ0MGqd6AmatPMQG4beiQlatLIOK52MtqkNPai4EEb07i7HdmUqEwLs/jadmRToioNjOKfd6QkatHAJO7pSElUov5B3NuRkqhD9XP4RkdKogrVj6GQaEHtY/hmR0qiBpVP4Ts6WtV+DExB5UP4npBsSYyv8iEUEm2obQivjks5/jlcKhurVXnd4en1/dcOuhdfHfswmJrKQjo+/t11v4+Ph8vrUtbDjrQsy1U5WQ3Z9Ccnpb+hv3B9NFsOX+u/eDD823VXzrohv/7kTEh8sGpCur3o6zmeb6/MNyX1e83moduwLV3fh3Rd7i78eohrOIg+nXLb3YV0W44vbsc7FCaokpCuzhbfbrur8nN79We56p6FdLm6D2m4uL2wOd2cbEO66bq7kLqb9bfF2dVoR8P01BDS7/PF+aaCz/0DtI2b8rl7FtLqMaTVn4W08fj0aPiuvz/6UJiq0UN6unc8VrC5tD+kX+Vob0jddp9bf9SBMGmjh1QWj6P+/pB+rK6Pyq/nz5F2vsXGeuF1Bz7C6CH9yY407/aG1Dv6tXorJDsSH2X0kLonz2Y+l5vtLev7l7IP9jxHerg4HMPLkG48R+Lj1BBSd793XJVv26sXm1ftFuV2E9KnPSEdlutNSJ/vvtFdSLdeteNDVRLSw8+Rtj/+mW+6OC9Xm5C+7wnptH+6NBzDxd23uQvJz5H4WNWEtHW8GN7ZsNj8PLa7nc//vlr9OrneNrN5F8Pg+v7CcHk2uznofp/ddTO8G2LE5TNZlYXU/Rzea3f/oGx99t8yO71+8gpDeXphm9fJvMzP7zryRjtGUltIu95++3f1h8AUVD+FftWcFlQ/hm+FVP8RMAX1j+H+kho4AKaggTn0ByKpXwODuC+kFtbPFLQwiK9/FsUmoxaWzxS0MYmvl9TI4pmCRmbxlU3JdkRFmhnGnZRkRFUaGsf7D2PufBgz1WltIDVElUwlBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIEBIECAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIEBIECAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIEBIECAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIEBIECAkCBASBAgJAoQEAUKCA
CFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIEBIECAkCBASBAgJAoQEAUKCACFBgJAgQEgQICQIEBIECAkChAQBQoIAIUGAkCBASBAgJAgQEgQICQKEBAFCggAhQYCQIEBIECAkCBASBAgJAv4XY/YbR/l6V40AAAAASUVORK5CYII=", "text/plain": [ "Plot with title \"\"" ] }, "metadata": { "image/png": { "height": 420, "width": 420 } }, "output_type": "display_data" } ], "source": [ "graph$plot()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "11860a7e", "metadata": {}, "source": [ "Create a new learner." ] }, { "cell_type": "code", "execution_count": 26, "id": "5e449f23", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [], "source": [ "glrn = as_learner(graph)" ] }, { "cell_type": "code", "execution_count": 27, "id": "909193b3", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "data": { "text/plain": [ "\n", "* Model: -\n", "* Parameters: mutate.mutation=, mutate.delete_originals=FALSE,\n", " variance.filter.frac=0.5, classif.rpart.xval=0\n", "* Packages: mlr3, mlr3pipelines, rpart\n", "* Predict Types: [response], prob\n", "* Feature Types: logical, integer, numeric, character, factor, ordered,\n", " POSIXct\n", "* Properties: featureless, hotstart_backward, hotstart_forward,\n", " importance, loglik, missings, multiclass, oob_error,\n", " selected_features, twoclass, weights" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
  1. 'GraphLearner'
  2. 'Learner'
  3. 'R6'
\n" ], "text/latex": [ "\\begin{enumerate*}\n", "\\item 'GraphLearner'\n", "\\item 'Learner'\n", "\\item 'R6'\n", "\\end{enumerate*}\n" ], "text/markdown": [ "1. 'GraphLearner'\n", "2. 'Learner'\n", "3. 'R6'\n", "\n", "\n" ], "text/plain": [ "[1] \"GraphLearner\" \"Learner\" \"R6\" " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "glrn\n", "class(glrn)" ] }, { "cell_type": "code", "execution_count": 28, "id": "96a281d1", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================= DoubleMLPLR Object ==================\n", "\n", "\n", "------------------ Data summary ------------------\n", "Outcome variable: inuidur1\n", "Treatment variable(s): tg\n", "Covariates: female, black, othrace, dep1, dep2, q2, q3, q4, q5, q6, agelt35, agegt54, durable, lusd, husd\n", "Instrument(s): \n", "No. Observations: 5099\n", "\n", "------------------ Score & algorithm ------------------\n", "Score function: partialling out\n", "DML algorithm: dml2\n", "\n", "------------------ Machine learner ------------------\n", "ml_l: regr.cv_glmnet\n", "ml_m: mutate.variance.classif.rpart\n", "\n", "------------------ Resampling ------------------\n", "No. folds: 5\n", "No. repeated sample splits: 1\n", "Apply cross-fitting: TRUE\n", "\n", "------------------ Fit summary ------------------\n", " Estimates and significance testing of the effect of target variables\n", " Estimate. Std. Error t value Pr(>|t|) \n", "tg -0.07366 0.03539 -2.081 0.0374 *\n", "---\n", "Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1\n", "\n", "\n" ] } ], "source": [ "set.seed(123)\n", "obj_dml_plr_bonus_pipe2 = DoubleMLPLR$new(dml_data_bonus,\n", " ml_l=ml_l_lasso,\n", " ml_m=glrn)\n", "obj_dml_plr_bonus_pipe2$fit()\n", "print(obj_dml_plr_bonus_pipe2)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "1d552c2f", "metadata": {}, "source": [ "Let's see how to set hyperparameters with a pipeline." 
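, "\n",
"\n",
"As a minimal sketch (not part of the original analysis), the hyperparameters of a `GraphLearner` can be listed and changed through its `param_set`. Each parameter id is prefixed with the id of the PipeOp it belongs to, which is why the filter fraction appears as `variance.filter.frac`. The names `graph_demo` and `glrn_demo` are hypothetical and only used for illustration; the snippet assumes that `mlr3`, `mlr3pipelines`, `mlr3filters` and `rpart` are installed.\n",
"\n",
"```r\n",
"library(mlr3)\n",
"library(mlr3pipelines)\n",
"library(mlr3filters)\n",
"\n",
"# Same structure as the preprocessing pipeline above, built under a new name\n",
"graph_demo = po(\"mutate\") %>>%\n",
"  po(\"filter\", filter = flt(\"variance\"), param_vals = list(filter.frac = 0.5)) %>>%\n",
"  po(\"learner\", learner = lrn(\"classif.rpart\"))\n",
"glrn_demo = as_learner(graph_demo)\n",
"\n",
"# All hyperparameter ids, each prefixed with the id of its PipeOp\n",
"print(glrn_demo$param_set$ids())\n",
"\n",
"# Read the current filter fraction, then change it\n",
"print(glrn_demo$param_set$values$variance.filter.frac)\n",
"glrn_demo$param_set$values$variance.filter.frac = 0.25\n",
"print(glrn_demo$param_set$values$variance.filter.frac)\n",
"```\n",
"\n",
"Because the sketch works on a separate object, `glrn` itself is left unchanged."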
] }, { "cell_type": "code", "execution_count": 29, "id": "27513f2e", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [], "source": [ "glrn$param_set$values$variance.filter.frac = 0.25" ] }, { "cell_type": "code", "execution_count": 30, "id": "3f35dca9", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================= DoubleMLPLR Object ==================\n", "\n", "\n", "------------------ Data summary ------------------\n", "Outcome variable: inuidur1\n", "Treatment variable(s): tg\n", "Covariates: female, black, othrace, dep1, dep2, q2, q3, q4, q5, q6, agelt35, agegt54, durable, lusd, husd\n", "Instrument(s): \n", "No. Observations: 5099\n", "\n", "------------------ Score & algorithm ------------------\n", "Score function: partialling out\n", "DML algorithm: dml2\n", "\n", "------------------ Machine learner ------------------\n", "ml_l: regr.cv_glmnet\n", "ml_m: mutate.variance.classif.rpart\n", "\n", "------------------ Resampling ------------------\n", "No. folds: 5\n", "No. repeated sample splits: 1\n", "Apply cross-fitting: TRUE\n", "\n", "------------------ Fit summary ------------------\n", " Estimates and significance testing of the effect of target variables\n", " Estimate. Std. Error t value Pr(>|t|) \n", "tg -0.07366 0.03539 -2.081 0.0374 *\n", "---\n", "Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1\n", "\n", "\n" ] } ], "source": [ "set.seed(123)\n", "obj_dml_plr_bonus_pipe3 = DoubleMLPLR$new(dml_data_bonus,\n", " ml_l=ml_l_lasso,\n", " ml_m=glrn)\n", "obj_dml_plr_bonus_pipe3$fit()\n", "print(obj_dml_plr_bonus_pipe3)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "2f90db3e", "metadata": {}, "source": [ "### Parameter tuning\n", "\n", "Next, we briefly illustrate how to perform parameter tuning in the simulated data example. Here, we use a pipeline to tune the lasso penalty together with the fraction of features retained by the variance filter. 
To do so, we generate a [GraphLearner](https://mlr3pipelines.mlr-org.com/reference/mlr_learners_graph.html) and then call [DoubleML](https://docs.doubleml.org)'s [tune()](https://docs.doubleml.org/r/stable/reference/DoubleML.html#method-tune-) method." ] }, { "attachments": {}, "cell_type": "markdown", "id": "e1cf206d", "metadata": {}, "source": [ "Let's define a [GraphLearner](https://mlr3pipelines.mlr-org.com/reference/mlr_learners_graph.html) based on the lasso." ] }, { "cell_type": "code", "execution_count": 31, "id": "935daf2f", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [], "source": [ "lasso_pipe = mutate %>>%\n", " filter %>>%\n", " po(\"learner\",\n", " learner = lrn(\"regr.glmnet\"))\n", "glrn_lasso = as_learner(lasso_pipe)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "4690daa8", "metadata": {}, "source": [ "Let's specify the parameter grid and the further settings required for parameter tuning. For more details, we refer to [the DoubleML user guide](https://docs.doubleml.org/stable/guide/learners.html#r-learners-and-hyperparameters). 
" ] }, { "cell_type": "code", "execution_count": 32, "id": "b9a50aa8", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Warning message:\n", "\"Paket 'paradox' wurde unter R Version 4.2.3 erstellt\"\n" ] } ], "source": [ "# Parameter grid for lambda and for optimal variance filter fraction\n", "library(paradox)\n", "par_grids = ps(regr.glmnet.lambda = p_dbl(lower = 0.05, upper = 0.1),\n", " variance.filter.frac = p_dbl(lower = 0.25, upper = 1))" ] }, { "cell_type": "code", "execution_count": 33, "id": "4ec8264d", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [], "source": [ "# Specify further tune settings\n", "library(mlr3tuning)\n", "tune_settings = list(terminator = trm(\"evals\", n_evals = 10),\n", " algorithm = tnr(\"grid_search\", resolution = 10),\n", " rsmp_tune = rsmp(\"cv\", folds = 5),\n", " measure = list(\"ml_l\" = msr(\"regr.mse\"),\n", " \"ml_m\" = msr(\"regr.mse\")))" ] }, { "cell_type": "code", "execution_count": 34, "id": "4885bebd", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO [12:57:48.469] [bbotk] Starting to optimize 2 parameter(s) with '' and ' [n_evals=10, k=0]'\n", "INFO [12:57:48.485] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:49.133] [bbotk] Result of batch 1:\n", "INFO [12:57:49.135] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:49.135] [bbotk] 0.07222222 0.6666667 107.2925 0 0\n", "INFO [12:57:49.135] [bbotk] runtime_learners uhash\n", "INFO [12:57:49.135] [bbotk] 0.38 5574dcd4-fc9e-463b-9345-d21ee5775b5f\n", "INFO [12:57:49.137] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:49.514] [bbotk] Result of batch 2:\n", "INFO [12:57:49.516] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:49.516] [bbotk] 0.1 0.6666667 106.2652 0 0\n", "INFO [12:57:49.516] [bbotk] runtime_learners 
uhash\n", "INFO [12:57:49.516] [bbotk] 0.32 8497f641-2700-4a53-ae89-5cb31a99b9cc\n", "INFO [12:57:49.517] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:49.892] [bbotk] Result of batch 3:\n", "INFO [12:57:49.893] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:49.893] [bbotk] 0.1 0.4166667 482.9753 0 0\n", "INFO [12:57:49.893] [bbotk] runtime_learners uhash\n", "INFO [12:57:49.893] [bbotk] 0.33 f3d24993-a09b-432f-ab71-44fa97767be8\n", "INFO [12:57:49.894] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:50.273] [bbotk] Result of batch 4:\n", "INFO [12:57:50.274] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:50.274] [bbotk] 0.09444444 0.75 112.1054 0 0\n", "INFO [12:57:50.274] [bbotk] runtime_learners uhash\n", "INFO [12:57:50.274] [bbotk] 0.3 bb2913dc-3cd0-4b8f-a09a-303f00f0bd62\n", "INFO [12:57:50.276] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:50.657] [bbotk] Result of batch 5:\n", "INFO [12:57:50.659] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:50.659] [bbotk] 0.08333333 0.9166667 10.81856 0 0\n", "INFO [12:57:50.659] [bbotk] runtime_learners uhash\n", "INFO [12:57:50.659] [bbotk] 0.28 8da924ce-f2e7-4dd2-a840-b5d34a6f42be\n", "INFO [12:57:50.660] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:51.051] [bbotk] Result of batch 6:\n", "INFO [12:57:51.053] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:51.053] [bbotk] 0.08888889 0.9166667 10.74189 0 0\n", "INFO [12:57:51.053] [bbotk] runtime_learners uhash\n", "INFO [12:57:51.053] [bbotk] 0.29 0434e374-ddc9-468d-b208-dc13a11076b3\n", "INFO [12:57:51.054] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:51.433] [bbotk] Result of batch 7:\n", "INFO [12:57:51.435] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:51.435] [bbotk] 0.07777778 0.5833333 201.8513 0 0\n", 
"INFO [12:57:51.435] [bbotk] runtime_learners uhash\n", "INFO [12:57:51.435] [bbotk] 0.26 280454dd-5804-498f-80a8-24080030a4de\n", "INFO [12:57:51.436] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:51.868] [bbotk] Result of batch 8:\n", "INFO [12:57:51.869] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:51.869] [bbotk] 0.07222222 0.3333333 531.5255 0 0\n", "INFO [12:57:51.869] [bbotk] runtime_learners uhash\n", "INFO [12:57:51.869] [bbotk] 0.32 7b428990-305b-47be-bde4-7215093d9089\n", "INFO [12:57:51.871] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:52.316] [bbotk] Result of batch 9:\n", "INFO [12:57:52.318] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:52.318] [bbotk] 0.06111111 0.9166667 11.17823 0 0\n", "INFO [12:57:52.318] [bbotk] runtime_learners uhash\n", "INFO [12:57:52.318] [bbotk] 0.34 bd929a9e-3e1c-4fee-ae56-f00584a57972\n", "INFO [12:57:52.319] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:52.773] [bbotk] Result of batch 10:\n", "INFO [12:57:52.775] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:52.775] [bbotk] 0.08888889 0.25 538.991 0 0\n", "INFO [12:57:52.775] [bbotk] runtime_learners uhash\n", "INFO [12:57:52.775] [bbotk] 0.38 e20ea26e-a6ba-4539-a3d9-0005a80b528f\n", "INFO [12:57:52.780] [bbotk] Finished optimizing after 10 evaluation(s)\n", "INFO [12:57:52.780] [bbotk] Result:\n", "INFO [12:57:52.782] [bbotk] regr.glmnet.lambda variance.filter.frac learner_param_vals x_domain regr.mse\n", "INFO [12:57:52.782] [bbotk] 0.08888889 0.9166667 10.74189\n", "INFO [12:57:52.869] [bbotk] Starting to optimize 2 parameter(s) with '' and ' [n_evals=10, k=0]'\n", "INFO [12:57:52.871] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:53.298] [bbotk] Result of batch 1:\n", "INFO [12:57:53.299] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:53.299] [bbotk] 
0.08888889 0.5833333 12.6173 0 0\n", "INFO [12:57:53.299] [bbotk] runtime_learners uhash\n", "INFO [12:57:53.299] [bbotk] 0.35 67ad635a-5346-41e5-b5d7-a79359d2da46\n", "INFO [12:57:53.301] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:53.723] [bbotk] Result of batch 2:\n", "INFO [12:57:53.725] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:53.725] [bbotk] 0.1 0.5833333 12.5155 0 0\n", "INFO [12:57:53.725] [bbotk] runtime_learners uhash\n", "INFO [12:57:53.725] [bbotk] 0.34 921e4f0d-e57c-4552-a5e6-8bdee1a1d83d\n", "INFO [12:57:53.726] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:54.130] [bbotk] Result of batch 3:\n", "INFO [12:57:54.132] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:54.132] [bbotk] 0.08333333 0.75 0.9870004 0 0\n", "INFO [12:57:54.132] [bbotk] runtime_learners uhash\n", "INFO [12:57:54.132] [bbotk] 0.3 26bd56a6-fd8a-4dba-9109-0ff823b17d45\n", "INFO [12:57:54.133] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:54.574] [bbotk] Result of batch 4:\n", "INFO [12:57:54.576] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:54.576] [bbotk] 0.07222222 0.5 17.75887 0 0\n", "INFO [12:57:54.576] [bbotk] runtime_learners uhash\n", "INFO [12:57:54.576] [bbotk] 0.34 fb5c25fa-1596-49d4-b371-d0cdb0ea4795\n", "INFO [12:57:54.577] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:55.015] [bbotk] Result of batch 5:\n", "INFO [12:57:55.017] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:55.017] [bbotk] 0.08333333 0.5833333 12.6722 0 0\n", "INFO [12:57:55.017] [bbotk] runtime_learners uhash\n", "INFO [12:57:55.017] [bbotk] 0.33 ee97bda7-6cea-440a-9248-d5a0c70f1d98\n", "INFO [12:57:55.018] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:55.452] [bbotk] Result of batch 6:\n", "INFO [12:57:55.453] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings 
errors\n", "INFO [12:57:55.453] [bbotk] 0.07222222 0.5833333 12.78818 0 0\n", "INFO [12:57:55.453] [bbotk] runtime_learners uhash\n", "INFO [12:57:55.453] [bbotk] 0.29 8e3aa840-c895-472e-85c5-681817dcfcda\n", "INFO [12:57:55.455] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:55.913] [bbotk] Result of batch 7:\n", "INFO [12:57:55.915] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:55.915] [bbotk] 0.07777778 0.5 17.69921 0 0\n", "INFO [12:57:55.915] [bbotk] runtime_learners uhash\n", "INFO [12:57:55.915] [bbotk] 0.36 4552b8af-3647-43f0-a5e7-55dc37e31fb1\n", "INFO [12:57:55.917] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:56.386] [bbotk] Result of batch 8:\n", "INFO [12:57:56.387] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:56.387] [bbotk] 0.07777778 0.9166667 0.9880384 0 0\n", "INFO [12:57:56.387] [bbotk] runtime_learners uhash\n", "INFO [12:57:56.387] [bbotk] 0.31 abb0fd28-0359-4ecd-8644-f1718fdeb9b0\n", "INFO [12:57:56.389] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:56.783] [bbotk] Result of batch 9:\n", "INFO [12:57:56.785] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:56.785] [bbotk] 0.07777778 0.4166667 22.43294 0 0\n", "INFO [12:57:56.785] [bbotk] runtime_learners uhash\n", "INFO [12:57:56.785] [bbotk] 0.26 cda85647-3ec2-4849-88ad-193f0d909729\n", "INFO [12:57:56.786] [bbotk] Evaluating 1 configuration(s)\n", "INFO [12:57:57.183] [bbotk] Result of batch 10:\n", "INFO [12:57:57.185] [bbotk] regr.glmnet.lambda variance.filter.frac regr.mse warnings errors\n", "INFO [12:57:57.185] [bbotk] 0.07222222 0.3333333 27.31378 0 0\n", "INFO [12:57:57.185] [bbotk] runtime_learners uhash\n", "INFO [12:57:57.185] [bbotk] 0.31 caac5a95-4462-42ba-99c8-ca1af7be64b2\n", "INFO [12:57:57.189] [bbotk] Finished optimizing after 10 evaluation(s)\n", "INFO [12:57:57.190] [bbotk] Result:\n", "INFO [12:57:57.191] [bbotk] 
regr.glmnet.lambda variance.filter.frac learner_param_vals x_domain regr.mse\n", "INFO [12:57:57.191] [bbotk] 0.08333333 0.75 0.9870004\n" ] } ], "source": [ "# Initiate new DoubleML object and execute tuning with graph learner\n", "set.seed(123)\n", "obj_dml_plr_sim_pipe_tune = DoubleMLPLR$new(dml_data_sim,\n", " ml_l=glrn_lasso,\n", " ml_m=glrn_lasso)\n", "obj_dml_plr_sim_pipe_tune$tune(param_set = list(\"ml_l\" = par_grids,\n", " \"ml_m\" = par_grids),\n", " tune_settings=tune_settings)" ] }, { "cell_type": "code", "execution_count": 35, "id": "839a8c99", "metadata": { "vscode": { "languageId": "r" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================= DoubleMLPLR Object ==================\n", "\n", "\n", "------------------ Data summary ------------------\n", "Outcome variable: y\n", "Treatment variable(s): d\n", "Covariates: X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X12, X13, X14, X15, X16, X17, X18, X19, X20, X21, X22, X23, X24, X25, X26, X27, X28, X29, X30, X31, X32, X33, X34, X35, X36, X37, X38, X39, X40, X41, X42, X43, X44, X45, X46, X47, X48, X49, X50, X51, X52, X53, X54, X55, X56, X57, X58, X59, X60, X61, X62, X63, X64, X65, X66, X67, X68, X69, X70, X71, X72, X73, X74, X75, X76, X77, X78, X79, X80, X81, X82, X83, X84, X85, X86, X87, X88, X89, X90, X91, X92, X93, X94, X95, X96, X97, X98, X99, X100\n", "Instrument(s): \n", "No. Observations: 500\n", "\n", "------------------ Score & algorithm ------------------\n", "Score function: partialling out\n", "DML algorithm: dml2\n", "\n", "------------------ Machine learner ------------------\n", "ml_l: regr.cv_glmnet\n", "ml_m: regr.cv_glmnet\n", "\n", "------------------ Resampling ------------------\n", "No. folds: 5\n", "No. repeated sample splits: 1\n", "Apply cross-fitting: TRUE\n", "\n", "------------------ Fit summary ------------------\n", " Estimates and significance testing of the effect of target variables\n", " Estimate. Std. 
Error t value Pr(>|t|) \n", "d 3.00133 0.04424 67.84 <2e-16 ***\n", "---\n", "Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1\n", "\n", "\n" ] } ], "source": [ "obj_dml_plr_sim_pipe$fit()\n", "print(obj_dml_plr_sim_pipe)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "b562abbe", "metadata": {}, "source": [ "## References\n", "\n", "Becker, M., Binder, M., Bischl, B., Lang, M., Pfisterer, F., Reich, N.G., Richter, J., Schratz, P., Sonabend, R. (2020), mlr3 book, available at https://mlr3book.mlr-org.com.\n", "\n", "Binder, M., Pfisterer, F., Lang, M., Schneider, L., Kotthoff, L., and Bischl, B. (2021), mlr3pipelines - flexible machine learning pipelines in R, Journal of Machine Learning Research, 22(184): 1-7, https://jmlr.org/papers/v22/21-0281.html.\n", "\n", "Lang, M., Binder, M., Richter, J., Schratz, P., Pfisterer, F., Coors, S., Au, Q., Casalicchio, G., Kotthoff, L., Bischl, B. (2019), mlr3: A modern object-oriented machine learning framework in R, Journal of Open Source Software, [doi:10.21105/joss.01903](https://doi.org/10.21105/joss.01903).\n", "\n", "Lang, M., Au, Q., Coors, S., and Schratz, P. (2021), mlr3learners: Recommended learners for mlr3, R package, https://CRAN.R-project.org/package=mlr3learners.\n", "\n", "Sonabend, R., and Schratz, P. (2021), extralearners: Extra learners for mlr3, R package, https://mlr3extralearners.mlr-org.com/." ] }, { "attachments": {}, "cell_type": "markdown", "id": "0ba7fae8", "metadata": {}, "source": [ "______\n", "\n", "**Acknowledgement**\n", "\n", "We would like to thank the developers of the [mlr3pipelines](https://mlr3pipelines.mlr-org.com/) package for providing such a powerful and easy-to-use implementation of many important ML tools." 
] } ], "metadata": { "kernelspec": { "display_name": "R", "language": "R", "name": "ir" }, "language_info": { "codemirror_mode": "r", "file_extension": ".r", "mimetype": "text/x-r-source", "name": "R", "pygments_lexer": "r", "version": "4.2.2" } }, "nbformat": 4, "nbformat_minor": 5 }