Note
Download Jupyter notebook: https://docs.doubleml.org/stable/examples/learners/py_tabpfn.ipynb
Python: Causal Machine Learning with TabPFN#
In this example, we demonstrate how to use TabPFN (Tabular Prior-data Fitted Network) as a machine learning estimator within the DoubleML framework for causal inference. We compare TabPFN’s performance against (untuned) traditional machine learning methods including Random Forest, Linear models, and LightGBM.
TabPFN is a foundation model specifically designed for tabular data that can perform inference without traditional training. It leverages a transformer architecture trained on a vast collection of synthetic tabular datasets, making it particularly effective for small to medium-sized datasets commonly encountered in causal inference applications.
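As a quick, self-contained illustration of this interface (a minimal sketch on toy data; TabPFN follows the familiar scikit-learn fit/predict API, which is exactly what lets it plug into DoubleML below):
import numpy as np
from sklearn.model_selection import train_test_split
from tabpfn import TabPFNRegressor

# Toy regression data, purely to illustrate the interface
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = X[:, 0] + 0.5 * X[:, 1] ** 2 + rng.normal(scale=0.1, size=200)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# "Fitting" stores the training data as context for the transformer;
# predictions come from a single forward pass (in-context learning)
model = TabPFNRegressor(device="cpu")
model.fit(X_train, y_train)
print(model.predict(X_test)[:5])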
We will estimate Average Potential Outcomes (APOs) using the DoubleMLAPOS model, which allows us to estimate

\(\theta_d = \mathbb{E}[Y(d)]\)

for different treatment levels \(d\) in a discrete treatment setting.
Imports and Setup#
We start by importing the necessary libraries. Note that TabPFN requires a separate installation; see the installation instructions.
For GPU acceleration (recommended), ensure you have a CUDA-enabled PyTorch installed. Alternatively, you can use the TabPFN API Client.
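A typical setup might look as follows (a sketch; the PyPI package name tabpfn matches the import used below, and the CUDA check assumes a standard PyTorch installation):
# Install TabPFN, e.g. via pip:
#   pip install tabpfn
# Then check whether a CUDA-enabled PyTorch build is available for GPU acceleration
import torch
print(torch.cuda.is_available())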
[1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.linear_model import LinearRegression, LogisticRegression
import lightgbm as lgbm
from tabpfn import TabPFNRegressor, TabPFNClassifier
import doubleml as dml
from doubleml.datasets import make_irm_data_discrete_treatments
import warnings
warnings.filterwarnings("ignore", message="Running on CPU*", category=UserWarning, module="tabpfn")
warnings.filterwarnings("ignore", message=".*does not have valid feature names.*", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")
Data Generating Process (DGP)#
We generate synthetic data using DoubleML’s discrete treatment data generating process. This creates:
A continuous treatment variable that is subsequently discretized into multiple levels \(D\)
True individual treatment effects (ITEs) for comparison with our estimates
Covariates \(X\) that affect both treatment assignment \(D\) and outcomes \(Y\)
The discretization allows us to compare estimated Average Potential Outcomes (APOs) and Average Treatment Effects (ATEs) against their true values, providing a clear benchmark for evaluating different machine learning methods. For more details on the data generating process and the APO model, we refer to the APO Model Example Notebook.
[2]:
# Parameters
n_obs = 1000
n_levels = 5
linear = False
n_rep = 1
np.random.seed(42)
data_apo = make_irm_data_discrete_treatments(n_obs=n_obs, n_levels=n_levels, linear=linear)
y0 = data_apo['oracle_values']['y0']
cont_d = data_apo['oracle_values']['cont_d']
ite = data_apo['oracle_values']['ite']
d = data_apo['d']
potential_level = data_apo['oracle_values']['potential_level']
level_bounds = data_apo['oracle_values']['level_bounds']
average_ites = np.full(n_levels + 1, np.nan)
apos = np.full(n_levels + 1, np.nan)
mid_points = np.full(n_levels, np.nan)
# Group-wise average ITEs (level 0 is the untreated reference, so its effect is 0)
for i in range(n_levels + 1):
    average_ites[i] = np.mean(ite[d == i]) * (i > 0)
    apos[i] = np.mean(y0) + average_ites[i]
print(f"Average treatment effects in each group:\n{np.round(average_ites, 2)}\n")
print(f"Average potential outcomes in each group:\n{np.round(apos, 2)}\n")
print(f"Levels and their counts:\n{np.unique(d, return_counts=True)}")
Average treatment effects in each group:
[ 0. 1.46 6.67 9.31 10.36 10.47]
Average potential outcomes in each group:
[209.9 211.36 216.56 219.2 220.26 220.37]
Levels and their counts:
(array([0., 1., 2., 3., 4., 5.]), array([183, 165, 154, 162, 175, 161]))
Visualizing the Treatment Effect Structure#
To better understand our data, let’s visualize the relationship between the continuous treatment variable and the individual treatment effects, along with how the treatment is discretized into levels.
[3]:
# Get a colorblind-friendly palette
palette = sns.color_palette("colorblind")
df = pd.DataFrame({'cont_d': cont_d, 'ite': ite})
df_sorted = df.sort_values('cont_d')
mid_points = np.full(n_levels, np.nan)
for i in range(n_levels):
    mid_points[i] = (level_bounds[i] + level_bounds[i + 1]) / 2
df_apos = pd.DataFrame({'mid_points': mid_points, 'treatment effects': apos[1:] - apos[0]})
# Create the primary plot with scatter and line plots
fig, ax1 = plt.subplots()
sns.lineplot(data=df_sorted, x='cont_d', y='ite', color=palette[0], label='ITE', ax=ax1)
sns.scatterplot(data=df_apos, x='mid_points', y='treatment effects', color=palette[1], label='Grouped Treatment Effects', ax=ax1)
# Add vertical dashed lines at level_bounds
for bound in level_bounds:
    ax1.axvline(x=bound, color='grey', linestyle='--', alpha=0.7)
ax1.set_title('Grouped Effects vs. Continuous Treatment')
ax1.set_xlabel('Continuous Treatment')
ax1.set_ylabel('Effects')
# Create a secondary y-axis for the histogram
ax2 = ax1.twinx()
# Plot the histogram on the secondary y-axis
ax2.hist(df_sorted['cont_d'], bins=30, alpha=0.3, weights=np.ones_like(df_sorted['cont_d']) / len(df_sorted['cont_d']), color=palette[2])
ax2.set_ylabel('Density')
# Make sure the legend includes all plots
lines, labels = ax1.get_legend_handles_labels()
ax1.legend(lines, labels, loc='upper left')
plt.show()
[Figure: Grouped Effects vs. Continuous Treatment]
Creating the DoubleMLData Object#
As with all DoubleML models, we need to create a DoubleMLData object to properly structure our data for causal inference. This object handles the separation of outcome variables, treatment variables, and covariates.
[4]:
y = data_apo['y']
x = data_apo['x']
d = data_apo['d']
df_apo = pd.DataFrame(
    np.column_stack((y, d, x)),
    columns=['y', 'd'] + ['x' + str(i) for i in range(data_apo['x'].shape[1])]
)
dml_data = dml.DoubleMLData(df_apo, 'y', 'd')
print(dml_data)
================== DoubleMLData Object ==================
------------------ Data summary ------------------
Outcome variable: y
Treatment variable(s): ['d']
Covariates: ['x0', 'x1', 'x2', 'x3', 'x4']
Instrument variable(s): None
No. Observations: 1000
------------------ DataFrame info ------------------
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Columns: 7 entries, y to x4
dtypes: float64(7)
memory usage: 54.8 KB
DoubleML with TabPFN#
The TabPFN package integrates seamlessly with the DoubleML framework for causal inference tasks.
For fitting average potential outcome models, the DoubleML interface requires specifying the ml_g and ml_m learners:
ml_g: A regressor for the outcome model \(g_0(D,X) = \mathbb{E}[Y|X,D]\)
ml_m: A classifier for the propensity score model \(m_{0,d}(X) = \mathbb{E}[1\{D=d\}|X]\)
Note: TabPFN works best with CUDA acceleration. If CUDA is not available, it falls back to CPU computation. Alternatively, you can use the TabPFN API Client.
[5]:
device = 'cpu'
ml_g = TabPFNRegressor(device=device)
ml_m = TabPFNClassifier(device=device)
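To pick up a GPU automatically when one is available, you can select the device dynamically (a small sketch, assuming PyTorch is installed, which TabPFN requires):
import torch

# Use CUDA when available, otherwise fall back to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ml_g = TabPFNRegressor(device=device)
ml_m = TabPFNClassifier(device=device)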
To model average potential outcomes, we initialize the DoubleMLAPOS object with the specified machine learning methods and treatment levels.
[6]:
treatment_levels = np.unique(dml_data.d)
dml_obj = dml.DoubleMLAPOS(
    dml_data,
    ml_g=ml_g,
    ml_m=ml_m,
    treatment_levels=treatment_levels,
    n_rep=n_rep,
)
As usual, you can estimate the parameters by calling the fit method on the dml_obj instance.
[7]:
dml_obj.fit()
print(dml_obj)
================== DoubleMLAPOS Object ==================
------------------ Fit summary ------------------
coef std err t P>|t| 2.5 % 97.5 %
0.0 209.395480 1.211939 172.777199 0.0 207.020123 211.770838
1.0 210.966031 1.367425 154.279824 0.0 208.285928 213.646134
2.0 216.538410 1.245027 173.922656 0.0 214.098202 218.978618
3.0 219.333914 1.334717 164.329850 0.0 216.717916 221.949912
4.0 219.905724 1.278724 171.972735 0.0 217.399470 222.411977
5.0 219.265669 1.177094 186.277179 0.0 216.958608 221.572730
Machine Learning Methods Comparison#
We compare four different machine learning approaches for estimating the nuisance functions in our causal model:
Random Forest: Ensemble method with bagging and random feature selection
Linear Models: Linear/Logistic regression
LightGBM: Gradient boosting framework
TabPFN: A foundation model for tabular data
[8]:
learner_dict = {
    'RandomForest': {
        'ml_g': RandomForestRegressor(),
        'ml_m': RandomForestClassifier()
    },
    'Linear': {
        'ml_g': LinearRegression(),
        'ml_m': LogisticRegression(max_iter=1000)
    },
    'LightGBM': {
        'ml_g': lgbm.LGBMRegressor(n_estimators=50, verbose=-1),
        'ml_m': lgbm.LGBMClassifier(n_estimators=50, verbose=-1)
    },
    'TabPFN': {
        'ml_g': TabPFNRegressor(device=device),
        'ml_m': TabPFNClassifier(device=device)
    }
}
Estimation of Average Potential Outcomes#
Now we estimate the Average Potential Outcomes (APOs) for each treatment level using all four machine learning methods. We use the DoubleMLAPOS class, which:
Estimates nuisance functions: Uses cross-fitting to estimate \(g_0(D,X)\) and \(m_{0,d}(X)\)
Computes APO estimates: Uses the efficient influence function to estimate \(\theta_d = \mathbb{E}[Y(d)]\)
Provides confidence intervals: Based on the asymptotic distribution of the estimator
We also compute causal contrasts (Average Treatment Effects) as differences between treatment levels and the reference level (no treatment).
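For intuition, the APO estimate combines the two nuisance functions in the standard doubly robust (AIPW) form, sketched here without the cross-fitting details:

\[\hat{\theta}_d = \frac{1}{n} \sum_{i=1}^{n} \left[ \hat{g}_0(d, X_i) + \frac{1\{D_i = d\}}{\hat{m}_{0,d}(X_i)} \big( Y_i - \hat{g}_0(d, X_i) \big) \right]\]

An accurate \(\hat{g}_0\) shrinks the correction term, while the inverse propensity weighting removes remaining bias, which is why the quality of both nuisance learners matters for the results below.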
[9]:
reference_level = 0
apo_results = []
causal_contrast_results = []
model_list = []
for learner_name, learner_pair in learner_dict.items():
    # Create a fresh DoubleMLAPOS object for each learner
    dml_obj = dml.DoubleMLAPOS(
        dml_data,
        learner_pair['ml_g'],
        learner_pair['ml_m'],
        treatment_levels=treatment_levels,
        n_rep=n_rep,
    )
    dml_obj.fit()
    model_list.append(dml_obj)
    # APO confidence intervals
    ci_pointwise = dml_obj.confint(level=0.95)
    df_apos = pd.DataFrame({
        'learner': learner_name,
        'treatment_level': treatment_levels,
        'apo': dml_obj.coef,
        'ci_lower': ci_pointwise.values[:, 0],
        'ci_upper': ci_pointwise.values[:, 1]
    })
    apo_results.append(df_apos)
    # ATE confidence intervals
    causal_contrast_model = dml_obj.causal_contrast(reference_levels=reference_level)
    ates = causal_contrast_model.thetas
    ci_ates = causal_contrast_model.confint(level=0.95)
    df_ates = pd.DataFrame({
        'learner': learner_name,
        'treatment_level': treatment_levels[1:],
        'ate': ates,
        'ci_lower': ci_ates.iloc[:, 0].values,
        'ci_upper': ci_ates.iloc[:, 1].values
    })
    causal_contrast_results.append(df_ates)
# Combine all results
df_all_apos = pd.concat(apo_results, ignore_index=True)
df_all_ates = pd.concat(causal_contrast_results, ignore_index=True)
df_all_ates
c:\Users\BAM5698\AppData\Local\miniconda3\envs\dml_docs\Lib\site-packages\doubleml\double_ml.py:1479: UserWarning: The estimated nu2 for d is not positive. Re-estimation based on riesz representer (non-orthogonal).
  warnings.warn(msg, UserWarning)
c:\Users\BAM5698\AppData\Local\miniconda3\envs\dml_docs\Lib\site-packages\doubleml\utils\_checks.py:194: UserWarning: Propensity predictions from learner RandomForestClassifier() for ml_m are close to zero or one (eps=1e-12).
  warnings.warn(
[9]:
|    | learner      | treatment_level | ate       | ci_lower   | ci_upper  |
|----|--------------|-----------------|-----------|------------|-----------|
| 0  | RandomForest | 1.0             | 1.325370  | -4.218919  | 6.869658  |
| 1  | RandomForest | 2.0             | 5.702248  | 1.094595   | 10.309901 |
| 2  | RandomForest | 3.0             | 8.743203  | 4.080729   | 13.405677 |
| 3  | RandomForest | 4.0             | 6.632407  | 1.979031   | 11.285783 |
| 4  | RandomForest | 5.0             | 7.967206  | 2.359083   | 13.575329 |
| 5  | Linear       | 1.0             | 4.971059  | -1.231281  | 11.173400 |
| 6  | Linear       | 2.0             | 7.963367  | 2.882965   | 13.043769 |
| 7  | Linear       | 3.0             | 10.747538 | 6.745536   | 14.749540 |
| 8  | Linear       | 4.0             | 11.594317 | 7.889236   | 15.299398 |
| 9  | Linear       | 5.0             | 6.626998  | 3.014525   | 10.239471 |
| 10 | LightGBM     | 1.0             | 2.156988  | -16.411264 | 20.725240 |
| 11 | LightGBM     | 2.0             | 5.973140  | -7.915225  | 19.861505 |
| 12 | LightGBM     | 3.0             | 10.824140 | -4.925410  | 26.573689 |
| 13 | LightGBM     | 4.0             | 11.547482 | -2.465476  | 25.560440 |
| 14 | LightGBM     | 5.0             | 12.528573 | -2.627754  | 27.684899 |
| 15 | TabPFN       | 1.0             | 1.531682  | 0.208840   | 2.854525  |
| 16 | TabPFN       | 2.0             | 6.958612  | 6.026292   | 7.890933  |
| 17 | TabPFN       | 3.0             | 10.317946 | 9.048090   | 11.587801 |
| 18 | TabPFN       | 4.0             | 10.391389 | 9.417502   | 11.365275 |
| 19 | TabPFN       | 5.0             | 9.816896  | 8.873275   | 10.760517 |
Visualizing Average Potential Outcomes#
Let’s compare the estimated APOs across all methods with their true values. The plot shows:
Estimated APOs: Point estimates with 95% confidence intervals for each method
True APOs: Red horizontal lines showing the oracle values
Treatment levels: Different dosage levels of the treatment (0 = no treatment)
[10]:
# Plot APOs and 95% CIs for all models
plt.figure(figsize=(12, 7))
palette = sns.color_palette("colorblind")
learners = df_all_apos['learner'].unique()
n_learners = len(learners)
jitter_strength = 0.12
for i, learner in enumerate(learners):
    df = df_all_apos[df_all_apos['learner'] == learner]
    # Jitter x positions for each learner
    jitter = (i - (n_learners - 1) / 2) * jitter_strength
    x_jittered = df['treatment_level'] + jitter
    plt.errorbar(
        x_jittered,
        df['apo'],
        yerr=[df['apo'] - df['ci_lower'], df['ci_upper'] - df['apo']],
        fmt='o',
        capsize=5,
        capthick=2,
        ecolor=palette[i % len(palette)],
        color=palette[i % len(palette)],
        label=f"{learner} APO ±95% CI",
        zorder=2
    )
# Get treatment levels for proper line positioning
treatment_levels = sorted(df_all_apos['treatment_level'].unique())
x_range = plt.xlim()
total_width = x_range[1] - x_range[0]
# Add true APOs as red horizontal lines
for i, level in enumerate(treatment_levels):
    # Center each line around its treatment level with a reasonable width
    line_width = 0.6  # Width of each horizontal line relative to treatment level spacing
    x_center = level
    x_start = x_center - line_width / 2
    x_end = x_center + line_width / 2
    # Convert to relative coordinates (0-1) for xmin/xmax
    xmin_rel = max(0, (x_start - x_range[0]) / total_width)
    xmax_rel = min(1, (x_end - x_range[0]) / total_width)
    plt.axhline(y=apos[int(level)], color='red', linestyle='-', alpha=0.7,
                xmin=xmin_rel, xmax=xmax_rel,
                linewidth=3, label='True APO' if i == 0 else "")
plt.title('Estimated APO and 95% Confidence Interval by Treatment Level')
plt.xlabel('Treatment Level')
plt.ylabel('Average Potential Outcome (APO)')
plt.xticks(sorted(df_all_apos['treatment_level'].unique()))
plt.legend()
plt.grid(True)
plt.show()
[Figure: Estimated APO and 95% Confidence Interval by Treatment Level]
Even without any hyperparameter tuning, the TabPFN model clearly achieves the best performance (tightest confidence intervals) across all treatment levels.
Visualizing Average Treatment Effects#
Now let’s examine the Average Treatment Effects (ATEs), which represent the causal effect of each treatment level compared to the reference level (no treatment). The ATE for treatment level \(d\) is defined as

\(\text{ATE}_d = \theta_d - \theta_0 = \mathbb{E}[Y(d)] - \mathbb{E}[Y(0)].\)
[11]:
# Plot ATEs and 95% CIs for all models
plt.figure(figsize=(12, 7))
palette = sns.color_palette("colorblind")
learners = df_all_ates['learner'].unique()
n_learners = len(learners)
jitter_strength = 0.12
for i, learner in enumerate(learners):
    df = df_all_ates[df_all_ates['learner'] == learner]
    # Jitter x positions for each learner
    jitter = (i - (n_learners - 1) / 2) * jitter_strength
    x_jittered = df['treatment_level'] + jitter
    plt.errorbar(
        x_jittered,
        df['ate'],
        yerr=[df['ate'] - df['ci_lower'], df['ci_upper'] - df['ate']],
        fmt='o',
        capsize=5,
        capthick=2,
        ecolor=palette[i % len(palette)],
        color=palette[i % len(palette)],
        label=f"{learner} ATE ±95% CI",
        zorder=2
    )
# Get treatment levels for proper line positioning
treatment_levels = sorted(df_all_ates['treatment_level'].unique())
x_range = plt.xlim()
total_width = x_range[1] - x_range[0]
# Add true ATEs as red horizontal lines
for i, level in enumerate(treatment_levels):
    # Center each line around its treatment level with a reasonable width
    line_width = 0.6  # Width of each horizontal line relative to treatment level spacing
    x_center = level
    x_start = x_center - line_width / 2
    x_end = x_center + line_width / 2
    # Convert to relative coordinates (0-1) for xmin/xmax
    xmin_rel = max(0, (x_start - x_range[0]) / total_width)
    xmax_rel = min(1, (x_end - x_range[0]) / total_width)
    # Use average_ites[level] for the true ATE (treatment levels start from 1 for ATEs)
    plt.axhline(y=average_ites[int(level)], color='red', linestyle='-', alpha=0.7,
                xmin=xmin_rel, xmax=xmax_rel,
                linewidth=3, label='True ATE' if i == 0 else "")
plt.title('Estimated ATE and 95% Confidence Interval by Treatment Level')
plt.xlabel('Treatment Level (vs. 0)')
plt.ylabel('ATE')
plt.xticks(sorted(df_all_ates['treatment_level'].unique()))
plt.legend()
plt.grid(True)
plt.show()
[Figure: Estimated ATE and 95% Confidence Interval by Treatment Level]
Model Performance Evaluation#
To understand why different methods perform differently, let’s examine the performance of the underlying machine learning models used for the nuisance functions. DoubleML provides access to performance metrics for each component:
RMSE g0: Root Mean Square Error for the outcome model when treatment \(D \neq d\)
RMSE g1: Root Mean Square Error for the outcome model when treatment \(D = d\)
LogLoss m: Logarithmic loss for the propensity score model (treatment assignment prediction)
Better performance on these nuisance functions typically translates to more accurate causal estimates.
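For reference, both losses follow their standard definitions; written here for a fixed treatment level \(d\) with \(Z_i = 1\{D_i = d\}\):

\[\text{RMSE} = \sqrt{\frac{1}{n} \sum_{i=1}^{n} \big( Y_i - \hat{g}(X_i) \big)^2}, \qquad \text{LogLoss} = -\frac{1}{n} \sum_{i=1}^{n} \Big[ Z_i \log \hat{m}_d(X_i) + (1 - Z_i) \log \big( 1 - \hat{m}_d(X_i) \big) \Big]\]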
[12]:
# Create a comprehensive table with RMSE for g0, g1 and log loss for all learners and treatment levels
performance_results = []
for idx_learner, learner_name in enumerate(learner_dict.keys()):
    for idx_treat, treatment_level in enumerate(treatment_levels):
        # Get the specific model for this learner and treatment level
        model = model_list[idx_learner].modellist[idx_treat]
        # Extract performance metrics from nuisance_loss
        if model.nuisance_loss is not None:
            # RMSE for g0 (outcome model for treatment level != d)
            rmse_g0 = model.nuisance_loss['ml_g_d_lvl0'][0][0]
            # RMSE for g1 (outcome model for treatment level = d)
            rmse_g1 = model.nuisance_loss['ml_g_d_lvl1'][0][0]
            # Log loss for propensity score model
            logloss_m = model.nuisance_loss['ml_m'][0][0]
        else:
            rmse_g0 = rmse_g1 = logloss_m = None
        # Store results
        performance_results.append({
            'Learner': learner_name,
            'Treatment_Level': treatment_level,
            'RMSE_g0': rmse_g0,
            'RMSE_g1': rmse_g1,
            'LogLoss_m': logloss_m
        })
# Create DataFrame and display as a nicely formatted table
df_performance = pd.DataFrame(performance_results)
# Round values for better readability
df_performance['RMSE_g0'] = df_performance['RMSE_g0'].round(4)
df_performance['RMSE_g1'] = df_performance['RMSE_g1'].round(4)
df_performance['LogLoss_m'] = df_performance['LogLoss_m'].round(4)
print("\n\nRMSE g0 by Learner and Treatment Level:")
print("=" * 80)
pivot_rmse_g0 = df_performance.pivot(index='Learner', columns='Treatment_Level', values='RMSE_g0')
print(pivot_rmse_g0.to_string())
print("\n\nRMSE g1 by Learner and Treatment Level:")
print("=" * 80)
pivot_rmse_g1 = df_performance.pivot(index='Learner', columns='Treatment_Level', values='RMSE_g1')
print(pivot_rmse_g1.to_string())
print("\n\nLogLoss m by Learner and Treatment Level:")
print("=" * 80)
pivot_logloss = df_performance.pivot(index='Learner', columns='Treatment_Level', values='LogLoss_m')
print(pivot_logloss.to_string())
RMSE g0 by Learner and Treatment Level:
================================================================================
Treatment_Level 1.0 2.0 3.0 4.0 5.0
Learner
LightGBM 15.2287 11.1751 15.2422 15.1509 14.8526
Linear 21.1795 17.4716 20.3264 20.9783 21.0718
RandomForest 14.3663 11.4107 14.5979 14.5166 14.3334
TabPFN 10.0373 2.8767 10.5702 9.9764 10.2091
RMSE g1 by Learner and Treatment Level:
================================================================================
Treatment_Level 1.0 2.0 3.0 4.0 5.0
Learner
LightGBM 17.2116 31.2223 20.1003 18.2962 15.7370
Linear 16.6144 30.1755 21.2086 18.4788 16.9636
RandomForest 16.0477 25.8403 18.6678 19.0762 15.5213
TabPFN 3.3852 16.5840 4.8130 6.5550 4.9701
LogLoss m by Learner and Treatment Level:
================================================================================
Treatment_Level 1.0 2.0 3.0 4.0 5.0
Learner
LightGBM 0.5732 0.5118 0.4976 0.5644 0.5650
Linear 0.4822 0.4352 0.4252 0.4460 0.4660
RandomForest 0.5277 0.4642 0.4364 0.4860 0.5101
TabPFN 0.4776 0.4347 0.4320 0.4462 0.4645
Performance Summary and Insights#
Let’s summarize the average performance across all treatment levels to identify the best-performing methods:
[13]:
# Best performing learners for each metric
print("\nBest performing learners (averaged across treatment levels):")
print("-" * 60)
# Calculate average metrics across treatment levels for each learner
summary_stats = df_performance.groupby('Learner')[['RMSE_g0', 'RMSE_g1', 'LogLoss_m']].mean().round(4)
print(summary_stats)
Best performing learners (averaged across treatment levels):
------------------------------------------------------------
RMSE_g0 RMSE_g1 LogLoss_m
Learner
LightGBM 14.3299 20.5135 0.5424
Linear 20.2055 20.6882 0.4509
RandomForest 13.8450 19.0307 0.4849
TabPFN 8.7339 7.2615 0.4510
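To read off the best learner per metric directly (a small sketch using the summary_stats frame from above; lower is better for all three columns):
# Learner with the lowest average error/loss for each metric
print(summary_stats.idxmin())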
Key Takeaways#
This example demonstrates several important findings about using TabPFN for causal inference:
Outcome modeling: TabPFN substantially outperforms the untuned traditional methods on both the g0 and g1 nuisance functions, with much lower RMSE values
Causal estimates: The superior nuisance-function performance translates into more accurate APO and ATE estimates with tighter confidence intervals
No hyperparameter tuning: TabPFN achieves these results without any model-specific tuning