doubleml.utils.DoubleMLPolicyTree
- class doubleml.utils.DoubleMLPolicyTree(orth_signal, features, depth=2, **tree_params)
Policy tree fitting for DoubleML. Currently available for IRM models.
- Parameters:
orth_signal (numpy.array) – The orthogonal signal to be predicted. Has to be of shape (n_obs,), where n_obs is the number of observations.
features (pandas.DataFrame) – The covariates for estimating the policy tree. Has to have the shape (n_obs, d), where n_obs is the number of observations and d is the number of predictors.
depth (int) – The depth of the policy tree that will be built. Default is 2.
**tree_params (dict) – Parameters that are forwarded to the sklearn.tree.DecisionTreeClassifier. Note that by default we perform minimal pruning by setting ccp_alpha = 0.01 and min_samples_leaf = 8. This can be adjusted.
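A policy tree of this kind can be fitted as a weighted classification problem: the sign of the orthogonal signal indicates whether treatment looks beneficial for an observation, and its absolute value indicates how much is at stake. The sketch below illustrates that idea with scikit-learn directly, on hypothetical toy data, using depth 2 and the pruning defaults listed above. It is an assumption about how such a fit can work, not a copy of the DoubleML implementation.

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Hypothetical toy data: the effect is positive when x1 > 0 and negative otherwise.
rng = np.random.default_rng(42)
n_obs = 1000
features = pd.DataFrame({"x1": rng.normal(size=n_obs), "x2": rng.normal(size=n_obs)})
orth_signal = np.where(features["x1"] > 0, 1.0, -1.0) + rng.normal(scale=0.5, size=n_obs)

# Label: 1 if the signal favours treatment, 0 otherwise.
target = (orth_signal > 0).astype(int)
# Sample weight: the magnitude of the signal, i.e. how much a wrong decision costs.
weights = np.abs(orth_signal)

# depth=2 plus the minimal pruning settings (ccp_alpha=0.01, min_samples_leaf=8).
tree = DecisionTreeClassifier(max_depth=2, ccp_alpha=0.01, min_samples_leaf=8)
tree.fit(features, target, sample_weight=weights)

# Recommended policy per observation: 1 = treat, 0 = do not treat.
policy = tree.predict(features)
```

Weighting by the absolute signal means that misclassifying observations with a large expected gain or loss is penalised more than misclassifying those near zero.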
Methods
fit() – Estimate DoubleMLPolicyTree models.
Plots the DoubleMLPolicyTree.
predict(features) – Predicts policy based on the DoubleMLPolicyTree.
Attributes
features – Covariates.
orth_signal – Orthogonal signal.
policy_tree – Policy tree model.
summary – A summary for the policy tree.
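predict(features) requires the new covariates to have the same keys as the covariates used for fitting. The sketch below uses a hypothetical helper, align_features, to check the column names and restore the training column order before predicting with a plain scikit-learn tree. To our knowledge, scikit-learn 1.2 and later raise an error when a DataFrame's column names or their order differ from those seen during fit. The toy data and helper are illustrative only and not part of the DoubleML API.

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier


def align_features(new_features: pd.DataFrame, train_columns) -> pd.DataFrame:
    """Hypothetical helper: verify the keys match and reorder to the training order."""
    missing = set(train_columns) - set(new_features.columns)
    extra = set(new_features.columns) - set(train_columns)
    if missing or extra:
        raise ValueError(f"Column mismatch: missing={sorted(missing)}, extra={sorted(extra)}")
    return new_features[list(train_columns)]


# Fit a small tree on toy data: the sign of the signal is the label, its size the weight.
rng = np.random.default_rng(0)
train = pd.DataFrame({"x1": rng.normal(size=500), "x2": rng.normal(size=500)})
signal = np.where(train["x1"] > 0, 1.0, -1.0)
tree = DecisionTreeClassifier(max_depth=2).fit(
    train, (signal > 0).astype(int), sample_weight=np.abs(signal)
)

# New observations with the same keys, supplied in a different column order.
new_obs = pd.DataFrame({"x2": [0.0, 0.0], "x1": [-2.0, 2.0]})
treat = tree.predict(align_features(new_obs, train.columns))  # 1 = treat, 0 = do not treat
```

Failing fast on missing or extra columns makes a key mismatch visible before any prediction is made.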
- DoubleMLPolicyTree.predict(features)
Predicts policy based on the DoubleMLPolicyTree.
- Parameters:
features (pandas.DataFrame) – The covariates for predicting based on the policy tree. Has to have the shape (n_obs, d), where n_obs is the number of observations and d is the number of predictors. Has to have the same keys as the original covariates.
- Returns:
self
- Return type: