{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# DoubleML meets FLAML - How to tune learners automatically within `DoubleML`\n", "\n", "Recent advances in automated machine learning (AutoML) make it easier to tune the hyperparameters of ML estimators automatically. The tuned learners can then be used for nuisance estimation within DoubleML. In this notebook, we explore how to tune learners with AutoML for the DoubleML framework.\n", "\n", "This notebook uses [FLAML](https://github.com/microsoft/FLAML), but many other AutoML frameworks exist. Particularly useful for DoubleML are packages that can export the tuned model as an `sklearn`-style estimator.\n", "\n", "Examples include [TPOT](https://epistasislab.github.io/tpot/), [auto-sklearn](https://automl.github.io/auto-sklearn/master/), [H2O](https://docs.h2o.ai/h2o/latest-stable/h2o-docs/automl.html) and [GAMA](https://openml-labs.github.io/gama/master/)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Generation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We generate synthetic data with the [make_plr_CCDDHNR2018()](https://docs.doubleml.org/stable/api/generated/doubleml.datasets.make_plr_CCDDHNR2018.html) process: $1000$ observations of $50$ covariates, $1$ treatment variable and an outcome. We calibrate the process so that hyperparameter tuning becomes more important." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| | X1 | X2 | X3 | X4 | X5 | X6 | X7 | X8 | X9 | X10 | ... | X43 | X44 | X45 | X46 | X47 | X48 | X49 | X50 | y | d |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.065368 | 1.162593 | 1.089964 | 0.824657 | 0.157733 | -1.228404 | -0.675775 | -0.223928 | 0.166238 | 0.124480 | ... | -2.021823 | -1.662975 | -2.100385 | -1.225670 | -1.223158 | 0.397536 | -0.450031 | 0.511257 | 0.845534 | -0.784792 |
| 1 | 0.214458 | 1.699616 | 3.222882 | 3.550242 | 2.692460 | 1.821970 | 1.223617 | -0.100154 | -0.234431 | 0.375844 | ... | -0.695711 | -0.819507 | -1.465424 | -0.341472 | -0.023537 | 0.436016 | -0.503374 | -1.342632 | 1.987307 | 0.835035 |
| 2 | 0.725820 | -0.310145 | -0.586921 | -0.879058 | 0.239267 | 0.638461 | 0.131024 | 0.459436 | -1.140081 | -0.583692 | ... | -0.002388 | 0.716801 | 0.075942 | 1.439958 | 0.674747 | -0.268343 | 0.682122 | 0.978303 | 0.154890 | -0.168089 |
| 3 | 0.265744 | 0.479655 | 0.013313 | 1.417736 | 0.908767 | 1.786090 | 0.996892 | -0.026822 | -0.867201 | 0.433753 | ... | -0.482616 | -0.172628 | -0.309539 | -0.609522 | -0.830263 | -0.883953 | -1.249986 | -2.688641 | 1.254035 | 0.161288 |
| 4 | 1.581827 | 0.926901 | 2.302382 | 0.803112 | -0.152896 | -0.389164 | -0.569590 | -0.124306 | 0.055439 | -0.383531 | ... | 0.048220 | -0.698751 | -0.754678 | -0.689600 | 0.726658 | 0.780068 | 1.475517 | 0.777718 | 1.773769 | 1.786563 |
5 rows × 52 columns
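For intuition about the columns above, here is a minimal, standard-library-only sketch of the partially linear design documented for `make_plr_CCDDHNR2018()`: $d_i = m_0(x_i) + s_1 v_i$ and $y_i = \alpha d_i + g_0(x_i) + s_2 \zeta_i$, with $m_0(x) = a_0 x_1 + a_1 \frac{e^{x_3}}{1+e^{x_3}}$, $g_0(x) = b_0 \frac{e^{x_1}}{1+e^{x_1}} + b_1 x_3$ and covariates $x \sim \mathcal{N}(0, \Sigma)$, $\Sigma_{kj} = 0.7^{|j-k|}$. It assumes the documented default parameters ($\alpha = 0.5$, $a_0 = b_0 = 1$, $a_1 = b_1 = 0.25$, $s_1 = s_2 = 1$), not the notebook's own calibration, which is not shown here. The helpers `simulate_plr` and `corr` are hypothetical; in practice you would call `make_plr_CCDDHNR2018()` directly.

```python
import math
import random
import statistics


def logistic(z):
    """Logistic function exp(z) / (1 + exp(z))."""
    return 1.0 / (1.0 + math.exp(-z))


def simulate_plr(n_obs=1000, dim_x=50, alpha=0.5, rho=0.7, seed=0,
                 a0=1.0, a1=0.25, b0=1.0, b1=0.25, s1=1.0, s2=1.0):
    """Hypothetical helper: draw (X, y, d) from the documented PLR design.

    Assumes the default parameters of make_plr_CCDDHNR2018(); the notebook's
    own calibration is not reproduced here.
    """
    rng = random.Random(seed)
    innovation_sd = math.sqrt(1.0 - rho ** 2)
    X, y, d = [], [], []
    for _ in range(n_obs):
        # Stationary AR(1) across columns gives unit variances and
        # Corr(x_j, x_k) = rho ** |j - k|, i.e. the Toeplitz covariance above
        x = [rng.gauss(0.0, 1.0)]
        for _ in range(dim_x - 1):
            x.append(rho * x[-1] + innovation_sd * rng.gauss(0.0, 1.0))
        m0 = a0 * x[0] + a1 * logistic(x[2])  # treatment nuisance m_0(x)
        g0 = b0 * logistic(x[0]) + b1 * x[2]  # outcome nuisance g_0(x)
        d_i = m0 + s1 * rng.gauss(0.0, 1.0)
        y_i = alpha * d_i + g0 + s2 * rng.gauss(0.0, 1.0)
        X.append(x)
        y.append(y_i)
        d.append(d_i)
    return X, y, d


def corr(a, b):
    """Pearson correlation of two equally long sequences."""
    ma, mb = statistics.fmean(a), statistics.fmean(b)
    cov = sum((u - ma) * (v - mb) for u, v in zip(a, b))
    var_a = sum((u - ma) ** 2 for u in a)
    var_b = sum((v - mb) ** 2 for v in b)
    return cov / math.sqrt(var_a * var_b)


X, y, d = simulate_plr()
# Empirical correlation of X1 and X2 (population value: 0.7)
print(round(corr([row[0] for row in X], [row[1] for row in X]), 3))
```

The AR(1) construction is used only to avoid a matrix library; it produces the same joint Gaussian distribution as sampling from $\mathcal{N}(0, \Sigma)$ directly.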
| Model Type | Metric | coef | std err | t | P>\|t\| | 2.5 % | 97.5 % |
|---|---|---|---|---|---|---|---|
| Full Sample | d | 0.498286 | 0.032738 | 15.220407 | 2.589147e-52 | 0.434121 | 0.562452 |
| On the folds | d | 0.502016 | 0.033265 | 15.091263 | 1.848688e-51 | 0.436817 | 0.567215 |
| Default | d | 0.431253 | 0.032580 | 13.236884 | 5.373218e-40 | 0.367398 | 0.495108 |
| Less time | d | 0.436394 | 0.031007 | 14.073929 | 5.493102e-45 | 0.375621 | 0.497168 |
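The columns of the summary table are linked. Assuming normal-approximation inference (the reported numbers are consistent with it), $t = \hat\theta / \widehat{\mathrm{se}}$, the two-sided p-value is $2\,\Phi(-|t|)$ and the 95% interval is $\hat\theta \pm \Phi^{-1}(0.975)\,\widehat{\mathrm{se}}$. The sketch below recomputes these from the rounded `coef` and `std err` of the "Full Sample" row using only the standard library; the helper name `summarize` is hypothetical, and discrepancies in the last digits come from the rounding of the displayed inputs.

```python
import math
from statistics import NormalDist


def summarize(coef, se, level=0.95):
    """Hypothetical helper: normal-approximation t-stat, p-value and CI."""
    t = coef / se
    # 2 * Phi(-|t|) written as erfc(|t| / sqrt(2)); this stays accurate for
    # large |t|, where 1 - Phi(|t|) would round to 0.0 in floating point
    p = math.erfc(abs(t) / math.sqrt(2.0))
    z = NormalDist().inv_cdf(0.5 + level / 2.0)  # about 1.96 for 95%
    return t, p, (coef - z * se, coef + z * se)


# Rounded values from the "Full Sample" row
t, p, (lo, hi) = summarize(0.498286, 0.032738)
print(f"t = {t:.4f}, p = {p:.3e}, 95% CI = [{lo:.6f}, {hi:.6f}]")
```

The same check can be applied to the other rows; for example, `summarize(0.431253, 0.032580)` reproduces the "Default" row up to rounding.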