Skip to content

Commit

Permalink
Fix bugs in Optuna integration with the prompt2model demo script (#374)
Browse files Browse the repository at this point in the history
* Fix bugs in Optuna integration

* Lint
  • Loading branch information
viswavi authored Oct 31, 2023
1 parent cc4995a commit 13a2049
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
4 changes: 2 additions & 2 deletions prompt2model/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
MAX_SUPPORTED_BATCH_SIZE = 4

DEFAULT_HYPERPARAMETERS_SPACE = {
"min_num_train_epochs": 10,
"max_num_train_epochs": 20,
"min_num_train_epochs": 5,
"max_num_train_epochs": 15,
"save_strategy": ["no"],
"evaluation_strategy": ["no"],
"per_device_train_batch_size": MAX_SUPPORTED_BATCH_SIZE,
Expand Down
32 changes: 23 additions & 9 deletions prompt2model_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ def parse_model_size_limit(line: str, default_size=3e9) -> float:
return default_size
model_units = {"B": 1e0, "KB": 1e3, "MB": 1e6, "GB": 1e9, "TB": 1e12, "PB": 1e15}
unit_disambiguations = {
"B": ["b", "bytes"],
"KB": ["Kb", "kb", "kilobytes"],
"MB": ["Mb", "mb", "megabytes"],
"GB": ["Gb", "gb", "gigabytes"],
"TB": ["Tb", "tb", "terabytes"],
"PB": ["Pb", "pb", "petabytes"],
"B": ["b", "bytes"],
}
unit_matched = False
for unit, disambiguations in unit_disambiguations.items():
Expand All @@ -110,6 +110,7 @@ def parse_model_size_limit(line: str, default_size=3e9) -> float:
numerical_part = line.strip()[: -len(unit_name)].strip()
else:
numerical_part = line.strip()

if not str.isdecimal(numerical_part):
raise ValueError(
"Invalid input. Please enter a number (integer " + "or number with units)."
Expand Down Expand Up @@ -368,28 +369,40 @@ def main():

if line == "y":
line_print("Starting training with hyperparameter selection.")
default_min_num_epochs = DEFAULT_HYPERPARAMETERS_SPACE[
"min_num_train_epochs"
]
min_num_epochs = input(
f"Enter min number of epochs. Press enter to use default value ({default_min_num_epochs}): " # noqa E501
)
default_max_num_epochs = DEFAULT_HYPERPARAMETERS_SPACE[
"max_num_train_epochs"
]
max_num_epochs = input(
"Enter max number of epochs: Press enter to use default value: "
f"Enter max number of epochs. Press enter to use default value ({default_max_num_epochs}): " # noqa E501
)
default_num_trials = 10
num_trials = input(
"Enter the number of trials to conduct hypeparamter search. Press enter to use default value: " # noqa E501
f"Enter the number of trials (maximum number of hyperparameter configurations to consider) for hyperparameter search. Press enter to use default value ({default_num_trials}): " # noqa E501
)
default_batch_size = DEFAULT_HYPERPARAMETERS_SPACE[
"per_device_train_batch_size"
] # noqa E501
max_batch_size = input(
"Enter the max batch size. "
+ f"Press enter to use default: {default_batch_size}"
+ f"Press enter to use default ({default_batch_size}): "
)

min_num_epochs = (
default_min_num_epochs if min_num_epochs == "" else eval(min_num_epochs)
)
max_num_epochs = (
DEFAULT_HYPERPARAMETERS_SPACE["max_num_train_epochs"]
if max_num_epochs == ""
else eval(max_num_epochs)
default_max_num_epochs if max_num_epochs == "" else eval(max_num_epochs)
)
num_trials = 1 if num_trials == "" else eval(num_trials)

max_batch_size = (
DEFAULT_HYPERPARAMETERS_SPACE["per_device_train_batch_size"][0]
DEFAULT_HYPERPARAMETERS_SPACE["per_device_train_batch_size"]
if max_batch_size == ""
else eval(max_batch_size)
)
Expand All @@ -406,12 +419,13 @@ def main():
line_print("Starting training.")

trained_model, trained_tokenizer = OptunaParamSelector(
n_trial=num_trials,
n_trials=num_trials,
trainer=trainer,
).select_from_hyperparameters(
training_datasets=training_datasets,
validation=validation_datasets,
hyperparameters={
"min_num_train_epochs": min_num_epochs,
"max_num_train_epochs": max_num_epochs,
"per_device_train_batch_size": [max_batch_size],
},
Expand Down

0 comments on commit 13a2049

Please sign in to comment.