9. Sample-splitting, cross-fitting and repeated cross-fitting

Sample-splitting and the application of cross-fitting is a central part of double/debiased machine learning (DML). For all DML models DoubleMLPLR, DoubleMLPLIV, DoubleMLIRM, and DoubleMLIIVM, the specification is done via the parameters n_folds and n_rep. Advanced resampling techniques can be obtained via the boolean parameters draw_sample_splitting and apply_cross_fitting as well as the methods draw_sample_splitting() and set_sample_splitting().

As an example we consider a partially linear regression model (PLR) implemented in DoubleMLPLR.

In [1]: import doubleml as dml

In [2]: import numpy as np

In [3]: from doubleml.datasets import make_plr_CCDDHNR2018

In [4]: from sklearn.ensemble import RandomForestRegressor

In [5]: from sklearn.base import clone

In [6]: learner = RandomForestRegressor(n_estimators=100, max_features=20, max_depth=5, min_samples_leaf=2)

In [7]: ml_g = clone(learner)

In [8]: ml_m = clone(learner)

In [9]: np.random.seed(1234)

In [10]: obj_dml_data = make_plr_CCDDHNR2018(alpha=0.5, n_obs=100)
 1library(DoubleML)
 2library(mlr3)
 3lgr::get_logger("mlr3")$set_threshold("warn")
 4library(mlr3learners)
 5library(data.table)
 6
 7learner = lrn("regr.ranger", num.trees = 100, mtry = 20, min.node.size = 2, max.depth = 5)
 8ml_g = learner
 9ml_m = learner
10data = make_plr_CCDDHNR2018(alpha=0.5, n_obs=100, return_type = "data.table")
11obj_dml_data = DoubleMLData$new(data,
12                                y_col = "y",
13                                d_cols = "d")

9.1. Cross-fitting with \(K\) folds

The default setting is n_folds = 5 and n_rep = 1, i.e., \(K=5\) folds and no repeated cross-fitting.

In [11]: dml_plr_obj = dml.DoubleMLPLR(obj_dml_data, ml_g, ml_m, n_folds = 5, n_rep = 1)

In [12]: print(dml_plr_obj.n_folds)
5

In [13]: print(dml_plr_obj.n_rep)
1
1dml_plr_obj = DoubleMLPLR$new(obj_dml_data, ml_g, ml_m, n_folds = 5, n_rep = 1)
2print(dml_plr_obj$n_folds)
3print(dml_plr_obj$n_rep)
[1] 5
[1] 1

During the initialization of a DML model like DoubleMLPLR a \(K\)-fold random partition \((I_k)_{k=1}^{K}\) of observation indices is generated. The \(K\)-fold random partition is stored in the smpls attribute of the DML model object.

In [14]: print(dml_plr_obj.smpls)
[[(array([ 0,  2,  3,  4,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 19,
       21, 22, 23, 24, 25, 27, 28, 29, 31, 32, 34, 35, 36, 37, 38, 40, 44,
       46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 58, 59, 60, 61, 62, 63, 64,
       65, 67, 68, 69, 70, 71, 72, 74, 75, 76, 77, 79, 80, 81, 82, 84, 85,
       86, 88, 89, 90, 91, 92, 93, 94, 95, 97, 98, 99]), array([ 1,  5, 18, 20, 26, 30, 33, 39, 41, 42, 43, 45, 56, 57, 66, 73, 78,
       83, 87, 96])), (array([ 0,  1,  2,  3,  4,  5,  6,  7,  8, 10, 11, 12, 13, 14, 15, 16, 17,
       18, 19, 20, 21, 22, 26, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40,
       41, 42, 43, 44, 45, 46, 48, 53, 54, 55, 56, 57, 58, 60, 61, 62, 63,
       65, 66, 69, 70, 71, 72, 73, 77, 78, 79, 80, 81, 82, 83, 84, 85, 87,
       88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]), array([ 9, 23, 24, 25, 27, 28, 35, 47, 49, 50, 51, 52, 59, 64, 67, 68, 74,
       75, 76, 86])), (array([ 0,  1,  2,  3,  5,  6,  7,  9, 10, 11, 12, 14, 16, 17, 18, 20, 21,
       22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 33, 35, 38, 39, 40, 41, 42,
       43, 44, 45, 47, 48, 49, 50, 51, 52, 53, 56, 57, 58, 59, 60, 61, 62,
       63, 64, 65, 66, 67, 68, 69, 71, 73, 74, 75, 76, 77, 78, 79, 80, 81,
       83, 84, 86, 87, 88, 89, 90, 91, 92, 96, 98, 99]), array([ 4,  8, 13, 15, 19, 32, 34, 36, 37, 46, 54, 55, 70, 72, 82, 85, 93,
       94, 95, 97])), (array([ 0,  1,  3,  4,  5,  6,  8,  9, 13, 14, 15, 17, 18, 19, 20, 21, 22,
       23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 39, 41,
       42, 43, 44, 45, 46, 47, 49, 50, 51, 52, 53, 54, 55, 56, 57, 59, 64,
       65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 82,
       83, 85, 86, 87, 88, 89, 93, 94, 95, 96, 97, 98]), array([ 2,  7, 10, 11, 12, 16, 38, 40, 48, 58, 60, 61, 62, 63, 81, 84, 90,
       91, 92, 99])), (array([ 1,  2,  4,  5,  7,  8,  9, 10, 11, 12, 13, 15, 16, 18, 19, 20, 23,
       24, 25, 26, 27, 28, 30, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
       43, 45, 46, 47, 48, 49, 50, 51, 52, 54, 55, 56, 57, 58, 59, 60, 61,
       62, 63, 64, 66, 67, 68, 70, 72, 73, 74, 75, 76, 78, 81, 82, 83, 84,
       85, 86, 87, 90, 91, 92, 93, 94, 95, 96, 97, 99]), array([ 0,  3,  6, 14, 17, 21, 22, 29, 31, 44, 53, 65, 69, 71, 77, 79, 80,
       88, 89, 98]))]]
1dml_plr_obj$smpls
  1. $train_ids
      1. 9
      2. 13
      3. 16
      4. 20
      5. 24
      6. 34
      7. 42
      8. 51
      9. 52
      10. 54
      11. 62
      12. 63
      13. 67
      14. 72
      15. 73
      16. 80
      17. 85
      18. 89
      19. 90
      20. 96
      21. 4
      22. 5
      23. 19
      24. 30
      25. 32
      26. 33
      27. 35
      28. 38
      29. 40
      30. 41
      31. 48
      32. 53
      33. 58
      34. 64
      35. 68
      36. 82
      37. 83
      38. 87
      39. 88
      40. 97
      41. 3
      42. 11
      43. 12
      44. 15
      45. 17
      46. 25
      47. 39
      48. 44
      49. 46
      50. 47
      51. 61
      52. 66
      53. 71
      54. 75
      55. 77
      56. 78
      57. 86
      58. 91
      59. 92
      60. 94
      61. 1
      62. 7
      63. 8
      64. 10
      65. 14
      66. 21
      67. 22
      68. 23
      69. 29
      70. 37
      71. 50
      72. 55
      73. 56
      74. 57
      75. 59
      76. 60
      77. 69
      78. 74
      79. 81
      80. 100
      1. 2
      2. 6
      3. 18
      4. 26
      5. 27
      6. 28
      7. 31
      8. 36
      9. 43
      10. 45
      11. 49
      12. 65
      13. 70
      14. 76
      15. 79
      16. 84
      17. 93
      18. 95
      19. 98
      20. 99
      21. 4
      22. 5
      23. 19
      24. 30
      25. 32
      26. 33
      27. 35
      28. 38
      29. 40
      30. 41
      31. 48
      32. 53
      33. 58
      34. 64
      35. 68
      36. 82
      37. 83
      38. 87
      39. 88
      40. 97
      41. 3
      42. 11
      43. 12
      44. 15
      45. 17
      46. 25
      47. 39
      48. 44
      49. 46
      50. 47
      51. 61
      52. 66
      53. 71
      54. 75
      55. 77
      56. 78
      57. 86
      58. 91
      59. 92
      60. 94
      61. 1
      62. 7
      63. 8
      64. 10
      65. 14
      66. 21
      67. 22
      68. 23
      69. 29
      70. 37
      71. 50
      72. 55
      73. 56
      74. 57
      75. 59
      76. 60
      77. 69
      78. 74
      79. 81
      80. 100
      1. 2
      2. 6
      3. 18
      4. 26
      5. 27
      6. 28
      7. 31
      8. 36
      9. 43
      10. 45
      11. 49
      12. 65
      13. 70
      14. 76
      15. 79
      16. 84
      17. 93
      18. 95
      19. 98
      20. 99
      21. 9
      22. 13
      23. 16
      24. 20
      25. 24
      26. 34
      27. 42
      28. 51
      29. 52
      30. 54
      31. 62
      32. 63
      33. 67
      34. 72
      35. 73
      36. 80
      37. 85
      38. 89
      39. 90
      40. 96
      41. 3
      42. 11
      43. 12
      44. 15
      45. 17
      46. 25
      47. 39
      48. 44
      49. 46
      50. 47
      51. 61
      52. 66
      53. 71
      54. 75
      55. 77
      56. 78
      57. 86
      58. 91
      59. 92
      60. 94
      61. 1
      62. 7
      63. 8
      64. 10
      65. 14
      66. 21
      67. 22
      68. 23
      69. 29
      70. 37
      71. 50
      72. 55
      73. 56
      74. 57
      75. 59
      76. 60
      77. 69
      78. 74
      79. 81
      80. 100
      1. 2
      2. 6
      3. 18
      4. 26
      5. 27
      6. 28
      7. 31
      8. 36
      9. 43
      10. 45
      11. 49
      12. 65
      13. 70
      14. 76
      15. 79
      16. 84
      17. 93
      18. 95
      19. 98
      20. 99
      21. 9
      22. 13
      23. 16
      24. 20
      25. 24
      26. 34
      27. 42
      28. 51
      29. 52
      30. 54
      31. 62
      32. 63
      33. 67
      34. 72
      35. 73
      36. 80
      37. 85
      38. 89
      39. 90
      40. 96
      41. 4
      42. 5
      43. 19
      44. 30
      45. 32
      46. 33
      47. 35
      48. 38
      49. 40
      50. 41
      51. 48
      52. 53
      53. 58
      54. 64
      55. 68
      56. 82
      57. 83
      58. 87
      59. 88
      60. 97
      61. 1
      62. 7
      63. 8
      64. 10
      65. 14
      66. 21
      67. 22
      68. 23
      69. 29
      70. 37
      71. 50
      72. 55
      73. 56
      74. 57
      75. 59
      76. 60
      77. 69
      78. 74
      79. 81
      80. 100
      1. 2
      2. 6
      3. 18
      4. 26
      5. 27
      6. 28
      7. 31
      8. 36
      9. 43
      10. 45
      11. 49
      12. 65
      13. 70
      14. 76
      15. 79
      16. 84
      17. 93
      18. 95
      19. 98
      20. 99
      21. 9
      22. 13
      23. 16
      24. 20
      25. 24
      26. 34
      27. 42
      28. 51
      29. 52
      30. 54
      31. 62
      32. 63
      33. 67
      34. 72
      35. 73
      36. 80
      37. 85
      38. 89
      39. 90
      40. 96
      41. 4
      42. 5
      43. 19
      44. 30
      45. 32
      46. 33
      47. 35
      48. 38
      49. 40
      50. 41
      51. 48
      52. 53
      53. 58
      54. 64
      55. 68
      56. 82
      57. 83
      58. 87
      59. 88
      60. 97
      61. 3
      62. 11
      63. 12
      64. 15
      65. 17
      66. 25
      67. 39
      68. 44
      69. 46
      70. 47
      71. 61
      72. 66
      73. 71
      74. 75
      75. 77
      76. 78
      77. 86
      78. 91
      79. 92
      80. 94
    $test_ids
      1. 2
      2. 6
      3. 18
      4. 26
      5. 27
      6. 28
      7. 31
      8. 36
      9. 43
      10. 45
      11. 49
      12. 65
      13. 70
      14. 76
      15. 79
      16. 84
      17. 93
      18. 95
      19. 98
      20. 99
      1. 9
      2. 13
      3. 16
      4. 20
      5. 24
      6. 34
      7. 42
      8. 51
      9. 52
      10. 54
      11. 62
      12. 63
      13. 67
      14. 72
      15. 73
      16. 80
      17. 85
      18. 89
      19. 90
      20. 96
      1. 4
      2. 5
      3. 19
      4. 30
      5. 32
      6. 33
      7. 35
      8. 38
      9. 40
      10. 41
      11. 48
      12. 53
      13. 58
      14. 64
      15. 68
      16. 82
      17. 83
      18. 87
      19. 88
      20. 97
      1. 3
      2. 11
      3. 12
      4. 15
      5. 17
      6. 25
      7. 39
      8. 44
      9. 46
      10. 47
      11. 61
      12. 66
      13. 71
      14. 75
      15. 77
      16. 78
      17. 86
      18. 91
      19. 92
      20. 94
      1. 1
      2. 7
      3. 8
      4. 10
      5. 14
      6. 21
      7. 22
      8. 23
      9. 29
      10. 37
      11. 50
      12. 55
      13. 56
      14. 57
      15. 59
      16. 60
      17. 69
      18. 74
      19. 81
      20. 100

For each \(k \in [K] = \lbrace 1, \ldots, K]\) the nuisance ML estimator

\[\hat{\eta}_{0,k} = \hat{\eta}_{0,k}\big((W_i)_{i\not\in I_k}\big)\]

is based on the observations of all other \(k-1\) folds. The values of the two score function components \(\psi_a(W_i; \hat{\eta}_0)\) and \(\psi_b(W_i; \hat{\eta}_0))\) for each observation index \(i \in I_k\) are computed and stored in the attributes psi_a and psi_b.

In [15]: dml_plr_obj.fit();

In [16]: print(dml_plr_obj.psi_a[:5])
[[[-2.88531   ]]

 [[-2.90210293]]

 [[-0.07859105]]

 [[-0.32047806]]

 [[-2.63196468]]]

In [17]: print(dml_plr_obj.psi_b[:5])
[[[ 0.14584976]]

 [[-1.07752306]]

 [[-0.11679041]]

 [[ 0.39065473]]

 [[ 1.14544635]]]
1dml_plr_obj$fit()
2print(dml_plr_obj$psi_a[1:5, ,1])
3print(dml_plr_obj$psi_b[1:5, ,1])
[1] -1.2496054 -0.2520070 -2.7121088 -0.1073410 -0.3584161
[1] -0.303834292  0.472306305  2.117100636  0.114220212  0.006785538

9.2. Repeated cross-fitting with \(K\) folds and \(M\) repetition

Repeated cross-fitting is obtained by choosing a value \(M>1\) for the number of repetition n_rep. It results in \(M\) random \(K\)-fold partitions being drawn.

In [18]: dml_plr_obj = dml.DoubleMLPLR(obj_dml_data, ml_g, ml_m, n_folds = 5, n_rep = 10)

In [19]: print(dml_plr_obj.n_folds)
5

In [20]: print(dml_plr_obj.n_rep)
10
1dml_plr_obj = DoubleMLPLR$new(obj_dml_data, ml_g, ml_m, n_folds = 5, n_rep = 10)
2print(dml_plr_obj$n_folds)
3print(dml_plr_obj$n_rep)
[1] 5
[1] 10

For each of the \(M\) partitions, the nuisance ML models are estimated and score functions computed as described in Cross-fitting with K folds. The resulting values of the score functions are stored in 3-dimensional arrays psi_a and psi_b, where the row index corresponds the observation index \(i \in [N] = \lbrace 1, \ldots, N\rbrace\) and the column index to the partition \(m \in [M] = \lbrace 1, \ldots, M\rbrace\). The third dimension refers to the treatment variable and becomes non-singleton in case of multiple treatment variables.

In [21]: dml_plr_obj.fit();

In [22]: print(dml_plr_obj.psi_a[:5, :, 0])
[[-2.68599367 -3.23380465 -3.87925464 -2.89041468 -3.08674439 -3.17372511
  -4.06162781 -2.71757056 -3.93769627 -4.80729735]
 [-2.88112325 -2.42456107 -2.02910892 -2.04403523 -3.37767737 -2.35320534
  -2.45193279 -2.22980749 -1.9562462  -2.09392654]
 [-0.181359   -0.16799817 -0.07332458 -0.21872073 -0.06599282 -0.11102169
  -0.16396342 -0.50714384 -0.29398956 -0.56135776]
 [-0.04570088 -0.12277897 -0.10397745 -0.13543039 -0.36304868 -0.11301572
  -0.19153314 -0.62534417 -0.02197674 -0.01718614]
 [-2.99590389 -4.12622929 -1.77374567 -2.84585724 -1.47892889 -2.80198843
  -4.7074332  -1.77410084 -3.5422875  -3.20771252]]

In [23]: print(dml_plr_obj.psi_b[:5, :, 0])
[[ 0.04605718  0.40626329  0.74793091  0.18205525 -0.13083083  0.24800699
  -0.15153994  0.02067865  0.17040069  0.40464152]
 [-0.44276955 -1.3415938  -0.50685506 -0.8130893  -0.40504965 -0.5433195
  -0.77235098 -0.97564857 -0.68378606 -0.8890806 ]
 [-0.2085827  -0.15144663 -0.10209256 -0.24270094 -0.20685474 -0.15512408
  -0.11551561 -0.35281149 -0.07560974 -0.20729528]
 [ 0.15462366  0.17402885  0.22050906  0.33823906  0.43836142  0.16201454
   0.20697463  0.53725044  0.08476671  0.10902499]
 [ 0.64042081  1.36752436  0.21184894  0.45545298  0.39017148  0.71515068
   1.00864349  0.91863437  0.93603507  0.55055093]]
1dml_plr_obj$fit()
2print(dml_plr_obj$psi_a[1:5, ,1])
3print(dml_plr_obj$psi_b[1:5, ,1])
           [,1]        [,2]       [,3]        [,4]       [,5]        [,6]
[1,] -0.5125802 -0.56863342 -1.0271681 -0.82759223 -0.7440757 -0.81512989
[2,] -0.1669142 -0.49455845 -0.2881985 -0.07168046 -0.1384273 -0.16363067
[3,] -3.0034116 -2.66582128 -1.8441882 -2.01811503 -2.4410936 -2.28062322
[4,] -0.1004134 -0.01233796 -0.1057360 -0.01057142 -0.1199144 -0.09529150
[5,] -0.3831080 -0.21900322 -0.7224490 -0.23193657 -0.2683936 -0.04342685
           [,7]        [,8]        [,9]      [,10]
[1,] -0.6115190 -0.63107257 -0.44237420 -0.5456893
[2,] -0.1382186 -0.55949465 -0.22134226 -0.1420030
[3,] -2.7942224 -2.06292029 -2.60088576 -2.6564128
[4,] -0.2236185 -0.03518702 -0.07328316 -0.3999448
[5,] -0.3823124 -0.16134755 -0.37728720 -0.1410774
            [,1]        [,2]        [,3]        [,4]         [,5]        [,6]
[1,] -0.44558214  0.06050378 -0.12463530 -0.70458470 -0.165399308 -0.25697123
[2,]  0.11269632  0.62689683  0.40862205  0.20370271  0.092105067  0.26696185
[3,]  2.37977991  2.52331739  1.60097977  1.13752533  1.994174479  1.78051710
[4,]  0.01917427  0.01565279  0.06061683  0.03540183 -0.001243385  0.21251709
[5,] -0.10022706 -0.05061832 -0.03553749 -0.06292934  0.178355224 -0.04545068
            [,7]        [,8]        [,9]       [,10]
[1,]  0.01983182 -0.51317227 -0.43675148 -0.27547781
[2,]  0.28210033  0.38989384  0.35157830  0.35810012
[3,]  2.12462806  2.14148875  2.17242752  2.94644679
[4,] -0.07002637 -0.10186844  0.15986034  0.46970580
[5,] -0.02650921  0.02349533 -0.03803074 -0.04718882

We estimate the causal parameter \(\tilde{\theta}_{0,m}\) for each of the \(M\) partitions with a DML algorithm as described in Double machine learning algorithms. Standard errors are obtained as described in Variance estimation and confidence intervals for a causal parameter of interest. The aggregation of the estimates of the causal parameter and its standard errors is done using the median

\[ \begin{align}\begin{aligned}\tilde{\theta}_{0} &= \text{Median}\big((\tilde{\theta}_{0,m})_{m \in [M]}\big),\\\hat{\sigma} &= \sqrt{\text{Median}\big((\hat{\sigma}_m^2 + (\tilde{\theta}_{0,m} - \tilde{\theta}_{0})^2)_{m \in [M]}\big)}.\end{aligned}\end{align} \]

The estimate of the causal parameter \(\tilde{\theta}_{0}\) is stored in the coef attribute and the asymptotic standard error \(\hat{\sigma}/\sqrt{N}\) in se.

In [24]: print(dml_plr_obj.coef)
[0.49855526]

In [25]: print(dml_plr_obj.se)
[0.07599746]
1print(dml_plr_obj$coef)
2print(dml_plr_obj$se)
        d 
0.3601161 
        d 
0.1294869 

The parameter estimates \((\tilde{\theta}_{0,m})_{m \in [M]}\) and asymptotic standard errors \((\hat{\sigma}_m/\sqrt{N})_{m \in [M]}\) for each of the \(M\) partitions are stored in the attributes _all_coef and _all_se, respectively.

In [26]: print(dml_plr_obj._all_coef)
[[0.4912799  0.47383624 0.47971449 0.51116967 0.50583062 0.54217549
  0.52744838 0.47070297 0.52106584 0.4732103 ]]

In [27]: print(dml_plr_obj._all_se)
[[0.07574142 0.0764744  0.07045978 0.08555445 0.08032833 0.0760157
  0.07579885 0.0761749  0.07356857 0.07122361]]
1print(dml_plr_obj$all_coef)
2print(dml_plr_obj$all_se)
          [,1]     [,2]      [,3]      [,4]      [,5]      [,6]      [,7]
[1,] 0.3491986 0.364952 0.2905704 0.3991302 0.4071478 0.3529034 0.3691178
          [,8]      [,9]     [,10]
[1,] 0.3552802 0.3496076 0.4121474
          [,1]      [,2]      [,3]      [,4]      [,5]      [,6]      [,7]
[1,] 0.1297423 0.1277344 0.1310055 0.1291675 0.1254188 0.1289139 0.1387186
          [,8]      [,9]     [,10]
[1,] 0.1200295 0.1349347 0.1338285

9.3. Externally provide a sample splitting / partition

All DML models allow a partition to be provided externally via the method set_sample_splitting(). In Python we can for example use the K-Folds cross-validator of sklearn KFold in order to generate a sample splitting and provide it to the DML model object. Note that by setting draw_sample_splitting = False one can prevent that a partition is drawn during initialization of the DML model object. The following calls are equivalent. In the first sample code, we use the standard interface and draw the sample-splitting with \(K=4\) folds during initialization of the DoubleMLPLR object.

In [28]: np.random.seed(314)

In [29]: dml_plr_obj_internal = dml.DoubleMLPLR(obj_dml_data, ml_g, ml_m, n_folds = 4)

In [30]: print(dml_plr_obj_internal.fit().summary)
     coef   std err         t         P>|t|     2.5 %    97.5 %
d  0.5005  0.086388  5.793628  6.888187e-09  0.331182  0.669817
1set.seed(314)
2dml_plr_obj_internal = DoubleMLPLR$new(obj_dml_data, ml_g, ml_m, n_folds = 4)
3dml_plr_obj_internal$fit()
4dml_plr_obj_internal$summary()
[1] "Estimates and significance testing of the effect of target variables"
  Estimate. Std. Error t value Pr(>|t|)    
d    0.3698     0.1108   3.339 0.000842 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1


In the second sample code, we use the K-Folds cross-validator of sklearn KFold and set the partition via the set_sample_splitting() method.

In [31]: dml_plr_obj_external = dml.DoubleMLPLR(obj_dml_data, ml_g, ml_m, draw_sample_splitting = False)

In [32]: from sklearn.model_selection import KFold

In [33]: np.random.seed(314)

In [34]: kf = KFold(n_splits=4, shuffle=True)

In [35]: smpls = [(train, test) for train, test in kf.split(obj_dml_data.x)]

In [36]: dml_plr_obj_external.set_sample_splitting(smpls);

In [37]: print(dml_plr_obj_external.fit().summary)
     coef   std err         t         P>|t|     2.5 %    97.5 %
d  0.5005  0.086388  5.793628  6.888187e-09  0.331182  0.669817
 1dml_plr_obj_external = DoubleMLPLR$new(obj_dml_data, ml_g, ml_m, draw_sample_splitting = FALSE)
 2
 3set.seed(314)
 4# set up a task and cross-validation resampling scheme in mlr3
 5my_task = Task$new("help task", "regr", data)
 6my_sampling = rsmp("cv", folds = 4)$instantiate(my_task)
 7
 8train_ids = lapply(1:4, function(x) my_sampling$train_set(x))
 9test_ids = lapply(1:4, function(x) my_sampling$test_set(x))
10smpls = list(list(train_ids = train_ids, test_ids = test_ids))
11
12dml_plr_obj_external$set_sample_splitting(smpls)
13dml_plr_obj_external$fit()
14dml_plr_obj_external$summary()
[1] "Estimates and significance testing of the effect of target variables"
  Estimate. Std. Error t value Pr(>|t|)    
d    0.3698     0.1108   3.339 0.000842 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1


9.4. Sample-splitting without cross-fitting

The boolean flag apply_cross_fitting allows to estimate DML models without applying cross-fitting. It results in randomly splitting the sample into two parts. The first half of the data is used for the estimation of the nuisance ML models and the second half for estimating the causal parameter. Note that cross-fitting performs well empirically and is recommended to remove bias induced by overfitting, see also Sample splitting to remove bias induced by overfitting.

In [38]: np.random.seed(314)

In [39]: dml_plr_obj_external = dml.DoubleMLPLR(obj_dml_data, ml_g, ml_m,
   ....:                                        n_folds = 2, apply_cross_fitting = False)
   ....: 

In [40]: print(dml_plr_obj_external.fit().summary)
       coef   std err         t     P>|t|     2.5 %   97.5 %
d  0.559242  0.118328  4.726196  0.000002  0.327323  0.79116
1dml_plr_obj_external = DoubleMLPLR$new(obj_dml_data, ml_g, ml_m,
2                                       n_folds = 2, apply_cross_fitting = FALSE)
3dml_plr_obj_external$fit()
4dml_plr_obj_external$summary()
[1] "Estimates and significance testing of the effect of target variables"
  Estimate. Std. Error t value Pr(>|t|)  
d    0.3130     0.1619   1.934   0.0532 .
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1


Note, that in order to split data unevenly into train and test sets the interface to externally set the sample splitting via set_sample_splitting() needs to be applied, like for example:

In [41]: np.random.seed(314)

In [42]: dml_plr_obj_external = dml.DoubleMLPLR(obj_dml_data, ml_g, ml_m,
   ....:                                        n_folds = 2, apply_cross_fitting = False, draw_sample_splitting = False)
   ....: 

In [43]: from sklearn.model_selection import train_test_split

In [44]: smpls = train_test_split(np.arange(obj_dml_data.n_obs), train_size=0.8)

In [45]: dml_plr_obj_external.set_sample_splitting(tuple(smpls));

In [46]: print(dml_plr_obj_external.fit().summary)
      coef   std err         t     P>|t|     2.5 %    97.5 %
d  0.47396  0.134194  3.531894  0.000413  0.210944  0.736976
 1dml_plr_obj_external = DoubleMLPLR$new(obj_dml_data, ml_g, ml_m,
 2                                        n_folds = 2, apply_cross_fitting = FALSE,
 3                                        draw_sample_splitting = FALSE)
 4
 5set.seed(314)
 6# set up a task and cross-validation resampling scheme in mlr3
 7my_task = Task$new("help task", "regr", data)
 8my_sampling = rsmp("holdout", ratio = 0.8)$instantiate(my_task)
 9
10train_ids = list(my_sampling$train_set(1))
11test_ids = list(my_sampling$test_set(1))
12smpls = list(list(train_ids = train_ids, test_ids = test_ids))
13
14dml_plr_obj_external$set_sample_splitting(smpls)
15dml_plr_obj_external$fit()
16dml_plr_obj_external$summary()
[1] "Estimates and significance testing of the effect of target variables"
  Estimate. Std. Error t value Pr(>|t|)    
d    0.6383     0.1784   3.578 0.000347 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1


9.5. Estimate DML models without sample-splitting

The implementation of the DML models allows the estimation without sample splitting, i.e., all observations are used for learning the nuisance models as well as for the estimation of the causal parameter. Note that this approach usually results in a bias and is therefore not recommended without appropriate theoretical justification, see also Sample splitting to remove bias induced by overfitting.

In [47]: np.random.seed(314)

In [48]: dml_plr_no_split = dml.DoubleMLPLR(obj_dml_data, ml_g, ml_m,
   ....:                                    n_folds = 1, apply_cross_fitting = False)
   ....: 

In [49]: print(dml_plr_obj_external.fit().summary)
       coef   std err        t    P>|t|     2.5 %    97.5 %
d  0.508076  0.125342  4.05353  0.00005  0.262411  0.753741
1dml_plr_no_split = DoubleMLPLR$new(obj_dml_data, ml_g, ml_m,
2                                   n_folds = 1, apply_cross_fitting = FALSE)
3
4set.seed(314)
5dml_plr_no_split$fit()
6dml_plr_no_split$summary()
[1] "Estimates and significance testing of the effect of target variables"
  Estimate. Std. Error t value Pr(>|t|)   
d    0.3777     0.1344    2.81  0.00495 **
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1