Commit 67c711b

Merge pull request #269 from DoubleML/dev
Update workflow with tuning
2 parents e9b2def + 5a1dc36

File tree

1 file changed: +87 -6 lines changed


doc/workflow/workflow.rst

Lines changed: 87 additions & 6 deletions
@@ -150,6 +150,8 @@ We can directly pass the parameters during initialization of the learner objects
 Because we have a binary treatment variable, we can use a classification learner for the corresponding nuisance part.
 We use a regression learner for the continuous outcome variable net financial assets.
 
+Hyperparameter tuning of the machine learning models can be performed in Step 5., before estimation.
+
 .. tab-set::
 
     .. tab-item:: Python
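
For context, the learners referenced here could be initialized along the following lines in Python. This is an illustrative sketch using scikit-learn random forests, not part of this diff; the workflow's actual learner choices and settings may differ:

    from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

    # regression learner for the continuous outcome (net financial assets)
    ml_l = RandomForestRegressor(n_estimators=100, max_depth=5, min_samples_leaf=2)
    # classification learner for the binary treatment variable
    ml_m = RandomForestClassifier(n_estimators=100, max_depth=5, min_samples_leaf=2)
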
@@ -249,10 +251,89 @@ the dml algorithm (:ref:`DML1 vs. DML2 <algorithms>`) and the score function (:r
                 score = 'partialling out',
                 dml_procedure = 'dml2')
 
-5. Estimation
+
+5. Hyperparameter Tuning
+------------------------
+
+As an (optional) step before estimation, we can perform hyperparameter tuning of the machine learning models.
+:ref:`DoubleML <doubleml_package>` for Python supports hyperparameter tuning via `Optuna <https://optuna.org/>`_ and
+the R version relies on the `mlr3tuning <https://mlr3tuning.mlr-org.com/>`_ package.
+For more details, please refer to the :ref:`hyperparameter tuning (Python) <py_tune_params>` and :ref:`hyperparameter tuning (R) <r_tune_params>`
+sections in the documentation.
+
+.. tab-set::
+
+    .. tab-item:: Python
+        :sync: py
+
+        .. ipython:: python
+
+            import optuna
+
+            # define search spaces for hyperparameters
+            def ml_l_params(trial):
+                return {
+                    'n_estimators': trial.suggest_int('n_estimators', 50, 200, step=50),
+                    'max_depth': trial.suggest_int('max_depth', 3, 10),
+                    'min_samples_leaf': trial.suggest_int('min_samples_leaf', 1, 5),
+                }
+
+            def ml_m_params(trial):
+                return {
+                    'n_estimators': trial.suggest_int('n_estimators', 50, 200, step=50),
+                    'max_depth': trial.suggest_int('max_depth', 3, 10),
+                    'min_samples_leaf': trial.suggest_int('min_samples_leaf', 1, 5),
+                }
+
+            param_space = {
+                'ml_l': ml_l_params,
+                'ml_m': ml_m_params
+            }
+
+            optuna_settings = {
+                'n_trials': 10,  # small number for illustration purposes
+                'show_progress_bar': True,
+                'verbosity': optuna.logging.WARNING,  # Suppress Optuna logs
+            }
+
+            # Hyperparameter tuning
+            dml_plr_tree.tune_ml_models(ml_param_space=param_space,
+                                        optuna_settings=optuna_settings,
+                                        )
+
+    .. tab-item:: R
+        :sync: r
+
+        .. jupyter-execute::
+
+            library(mlr3tuning)
+            library(paradox)
+            lgr::get_logger("mlr3")$set_threshold("warn")
+            lgr::get_logger("bbotk")$set_threshold("warn")
+
+            # Define search spaces for hyperparameters
+            param_grid = list(
+                "ml_l" = ps(mtry = p_int(lower = 2, upper = 5),
+                            max.depth = p_int(lower = 3, upper = 7)),
+                "ml_m" = ps(mtry = p_int(lower = 2, upper = 5),
+                            max.depth = p_int(lower = 3, upper = 7))
+            )
+
+            tune_settings = list(
+                terminator = trm("evals", n_evals = 10),
+                algorithm = tnr("grid_search", resolution = 5),
+                measure = list("ml_l" = msr("regr.mse"),
+                               "ml_m" = msr("classif.ce"))
+            )
+
+            # Hyperparameter tuning
+            dml_plr_forest$tune(param_set = param_grid, tune_settings = tune_settings, tune_on_folds = FALSE)
+
+
+6. Estimation
 -------------
 
-We perform estimation in Step 5. In this step, the cross-fitting algorithm is executed such that the predictions
+We perform estimation in Step 6. In this step, the cross-fitting algorithm is executed such that the predictions
 in the score are computed. As an output, users can access the coefficient estimates and standard errors either via the
 corresponding fields or via a summary.

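For reference, the Python side of this estimation step follows the same pattern as the R summary call shown below; a minimal sketch, not part of this diff, reusing the dml_plr_tree object from the tuning snippet above:

    # run cross-fitting and estimate the causal parameter
    dml_plr_tree.fit()

    # access estimates via the corresponding fields ...
    print(dml_plr_tree.coef)  # coefficient estimate
    print(dml_plr_tree.se)    # standard error

    # ... or via a summary
    print(dml_plr_tree.summary)
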
@@ -292,10 +373,10 @@ corresponding fields or via a summary.
             # Summary
             dml_plr_forest$summary()
 
-6. Inference
+7. Inference
 ------------
 
-In Step 6., we can perform further inference methods and finally interpret our findings. For example, we can set up confidence intervals
+In Step 7., we can perform further inference methods and finally interpret our findings. For example, we can set up confidence intervals
 or, in case multiple causal parameters are estimated, adjust the analysis for multiple testing. :ref:`DoubleML <doubleml_package>`
 supports various approaches to perform :ref:`valid simultaneous inference <sim_inf>`
 which are partly based on a multiplier bootstrap.
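
A minimal Python sketch of this inference step (not part of this diff): joint confidence intervals are based on the multiplier bootstrap, which has to be run first.

    # pointwise 95% confidence interval
    print(dml_plr_tree.confint(level=0.95))

    # multiplier bootstrap, then joint confidence intervals
    dml_plr_tree.bootstrap(method='normal', n_rep_boot=500)
    print(dml_plr_tree.confint(joint=True))
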
@@ -342,10 +423,10 @@ If we did not control for the confounding variables, the average treatment effect
             dml_plr_forest$confint(joint = TRUE)
 
 
-7. Sensitivity Analysis
+8. Sensitivity Analysis
 ------------------------
 
-In Step 7., we can analyze the sensitivity of the estimated parameters. In the :ref:`plr-model` the causal interpretation
+In Step 8., we can analyze the sensitivity of the estimated parameters. In the :ref:`plr-model` the causal interpretation
 relies on conditional exogeneity, which requires controlling for confounding variables. The :ref:`DoubleML <doubleml_package>` Python package
 implements :ref:`sensitivity` with respect to omitted confounders.

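For reference, a minimal Python sketch of this sensitivity step (not part of this diff; the benchmark values for cf_y, cf_d and rho are purely illustrative):

    # sensitivity analysis with respect to omitted confounders
    dml_plr_tree.sensitivity_analysis(cf_y=0.03, cf_d=0.03, rho=1.0)
    # summary of bounds and robustness values
    print(dml_plr_tree.sensitivity_summary)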