Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion src/ConfigSpace/hyperparameters/hyperparameter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import inspect
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Hashable, Mapping, Sequence
Expand Down Expand Up @@ -132,11 +133,28 @@ def __init__(
if not self.legal_value(self.default_value):
raise ValueError(
f"Illegal default value {self.default_value} for"
f" hyperparamter '{self.name}'.",
f" hyperparameter '{self.name}'.",
)

self._normalized_default_value = self.to_vector(self.default_value)

def __setattr__(self, name: str, value: Any):
"""Check if attribute can be set on HP, and reinitialises the class if so."""
# NOTE: The following check is 'ugly', but it works...
if inspect.stack()[1][3] != '__init__': # This should be only executed on update, not init
# Extract all editable attributes
init_params: tuple[str] = self.__init__.__code__.co_varnames[:self.__init__.__code__.co_argcount]

if name not in init_params or not hasattr(self, name):
raise ValueError("Can't set attribute {name}, must be one passed to init.") # Something better error message than this

init_params = {key: self.__dict__[key] for key in init_params if hasattr(self, key)} # This will break if the parameter is not saved under its passed name
init_params[name] = value # Place the update value

self.__init__(**init_params) # Reinitialise
else:
super().__setattr__(name, value)

@property
def lower_vectorized(self) -> f64:
"""Lower bound of the hyperparameter in vector space."""
Expand Down
91 changes: 91 additions & 0 deletions test/test_hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3080,3 +3080,94 @@ def test_arbitrary_object_allowed_in_categorical_ordinal(
list(get_one_exchange_neighbourhood(sample, seed=1)) # no raise
for n in ns:
n.check_valid_configuration() # no raise


def test_update_hyperparameters():
space = ConfigurationSpace()
space.add(
[
UniformIntegerHyperparameter("a", 0, 100),
UniformFloatHyperparameter("b", -1.0, 1.0),
CategoricalHyperparameter("c", [1, 2, 3]),
OrdinalHyperparameter("d", [1, 2, 3]),
],
)
# Test updating numerical HP min/max values
space["a"].upper = 51
assert space["a"].upper == 51
space["a"].lower = 49
assert space["a"].lower == 49

# Sample the space to verify it does not sample OOD
sample = space.sample_configuration(size=25)
for value in sample:
assert 49 <= value["a"] <= 51

# Test updating default values
space["a"].lower = 1 # Update first to avoid error
space["a"].default_value = 5
assert space["a"].default_value == 5

# Test that it cannot change to an illegal value
with pytest.raises(ValueError):
space["a"].upper = 0 # lower than lower
with pytest.raises(ValueError):
space["a"].lower = 100 # higher than upper
with pytest.raises(ValueError):
space["a"].default_value = 1000 # Out of bounds

# Test Float
space["b"].upper = 0.1
assert space["b"].upper == 0.1
space["b"].lower = -0.1
assert space["b"].lower == -0.1

# Check sampling
sample = space.sample_configuration(size=100)
for value in sample:
assert -0.1 <= value["b"] <= 0.1

# Test illegal changes
with pytest.raises(ValueError):
space["b"].upper = -0.11 # lower than lower
with pytest.raises(ValueError):
space["b"].lower = 0.11 # higher than upper
with pytest.raises(ValueError):
space["b"].default_value = -10.0 # Out of bounds

# Test categorical HP
space["c"].choices = [1, 2, 3, 4] # Change range
assert space["c"].choices == (1, 2, 3, 4)

space["c"].default_value = 4 # Change default value
assert space["c"].default_value == 4

space["c"].weights = [0.1, 0.4, 0.1, 0.4] # Change weights
assert space["c"].weights == (0.1, 0.4, 0.1, 0.4)

# Test sampling
sample_count = {1: 0, 2: 0, 3: 0, 4: 0}
sample = space.sample_configuration(size=100)
for value in sample:
sample_count[value["c"]] += 1
assert sample_count[2] > sample_count[1]
assert sample_count[2] > sample_count[3]
assert sample_count[4] > sample_count[1]
assert sample_count[4] > sample_count[3]

# Test ordinal HP
space["d"].sequence = [1, 2, 3, 4] # Change range
assert space["d"].sequence == (1, 2, 3, 4)

space["d"].default_value = 4 # Change default value
assert space["d"].default_value == 4

# Test sampling
sample = space.sample_configuration(size=100)
assert 4 in [s["d"] for s in sample]

# TODO: Test changing HP type int -> float
# TODO: Test changing HP type float -> int
# TODO: Test changing HP type categorical -> ordinal
# TODO: Test changing HP type ordinal -> categorical
# TODO: Check that HP type cannot change between float/int and categorical/ordinal