{ "cells": [ { "cell_type": "markdown", "id": "3364570e", "metadata": {}, "source": [ "# Python: Causal Machine Learning with TabPFN\n", "\n", "In this example, we demonstrate how to use [TabPFN](https://github.com/automl/TabPFN) (Tabular Prior-data Fitted Network) as a machine learning estimator within the [DoubleML](https://docs.doubleml.org/stable/index.html) framework for causal inference. We compare TabPFN's performance against untuned traditional machine learning methods: Random Forest, linear models, and LightGBM.\n", "\n", "TabPFN is a foundation model for tabular data that makes predictions via in-context learning: the training data is passed to a pretrained network at prediction time, so no dataset-specific training is required. Its transformer architecture was pretrained on a large collection of synthetic tabular datasets, which makes it particularly well suited to the small to medium-sized datasets common in causal inference applications.\n", "\n", "We will estimate **Average Potential Outcomes (APOs)** using the [DoubleMLAPOS](https://docs.doubleml.org/stable/api/generated/doubleml.irm.DoubleMLAPOS.html) model, which allows us to estimate:\n", "\n", "$$\\theta_d = \\mathbb{E}[Y(d)]$$\n", "\n", "for different treatment levels $d$ in a discrete treatment setting." ] }, { "cell_type": "markdown", "id": "8dc2e533", "metadata": {}, "source": [ "## Imports and Setup\n", "\n", "We start by importing the necessary libraries. Note that TabPFN must be installed separately; see the [installation instructions](https://priorlabs.ai/getting_started/install/).\n", "\n", "For GPU acceleration (recommended), ensure you have CUDA-enabled PyTorch installed.\n", "Alternatively, you can use the [TabPFN API Client](https://github.com/PriorLabs/tabpfn-client)." 
] }, { "cell_type": "code", "execution_count": 1, "id": "49c76183", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier\n", "from sklearn.linear_model import LinearRegression, LogisticRegression\n", "import lightgbm as lgbm\n", "from tabpfn import TabPFNRegressor, TabPFNClassifier\n", "\n", "import doubleml as dml\n", "from doubleml.datasets import make_irm_data_discrete_treatments\n", "\n", "import warnings\n", "# Silence TabPFN's CPU notice and repeated library warnings\n", "warnings.filterwarnings(\"ignore\", message=\"Running on CPU.*\", category=UserWarning, module=\"tabpfn\")\n", "# `lgbm` is only an import alias, so it cannot serve as a `module` filter; match on the message instead\n", "warnings.filterwarnings(\"ignore\", message=\".*does not have valid feature names.*\", category=UserWarning)\n", "warnings.filterwarnings(\"ignore\", category=FutureWarning, module=\"sklearn\")" ] }, { "cell_type": "markdown", "id": "4a04c896", "metadata": {}, "source": [ "## Data Generating Process (DGP)\n", "\n", "We generate synthetic data using DoubleML's discrete treatment data generating process. This creates:\n", "- A continuous treatment variable that is subsequently discretized into multiple levels $D$\n", "- True individual treatment effects (ITEs) for comparison with our estimates\n", "- Covariates $X$ that affect both treatment assignment $D$ and outcomes $Y$\n", "\n", "The discretization allows us to compare estimated Average Potential Outcomes (APOs) and Average Treatment Effects (ATEs) against their true values, providing a clear benchmark for evaluating different machine learning methods.\n", "For more details on the data generating process and the APO model, we refer to the [APO Model Example Notebook](https://docs.doubleml.org/stable/examples/py_double_ml_apo.html)." 
] }, { "cell_type": "code", "execution_count": 2, "id": "746d6b11", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Average treatment effects in each group:\n", "[ 0. 1.46 6.67 9.31 10.36 10.47]\n", "\n", "Average potential outcomes in each group:\n", "[209.9 211.36 216.56 219.2 220.26 220.37]\n", "\n", "Levels and their counts:\n", "(array([0., 1., 2., 3., 4., 5.]), array([183, 165, 154, 162, 175, 161]))\n" ] } ], "source": [ "# Parameters\n", "n_obs = 1000\n", "n_levels = 5\n", "linear = False\n", "n_rep = 1\n", "\n", "np.random.seed(42)\n", "data_apo = make_irm_data_discrete_treatments(n_obs=n_obs, n_levels=n_levels, linear=linear)\n", "\n", "y0 = data_apo['oracle_values']['y0']\n", "cont_d = data_apo['oracle_values']['cont_d']\n", "ite = data_apo['oracle_values']['ite']\n", "d = data_apo['d']\n", "potential_level = data_apo['oracle_values']['potential_level']\n", "level_bounds = data_apo['oracle_values']['level_bounds']\n", "\n", "# True group-wise treatment effects and APOs (level 0 is the untreated reference)\n", "average_ites = np.full(n_levels + 1, np.nan)\n", "apos = np.full(n_levels + 1, np.nan)\n", "\n", "for i in range(n_levels + 1):\n", " average_ites[i] = np.mean(ite[d == i]) * (i > 0)\n", " apos[i] = np.mean(y0) + average_ites[i]\n", "\n", "print(f\"Average treatment effects in each group:\\n{np.round(average_ites,2)}\\n\")\n", "print(f\"Average potential outcomes in each group:\\n{np.round(apos,2)}\\n\")\n", "print(f\"Levels and their counts:\\n{np.unique(d, return_counts=True)}\")" ] }, { "cell_type": "markdown", "id": "230ae06b", "metadata": {}, "source": [ "### Visualizing the Treatment Effect Structure\n", "\n", "To better understand our data, let's visualize the relationship between the continuous treatment variable and the individual treatment effects, along with how the treatment is discretized into levels." ] }, { "cell_type": "code", "execution_count": 3, "id": "906c6c36", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Get a colorblind-friendly palette\n", "palette = sns.color_palette(\"colorblind\")\n", "\n", "df = pd.DataFrame({'cont_d': cont_d, 'ite': ite})\n", "df_sorted = df.sort_values('cont_d')\n", "\n", "mid_points = np.full(n_levels, np.nan)\n", "for i in range(n_levels):\n", " mid_points[i] = (level_bounds[i] + level_bounds[i + 1]) / 2\n", "\n", "df_apos = pd.DataFrame({'mid_points': mid_points, 'treatment effects': apos[1:] - apos[0]})\n", "\n", "# Create the primary plot with scatter and line plots\n", "fig, ax1 = plt.subplots()\n", "\n", "sns.lineplot(data=df_sorted, x='cont_d', y='ite', color=palette[0], label='ITE', ax=ax1)\n", "sns.scatterplot(data=df_apos, x='mid_points', y='treatment effects', color=palette[1], label='Grouped Treatment Effects', ax=ax1)\n", "\n", "# Add vertical dashed lines at level_bounds\n", "for bound in level_bounds:\n", " ax1.axvline(x=bound, color='grey', linestyle='--', alpha=0.7)\n", "\n", "ax1.set_title('Grouped Effects vs. Continuous Treatment')\n", "ax1.set_xlabel('Continuous Treatment')\n", "ax1.set_ylabel('Effects')\n", "\n", "# Create a secondary y-axis for the histogram\n", "ax2 = ax1.twinx()\n", "\n", "# Plot the histogram on the secondary y-axis (the weights make the bar heights sum to one)\n", "ax2.hist(df_sorted['cont_d'], bins=30, alpha=0.3, weights=np.ones_like(df_sorted['cont_d']) / len(df_sorted['cont_d']), color=palette[2])\n", "ax2.set_ylabel('Share of observations')\n", "\n", "# Make sure the legend includes all plots\n", "lines, labels = ax1.get_legend_handles_labels()\n", "ax1.legend(lines, labels, loc='upper left')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "2a3ef4e2", "metadata": {}, "source": [ "### Creating the DoubleMLData Object\n", "\n", "As with all DoubleML models, we need to create a [DoubleMLData](https://docs.doubleml.org/stable/api/generated/doubleml.data.DoubleMLData.html) object to properly structure our data for causal inference. 
This object handles the separation of outcome variables, treatment variables, and covariates." ] }, { "cell_type": "code", "execution_count": 4, "id": "d827dfab", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================== DoubleMLData Object ==================\n", "\n", "------------------ Data summary ------------------\n", "Outcome variable: y\n", "Treatment variable(s): ['d']\n", "Covariates: ['x0', 'x1', 'x2', 'x3', 'x4']\n", "Instrument variable(s): None\n", "No. Observations: 1000\n", "\n", "------------------ DataFrame info ------------------\n", "\n", "RangeIndex: 1000 entries, 0 to 999\n", "Columns: 7 entries, y to x4\n", "dtypes: float64(7)\n", "memory usage: 54.8 KB\n", "\n" ] } ], "source": [ "y = data_apo['y']\n", "x = data_apo['x']\n", "d = data_apo['d']\n", "df_apo = pd.DataFrame(\n", " np.column_stack((y, d, x)),\n", " columns=['y', 'd'] + ['x' + str(i) for i in range(data_apo['x'].shape[1])]\n", ")\n", "\n", "dml_data = dml.DoubleMLData(df_apo, 'y', 'd')\n", "print(dml_data)" ] }, { "cell_type": "markdown", "id": "70beea16", "metadata": {}, "source": [ "## DoubleML with TabPFN\n", "\n", "The [TabPFN package](https://github.com/PriorLabs/tabpfn) provides `TabPFNRegressor` and `TabPFNClassifier`, which follow the scikit-learn estimator interface and can therefore be passed directly as learners to the [DoubleML](https://docs.doubleml.org/stable/index.html) framework.\n", "\n", "For fitting [average potential outcome models](https://docs.doubleml.org/stable/guide/models.html#average-potential-outcomes-apos), the `DoubleML` interface requires two learners, `ml_g` and `ml_m`:\n", "- `ml_g`: A regressor for the outcome model $g_0(D,X) = \\mathbb{E}[Y|X,D]$\n", "- `ml_m`: A classifier for the propensity score model $m_{0,d}(X) = \\mathbb{E}[1\\{D=d\\}|X]$\n", "\n", "**Note**: TabPFN works best with CUDA acceleration. If no GPU is available, computation runs on the CPU, as in this example, which sets `device='cpu'`. Alternatively, you can use the [TabPFN API Client](https://github.com/PriorLabs/tabpfn-client)." 
] }, { "cell_type": "code", "execution_count": 5, "id": "15aa7b39", "metadata": {}, "outputs": [], "source": [ "# Set device = 'cuda' instead if a CUDA-enabled GPU is available\n", "device = 'cpu'\n", "ml_g = TabPFNRegressor(device=device)\n", "ml_m = TabPFNClassifier(device=device)" ] }, { "cell_type": "markdown", "id": "23c720e8", "metadata": {}, "source": [ "To model average potential outcomes, we initialize the [DoubleMLAPOS](https://docs.doubleml.org/stable/api/generated/doubleml.irm.DoubleMLAPOS.html#doubleml.irm.DoubleMLAPOS) object with the learners defined above and the treatment levels observed in the data." ] }, { "cell_type": "code", "execution_count": 6, "id": "a85a3301", "metadata": {}, "outputs": [], "source": [ "treatment_levels = np.unique(dml_data.d)\n", "dml_obj = dml.DoubleMLAPOS(\n", " dml_data,\n", " ml_g=ml_g,\n", " ml_m=ml_m,\n", " treatment_levels=treatment_levels,\n", " n_rep=n_rep,\n", ")" ] }, { "cell_type": "markdown", "id": "941f3c8e", "metadata": {}, "source": [ "As usual, you can estimate the parameters by calling the `fit` method on the `dml_obj` instance." 
] }, { "cell_type": "code", "execution_count": 7, "id": "dbd90a29", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0b99ff64cd2644dfaa62c41b58bf02f9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tabpfn-v2-classifier-finetuned-zk73skhh.(…): 0%| | 0.00/29.0M [00:00|t| 2.5 % 97.5 %\n", "0.0 209.395480 1.211939 172.777199 0.0 207.020123 211.770838\n", "1.0 210.966031 1.367425 154.279824 0.0 208.285928 213.646134\n", "2.0 216.538410 1.245027 173.922656 0.0 214.098202 218.978618\n", "3.0 219.333914 1.334717 164.329850 0.0 216.717916 221.949912\n", "4.0 219.905724 1.278724 171.972735 0.0 217.399470 222.411977\n", "5.0 219.265669 1.177094 186.277179 0.0 216.958608 221.572730\n" ] } ], "source": [ "dml_obj.fit()\n", "print(dml_obj)" ] }, { "cell_type": "markdown", "id": "bd66a2f9", "metadata": {}, "source": [ "## Machine Learning Methods Comparison\n", "\n", "We compare four different machine learning approaches for estimating the nuisance functions in our causal model:\n", "\n", "1. **Random Forest**: Ensemble method with bagging and random feature selection\n", "2. **Linear Models**: Linear/Logistic regression\n", "3. **LightGBM**: Gradient boosting framework\n", "4. 
**TabPFN**: A foundation model for tabular data" ] }, { "cell_type": "code", "execution_count": 8, "id": "ea1c0ce4", "metadata": {}, "outputs": [], "source": [ "learner_dict = {\n", " 'RandomForest': {\n", " 'ml_g': RandomForestRegressor(),\n", " 'ml_m': RandomForestClassifier()\n", " },\n", " 'Linear': {\n", " 'ml_g': LinearRegression(),\n", " 'ml_m': LogisticRegression(max_iter=1000)\n", " },\n", " 'LightGBM': {\n", " 'ml_g': lgbm.LGBMRegressor(n_estimators=50, verbose=-1),\n", " 'ml_m': lgbm.LGBMClassifier(n_estimators=50, verbose=-1)\n", " },\n", " 'TabPFN': {\n", " 'ml_g': TabPFNRegressor(device=device),\n", " 'ml_m': TabPFNClassifier(device=device)\n", " }\n", "}" ] }, { "cell_type": "markdown", "id": "7ffd5a74", "metadata": {}, "source": [ "### Estimation of Average Potential Outcomes\n", "\n", "Now we estimate the Average Potential Outcomes (APOs) for each treatment level using all four machine learning methods. We use the [DoubleMLAPOS](https://docs.doubleml.org/dev/api/generated/doubleml.irm.DoubleMLAPOS.html) class, which:\n", "\n", "1. **Estimates nuisance functions**: Uses cross-fitting to estimate $g_0(D,X)$ and $m_{0,d}(X)$\n", "2. **Computes APO estimates**: Uses the doubly robust score implied by the efficient influence function, estimating $\\theta_d = \\mathbb{E}[Y(d)]$ by the sample mean of $g_0(d,X) + \\frac{1\\{D=d\\}}{m_{0,d}(X)}\\big(Y - g_0(d,X)\\big)$ with cross-fitted nuisance estimates\n", "3. **Provides confidence intervals**: Based on the asymptotic distribution of the estimator\n", "\n", "We also compute **causal contrasts** (Average Treatment Effects) as the differences between the APO of each treatment level and the APO of the reference level $0$ (no treatment)." ] }, { "cell_type": "code", "execution_count": 9, "id": "db8b5c59", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\BAM5698\\AppData\\Local\\miniconda3\\envs\\dml_docs\\Lib\\site-packages\\doubleml\\double_ml.py:1479: UserWarning: The estimated nu2 for d is not positive. 
Re-estimation based on riesz representer (non-orthogonal).\n", " warnings.warn(msg, UserWarning)\n", "c:\\Users\\BAM5698\\AppData\\Local\\miniconda3\\envs\\dml_docs\\Lib\\site-packages\\doubleml\\utils\\_checks.py:194: UserWarning: Propensity predictions from learner RandomForestClassifier() for ml_m are close to zero or one (eps=1e-12).\n", " warnings.warn(\n", "c:\\Users\\BAM5698\\AppData\\Local\\miniconda3\\envs\\dml_docs\\Lib\\site-packages\\doubleml\\double_ml.py:1479: UserWarning: The estimated nu2 for d is not positive. Re-estimation based on riesz representer (non-orthogonal).\n", " warnings.warn(msg, UserWarning)\n", "c:\\Users\\BAM5698\\AppData\\Local\\miniconda3\\envs\\dml_docs\\Lib\\site-packages\\doubleml\\utils\\_checks.py:194: UserWarning: Propensity predictions from learner RandomForestClassifier() for ml_m are close to zero or one (eps=1e-12).\n", " warnings.warn(\n", "c:\\Users\\BAM5698\\AppData\\Local\\miniconda3\\envs\\dml_docs\\Lib\\site-packages\\doubleml\\double_ml.py:1479: UserWarning: The estimated nu2 for d is not positive. Re-estimation based on riesz representer (non-orthogonal).\n", " warnings.warn(msg, UserWarning)\n", "c:\\Users\\BAM5698\\AppData\\Local\\miniconda3\\envs\\dml_docs\\Lib\\site-packages\\doubleml\\utils\\_checks.py:194: UserWarning: Propensity predictions from learner RandomForestClassifier() for ml_m are close to zero or one (eps=1e-12).\n", " warnings.warn(\n", "c:\\Users\\BAM5698\\AppData\\Local\\miniconda3\\envs\\dml_docs\\Lib\\site-packages\\doubleml\\double_ml.py:1479: UserWarning: The estimated nu2 for d is not positive. 
Re-estimation based on riesz representer (non-orthogonal).\n", " warnings.warn(msg, UserWarning)\n", "c:\\Users\\BAM5698\\AppData\\Local\\miniconda3\\envs\\dml_docs\\Lib\\site-packages\\doubleml\\utils\\_checks.py:194: UserWarning: Propensity predictions from learner RandomForestClassifier() for ml_m are close to zero or one (eps=1e-12).\n", " warnings.warn(\n", "c:\\Users\\BAM5698\\AppData\\Local\\miniconda3\\envs\\dml_docs\\Lib\\site-packages\\doubleml\\double_ml.py:1479: UserWarning: The estimated nu2 for d is not positive. Re-estimation based on riesz representer (non-orthogonal).\n", " warnings.warn(msg, UserWarning)\n", "c:\\Users\\BAM5698\\AppData\\Local\\miniconda3\\envs\\dml_docs\\Lib\\site-packages\\doubleml\\double_ml.py:1479: UserWarning: The estimated nu2 for d is not positive. Re-estimation based on riesz representer (non-orthogonal).\n", " warnings.warn(msg, UserWarning)\n", "c:\\Users\\BAM5698\\AppData\\Local\\miniconda3\\envs\\dml_docs\\Lib\\site-packages\\doubleml\\double_ml.py:1479: UserWarning: The estimated nu2 for d is not positive. Re-estimation based on riesz representer (non-orthogonal).\n", " warnings.warn(msg, UserWarning)\n", "c:\\Users\\BAM5698\\AppData\\Local\\miniconda3\\envs\\dml_docs\\Lib\\site-packages\\doubleml\\double_ml.py:1479: UserWarning: The estimated nu2 for d is not positive. Re-estimation based on riesz representer (non-orthogonal).\n", " warnings.warn(msg, UserWarning)\n", "c:\\Users\\BAM5698\\AppData\\Local\\miniconda3\\envs\\dml_docs\\Lib\\site-packages\\doubleml\\double_ml.py:1479: UserWarning: The estimated nu2 for d is not positive. Re-estimation based on riesz representer (non-orthogonal).\n", " warnings.warn(msg, UserWarning)\n", "c:\\Users\\BAM5698\\AppData\\Local\\miniconda3\\envs\\dml_docs\\Lib\\site-packages\\doubleml\\double_ml.py:1479: UserWarning: The estimated nu2 for d is not positive. 
Re-estimation based on riesz representer (non-orthogonal).\n", " warnings.warn(msg, UserWarning)\n", "c:\\Users\\BAM5698\\AppData\\Local\\miniconda3\\envs\\dml_docs\\Lib\\site-packages\\doubleml\\double_ml.py:1479: UserWarning: The estimated nu2 for d is not positive. Re-estimation based on riesz representer (non-orthogonal).\n", " warnings.warn(msg, UserWarning)\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " learner treatment_level ate ci_lower ci_upper\n", "0 RandomForest 1.0 1.325370 -4.218919 6.869658\n", "1 RandomForest 2.0 5.702248 1.094595 10.309901\n", "2 RandomForest 3.0 8.743203 4.080729 13.405677\n", "3 RandomForest 4.0 6.632407 1.979031 11.285783\n", "4 RandomForest 5.0 7.967206 2.359083 13.575329\n", "5 Linear 1.0 4.971059 -1.231281 11.173400\n", "6 Linear 2.0 7.963367 2.882965 13.043769\n", "7 Linear 3.0 10.747538 6.745536 14.749540\n", "8 Linear 4.0 11.594317 7.889236 15.299398\n", "9 Linear 5.0 6.626998 3.014525 10.239471\n", "10 LightGBM 1.0 2.156988 -16.411264 20.725240\n", "11 LightGBM 2.0 5.973140 -7.915225 19.861505\n", "12 LightGBM 3.0 10.824140 -4.925410 26.573689\n", "13 LightGBM 4.0 11.547482 -2.465476 25.560440\n", "14 LightGBM 5.0 12.528573 -2.627754 27.684899\n", "15 TabPFN 1.0 1.531682 0.208840 2.854525\n", "16 TabPFN 2.0 6.958612 6.026292 7.890933\n", "17 TabPFN 3.0 10.317946 9.048090 11.587801\n", "18 TabPFN 4.0 10.391389 9.417502 11.365275\n", "19 TabPFN 5.0 9.816896 8.873275 10.760517" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reference_level = 0  # contrast every level against level 0 (no treatment)\n", "\n", "apo_results = []\n", "causal_contrast_results = []\n", "model_list = []\n", "\n", "for learner_name, learner_pair in learner_dict.items():\n", " # Fit a separate DoubleMLAPOS model with this learner pair\n", " dml_obj = dml.DoubleMLAPOS(\n", " dml_data,\n", " learner_pair['ml_g'],\n", " learner_pair['ml_m'],\n", " treatment_levels=treatment_levels,\n", " n_rep=n_rep,\n", " )\n", " dml_obj.fit()\n", " model_list.append(dml_obj)\n", "\n", " # APO confidence intervals\n", " ci_pointwise = dml_obj.confint(level=0.95)\n", " df_apos = pd.DataFrame({\n", " 'learner': learner_name,\n", " 'treatment_level': treatment_levels,\n", " 'apo': dml_obj.coef,\n", " 'ci_lower': ci_pointwise.values[:, 0],\n", " 'ci_upper': ci_pointwise.values[:, 1]}\n", " )\n", " apo_results.append(df_apos)\n", "\n", " # ATE confidence 
intervals\n", " causal_contrast_model = dml_obj.causal_contrast(reference_levels=reference_level)\n", " ates = causal_contrast_model.thetas\n", " ci_ates = causal_contrast_model.confint(level=0.95)\n", " df_ates = pd.DataFrame({\n", " 'learner': learner_name,\n", " 'treatment_level': treatment_levels[1:],\n", " 'ate': ates,\n", " 'ci_lower': ci_ates.iloc[:, 0].values,\n", " 'ci_upper': ci_ates.iloc[:, 1].values\n", " })\n", " causal_contrast_results.append(df_ates)\n", "\n", "# Combine all results\n", "df_all_apos = pd.concat(apo_results, ignore_index=True)\n", "df_all_ates = pd.concat(causal_contrast_results, ignore_index=True)\n", "df_all_ates" ] }, { "cell_type": "markdown", "id": "9c4a4e31", "metadata": {}, "source": [ "### Visualizing Average Potential Outcomes\n", "\n", "Let's compare the estimated APOs across all methods with their true values. The plot shows:\n", "- **Estimated APOs**: Point estimates with 95% confidence intervals for each method\n", "- **True APOs**: Red horizontal lines showing the oracle values\n", "- **Treatment levels**: Different dosage levels of the treatment (0 = no treatment)" ] }, { "cell_type": "code", "execution_count": 10, "id": "6076a90e", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot APOs and 95% CIs for all models\n", "plt.figure(figsize=(12, 7))\n", "palette = sns.color_palette(\"colorblind\")\n", "learners = df_all_apos['learner'].unique()\n", "n_learners = len(learners)\n", "jitter_strength = 0.12\n", "\n", "for i, learner in enumerate(learners):\n", " df = df_all_apos[df_all_apos['learner'] == learner]\n", " # Jitter x positions for each learner\n", " jitter = (i - (n_learners - 1) / 2) * jitter_strength\n", " x_jittered = df['treatment_level'] + jitter\n", " plt.errorbar(\n", " x_jittered,\n", " df['apo'],\n", " yerr=[df['apo'] - df['ci_lower'], df['ci_upper'] - df['apo']],\n", " fmt='o',\n", " capsize=5,\n", " capthick=2,\n", " ecolor=palette[i % len(palette)],\n", " color=palette[i % len(palette)],\n", " label=f\"{learner} APO ±95% CI\",\n", " zorder=2\n", " )\n", "\n", "# Get treatment levels for proper line positioning\n", "treatment_levels = sorted(df_all_apos['treatment_level'].unique())\n", "x_range = plt.xlim()\n", "total_width = x_range[1] - x_range[0]\n", "\n", "# Add true APOs as red horizontal lines\n", "for i, level in enumerate(treatment_levels):\n", " # Center each line around its treatment level with a reasonable width\n", " line_width = 0.6 # Width of each horizontal line relative to treatment level spacing\n", " x_center = level\n", " x_start = x_center - line_width/2\n", " x_end = x_center + line_width/2\n", " \n", " # Convert to relative coordinates (0-1) for xmin/xmax\n", " xmin_rel = max(0, (x_start - x_range[0]) / total_width)\n", " xmax_rel = min(1, (x_end - x_range[0]) / total_width)\n", " \n", " plt.axhline(y=apos[int(level)], color='red', linestyle='-', alpha=0.7, \n", " xmin=xmin_rel, xmax=xmax_rel,\n", " linewidth=3, label='True APO' if i == 0 else \"\")\n", "\n", "plt.title('Estimated APO and 95% Confidence Interval by Treatment Level')\n", "plt.xlabel('Treatment Level')\n", "plt.ylabel('Average Potential Outcome (APO)')\n", 
"plt.xticks(sorted(df_all_apos['treatment_level'].unique()))\n", "plt.legend()\n", "plt.grid(True)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "95e25c9e", "metadata": {}, "source": [ "Without any hyperparameter tuning, TabPFN produces the narrowest confidence intervals across all treatment levels in this example, and its estimates lie close to the true APOs. Since narrow intervals alone do not guarantee accuracy, it is worth comparing the estimates against the true values (red lines)." ] }, { "cell_type": "markdown", "id": "aee3fc20", "metadata": {}, "source": [ "### Visualizing Average Treatment Effects\n", "\n", "Now let's examine the Average Treatment Effects (ATEs), which represent the causal effect of each treatment level compared to the reference level (no treatment). The ATE for treatment level $d$ is defined as:\n", "\n", "$$\\text{ATE}_d = \\mathbb{E}[Y(d)] - \\mathbb{E}[Y(0)]$$" ] }, { "cell_type": "code", "execution_count": 11, "id": "bd512fd4", "metadata": { "tags": [ "nbsphinx-gallery" ] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot ATEs and 95% CIs for all models\n", "plt.figure(figsize=(12, 7))\n", "palette = sns.color_palette(\"colorblind\")\n", "learners = df_all_ates['learner'].unique()\n", "n_learners = len(learners)\n", "jitter_strength = 0.12\n", "\n", "for i, learner in enumerate(learners):\n", " df = df_all_ates[df_all_ates['learner'] == learner]\n", " # Jitter x positions for each learner\n", " jitter = (i - (n_learners - 1) / 2) * jitter_strength\n", " x_jittered = df['treatment_level'] + jitter\n", " plt.errorbar(\n", " x_jittered,\n", " df['ate'],\n", " yerr=[df['ate'] - df['ci_lower'], df['ci_upper'] - df['ate']],\n", " fmt='o',\n", " capsize=5,\n", " capthick=2,\n", " ecolor=palette[i % len(palette)],\n", " color=palette[i % len(palette)],\n", " label=f\"{learner} ATE ±95% CI\",\n", " zorder=2\n", " )\n", "\n", "# Get treatment levels for proper line positioning\n", "treatment_levels = sorted(df_all_ates['treatment_level'].unique())\n", "x_range = plt.xlim()\n", "total_width = x_range[1] - x_range[0]\n", "\n", "# Add true ATEs as red horizontal lines\n", "for i, level in enumerate(treatment_levels):\n", " # Center each line around its treatment level with a reasonable width\n", " line_width = 0.6 # Width of each horizontal line relative to treatment level spacing\n", " x_center = level\n", " x_start = x_center - line_width/2\n", " x_end = x_center + line_width/2\n", " \n", " # Convert to relative coordinates (0-1) for xmin/xmax\n", " xmin_rel = max(0, (x_start - x_range[0]) / total_width)\n", " xmax_rel = min(1, (x_end - x_range[0]) / total_width)\n", " \n", " # Use average_ites[level] for the true ATE (treatment levels start from 1 for ATEs)\n", " plt.axhline(y=average_ites[int(level)], color='red', linestyle='-', alpha=0.7, \n", " xmin=xmin_rel, xmax=xmax_rel,\n", " linewidth=3, label='True ATE' if i == 0 else \"\")\n", "\n", "plt.title('Estimated ATE and 95% Confidence Interval by Treatment 
Level')\n", "plt.xlabel('Treatment Level (vs. 0)')\n", "plt.ylabel('ATE')\n", "plt.xticks(sorted(df_all_ates['treatment_level'].unique()))\n", "plt.legend()\n", "plt.grid(True)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "391d7fb2", "metadata": {}, "source": [ "### Model Performance Evaluation\n", "\n", "To understand why different methods perform differently, let's examine the performance of the underlying machine learning models used for the nuisance functions. DoubleML provides access to performance metrics for each component:\n", "\n", "- **RMSE g0**: Root Mean Square Error for the outcome model when treatment $D \\neq d$\n", "- **RMSE g1**: Root Mean Square Error for the outcome model when treatment $D = d$\n", "- **LogLoss m**: Logarithmic loss for the propensity score model (treatment assignment prediction)\n", "\n", "Better performance on these nuisance functions typically translates to more accurate causal estimates." ] }, { "cell_type": "code", "execution_count": 12, "id": "9d683935", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "RMSE g0 by Learner and Treatment Level:\n", "================================================================================\n", "Treatment_Level 1.0 2.0 3.0 4.0 5.0\n", "Learner \n", "LightGBM 15.2287 11.1751 15.2422 15.1509 14.8526\n", "Linear 21.1795 17.4716 20.3264 20.9783 21.0718\n", "RandomForest 14.3663 11.4107 14.5979 14.5166 14.3334\n", "TabPFN 10.0373 2.8767 10.5702 9.9764 10.2091\n", "\n", "\n", "RMSE g1 by Learner and Treatment Level:\n", "================================================================================\n", "Treatment_Level 1.0 2.0 3.0 4.0 5.0\n", "Learner \n", "LightGBM 17.2116 31.2223 20.1003 18.2962 15.7370\n", "Linear 16.6144 30.1755 21.2086 18.4788 16.9636\n", "RandomForest 16.0477 25.8403 18.6678 19.0762 15.5213\n", "TabPFN 3.3852 16.5840 4.8130 6.5550 4.9701\n", "\n", "\n", "LogLoss m by Learner and Treatment Level:\n", 
"================================================================================\n", "Treatment_Level 1.0 2.0 3.0 4.0 5.0\n", "Learner \n", "LightGBM 0.5732 0.5118 0.4976 0.5644 0.5650\n", "Linear 0.4822 0.4352 0.4252 0.4460 0.4660\n", "RandomForest 0.5277 0.4642 0.4364 0.4860 0.5101\n", "TabPFN 0.4776 0.4347 0.4320 0.4462 0.4645\n" ] } ], "source": [ "# Create a comprehensive table with RMSE for g0, g1 and log loss for all learners and treatment levels\n", "performance_results = []\n", "\n", "# Iterate over all fitted levels (0-5); the plotting cells above reassign `treatment_levels`\n", "fitted_levels = np.unique(dml_data.d)\n", "\n", "for idx_learner, learner_name in enumerate(learner_dict.keys()):\n", " for idx_treat, treatment_level in enumerate(fitted_levels):\n", " # Get the specific model for this learner and treatment level\n", " model = model_list[idx_learner].modellist[idx_treat]\n", " \n", " # Extract performance metrics from nuisance_loss\n", " if model.nuisance_loss is not None:\n", " # RMSE for g0 (outcome model for treatment level != d)\n", " rmse_g0 = model.nuisance_loss['ml_g_d_lvl0'][0][0]\n", " \n", " # RMSE for g1 (outcome model for treatment level = d)\n", " rmse_g1 = model.nuisance_loss['ml_g_d_lvl1'][0][0]\n", " \n", " # Log loss for propensity score model\n", " logloss_m = model.nuisance_loss['ml_m'][0][0]\n", " else:\n", " rmse_g0 = rmse_g1 = logloss_m = None\n", " \n", " # Store results\n", " performance_results.append({\n", " 'Learner': learner_name,\n", " 'Treatment_Level': treatment_level,\n", " 'RMSE_g0': rmse_g0,\n", " 'RMSE_g1': rmse_g1,\n", " 'LogLoss_m': logloss_m\n", " })\n", "\n", "# Create DataFrame and display as a nicely formatted table\n", "df_performance = pd.DataFrame(performance_results)\n", "\n", "# Round values for better readability\n", "df_performance['RMSE_g0'] = df_performance['RMSE_g0'].round(4)\n", "df_performance['RMSE_g1'] = df_performance['RMSE_g1'].round(4)\n", "df_performance['LogLoss_m'] = df_performance['LogLoss_m'].round(4)\n", "\n", "print(\"\\n\\nRMSE g0 by Learner and Treatment Level:\")\n", "print(\"=\" * 80)\n", "pivot_rmse_g0 = 
df_performance.pivot(index='Learner', columns='Treatment_Level', values='RMSE_g0')\n", "print(pivot_rmse_g0.to_string())\n", "\n", "print(\"\\n\\nRMSE g1 by Learner and Treatment Level:\")\n", "print(\"=\" * 80)\n", "pivot_rmse_g1 = df_performance.pivot(index='Learner', columns='Treatment_Level', values='RMSE_g1')\n", "print(pivot_rmse_g1.to_string())\n", "\n", "print(\"\\n\\nLogLoss m by Learner and Treatment Level:\")\n", "print(\"=\" * 80)\n", "pivot_logloss = df_performance.pivot(index='Learner', columns='Treatment_Level', values='LogLoss_m')\n", "print(pivot_logloss.to_string())" ] }, { "cell_type": "markdown", "id": "96d1cb08", "metadata": {}, "source": [ "### Performance Summary and Insights\n", "\n", "Let's summarize the average performance across all treatment levels to identify the best-performing methods:" ] }, { "cell_type": "code", "execution_count": 13, "id": "3fb531ef", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Best performing learners (averaged across treatment levels):\n", "------------------------------------------------------------\n", " RMSE_g0 RMSE_g1 LogLoss_m\n", "Learner \n", "LightGBM 14.3299 20.5135 0.5424\n", "Linear 20.2055 20.6882 0.4509\n", "RandomForest 13.8450 19.0307 0.4849\n", "TabPFN 8.7339 7.2615 0.4510\n" ] } ], "source": [ "# Average nuisance performance per learner (lower is better for all three metrics)\n", "print(\"\\nBest performing learners (averaged across treatment levels):\")\n", "print(\"-\" * 60)\n", "\n", "# Calculate average metrics across treatment levels for each learner\n", "summary_stats = df_performance.groupby('Learner')[['RMSE_g0', 'RMSE_g1', 'LogLoss_m']].mean().round(4)\n", "print(summary_stats)" ] }, { "cell_type": "markdown", "id": "99cefd29", "metadata": {}, "source": [ "## Key Takeaways\n", "\n", "On this simulated dataset ($n = 1000$ observations, a single repetition), the comparison suggests the following about using TabPFN for causal inference:\n", "\n", "- **Outcome modeling**: TabPFN clearly outperforms the other learners for both the g0 and g1 outcome models, with much lower RMSE values\n", "- **Propensity modeling**: TabPFN's log loss for the propensity score is on par with logistic regression and lower than that of Random Forest and LightGBM\n", "- **Causal estimates**: The better nuisance fits translate into more precise APO and ATE estimates (much narrower confidence intervals) that stay close to the true values\n", "- **No hyperparameter tuning**: TabPFN achieves these results without any model-specific tuning; the baselines were also left untuned, so tuning them could narrow the gap\n" ] } ], "metadata": { "kernelspec": { "display_name": "dml_docs", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 5 }