
[GENERAL SUPPORT]: Can't load from a checkpoint. #3440

Closed
1 task done
Expertium opened this issue Feb 28, 2025 · 1 comment
Labels
question Further information is requested

Comments


Expertium commented Feb 28, 2025

Question

When trying to load from a checkpoint, I get "Experiment not set on Ax client. Must first call load_experiment or create_experiment to use handler functions." See the toy example below and run it twice: the first run creates the checkpoint, and the second run tries to load it.
I'm using Ax version 0.5.0.

Please provide any relevant code snippet if applicable.

import os
import numpy as np
from ax.service.ax_client import AxClient

def train_model(parameters_dict: dict):
    # Toy objective: noisy baseline plus a small dependence on 'a'
    return 0.5 + float(np.random.normal(0, 0.025)) + parameters_dict['a'] / 3000

parameters = [{'name': 'a', 'type': 'range', 'bounds': [1, 1000], 'log_scale': True, 'value_type': 'int'}]

ax_seed = 42
total_trials = 2
checkpoint_filename = f'{ax_seed}_test.json'
ax = AxClient(random_seed=ax_seed)

if os.path.isfile(checkpoint_filename):
    try:
        ax.load_from_json_file(checkpoint_filename)
        completed_trials = len(ax.experiment.trials)
        print(f'Successfully loaded experiment with {completed_trials} completed trials')
    except Exception as e:
        print(f'Error loading checkpoint: {e}')
        raise SystemExit(1)
else:
    ax.create_experiment(name='Test', parameters=parameters)
    completed_trials = 0
    ax.save_to_json_file(checkpoint_filename)

for i in range(completed_trials, total_trials):
    print(f'Starting trial {i + 1}/{total_trials}')
    trial_params, trial_index = ax.get_next_trial()
    ax.complete_trial(trial_index=trial_index, raw_data=train_model(trial_params))
    # Backup after each trial
    ax.save_to_json_file(checkpoint_filename)

print('')
best_parameters, values = ax.get_best_parameters()
print(f'Best parameter={best_parameters}')
print(f'Best value={values}')

Code of Conduct

  • I agree to follow Ax's Code of Conduct
Expertium added the question label (Further information is requested) Feb 28, 2025
Expertium (author) commented:

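Never mind, this was the problem:
ax.load_from_json_file(checkpoint_filename)

load_from_json_file is a classmethod that returns a new, restored AxClient. Calling it on an existing instance discards that return value, so ax still has no experiment attached.

Fix:
ax = AxClient.load_from_json_file(checkpoint_filename)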
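To see why the first form silently does nothing, here is a minimal pure-Python sketch of the same pitfall that runs without Ax. ToyClient and its methods are hypothetical stand-ins that only borrow the method names from the issue; they are not Ax's implementation. The only assumption carried over is the one the fix relies on: the loader is a classmethod that returns a new instance instead of mutating the one it is called on.

```python
import json
import os
import tempfile
from typing import Optional


class ToyClient:
    """Hypothetical stand-in for a client whose loader is a classmethod.

    State lives on the instance; loading from disk builds and returns a
    NEW instance rather than filling in an existing one.
    """

    def __init__(self) -> None:
        self.experiment: Optional[dict] = None

    def _require_experiment(self) -> dict:
        # Handler methods fail if no experiment is attached to this instance.
        if self.experiment is None:
            raise ValueError("Experiment not set on client.")
        return self.experiment

    def create_experiment(self, name: str) -> None:
        self.experiment = {"name": name, "trials": []}

    def complete_trial(self, value: float) -> None:
        self._require_experiment()["trials"].append(value)

    def save_to_json_file(self, path: str) -> None:
        with open(path, "w") as f:
            json.dump(self._require_experiment(), f)

    @classmethod
    def load_from_json_file(cls, path: str) -> "ToyClient":
        # Creates a fresh instance; it never touches the caller's instance.
        client = cls()
        with open(path) as f:
            client.experiment = json.load(f)
        return client


path = os.path.join(tempfile.mkdtemp(), "42_test.json")

# First "run": create an experiment, record one trial, and checkpoint it.
original = ToyClient()
original.create_experiment("Test")
original.complete_trial(0.5)
original.save_to_json_file(path)

# Pitfall: calling the classmethod on an instance is legal Python, but the
# returned client is thrown away, so `broken` still has no experiment.
broken = ToyClient()
broken.load_from_json_file(path)
print(broken.experiment)  # None

# Fix: bind the instance that the classmethod returns.
fixed = ToyClient.load_from_json_file(path)
print(len(fixed.experiment["trials"]))  # 1
```

Because Python lets a classmethod be called through an instance without any warning, the mistake produces no error at the call site; it only surfaces later, when a handler method finds no experiment on the original object.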
