diff --git a/src/ConfigSpace/hyperparameters/hyperparameter.py b/src/ConfigSpace/hyperparameters/hyperparameter.py index 9fb12f94..f10de318 100644 --- a/src/ConfigSpace/hyperparameters/hyperparameter.py +++ b/src/ConfigSpace/hyperparameters/hyperparameter.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Hashable, Mapping, Sequence @@ -132,11 +133,28 @@ def __init__( if not self.legal_value(self.default_value): raise ValueError( f"Illegal default value {self.default_value} for" - f" hyperparamter '{self.name}'.", + f" hyperparameter '{self.name}'.", ) self._normalized_default_value = self.to_vector(self.default_value) + def __setattr__(self, name: str, value: Any): + """Check if attribute can be set on HP, and reinitialises the class if so.""" + # NOTE: The following check is 'ugly', but it works... + if inspect.stack()[1][3] != '__init__': # This should be only executed on update, not init + # Extract all editable attributes + init_params: tuple[str] = self.__init__.__code__.co_varnames[:self.__init__.__code__.co_argcount] + + if name not in init_params or not hasattr(self, name): + raise ValueError("Can't set attribute {name}, must be one passed to init.") # Something better error message than this + + init_params = {key: self.__dict__[key] for key in init_params if hasattr(self, key)} # This will break if the parameter is not saved under its passed name + init_params[name] = value # Place the update value + + self.__init__(**init_params) # Reinitialise + else: + super().__setattr__(name, value) + @property def lower_vectorized(self) -> f64: """Lower bound of the hyperparameter in vector space.""" diff --git a/test/test_hyperparameters.py b/test/test_hyperparameters.py index 87cd07ce..05484244 100644 --- a/test/test_hyperparameters.py +++ b/test/test_hyperparameters.py @@ -3080,3 +3080,94 @@ def test_arbitrary_object_allowed_in_categorical_ordinal( list(get_one_exchange_neighbourhood(sample, seed=1)) # no raise for n in ns: n.check_valid_configuration() # no raise + + +def test_update_hyperparameters(): + space = ConfigurationSpace() + space.add( + [ + UniformIntegerHyperparameter("a", 0, 100), + UniformFloatHyperparameter("b", -1.0, 1.0), + CategoricalHyperparameter("c", [1, 2, 3]), + OrdinalHyperparameter("d", [1, 2, 3]), + ], + ) + # Test updating numerical HP min/max values + space["a"].upper = 51 + assert space["a"].upper == 51 + space["a"].lower = 49 + assert space["a"].lower == 49 + + # Sample the space to verify it does not sample OOD + sample = space.sample_configuration(size=25) + for value in sample: + assert 49 <= value["a"] <= 51 + + # Test updating default values + space["a"].lower = 1 # Update first to avoid error + space["a"].default_value = 5 + assert space["a"].default_value == 5 + + # Test that it cannot change to an illegal value + with pytest.raises(ValueError): + space["a"].upper = 0 # lower than lower + with pytest.raises(ValueError): + space["a"].lower = 100 # higher than upper + with pytest.raises(ValueError): + space["a"].default_value = 1000 # Out of bounds + + # Test Float + space["b"].upper = 0.1 + assert space["b"].upper == 0.1 + space["b"].lower = -0.1 + assert space["b"].lower == -0.1 + + # Check sampling + sample = space.sample_configuration(size=100) + for value in sample: + assert -0.1 <= value["b"] <= 0.1 + + # Test illegal changes + with pytest.raises(ValueError): + space["b"].upper = -0.11 # lower than lower + with pytest.raises(ValueError): + space["b"].lower = 0.11 # higher than upper + with pytest.raises(ValueError): + space["b"].default_value = -10.0 # Out of bounds + + # Test categorical HP + space["c"].choices = [1, 2, 3, 4] # Change range + assert space["c"].choices == (1, 2, 3, 4) + + space["c"].default_value = 4 # Change default value + assert space["c"].default_value == 4 + + space["c"].weights = [0.1, 0.4, 0.1, 0.4] # Change weights + assert space["c"].weights == (0.1, 0.4, 0.1, 0.4) + + # Test sampling + sample_count = {1: 0, 2: 0, 3: 0, 4: 0} + sample = space.sample_configuration(size=100) + for value in sample: + sample_count[value["c"]] += 1 + assert sample_count[2] > sample_count[1] + assert sample_count[2] > sample_count[3] + assert sample_count[4] > sample_count[1] + assert sample_count[4] > sample_count[3] + + # Test ordinal HP + space["d"].sequence = [1, 2, 3, 4] # Change range + assert space["d"].sequence == (1, 2, 3, 4) + + space["d"].default_value = 4 # Change default value + assert space["d"].default_value == 4 + + # Test sampling + sample = space.sample_configuration(size=100) + assert 4 in [s["d"] for s in sample] + + # TODO: Test changing HP type int -> float + # TODO: Test changing HP type float -> int + # TODO: Test changing HP type categorical -> ordinal + # TODO: Test changing HP type ordinal -> categorical + # TODO: Check that HP type cannot change between float/int and categorical/ordinal