{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Python: IRM and APO Model Comparison\n", "\n", "In this simple example, we illustrate how the (binary) [DoubleMLIRM](https://docs.doubleml.org/stable/guide/models.html#binary-interactive-regression-model-irm) class relates to the [DoubleMLAPOS](https://docs.doubleml.org/stable/guide/models.html#average-potential-outcomes-apos-for-multiple-treatment-levels) class.\n", "\n", "More specifically, we focus on the `causal_contrast()` method of [DoubleMLAPOS](https://docs.doubleml.org/stable/guide/models.html#average-potential-outcomes-apos-for-multiple-treatment-levels) in a binary setting to highlight when both methods coincide." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import doubleml as dml\n", "\n", "from sklearn.linear_model import LinearRegression, LogisticRegression\n", "from sklearn.preprocessing import PolynomialFeatures\n", "\n", "from matplotlib import pyplot as plt\n", "\n", "from doubleml.datasets import make_irm_data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data\n", "\n", "We rely on the [make_irm_data](https://docs.doubleml.org/stable/api/generated/doubleml.datasets.make_irm_data.html) function to generate data with a binary treatment." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_obs = 2000\n", "\n", "np.random.seed(42)\n", "df = make_irm_data(\n", "    n_obs=n_obs,\n", "    dim_x=10,\n", "    theta=5.0,\n", "    return_type='DataFrame'\n", ")\n", "\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, define the ``DoubleMLData`` object." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dml_data = dml.DoubleMLData(\n", "    df,\n", "    y_col='y',\n", "    d_cols='d'\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Learners and Hyperparameters\n", "\n", "To simplify the comparison and keep the variation in learners as small as possible, we will use linear models." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_folds = 5\n", "n_rep = 1\n", "\n", "dml_kwargs = {\n", "    \"obj_dml_data\": dml_data,\n", "    \"ml_g\": LinearRegression(),\n", "    \"ml_m\": LogisticRegression(random_state=42),\n", "    \"n_folds\": n_folds,\n", "    \"n_rep\": n_rep,\n", "    \"normalize_ipw\": False,\n", "    \"trimming_threshold\": 1e-2,\n", "    \"draw_sample_splitting\": False,\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Remark:**\n", "All results rely on the exact same predictions for the machine learning algorithms. If more than two treatment levels exist, the `DoubleMLAPOS` model fits multiple binary models, such that the combined model might differ." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Further, to remove all uncertainty from sample splitting, we will rely on externally provided sample splits." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from doubleml.utils import DoubleMLResampling\n", "\n", "rskf = DoubleMLResampling(\n", "    n_folds=n_folds,\n", "    n_rep=n_rep,\n", "    n_obs=n_obs,\n", "    stratify=df['d'],\n", ")\n", "all_smpls = rskf.split_samples()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Average Treatment Effect\n", "\n", "Comparing the effect estimates of the `DoubleMLIRM` model and the `causal_contrast()` method of the `DoubleMLAPOS` model, we obtain numerically equivalent results for the ATE." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dml_irm = dml.DoubleMLIRM(**dml_kwargs)\n", "dml_irm.set_sample_splitting(all_smpls)\n", "print(\"Training IRM Model\")\n", "dml_irm.fit()\n", "\n", "print(dml_irm.summary)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dml_apos = dml.DoubleMLAPOS(treatment_levels=[0, 1], **dml_kwargs)\n", "dml_apos.set_sample_splitting(all_smpls)\n", "print(\"Training APOS Model\")\n", "dml_apos.fit()\n", "print(dml_apos.summary)\n", "\n", "print(\"Evaluate Causal Contrast\")\n", "causal_contrast = dml_apos.causal_contrast(reference_levels=[0])\n", "print(causal_contrast.summary)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For a direct comparison, see" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"IRM Model\")\n", "print(dml_irm.summary)\n", "print(\"Causal Contrast\")\n", "print(causal_contrast.summary)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Average Treatment Effect on the Treated\n", "\n", "For the average treatment effect on the treated, we can adjust the score in the `DoubleMLIRM` model to `score=\"ATTE\"`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dml_irm_atte = dml.DoubleMLIRM(score=\"ATTE\", **dml_kwargs)\n", "dml_irm_atte.set_sample_splitting(all_smpls)\n", "print(\"Training IRM Model\")\n", "dml_irm_atte.fit()\n", "\n", "print(dml_irm_atte.summary)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order to consider weighted effects in the `DoubleMLAPOS` model, we have to specify the correct weights, see the [User Guide](https://docs.doubleml.org/stable/guide/heterogeneity.html#weighted-average-treatment-effects).\n", "\n", "As these weights include the propensity score, we will use the predicted propensity score from the previous `DoubleMLIRM` model.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "p_hat = df[\"d\"].mean()\n", "m_hat = dml_irm_atte.predictions[\"ml_m\"][:, :, 0]\n", "\n", "weights_dict = {\n", "    \"weights\": df[\"d\"] / p_hat,\n", "    \"weights_bar\": m_hat / p_hat,\n", "}\n", "\n", "dml_apos_atte = dml.DoubleMLAPOS(treatment_levels=[0, 1], weights=weights_dict, **dml_kwargs)\n", "dml_apos_atte.set_sample_splitting(all_smpls)\n", "print(\"Training APOS Model\")\n", "dml_apos_atte.fit()\n", "print(dml_apos_atte.summary)\n", "\n", "print(\"Evaluate Causal Contrast\")\n", "causal_contrast_atte = dml_apos_atte.causal_contrast(reference_levels=[0])\n", "print(causal_contrast_atte.summary)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The same results can be achieved by specifying the weights for the `DoubleMLIRM` class with `score='ATE'`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dml_irm_weighted_atte = dml.DoubleMLIRM(score=\"ATE\", weights=weights_dict, **dml_kwargs)\n", "dml_irm_weighted_atte.set_sample_splitting(all_smpls)\n", "print(\"Training IRM Model\")\n", "dml_irm_weighted_atte.fit()\n", "\n", "print(dml_irm_weighted_atte.summary)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In summary, see" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"IRM Model ATTE Score\")\n", "print(dml_irm_atte.summary.round(4))\n", "print(\"IRM Model (Weighted)\")\n", "print(dml_irm_weighted_atte.summary.round(4))\n", "print(\"Causal Contrast (Weighted)\")\n", "print(causal_contrast_atte.summary.round(4))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sensitivity Analysis\n", "\n", "The sensitivity analysis gives identical results." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dml_irm.sensitivity_analysis()\n", "print(dml_irm.sensitivity_summary)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "causal_contrast.sensitivity_analysis()\n", "print(causal_contrast.sensitivity_summary)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Effect Heterogeneity\n", "\n", "For conditional treatment effects, the two classes do not offer exactly the same methods.\n", "Nevertheless, we will compare the `capo()` method of the `DoubleMLAPO` class to the corresponding `cate()` method of the `DoubleMLIRM` class.\n", "\n", "For a simple case, we will just use a polynomial basis of the first feature `X1`. To plot the results, we will evaluate the methods on the corresponding grid of basis values." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X = df[[\"X1\"]]\n", "poly = PolynomialFeatures(degree=2, include_bias=True)\n", "\n", "basis_matrix = poly.fit_transform(X)\n", "basis_df = pd.DataFrame(basis_matrix, columns=poly.get_feature_names_out([\"X1\"]))\n", "\n", "grid = pd.DataFrame({\"X1\": np.linspace(np.quantile(df[\"X1\"], 0.1), np.quantile(df[\"X1\"], 0.9), 100)})\n", "grid_basis = pd.DataFrame(poly.transform(grid), columns=poly.get_feature_names_out([\"X1\"]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Apply the `cate()` method to the basis and evaluate on the transformed grid values." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cate = dml_irm.cate(basis_df)\n", "print(cate)\n", "np.random.seed(42)\n", "df_cate = cate.confint(grid_basis, level=0.95, joint=True, n_rep_boot=2000)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The corresponding `capo()` method can be used for the treatment levels $0$ and $1$." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "capo0 = dml_apos.modellist[0].capo(basis_df)\n", "print(capo0)\n", "np.random.seed(42)\n", "df_capo0 = capo0.confint(grid_basis, level=0.95, joint=True, n_rep_boot=2000)\n", "\n", "capo1 = dml_apos.modellist[1].capo(basis_df)\n", "print(capo1)\n", "np.random.seed(42)\n", "df_capo1 = capo1.confint(grid_basis, level=0.95, joint=True, n_rep_boot=2000)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example the average potential outcome of the control group is zero (as can be seen in the outcome definition, see [documentation](https://docs.doubleml.org/stable/api/generated/doubleml.datasets.make_irm_data.html#doubleml.datasets.make_irm_data)).\n", "Let us visualize the effects" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "nbsphinx-thumbnail" ] }, "outputs": [], "source": [ "df_cate['x'] = grid_basis['X1']\n", "\n", "plt.rcParams['figure.figsize'] = 10., 7.5\n", "fig, (ax1, ax2) = plt.subplots(1, 2)\n", "\n", "# Plot CATE\n", "ax1.plot(df_cate['x'], df_cate['effect'], label='Estimated Effect')\n", "ax1.fill_between(df_cate['x'], df_cate['2.5 %'], df_cate['97.5 %'], alpha=.3, label='Confidence Interval')\n", "ax1.legend()\n", "ax1.set_title('CATE')\n", "ax1.set_xlabel('X1')\n", "ax1.set_ylabel('Effect and 95%-CI')\n", "\n", "# Plot Average Potential Outcomes\n", "ax2.plot(df_cate['x'], df_capo0['effect'], label='APO(0)')\n", "ax2.fill_between(df_cate['x'], df_capo0['2.5 %'], df_capo0['97.5 %'], alpha=.3, label='Confidence Interval')\n", "ax2.plot(df_cate['x'], df_capo1['effect'], label='APO(1)')\n", "ax2.fill_between(df_cate['x'], df_capo1['2.5 %'], df_capo1['97.5 %'], alpha=.3, label='Confidence Interval')\n", "ax2.legend()\n", "ax2.set_title('Average Potential Outcomes')\n", "ax2.set_xlabel('X1')\n", "ax2.set_ylabel('Effect and 95%-CI')\n", "\n", "# Ensure the same scale on y-axis\n", 
"ax1.set_ylim(min(ax1.get_ylim()[0], ax2.get_ylim()[0]), max(ax1.get_ylim()[1], ax2.get_ylim()[1]))\n", "ax2.set_ylim(min(ax1.get_ylim()[0], ax2.get_ylim()[0]), max(ax1.get_ylim()[1], ax2.get_ylim()[1]))\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `causal_contrast()` method does not currently not have a `cate()` method implemented. But the cate can be manually constructed via the the correct score function." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "orth_signal = -1.0 * (causal_contrast.scaled_psi.reshape(-1) - causal_contrast.thetas)\n", "\n", "causal_contrast_cate = dml.utils.DoubleMLBLP(orth_signal, basis_df)\n", "causal_contrast_cate.fit()\n", "print(causal_contrast_cate.summary)\n", "np.random.seed(42)\n", "df_causal_contrast_cate = causal_contrast_cate.confint(grid_basis, level=0.95, joint=True, n_rep_boot=2000)\n", "\n", "print(\"CATE (IRM) as comparison:\")\n", "print(cate.summary)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.rcParams['figure.figsize'] = 10., 7.5\n", "fig, (ax1, ax2) = plt.subplots(1, 2)\n", "\n", "# Plot CATE\n", "ax1.plot(df_cate['x'], df_cate['effect'], label='Estimated Effect')\n", "ax1.fill_between(df_cate['x'], df_cate['2.5 %'], df_cate['97.5 %'], alpha=.3, label='Confidence Interval')\n", "ax1.legend()\n", "ax1.set_title('CATE (IRM)')\n", "ax1.set_xlabel('X1')\n", "ax1.set_ylabel('Effect and 95%-CI')\n", "\n", "# Plot Average Potential Outcomes\n", "ax2.plot(df_cate['x'], df_causal_contrast_cate['effect'], label='Estimated Effect')\n", "ax2.fill_between(df_cate['x'], df_causal_contrast_cate['2.5 %'], df_causal_contrast_cate['97.5 %'], alpha=.3, label='Confidence Interval')\n", "ax2.legend()\n", "ax2.set_title('CATE (Causal Contrast)')\n", "ax2.set_xlabel('X1')\n", "ax2.set_ylabel('Effect and 95%-CI')\n", "\n", "# Ensure the same scale on y-axis\n", 
"ax1.set_ylim(min(ax1.get_ylim()[0], ax2.get_ylim()[0]), max(ax1.get_ylim()[1], ax2.get_ylim()[1]))\n", "ax2.set_ylim(min(ax1.get_ylim()[0], ax2.get_ylim()[0]), max(ax1.get_ylim()[1], ax2.get_ylim()[1]))\n", "\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 2 }