Python: Causal Machine Learning with TabPFN#

In this example, we demonstrate how to use TabPFN (Tabular Prior-data Fitted Network) as a machine learning estimator within the DoubleML framework for causal inference. We compare TabPFN’s performance against (untuned) traditional machine learning methods including Random Forest, Linear models, and LightGBM.

TabPFN is a foundation model specifically designed for tabular data that can perform inference without traditional training. It leverages a transformer architecture trained on a vast collection of synthetic tabular datasets, making it particularly effective for small to medium-sized datasets commonly encountered in causal inference applications.
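As a brief illustration (separate from the DoubleML workflow below, and using a purely synthetic toy dataset), TabPFN can be used like any scikit-learn estimator via fit and predict:

from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from tabpfn import TabPFNRegressor
from tabpfn.constants import ModelVersion

# Toy regression data, only to illustrate the interface
X, y = make_regression(n_samples=500, n_features=5, noise=10.0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Same constructor as used later in this notebook; no hyperparameter tuning required
reg = TabPFNRegressor.create_default_for_version(ModelVersion.V2, device='cpu', random_state=42)
reg.fit(X_train, y_train)
y_pred = reg.predict(X_test)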

We will estimate Average Potential Outcomes (APOs) using the DoubleMLAPOS model, which allows us to estimate:

\[\theta_d = \mathbb{E}[Y(d)]\]

for different treatment levels \(d\) in a discrete treatment setting.
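Under standard identification assumptions, the estimation below builds on the doubly robust (AIPW) representation, which combines an outcome regression \(g_0(d, X) = \mathbb{E}[Y|D=d, X]\) with a propensity model \(m_{0,d}(X) = \mathbb{E}[1\{D=d\}|X]\):

\[\theta_d = \mathbb{E}\left[g_0(d, X) + \frac{1\{D = d\}}{m_{0,d}(X)}\big(Y - g_0(d, X)\big)\right].\]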

Imports and Setup#

We start by importing the necessary libraries. Note that TabPFN requires a separate installation; see the TabPFN documentation for installation instructions.

For GPU acceleration (recommended), ensure that you have a CUDA-enabled PyTorch installation. Alternatively, you can use the TabPFN API Client.
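If you are unsure whether a GPU is available, the device can be selected automatically with PyTorch; this optional snippet is only illustrative:

import torch

# Pick CUDA if a compatible GPU is available, otherwise fall back to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")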

[1]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
import time

from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.linear_model import LinearRegression, LogisticRegression
import lightgbm as lgbm
from tabpfn import TabPFNRegressor, TabPFNClassifier
from tabpfn.constants import ModelVersion

import doubleml as dml
from doubleml.irm.datasets import make_irm_data_discrete_treatments

import warnings
warnings.filterwarnings("ignore", message="Running on CPU*", category=UserWarning, module="tabpfn")
warnings.filterwarnings("ignore", message=".*does not have valid feature names.*", category=UserWarning)

Data Generating Process (DGP)#

We generate synthetic data using DoubleML’s discrete treatment data generating process. This creates:

  • A continuous treatment variable that is subsequently discretized into multiple levels \(D\)

  • True individual treatment effects (ITEs) for comparison with our estimates

  • Covariates \(X\) that affect both treatment assignment \(D\) and outcomes \(Y\)

The discretization allows us to compare estimated Average Potential Outcomes (APOs) and Average Treatment Effects (ATEs) against their true values, providing a clear benchmark for evaluating different machine learning methods. For more details on the data generating process and the APO model, we refer to the APO Model Example Notebook.
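Concretely, the oracle benchmark values computed in the next cell are obtained as (with \(n_d\) denoting the number of observations in treatment group \(d\))

\[\text{APO}_d = \frac{1}{n}\sum_{i=1}^{n} Y_i(0) + \text{ATE}_d, \qquad \text{ATE}_d = \frac{1}{n_d}\sum_{i:\, D_i = d} \text{ITE}_i \;\text{ for } d > 0, \qquad \text{ATE}_0 = 0.\]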

[2]:
# Parameters
n_obs = 1000
n_levels = 5
linear = False
n_rep = 1

np.random.seed(42)
data_apo = make_irm_data_discrete_treatments(n_obs=n_obs, n_levels=n_levels, linear=linear)

y0 = data_apo['oracle_values']['y0']
cont_d = data_apo['oracle_values']['cont_d']
ite = data_apo['oracle_values']['ite']
d = data_apo['d']
potential_level = data_apo['oracle_values']['potential_level']
level_bounds = data_apo['oracle_values']['level_bounds']

average_ites = np.full(n_levels + 1, np.nan)
apos = np.full(n_levels + 1, np.nan)
mid_points = np.full(n_levels, np.nan)

for i in range(n_levels + 1):
    average_ites[i] = np.mean(ite[d == i]) * (i > 0)
    apos[i] = np.mean(y0) + average_ites[i]

print(f"Average treatment effects in each group:\n{np.round(average_ites,2)}\n")
print(f"Average potential outcomes in each group:\n{np.round(apos,2)}\n")
print(f"Levels and their counts:\n{np.unique(d, return_counts=True)}")
Average treatment effects in each group:
[ 0.    1.46  6.67  9.31 10.36 10.47]

Average potential outcomes in each group:
[209.9  211.36 216.56 219.2  220.26 220.37]

Levels and their counts:
(array([0., 1., 2., 3., 4., 5.]), array([183, 165, 154, 162, 175, 161]))

Visualizing the Treatment Effect Structure#

To better understand our data, let’s visualize the relationship between the continuous treatment variable and the individual treatment effects, along with how the treatment is discretized into levels.

[3]:
# Get a colorblind-friendly palette
palette = sns.color_palette("colorblind")

df = pd.DataFrame({'cont_d': cont_d, 'ite': ite})
df_sorted = df.sort_values('cont_d')

mid_points = np.full(n_levels, np.nan)
for i in range(n_levels):
    mid_points[i] = (level_bounds[i] + level_bounds[i + 1]) / 2

df_apos = pd.DataFrame({'mid_points': mid_points, 'treatment effects': apos[1:] - apos[0]})

# Create the primary plot with scatter and line plots
fig, ax1 = plt.subplots()

sns.lineplot(data=df_sorted, x='cont_d', y='ite', color=palette[0], label='ITE', ax=ax1)
sns.scatterplot(data=df_apos, x='mid_points', y='treatment effects', color=palette[1], label='Grouped Treatment Effects', ax=ax1)

# Add vertical dashed lines at level_bounds
for bound in level_bounds:
    ax1.axvline(x=bound, color='grey', linestyle='--', alpha=0.7)

ax1.set_title('Grouped Effects vs. Continuous Treatment')
ax1.set_xlabel('Continuous Treatment')
ax1.set_ylabel('Effects')

# Create a secondary y-axis for the histogram
ax2 = ax1.twinx()

# Plot the histogram on the secondary y-axis
ax2.hist(df_sorted['cont_d'], bins=30, alpha=0.3, weights=np.ones_like(df_sorted['cont_d']) / len(df_sorted['cont_d']), color=palette[2])
ax2.set_ylabel('Density')

# Make sure the legend includes all plots
lines, labels = ax1.get_legend_handles_labels()
ax1.legend(lines, labels, loc='upper left')

plt.show()
../../_images/examples_learners_py_tabpfn_6_0.png

Creating the DoubleMLData Object#

As with all DoubleML models, we need to create a DoubleMLData object to properly structure our data for causal inference. This object handles the separation of outcome variables, treatment variables, and covariates.

[4]:
y = data_apo['y']
x = data_apo['x']
d = data_apo['d']
df_apo = pd.DataFrame(
    np.column_stack((y, d, x)),
    columns=['y', 'd'] + ['x' + str(i) for i in range(data_apo['x'].shape[1])]
)

dml_data = dml.DoubleMLData(df_apo, 'y', 'd')
print(dml_data)
================== DoubleMLData Object ==================

------------------ Data summary      ------------------
Outcome variable: y
Treatment variable(s): ['d']
Covariates: ['x0', 'x1', 'x2', 'x3', 'x4']
Instrument variable(s): None
No. Observations: 1000

------------------ DataFrame info    ------------------
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Columns: 7 entries, y to x4
dtypes: float64(7)
memory usage: 54.8 KB

DoubleML with TabPFN#

Since TabPFN follows the scikit-learn estimator API, it integrates seamlessly with the DoubleML framework for causal inference tasks.

For fitting average potential outcome models, the DoubleML interface requires specifying the ml_g and ml_m learners:

  • ml_g: A regressor for the outcome model \(g_0(D,X) = \mathbb{E}[Y|X,D]\)

  • ml_m: A classifier for the propensity score model \(m_{0,d}(X) = \mathbb{E}[1\{D=d\}|X]\)

Note: TabPFN works best with CUDA acceleration. If CUDA is not available, it will fall back to CPU computation. Alternatively, you can use the TabPFN API Client.

[5]:
device = 'cpu'
ml_g = TabPFNRegressor.create_default_for_version(ModelVersion.V2, device=device, random_state=42)
ml_m = TabPFNClassifier.create_default_for_version(ModelVersion.V2, device=device, random_state=42)

To model average potential outcomes, we initialize the DoubleMLAPOS object with the specified machine learning methods and treatment levels.

[6]:
treatment_levels = np.unique(dml_data.d)
dml_obj = dml.DoubleMLAPOS(
    dml_data,
    ml_g=ml_g,
    ml_m=ml_m,
    treatment_levels=treatment_levels,
    n_rep=n_rep,
)

As usual, you can estimate the parameters by calling the fit method on the dml_obj instance.

[7]:
dml_obj.fit()
print(dml_obj)
================== DoubleMLAPOS Object ==================

------------------ Fit summary       ------------------
           coef   std err           t  P>|t|       2.5 %      97.5 %
0.0  209.309952  1.208485  173.200225    0.0  206.941364  211.678540
1.0  210.984368  1.359622  155.178751    0.0  208.319559  213.649177
2.0  216.506080  1.265888  171.030964    0.0  214.024985  218.987175
3.0  219.259184  1.316197  166.585402    0.0  216.679486  221.838883
4.0  219.891299  1.279140  171.905612    0.0  217.384232  222.398367
5.0  219.380538  1.195920  183.440790    0.0  217.036577  221.724498

Machine Learning Methods Comparison#

We compare four different machine learning approaches for estimating the nuisance functions in our causal model:

  • Random Forest

  • Linear Models

  • Boosted Trees (LightGBM)

  • Foundation Model (TabPFN)

Note that we did not tune the machine learning models in detail; a rough sketch of how tuning could be added is shown below.
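For example, the LightGBM outcome learner could be tuned with a simple randomized search before being passed to DoubleML; the parameter ranges below are placeholders for illustration only and are not part of the benchmark:

from sklearn.model_selection import RandomizedSearchCV

# Placeholder search space; in practice this should be chosen more carefully
param_distributions = {
    'n_estimators': [100, 300, 500],
    'learning_rate': [0.005, 0.01, 0.05, 0.1],
    'max_depth': [2, 3, 5],
    'min_data_in_leaf': [5, 10, 20],
}

search = RandomizedSearchCV(
    lgbm.LGBMRegressor(verbose=-1, random_state=42),
    param_distributions=param_distributions,
    n_iter=10,
    cv=5,
    random_state=42,
)
search.fit(x, np.ravel(y))
tuned_ml_g = search.best_estimator_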

[8]:
rf_arguments = {
    'n_estimators': 500,
    'min_samples_leaf': 10,
    'random_state': 42
}

lgbm_arguments = {
    'n_estimators': 500,
    'learning_rate': 0.01,
    'max_depth': 3,
    'min_data_in_leaf': 10,
    'lambda_l1': 1,
    'lambda_l2': 2,
    'random_state': 42,
    'verbose': -1,
}
[9]:
learner_dict = {
    'RandomForest': {
        'ml_g': RandomForestRegressor(**rf_arguments),
        'ml_m': RandomForestClassifier(**rf_arguments)
    },
    'Linear': {
        'ml_g': LinearRegression(),
        'ml_m': LogisticRegression(max_iter=1000)
    },
    'LightGBM': {
        'ml_g': lgbm.LGBMRegressor(**lgbm_arguments),
        'ml_m': lgbm.LGBMClassifier(**lgbm_arguments)
    },
    'TabPFN': {
        'ml_g': TabPFNRegressor.create_default_for_version(ModelVersion.V2, device=device, random_state=42),
        'ml_m': TabPFNClassifier.create_default_for_version(ModelVersion.V2, device=device, random_state=42)
    }
}

Estimation of Average Potential Outcomes#

Now we estimate the Average Potential Outcomes (APOs) for each treatment level using all four machine learning methods. We use the DoubleMLAPOS class, which:

  1. Estimates nuisance functions: Uses cross-fitting to estimate \(g_0(D,X)\) and \(m_{0,d}(X)\)

  2. Computes APO estimates: Uses the efficient influence function to estimate \(\theta_d = \mathbb{E}[Y(d)]\)

  3. Provides confidence intervals: Based on the asymptotic distribution of the estimator

We also compute causal contrasts (Average Treatment Effects) as differences between treatment levels and the reference level (no treatment).
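For reference, the confidence intervals mentioned in the third point follow the standard DoubleML construction based on the estimated asymptotic variance \(\hat{\sigma}_d^2\) of the score:

\[\left[\hat{\theta}_d \pm \Phi^{-1}(1 - \alpha/2)\, \frac{\hat{\sigma}_d}{\sqrt{n}}\right].\]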

[10]:
reference_level = 0

apo_results = []
causal_contrast_results = []
model_list = []

for learner_name, learner_pair in learner_dict.items():
    print(f"\n{'='*60}")
    print(f"Fitting model: {learner_name}")
    print(f"{'='*60}")

    start_time = time.time()

    dml_obj = dml.DoubleMLAPOS(
        dml_data,
        learner_pair['ml_g'],
        learner_pair['ml_m'],
        treatment_levels=treatment_levels,
        n_rep=n_rep,
    )
    dml_obj.fit()

    end_time = time.time()
    elapsed_time = end_time - start_time

    print(f"{learner_name} fitted in {elapsed_time:.2f} seconds")

    model_list.append(dml_obj)

    # APO confidence intervals
    ci_pointwise = dml_obj.confint(level=0.95)
    df_apos = pd.DataFrame({
        'learner': learner_name,
        'treatment_level': treatment_levels,
        'apo': dml_obj.coef,
        'ci_lower': ci_pointwise.values[:, 0],
        'ci_upper': ci_pointwise.values[:, 1]}
    )
    apo_results.append(df_apos)

    # ATE confidence intervals
    causal_contrast_model = dml_obj.causal_contrast(reference_levels=reference_level)
    ates = causal_contrast_model.thetas
    ci_ates = causal_contrast_model.confint(level=0.95)
    df_ates = pd.DataFrame({
        'learner': learner_name,
        'treatment_level': treatment_levels[1:],
        'ate': ates,
        'ci_lower': ci_ates.iloc[:, 0].values,
        'ci_upper': ci_ates.iloc[:, 1].values
    })
    causal_contrast_results.append(df_ates)

# Combine all results
df_all_apos = pd.concat(apo_results, ignore_index=True)
df_all_ates = pd.concat(causal_contrast_results, ignore_index=True)
df_all_ates

============================================================
Fitting model: RandomForest
============================================================
RandomForest fitted in 46.82 seconds

============================================================
Fitting model: Linear
============================================================
Linear fitted in 0.11 seconds

============================================================
Fitting model: LightGBM
============================================================
LightGBM fitted in 3.12 seconds

============================================================
Fitting model: TabPFN
============================================================
TabPFN fitted in 183.38 seconds
[10]:
         learner  treatment_level        ate   ci_lower   ci_upper
0   RandomForest              1.0   2.253828  -2.699097   7.206752
1   RandomForest              2.0   7.498048   2.237756  12.758340
2   RandomForest              3.0  11.756539   6.289006  17.224073
3   RandomForest              4.0   7.678953   2.767027  12.590880
4   RandomForest              5.0   6.865101   0.994615  12.735587
5         Linear              1.0   3.932233  -1.962948   9.827414
6         Linear              2.0   8.720447   3.747057  13.693837
7         Linear              3.0  10.625888   6.871080  14.380695
8         Linear              4.0  11.339177   7.917038  14.761317
9         Linear              5.0   8.254858   4.324910  12.184806
10      LightGBM              1.0   1.276002  -4.056172   6.608177
11      LightGBM              2.0   6.166088   1.817977  10.514199
12      LightGBM              3.0  10.974491   6.011016  15.937967
13      LightGBM              4.0   9.726078   5.797966  13.654191
14      LightGBM              5.0   8.085730   3.570351  12.601109
15        TabPFN              1.0   1.621060   0.289549   2.952572
16        TabPFN              2.0   7.064673   6.089229   8.040118
17        TabPFN              3.0  10.021728   8.609947  11.433509
18        TabPFN              4.0  10.454991   9.231879  11.678104
19        TabPFN              5.0   9.751194   8.762049  10.740339

Visualizing Average Potential Outcomes#

Let’s compare the estimated APOs across all methods with their true values. The plot shows:

  • Estimated APOs: Point estimates with 95% confidence intervals for each method

  • True APOs: Red horizontal lines showing the oracle values

  • Treatment levels: Different dosage levels of the treatment (0 = no treatment)

[11]:
# Plot APOs and 95% CIs for all models
plt.figure(figsize=(12, 7))
palette = sns.color_palette("colorblind")
learners = df_all_apos['learner'].unique()
n_learners = len(learners)
jitter_strength = 0.12

for i, learner in enumerate(learners):
    df = df_all_apos[df_all_apos['learner'] == learner]
    # Jitter x positions for each learner
    jitter = (i - (n_learners - 1) / 2) * jitter_strength
    x_jittered = df['treatment_level'] + jitter
    plt.errorbar(
        x_jittered,
        df['apo'],
        yerr=[df['apo'] - df['ci_lower'], df['ci_upper'] - df['apo']],
        fmt='o',
        capsize=5,
        capthick=2,
        ecolor=palette[i % len(palette)],
        color=palette[i % len(palette)],
        label=f"{learner} APO ±95% CI",
        zorder=2
    )

# Get treatment levels for proper line positioning
treatment_levels_plot = sorted(df_all_apos['treatment_level'].unique())
x_range = plt.xlim()
total_width = x_range[1] - x_range[0]

# Add true APOs as red horizontal lines
for i, level in enumerate(treatment_levels_plot):
    # Center each line around its treatment level with a reasonable width
    line_width = 0.6  # Width of each horizontal line relative to treatment level spacing
    x_center = level
    x_start = x_center - line_width/2
    x_end = x_center + line_width/2

    # Convert to relative coordinates (0-1) for xmin/xmax
    xmin_rel = max(0, (x_start - x_range[0]) / total_width)
    xmax_rel = min(1, (x_end - x_range[0]) / total_width)

    plt.axhline(y=apos[int(level)], color='red', linestyle='-', alpha=0.7,
                xmin=xmin_rel, xmax=xmax_rel,
                linewidth=3, label='True APO' if i == 0 else "")

plt.title('Estimated APO and 95% Confidence Interval by Treatment Level')
plt.xlabel('Treatment Level')
plt.ylabel('Average Potential Outcome (APO)')
plt.xticks(sorted(df_all_apos['treatment_level'].unique()))
plt.legend()
plt.grid(True)
plt.show()
../../_images/examples_learners_py_tabpfn_21_0.png

It is clear that, even without any hyperparameter tuning, TabPFN achieves the best performance (the narrowest confidence intervals) across all treatment levels.

Visualizing Average Treatment Effects#

Now let’s examine the Average Treatment Effects (ATEs), which represent the causal effect of each treatment level compared to the reference level (no treatment). The ATE for treatment level \(d\) is defined as:

\[\text{ATE}_d = \mathbb{E}[Y(d)] - \mathbb{E}[Y(0)]\]
[12]:
# Plot ATEs and 95% CIs for all models
plt.figure(figsize=(12, 7))
palette = sns.color_palette("colorblind")
learners = df_all_ates['learner'].unique()
n_learners = len(learners)
jitter_strength = 0.12

for i, learner in enumerate(learners):
    df = df_all_ates[df_all_ates['learner'] == learner]
    # Jitter x positions for each learner
    jitter = (i - (n_learners - 1) / 2) * jitter_strength
    x_jittered = df['treatment_level'] + jitter
    plt.errorbar(
        x_jittered,
        df['ate'],
        yerr=[df['ate'] - df['ci_lower'], df['ci_upper'] - df['ate']],
        fmt='o',
        capsize=5,
        capthick=2,
        ecolor=palette[i % len(palette)],
        color=palette[i % len(palette)],
        label=f"{learner} ATE ±95% CI",
        zorder=2
    )

# Get treatment levels for proper line positioning
treatment_levels_plot = sorted(df_all_ates['treatment_level'].unique())
x_range = plt.xlim()
total_width = x_range[1] - x_range[0]

# Add true ATEs as red horizontal lines
for i, level in enumerate(treatment_levels_plot):
    # Center each line around its treatment level with a reasonable width
    line_width = 0.6  # Width of each horizontal line relative to treatment level spacing
    x_center = level
    x_start = x_center - line_width/2
    x_end = x_center + line_width/2

    # Convert to relative coordinates (0-1) for xmin/xmax
    xmin_rel = max(0, (x_start - x_range[0]) / total_width)
    xmax_rel = min(1, (x_end - x_range[0]) / total_width)

    # Use average_ites[level] for the true ATE (treatment levels start from 1 for ATEs)
    plt.axhline(y=average_ites[int(level)], color='red', linestyle='-', alpha=0.7,
                xmin=xmin_rel, xmax=xmax_rel,
                linewidth=3, label='True ATE' if i == 0 else "")

plt.title('Estimated ATE and 95% Confidence Interval by Treatment Level')
plt.xlabel('Treatment Level (vs. 0)')
plt.ylabel('ATE')
plt.xticks(sorted(df_all_ates['treatment_level'].unique()))
plt.legend()
plt.grid(True)
plt.show()
../../_images/examples_learners_py_tabpfn_24_0.png

Model Performance Evaluation#

To understand why different methods perform differently, let’s examine the performance of the underlying machine learning models used for the nuisance functions. DoubleML provides access to performance metrics for each component:

  • RMSE g0: Root Mean Square Error for the outcome model when treatment \(D \neq d\)

  • RMSE g1: Root Mean Square Error for the outcome model when treatment \(D = d\)

  • LogLoss m: Logarithmic loss for the propensity score model (treatment assignment prediction)

Better performance on these nuisance functions typically translates to more accurate causal estimates.

[13]:
# Create a comprehensive table with RMSE for g0, g1 and log loss for all learners and treatment levels
performance_results = []

for idx_learner, learner_name in enumerate(learner_dict.keys()):
    for idx_treat, treatment_level in enumerate(treatment_levels):
        # Get the specific model for this learner and treatment level
        model = model_list[idx_learner].modellist[idx_treat]

        # Extract performance metrics from nuisance_loss
        if model.nuisance_loss is not None:
            # RMSE for g0 (outcome model for treatment level != d)
            rmse_g0 = model.nuisance_loss['ml_g_d_lvl0'][0][0]

            # RMSE for g1 (outcome model for treatment level = d)
            rmse_g1 = model.nuisance_loss['ml_g_d_lvl1'][0][0]

            # Log loss for propensity score model
            logloss_m = model.nuisance_loss['ml_m'][0][0]
        else:
            rmse_g0 = rmse_g1 = logloss_m = None

        # Store results
        performance_results.append({
            'Learner': learner_name,
            'Treatment_Level': treatment_level,
            'RMSE_g0': rmse_g0,
            'RMSE_g1': rmse_g1,
            'LogLoss_m': logloss_m
        })

# Create DataFrame and display as a nicely formatted table
df_performance = pd.DataFrame(performance_results)

# Round values for better readability
df_performance['RMSE_g0'] = df_performance['RMSE_g0'].round(4)
df_performance['RMSE_g1'] = df_performance['RMSE_g1'].round(4)
df_performance['LogLoss_m'] = df_performance['LogLoss_m'].round(4)

print("\n\nRMSE g0 by Learner and Treatment Level:")
print("=" * 80)
pivot_rmse_g0 = df_performance.pivot(index='Learner', columns='Treatment_Level', values='RMSE_g0')
print(pivot_rmse_g0.to_string())

print("\n\nRMSE g1 by Learner and Treatment Level:")
print("=" * 80)
pivot_rmse_g1 = df_performance.pivot(index='Learner', columns='Treatment_Level', values='RMSE_g1')
print(pivot_rmse_g1.to_string())

print("\n\nLogLoss m by Learner and Treatment Level:")
print("=" * 80)
pivot_logloss = df_performance.pivot(index='Learner', columns='Treatment_Level', values='LogLoss_m')
print(pivot_logloss.to_string())


RMSE g0 by Learner and Treatment Level:
================================================================================
Treatment_Level      0.0      1.0      2.0      3.0      4.0      5.0
Learner
LightGBM         16.6043  12.6722  16.4420  16.1879  16.1063  16.3937
Linear           21.4232  17.3711  20.5572  21.3341  21.3044  21.3389
RandomForest     17.8429  14.2921  17.1360  17.0445  17.4103  17.7712
TabPFN            9.3977   2.8226   9.5238   9.4965   9.2663   9.2945


RMSE g1 by Learner and Treatment Level:
================================================================================
Treatment_Level      0.0      1.0      2.0      3.0      4.0      5.0
Learner
LightGBM         16.2032  30.3334  18.8350  16.5337  15.6930  13.1095
Linear           16.3087  29.3359  21.6232  17.8452  16.6807  14.6180
RandomForest     21.0889  31.6349  22.8979  20.8461  19.2861  17.7393
TabPFN            3.9209  15.6921   4.9981   7.4276   5.9365   3.2138


LogLoss m by Learner and Treatment Level:
================================================================================
Treatment_Level     0.0     1.0     2.0     3.0     4.0     5.0
Learner
LightGBM         0.4996  0.4506  0.4417  0.4663  0.4782  0.4302
Linear           0.4770  0.4391  0.4247  0.4450  0.4637  0.4234
RandomForest     0.5011  0.4401  0.4235  0.4649  0.4766  0.4344
TabPFN           0.4772  0.4348  0.4286  0.4481  0.4652  0.4218

Performance Summary and Insights#

Let’s summarize the average performance across all treatment levels to identify the best-performing methods:

[14]:
# Best performing learners for each metric
print("\nBest performing learners (averaged across treatment levels):")
print("-" * 60)

# Calculate average metrics across treatment levels for each learner
summary_stats = df_performance.groupby('Learner')[['RMSE_g0', 'RMSE_g1', 'LogLoss_m']].mean().round(4)
print(summary_stats)

Best performing learners (averaged across treatment levels):
------------------------------------------------------------
              RMSE_g0  RMSE_g1  LogLoss_m
Learner
LightGBM      15.7344  18.4513     0.4611
Linear        20.5548  19.4020     0.4455
RandomForest  16.9162  22.2489     0.4568
TabPFN         8.3002   6.8648     0.4460

Key Takeaways#

This example demonstrates several important findings about using TabPFN for causal inference:

  • Outcome modeling: TabPFN significantly outperforms traditional methods for both g0 and g1 functions, with much lower RMSE values

  • Causal estimates: The superior nuisance function performance translates to more accurate APO and ATE estimates

  • No hyperparameter tuning: TabPFN achieves these results without any model-specific tuning