Skip to content

Commit

Permalink
Update optuna to v3
Browse files Browse the repository at this point in the history
  • Loading branch information
keisuke-umezawa committed Jun 28, 2023
1 parent a53b320 commit b9d6f48
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,8 @@
BaseDistribution,
CategoricalChoiceType,
CategoricalDistribution,
DiscreteUniformDistribution,
IntLogUniformDistribution,
IntUniformDistribution,
LogUniformDistribution,
UniformDistribution,
IntDistribution,
FloatDistribution,
)
from optuna.trial import Trial

Expand All @@ -62,17 +59,17 @@ def create_optuna_distribution_from_config(
assert param.low is not None
assert param.high is not None
if param.log:
return IntLogUniformDistribution(int(param.low), int(param.high))
return IntDistribution(int(param.low), int(param.high), log=True)
step = int(param.step) if param.step is not None else 1
return IntUniformDistribution(int(param.low), int(param.high), step=step)
return IntDistribution(int(param.low), int(param.high), step=step)
if param.type == DistributionType.float:
assert param.low is not None
assert param.high is not None
if param.log:
return LogUniformDistribution(param.low, param.high)
return FloatDistribution(param.low, param.high, log=True)
if param.step is not None:
return DiscreteUniformDistribution(param.low, param.high, param.step)
return UniformDistribution(param.low, param.high)
return FloatDistribution(param.low, param.high, step=param.step)
return FloatDistribution(param.low, param.high)
raise NotImplementedError(f"{param.type} is not supported by Optuna sweeper.")


Expand Down Expand Up @@ -107,23 +104,21 @@ def create_optuna_distribution_from_override(override: Override) -> Any:
or isinstance(value.stop, float)
or isinstance(value.step, float)
):
return DiscreteUniformDistribution(value.start, value.stop, value.step)
return IntUniformDistribution(
int(value.start), int(value.stop), step=int(value.step)
)
return FloatDistribution(value.start, value.stop, step=value.step)
return IntDistribution(int(value.start), int(value.stop), step=int(value.step))

if override.is_interval_sweep():
assert isinstance(value, IntervalSweep)
assert value.start is not None
assert value.end is not None
if "log" in value.tags:
if isinstance(value.start, int) and isinstance(value.end, int):
return IntLogUniformDistribution(int(value.start), int(value.end))
return LogUniformDistribution(value.start, value.end)
return IntDistribution(int(value.start), int(value.end))
return FloatDistribution(value.start, value.end, log=True)
else:
if isinstance(value.start, int) and isinstance(value.end, int):
return IntUniformDistribution(value.start, value.end)
return UniformDistribution(value.start, value.end)
return IntDistribution(value.start, value.end)
return FloatDistribution(value.start, value.end)

raise NotImplementedError(f"{override} is not supported by Optuna sweeper.")

Expand Down Expand Up @@ -266,13 +261,13 @@ def _parse_sweeper_params_config(self) -> List[str]:
def _to_grid_sampler_choices(self, distribution: BaseDistribution) -> Any:
if isinstance(distribution, CategoricalDistribution):
return distribution.choices
elif isinstance(distribution, IntUniformDistribution):
elif isinstance(distribution, IntDistribution):
assert (
distribution.step is not None
), "`step` of IntUniformDistribution must be a positive integer."
), "`step` of IntDistribution must be a positive integer."
n_items = (distribution.high - distribution.low) // distribution.step
return [distribution.low + i * distribution.step for i in range(n_items)]
elif isinstance(distribution, DiscreteUniformDistribution):
elif isinstance(distribution, FloatDistribution):
n_items = int((distribution.high - distribution.low) // distribution.q)
return [distribution.low + i * distribution.q for i in range(n_items)]
else:
Expand Down
2 changes: 1 addition & 1 deletion plugins/hydra_optuna_sweeper/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
],
install_requires=[
"hydra-core>=1.1.0.dev7",
"optuna>=2.10.0,<3.0.0",
"optuna>=3.0.0",
],
include_package_data=True,
)
35 changes: 16 additions & 19 deletions plugins/hydra_optuna_sweeper/tests/test_optuna_sweeper_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@
from optuna.distributions import (
BaseDistribution,
CategoricalDistribution,
DiscreteUniformDistribution,
IntLogUniformDistribution,
IntUniformDistribution,
LogUniformDistribution,
UniformDistribution,
IntDistribution,
FloatDistribution,
)
from optuna.samplers import RandomSampler
from pytest import mark, warns
Expand Down Expand Up @@ -59,24 +56,24 @@ def check_distribution(expected: BaseDistribution, actual: BaseDistribution) ->
{"type": "categorical", "choices": [1, 2, 3]},
CategoricalDistribution([1, 2, 3]),
),
({"type": "int", "low": 0, "high": 10}, IntUniformDistribution(0, 10)),
({"type": "int", "low": 0, "high": 10}, IntDistribution(0, 10)),
(
{"type": "int", "low": 0, "high": 10, "step": 2},
IntUniformDistribution(0, 10, step=2),
IntDistribution(0, 10, step=2),
),
({"type": "int", "low": 0, "high": 5}, IntUniformDistribution(0, 5)),
({"type": "int", "low": 0, "high": 5}, IntDistribution(0, 5)),
(
{"type": "int", "low": 1, "high": 100, "log": True},
IntLogUniformDistribution(1, 100),
IntDistribution(1, 100, log=True),
),
({"type": "float", "low": 0, "high": 1}, UniformDistribution(0, 1)),
({"type": "float", "low": 0, "high": 1}, FloatDistribution(0, 1)),
(
{"type": "float", "low": 0, "high": 10, "step": 2},
DiscreteUniformDistribution(0, 10, 2),
FloatDistribution(0, 10, step=2),
),
(
{"type": "float", "low": 1, "high": 100, "log": True},
LogUniformDistribution(1, 100),
FloatDistribution(1, 100, log=True),
),
],
)
Expand All @@ -92,12 +89,12 @@ def test_create_optuna_distribution_from_config(input: Any, expected: Any) -> No
("key=choice(true, false)", CategoricalDistribution([True, False])),
("key=choice('hello', 'world')", CategoricalDistribution(["hello", "world"])),
("key=shuffle(range(1,3))", CategoricalDistribution((1, 2))),
("key=range(1,3)", IntUniformDistribution(1, 3)),
("key=interval(1, 5)", UniformDistribution(1, 5)),
("key=int(interval(1, 5))", IntUniformDistribution(1, 5)),
("key=tag(log, interval(1, 5))", LogUniformDistribution(1, 5)),
("key=tag(log, int(interval(1, 5)))", IntLogUniformDistribution(1, 5)),
("key=range(0.5, 5.5, step=1)", DiscreteUniformDistribution(0.5, 5.5, 1)),
("key=range(1,3)", IntDistribution(1, 3)),
("key=interval(1, 5)", FloatDistribution(1, 5)),
("key=int(interval(1, 5))", IntDistribution(1, 5)),
("key=tag(log, interval(1, 5))", FloatDistribution(1, 5, log=True)),
("key=tag(log, int(interval(1, 5)))", IntDistribution(1, 5, log=True)),
("key=range(0.5, 5.5, step=1)", FloatDistribution(0.5, 5.5, step=1)),
],
)
def test_create_optuna_distribution_from_override(input: Any, expected: Any) -> None:
Expand All @@ -121,7 +118,7 @@ def test_create_optuna_distribution_from_override(input: Any, expected: Any) ->
(
{
"key1": CategoricalDistribution([1, 2]),
"key3": IntUniformDistribution(1, 3),
"key3": IntDistribution(1, 3),
},
{"key2": "5"},
),
Expand Down
8 changes: 4 additions & 4 deletions website/docs/plugins/optuna_sweeper.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ Hydra provides a override parser that support rich syntax. Please refer to [Over

#### Interval override

By default, `interval` is converted to [`UniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.UniformDistribution.html). You can use [`IntUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntUniformDistribution.html), [`LogUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.LogUniformDistribution.html) or [`IntLogUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntLogUniformDistribution.html) by casting the interval to `int` and tagging it with `log`.
By default, `interval` is converted to [`FloatDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.FloatDistribution.html). You can use [`IntDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntDistribution.html) by casting the interval to `int`.

<details><summary>Example for interval override</summary>

Expand Down Expand Up @@ -147,8 +147,8 @@ The output is as follows:

#### Range override

`range` is converted to [`IntUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntUniformDistribution.html). If you apply `shuffle` to `range`, [`CategoricalDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.CategoricalDistribution.html) is used instead.
If any of `range`'s start, stop or step is of type float, it will be converted to [`DiscreteUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.DiscreteUniformDistribution.html)
`range` is converted to [`IntDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntDistribution.html). If you apply `shuffle` to `range`, [`CategoricalDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.CategoricalDistribution.html) is used instead.
If any of `range`'s start, stop or step is of type float, it will be converted to [`FloatDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.FloatDistribution.html)

<details><summary>Example for range override</summary>

Expand Down Expand Up @@ -321,4 +321,4 @@ Configuring a trial object is done in the following sequence:
- Command line overrides are set
- `custom_search_space` parameters are set

It is not allowed to set search space parameters in the `custom_search_space` method for parameters which have a fixed value from command line overrides. [Trial.user_attrs](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.user_attrs) can be inspected to find any of such fixed parameters.
It is not allowed to set search space parameters in the `custom_search_space` method for parameters which have a fixed value from command line overrides. [Trial.user_attrs](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.user_attrs) can be inspected to find any of such fixed parameters.

0 comments on commit b9d6f48

Please sign in to comment.