MLflow + Optuna: Parallel hyper-parameter optimization and logging

Optuna is a Python library that makes it easy to optimize the hyper-parameters of machine learning models. MLflow is a tool for keeping track of experiments. In this post I want to show how to use them together: Optuna to find optimal hyper-parameters and MLflow to record each hyper-parameter candidate (Optuna trial).

I will create one MLflow run for the overall Optuna study and one nested run for each trial. The trials will run in parallel. The default MLflow fluent interface does not work properly when several threads log at the same time; you will see errors like this:

mlflow.exceptions.MlflowException: Changing param values is not allowed. Param with key='x' was already logged with value='4.826018001260979' for run ID='664a3b7001b04fcdb132c351238a8cf4'. Attempted logging new value '4.799057323848487'.

You get this error with the “standard MLflow approach”:

import optuna
import mlflow

def objective(trial):
    with mlflow.start_run(nested=True):                # Race condition possible
        x = trial.suggest_float("x", -10.0, 10.0)
        mlflow.log_param("x", x)                       # Race condition possible
        val = x**2
        mlflow.log_metric("xsq", val)                  # Race condition possible
        return val

with mlflow.start_run():
    study = optuna.create_study(direction="minimize")
    study.optimize(objective, n_trials=50, n_jobs=2)

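This is because the fluent interface (mlflow.*) internally uses a stack of active runs that, in the MLflow version used here, is shared by all threads and is not thread-safe. When several threads start runs and log parameters or metrics at the same time, one thread can end up logging to the run that another thread has just started. If two trials then log different values for x to that same run, MLflow refuses to change the parameter and raises the error above.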
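To make the failure mode concrete, here is a deliberately simplified model of a fluent API whose active-run stack is shared by all threads. It is not MLflow's actual code: ToyTracker and ParamConflictError are hypothetical, and threading.Event is only used to force the unlucky interleaving deterministically. Trial A's log_param lands on trial B's freshly started run, so B's own value for x then conflicts, just like in the error message above.

```python
import threading


class ParamConflictError(Exception):
    """Hypothetical analogue of 'Changing param values is not allowed'."""


class ToyTracker:
    """Toy fluent tracker: one run stack shared by ALL threads (the flaw)."""

    def __init__(self):
        self.params = {}        # run_id -> {key: value}
        self.active_stack = []  # global, not per thread
        self._counter = 0
        self._lock = threading.Lock()

    def start_run(self):
        with self._lock:
            run_id = f"run{self._counter}"
            self._counter += 1
            self.params[run_id] = {}
            self.active_stack.append(run_id)
        return run_id

    def log_param(self, key, value):
        # The target is whatever run is on top of the shared stack,
        # which may have been pushed by a different thread.
        run = self.params[self.active_stack[-1]]
        if key in run and run[key] != value:
            raise ParamConflictError(f"param {key!r} already logged with value {run[key]}")
        run[key] = value


def run_demo():
    tracker = ToyTracker()
    a_started, b_started, a_logged = threading.Event(), threading.Event(), threading.Event()
    errors = []

    def trial_a():
        tracker.start_run()           # pushes run0
        a_started.set()
        b_started.wait()              # let trial B push its run first
        tracker.log_param("x", 4.83)  # lands on run1, i.e. B's run
        a_logged.set()

    def trial_b():
        a_started.wait()
        tracker.start_run()           # pushes run1
        b_started.set()
        a_logged.wait()
        try:
            tracker.log_param("x", 4.80)  # run1 already has x = 4.83
        except ParamConflictError as exc:
            errors.append(exc)

    threads = [threading.Thread(target=trial_a), threading.Thread(target=trial_b)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return tracker, errors
```

Run 0 never receives its parameter while run 1 receives two conflicting values. Keeping the stack per thread or passing the run ID explicitly removes the problem; the MlflowClient approach below takes the second route.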

Side note: For simplicity, the objective function here just minimizes x^2 instead of training a model. For real hyper-parameter optimization, the objective function has to train your model with the hyper-parameters suggested by the trial, evaluate it, and return the error metric.
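As a rough sketch of what such an objective could look like, the function below trains a one-feature linear model with plain gradient descent and returns the validation MSE, with the learning rate as the tuned hyper-parameter. The data, the 500 epochs and the search range are made up for illustration, and StubTrial is a hypothetical stand-in for optuna.trial.Trial so the function can be called without Optuna; with Optuna installed you would pass objective to study.optimize directly.

```python
class StubTrial:
    """Hypothetical stand-in for optuna.trial.Trial returning preset values."""

    def __init__(self, values):
        self.values = values

    def suggest_float(self, name, low, high, log=False):
        value = self.values[name]
        if not low <= value <= high:
            raise ValueError(f"{name}={value} outside [{low}, {high}]")
        return value


# Hypothetical noise-free data for y = 2x + 1
TRAIN = [(x, 2 * x + 1) for x in range(5)]
VALID = [(x, 2 * x + 1) for x in (5, 6)]


def mse(w, b, data):
    return sum((w * x + b - y) ** 2 for x, y in data) / len(data)


def objective(trial):
    # 1. hyper-parameter(s) suggested for this trial
    lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)
    # 2. train the model on the training set (gradient descent on the MSE)
    w, b = 0.0, 0.0
    n = len(TRAIN)
    for _ in range(500):
        grad_w = sum(2 * (w * x + b - y) * x for x, y in TRAIN) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in TRAIN) / n
        w -= lr * grad_w
        b -= lr * grad_b
    # 3. return the validation error Optuna should minimize
    return mse(w, b, VALID)
```

A well-chosen learning rate (e.g. 0.1) fits the data almost perfectly, while a very small one (1e-4) leaves the model under-trained after 500 epochs, so Optuna would prefer the former.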

To avoid this error, we need to stop using the mlflow.* interface directly and use the tracking API via MlflowClient instead, passing the run ID explicitly to every call. Here is how it works:

import optuna
from mlflow.tracking import MlflowClient
from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID

def get_objective(parent_run_id):
    # get an objective function for Optuna that creates nested MLflow runs

    def objective(trial):
        # create the trial run and link it to the study run via the parent tag
        trial_run = client.create_run(
            experiment_id=experiment,
            tags={
                MLFLOW_PARENT_RUN_ID: parent_run_id
            }
        )
        trial_run_id = trial_run.info.run_id

        x = trial.suggest_float("x", -10.0, 10.0)
        client.log_param(trial_run_id, "x", x)
        val = x**2
        client.log_metric(trial_run_id, "xsq", val)
        # mark the trial run as finished, otherwise it stays in status RUNNING
        client.set_terminated(trial_run_id)
        return val

    return objective

client = MlflowClient()
experiment_name = "min_x_sq"
# reuse the experiment if it already exists, otherwise create it
existing = client.get_experiment_by_name(experiment_name)
if existing is None:
    experiment = client.create_experiment(experiment_name)
else:
    experiment = existing.experiment_id

study_run = client.create_run(experiment_id=experiment)
study_run_id = study_run.info.run_id

study = optuna.create_study(direction="minimize")
study.optimize(get_objective(study_run_id), n_trials=50, n_jobs=2)

client.log_param(study_run_id, "best_x", study.best_trial.params["x"])
client.log_metric(study_run_id, "best_xsq", study.best_value)
client.set_terminated(study_run_id)

Creating nested runs with MlflowClient is a bit more work, but if you look at what mlflow.start_run(nested=True) does internally, it is essentially the same: it sets the MLFLOW_PARENT_RUN_ID tag of the new run to the ID of the parent run (here, the study run).

Because every call now passes the run ID explicitly instead of relying on mlflow.start_run and its active-run stack, the threads no longer interfere with each other. Unfortunately, this means writing more code (e.g. each log_* call needs to know which run it refers to). In my opinion it would be much easier if MLflow provided a way to call log_* directly on the run objects and to create child runs with something like run.create_nested_run(). Maybe that is something they will add in the future.
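Until then, a thin wrapper can provide roughly that interface. The sketch below is hypothetical: RunHandle and create_nested_run are not part of MLflow, and InMemoryClient is a stand-in for MlflowClient so the example runs without a tracking server. The wrapper only relies on the MlflowClient methods create_run, log_param, log_metric and set_terminated, plus the tag key "mlflow.parentRunId", which is the value of MLFLOW_PARENT_RUN_ID.

```python
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor
from types import SimpleNamespace

# Tag MLflow uses to link a child run to its parent
# (value of mlflow.utils.mlflow_tags.MLFLOW_PARENT_RUN_ID).
PARENT_RUN_ID_TAG = "mlflow.parentRunId"


class RunHandle:
    """Hypothetical wrapper binding one run ID to an MlflowClient-like client."""

    def __init__(self, client, experiment_id, run_id):
        self.client = client
        self.experiment_id = experiment_id
        self.run_id = run_id

    @classmethod
    def create(cls, client, experiment_id, parent_run_id=None):
        tags = {PARENT_RUN_ID_TAG: parent_run_id} if parent_run_id else {}
        run = client.create_run(experiment_id=experiment_id, tags=tags)
        return cls(client, experiment_id, run.info.run_id)

    def create_nested_run(self):
        # child run in the same experiment, tagged with this run as its parent
        return RunHandle.create(self.client, self.experiment_id, parent_run_id=self.run_id)

    def log_param(self, key, value):
        self.client.log_param(self.run_id, key, value)

    def log_metric(self, key, value):
        self.client.log_metric(self.run_id, key, value)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # finish the run; mark it FAILED if the with-block raised
        self.client.set_terminated(self.run_id, status="FAILED" if exc_type else "FINISHED")
        return False


class InMemoryClient:
    """Hypothetical stand-in for MlflowClient that keeps runs in a dict."""

    def __init__(self):
        self.runs = {}
        self._lock = threading.Lock()

    def create_run(self, experiment_id, tags=None):
        run_id = uuid.uuid4().hex
        with self._lock:
            self.runs[run_id] = {"experiment_id": experiment_id, "tags": dict(tags or {}),
                                 "params": {}, "metrics": {}, "status": "RUNNING"}
        return SimpleNamespace(info=SimpleNamespace(run_id=run_id))

    def log_param(self, run_id, key, value):
        self.runs[run_id]["params"][key] = value

    def log_metric(self, run_id, key, value):
        self.runs[run_id]["metrics"][key] = value

    def set_terminated(self, run_id, status="FINISHED"):
        self.runs[run_id]["status"] = status


def run_parallel_demo(n_trials=8):
    client = InMemoryClient()
    with RunHandle.create(client, experiment_id="0") as study_run:
        def trial(i):
            # each thread only ever touches its own run ID
            with study_run.create_nested_run() as trial_run:
                trial_run.log_param("x", i)
                trial_run.log_metric("xsq", i ** 2)

        with ThreadPoolExecutor(max_workers=4) as pool:
            list(pool.map(trial, range(n_trials)))
    return client, study_run.run_id
```

With a real MlflowClient, the body of the objective from above would shrink to `with study_run.create_nested_run() as trial_run:` followed by `trial_run.log_param(...)` and `trial_run.log_metric(...)`, with no run IDs passed around by hand.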
