diff --git a/pyproject.toml b/pyproject.toml index e4351f2e..ff585cc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -173,6 +173,7 @@ select = [ ignore = [ "T201", # TODO: Remove + "COM812", # Causes issues with ruff formatter "D100", "D104", # Missing docstring in public package "D105", # Missing docstring in magic mthod diff --git a/src/ConfigSpace/_condition_tree.py b/src/ConfigSpace/_condition_tree.py index 30fbac9e..ec33fc5d 100644 --- a/src/ConfigSpace/_condition_tree.py +++ b/src/ConfigSpace/_condition_tree.py @@ -41,7 +41,7 @@ import numpy as np from more_itertools import unique_everseen -from ConfigSpace.conditions import Condition, Conjunction +from ConfigSpace.conditions import Condition, ConditionLike, Conjunction from ConfigSpace.exceptions import ( AmbiguousConditionError, ChildNotFoundError, @@ -62,7 +62,6 @@ from ConfigSpace.types import f64 if TYPE_CHECKING: - from ConfigSpace.conditions import ConditionLike from ConfigSpace.hyperparameters import Hyperparameter from ConfigSpace.types import Array @@ -782,8 +781,7 @@ def _minimum_conditions(self) -> list[ConditionNode]: # i.e. two hyperparameters both rely on algorithm == "A" base_conditions: dict[int, ConditionNode] = {} for node in self.nodes.values(): - # This node has no parent as is a root - if node.parent_condition is None: + if node.parent_condition is None: # This node has no parent as it is a root node assert node.name in self.roots continue diff --git a/src/ConfigSpace/api/__init__.py b/src/ConfigSpace/api/__init__.py index e1ea41b9..473cafe5 100644 --- a/src/ConfigSpace/api/__init__.py +++ b/src/ConfigSpace/api/__init__.py @@ -3,13 +3,13 @@ from ConfigSpace.api.types import Categorical, Float, Integer __all__ = [ - "types", - "distributions", "Beta", - "Distribution", - "Normal", - "Uniform", "Categorical", + "Distribution", "Float", "Integer", + "Normal", + "Uniform", + "distributions", + "types", ] diff --git a/src/ConfigSpace/configuration.py b/src/ConfigSpace/configuration.py index 1b556f7f..ed080b8b 100644 --- a/src/ConfigSpace/configuration.py +++ b/src/ConfigSpace/configuration.py @@ -73,10 +73,10 @@ def __init__( ConfigSpace package. """ if ( - values is not None - and vector is not None - or values is None - and vector is None + (values is not None + and vector is not None) + or (values is None + and vector is None) ): raise ValueError( "Specify Configuration as either a dictionary or a vector.", diff --git a/src/ConfigSpace/configuration_space.py b/src/ConfigSpace/configuration_space.py index c0c0ecc6..1d2720b3 100644 --- a/src/ConfigSpace/configuration_space.py +++ b/src/ConfigSpace/configuration_space.py @@ -49,6 +49,7 @@ from ConfigSpace.configuration import Configuration, NotSet from ConfigSpace.exceptions import ( ActiveHyperparameterNotSetError, + HyperparameterNotFoundError, ForbiddenValueError, IllegalVectorizedValueError, InactiveHyperparameterSetError, @@ -350,6 +351,101 @@ def _put_to_list( self._len = len(self._dag.nodes) self._check_default_configuration() + def remove( + self, + *args: Hyperparameter, + ) -> None: + """Remove a hyperparameter from the configuration space. + + If the hyperparameter has children, the children are also removed. + This includes defined conditions and conjunctions! + + !!! note + + If removing multiple hyperparameters, it is better to remove them all + at once with one call to `remove()`, as we rebuilt a cache after each + call to `remove()`. + + Args: + args: Hyperparameter(s) to remove + """ + remove_hps = [] + for arg in args: + if isinstance(arg, Hyperparameter): + if arg.name not in self._dag.nodes: + raise HyperparameterNotFoundError( + f"Hyperparameter '{arg.name}' does not exist in space.", + ) + remove_hps.append(arg) + else: + raise TypeError(f"Unknown type {type(arg)}") + remove_hps_names = [hp.name for hp in remove_hps] + + # Filter HPs from the DAG + hps: list[Hyperparameter] = [node.hp for node in self._dag.nodes.values() if node.hp.name not in remove_hps_names] + + def remove_hyperparameter_from_conjunction( + target: Conjunction | Condition | ForbiddenRelation | ForbiddenClause, + ) -> ( + Conjunction + | Condition + | ForbiddenClause + | ForbiddenRelation + | ForbiddenConjunction + | None + ): + if isinstance(target, ForbiddenRelation) and ( + target.left.name in remove_hps_names or target.right.name in remove_hps_names + ): + return None + if isinstance(target, ForbiddenClause) and target.hyperparameter.name in remove_hps_names: + return None + if isinstance(target, Condition) and ( + target.parent.name in remove_hps_names or target.child.name in remove_hps_names + ): + return None + if isinstance(target, (Conjunction, ForbiddenConjunction)): + new_components = [] + for component in target.components: + new_component = remove_hyperparameter_from_conjunction(component) + if new_component is not None: + new_components.append(new_component) + if len(new_components) >= 2: # Can create a conjunction + return type(target)(*new_components) + if len(new_components) == 1: # Only one component remains + return new_components[0] + return None # No components remain + return target # Nothing to change + + # Remove HPs from conditions + conditions = [] + for condition in self._dag.conditions: + condition = remove_hyperparameter_from_conjunction(condition) + if condition is not None: # If None, the conditional clause is empty and thus not added + conditions.append(condition) + + # Remove HPs from Forbiddens + forbiddens = [] + for forbidden in self._dag.forbiddens: + forbidden = remove_hyperparameter_from_conjunction(forbidden) + if forbidden is not None: # If None, the forbidden clause is empty and is not added + forbiddens.append( + remove_hyperparameter_from_conjunction(forbidden) + ) + + # Rebuild the DAG + self._dag = DAG() + with self._dag.update(): + for hp in hps: + self._dag.add(hp) + for condition in conditions: + self._dag.add_condition(condition) + for forbidden in forbiddens: + self._dag.add_forbidden(forbidden) + + self._len = len(self._dag.nodes) + self._check_default_configuration() + def add_configuration_space( self, prefix: str, @@ -878,7 +974,7 @@ def __iter__(self) -> Iterator[str]: return iter(self._dag.nodes.keys()) def items(self) -> ItemsView[str, Hyperparameter]: - """Return an items view of the hyperparameters, same as `dict.items()`.""" # noqa: D402 + """Return an items view of the hyperparameters, same as `dict.items()`.""" return {name: node.hp for name, node in self._dag.nodes.items()}.items() def __len__(self) -> int: diff --git a/src/ConfigSpace/hyperparameters/__init__.py b/src/ConfigSpace/hyperparameters/__init__.py index 75a67a14..545703fa 100644 --- a/src/ConfigSpace/hyperparameters/__init__.py +++ b/src/ConfigSpace/hyperparameters/__init__.py @@ -24,7 +24,7 @@ "NormalIntegerHyperparameter", "NumericalHyperparameter", "OrdinalHyperparameter", + "UnParametrizedHyperparameter", "UniformFloatHyperparameter", "UniformIntegerHyperparameter", - "UnParametrizedHyperparameter", ] diff --git a/test/test_configuration_space.py b/test/test_configuration_space.py index 381e3f9c..c4fb2d8f 100644 --- a/test/test_configuration_space.py +++ b/test/test_configuration_space.py @@ -82,6 +82,110 @@ def test_add(): cs.add(hp) +def test_remove(): + cs = ConfigurationSpace() + hp = UniformIntegerHyperparameter("name", 0, 10) + hp2 = UniformFloatHyperparameter("name2", 0, 10) + hp3 = CategoricalHyperparameter( + "weather", ["dry", "rainy", "snowy"], default_value="dry" + ) + cs.add(hp, hp2, hp3) + cs.remove(hp) + assert len(cs) == 2 + + # Test multi removal + cs.add(hp) + cs.remove(hp, hp2) + assert len(cs) == 1 + + # Test faulty input + with pytest.raises(TypeError): + cs.remove(object()) + + # Non existant HP + with pytest.raises(HyperparameterNotFoundError): + cs.remove(hp) + + cs.add(hp, hp2) + # Test one correct one faulty, nothing should happen + with pytest.raises(TypeError): + cs.remove(hp, object()) + assert len(cs) == 3 + + # Make hp2 a conditional parameter, the condition should also be removed when hp is removed + cond = EqualsCondition(hp, hp2, 1) + cs.add(cond) + cs.remove(hp) + assert len(cs) == 2 + assert cs.conditional_hyperparameters == [] + assert cs.conditions == [] + + # Set up forbidden relation, the relation should also be removed + forb = ForbiddenEqualsClause(hp3, "snowy") + cs.add(forb) + cs.remove(hp3) + assert len(cs) == 1 + assert cs.forbidden_clauses == [] + + # And now for more complicated conditions + cs = ConfigurationSpace() + hp1 = CategoricalHyperparameter("input1", [0, 1]) + cs.add(hp1) + hp2 = CategoricalHyperparameter("input2", [0, 1]) + cs.add(hp2) + hp3 = CategoricalHyperparameter("input3", [0, 1]) + cs.add(hp3) + hp4 = CategoricalHyperparameter("input4", [0, 1]) + cs.add(hp4) + hp5 = CategoricalHyperparameter("input5", [0, 1]) + cs.add(hp5) + hp6 = Constant("constant1", "True") + cs.add(hp6) + + cond1 = EqualsCondition(hp6, hp1, 1) + cond2 = NotEqualsCondition(hp6, hp2, 1) + cond3 = InCondition(hp6, hp3, [1]) + cond4 = EqualsCondition(hp6, hp4, 1) + cond5 = EqualsCondition(hp6, hp5, 1) + + conj1 = AndConjunction(cond1, cond2) + conj2 = OrConjunction(conj1, cond3) + conj3 = AndConjunction(conj2, cond4, cond5) + cs.add(conj3) + + cs.remove(hp3) + assert len(cs) == 5 + # Only one part of the condition should be removed, not the entire condition + assert len(cs.conditional_hyperparameters) == 1 + assert len(cs.conditions) == 1 + # Test the exact value + assert ( + str(cs.conditions[0]) + == "((constant1 | input1 == 1 && constant1 | input2 != 1) && constant1 | input4 == 1 && constant1 | input5 == 1)" + ) + + # Now more complicated forbiddens + cs = ConfigurationSpace() + cs.add([hp1, hp2, hp3, hp4, hp5, hp6]) + cs.add(conj3) + + forb1 = ForbiddenEqualsClause(hp1, 1) + forb2 = ForbiddenAndConjunction(forb1, ForbiddenEqualsClause(hp2, 1)) + forb3 = ForbiddenAndConjunction(forb2, ForbiddenEqualsClause(hp3, 1)) + forb4 = ForbiddenEqualsClause(hp3, 1) + forb5 = ForbiddenEqualsClause(hp4, 1) + cs.add(forb3, forb4, forb5) + + cs.remove(hp3) + assert len(cs) == 5 + assert len(cs.forbidden_clauses) == 2 + assert ( + str(cs.forbidden_clauses[0]) + == "(Forbidden: input1 == 1 && Forbidden: input2 == 1)" + ) + assert str(cs.forbidden_clauses[1]) == "Forbidden: input4 == 1" + + def test_add_non_hyperparameter(): cs = ConfigurationSpace() with pytest.raises(TypeError):