Skip to content
Draft
583 changes: 583 additions & 0 deletions notebooks/structural_components_dataclass.ipynb

Large diffs are not rendered by default.

241 changes: 241 additions & 0 deletions pymc_extras/statespace/core/properties.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
from __future__ import annotations

import warnings

from collections.abc import Iterator
from copy import deepcopy
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, Generic, Self, TypeVar

from pymc_extras.statespace.core import PyMCStateSpace
from pymc_extras.statespace.utils.constants import (
ALL_STATE_AUX_DIM,
ALL_STATE_DIM,
OBS_STATE_AUX_DIM,
OBS_STATE_DIM,
SHOCK_AUX_DIM,
SHOCK_DIM,
)

if TYPE_CHECKING:
from pymc_extras.statespace.models.structural.core import Component


@dataclass(frozen=True)
class Property:
    """Immutable base record; subclasses declare the actual fields."""

    def __str__(self) -> str:
        """Render one ``name: value`` line per dataclass field, in declaration order."""
        rendered = []
        for spec in fields(self):
            value = getattr(self, spec.name)
            rendered.append(f"{spec.name}: {value}")
        return "\n".join(rendered)


# Type variable for the record type held by an ``Info`` container; bound to ``Property`` subclasses.
T = TypeVar("T", bound=Property)


@dataclass(frozen=True)
class Info(Generic[T]):
items: tuple[T, ...]
key_field: str = "name"
_index: dict[str, T] | None = None

def __post_init__(self):
index = {}
missing_attr = []
for item in self.items:
if not hasattr(item, self.key_field):
missing_attr.append(item)
continue
key = getattr(item, self.key_field)
# if key in index:
# raise ValueError(f"Duplicate {self.key_field} '{key}' detected.") # This needs to be possible for shared states
Comment on lines +47 to +48
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That shouldn't happen here, though; it should come up in `merge` or `add`, right? And we handle it there with the `allow_duplicates` flag.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think what happens is this: because our dataclasses are immutable, `merge`/`add` always return a new object of the same dataclass, so `__post_init__` runs right after the merge/add and sees the duplicate keys, even though the merge/add method had allowed them via `allow_duplicates`.

index[key] = item
if missing_attr:
raise AttributeError(f"Items missing attribute '{self.key_field}': {missing_attr}")
object.__setattr__(self, "_index", index)

def _key(self, item: T) -> str:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this used?

return getattr(item, self.key_field)

def get(self, key: str, default=None) -> T | None:
    """Return the item registered under ``key``, or ``default`` when the key is absent."""
    registry = self._index
    return registry.get(key, default)

def __getitem__(self, key: str) -> T:
    """Look up an item by its key.

    Raises
    ------
    KeyError
        If ``key`` is not registered; the message lists every available key.
    """
    registry = self._index
    try:
        found = registry[key]
    except KeyError as err:
        known = ", ".join(registry)
        raise KeyError(f"No {self.key_field} '{key}'. Available: [{known}]") from err
    return found

def __contains__(self, key: object) -> bool:
    """Report whether ``key`` is one of the registered keys."""
    registry = self._index
    return key in registry

def __iter__(self) -> Iterator[T]:
    """Iterate over the stored items in the order they are held in ``items``."""
    members = self.items
    return iter(members)

def __len__(self) -> int:
    """Number of stored items (duplicated keys are counted individually)."""
    members = self.items
    return len(members)

def __str__(self) -> str:
    """Summarise the container as ``<key_field>s: [<keys>]``."""
    keys = [*self._index]
    return f"{self.key_field}s: {keys}"

def add(self, new_item: T):
    """Return a new container of the same type with ``new_item`` appended.

    The receiver itself is not modified.
    """
    extended = list(self.items)
    extended.append(new_item)
    return type(self)(extended)

def merge(self, other: Self, allow_duplicates: bool = False) -> Self:
    """Build a new container holding this container's items followed by ``other``'s.

    Raises
    ------
    TypeError
        If ``other`` is not an instance of this container's type.
    ValueError
        If the two containers share a key and ``allow_duplicates`` is False.
    """
    cls = type(self)
    if not isinstance(other, cls):
        raise TypeError(f"Cannot merge {type(other).__name__} with {cls.__name__}")

    shared = set(self.names).intersection(other.names)
    if shared and not allow_duplicates:
        raise ValueError(f"Duplicate names found: {shared}")

    return cls([*self.items, *other.items])

@property
def names(self) -> tuple[str, ...]:
    """Keys of the lookup index as a tuple, in the index's insertion order."""
    return tuple(self._index)

def copy(self) -> Self:
    """Return a deep copy of this container.

    Annotated as ``Self`` rather than ``Info[T]``: ``deepcopy`` preserves the runtime
    type, so copying a subclass instance yields that same subclass.
    """
    return deepcopy(self)


@dataclass(frozen=True)
class Parameter(Property):
    """Immutable record describing one parameter."""

    name: str  # key under which ``ParameterInfo`` indexes this record
    shape: tuple[int, ...]
    dims: tuple[str, ...]
    constraints: str | None = None  # optional; ``None`` when not supplied


@dataclass(frozen=True)
class ParameterInfo(Info[Parameter]):
    """Container of ``Parameter`` records, indexed by ``name``."""

    def __init__(self, parameters: list[Parameter]):
        """Freeze ``parameters`` into a tuple and pass it to the base initializer."""
        frozen_items = tuple(parameters)
        super().__init__(items=frozen_items, key_field="name")


@dataclass(frozen=True)
class Data(Property):
    """Immutable record describing one data input."""

    name: str  # key under which ``DataInfo`` indexes this record
    shape: tuple[int, ...]
    dims: tuple[str, ...]
    is_exogenous: bool  # read by ``DataInfo`` to pick out exogenous entries


@dataclass(frozen=True)
class DataInfo(Info[Data]):
    """Container of ``Data`` records, indexed by ``name``."""

    def __init__(self, data: list[Data]):
        """Freeze ``data`` into a tuple and pass it to the base initializer."""
        frozen_items = tuple(data)
        super().__init__(items=frozen_items, key_field="name")

    @property
    def needs_exogenous_data(self) -> bool:
        """True when at least one entry is flagged ``is_exogenous``."""
        for entry in self.items:
            if entry.is_exogenous:
                return True
        return False

    @property
    def exogenous_names(self) -> tuple[str, ...]:
        """Names of the entries flagged ``is_exogenous``, in stored order."""
        exogenous = filter(lambda entry: entry.is_exogenous, self.items)
        return tuple(entry.name for entry in exogenous)

    def __str__(self) -> str:
        """Two-line summary: all data names, then whether exogenous data is needed."""
        all_names = [entry.name for entry in self.items]
        return f"data: {all_names}\nneeds exogenous data: {self.needs_exogenous_data}"


@dataclass(frozen=True)
class Coord(Property):
    """Immutable record pairing a dimension name with its labels."""

    dimension: str  # key under which ``CoordInfo`` indexes this record
    labels: tuple[str, ...]


@dataclass(frozen=True)
class CoordInfo(Info[Coord]):
    """Container of ``Coord`` records, indexed by ``dimension``."""

    def __init__(self, coords: list[Coord]):
        """Freeze ``coords`` into a tuple and pass it to the base initializer."""
        frozen_items = tuple(coords)
        super().__init__(items=frozen_items, key_field="dimension")

    def __str__(self) -> str:
        """Render a ``coordinates:`` header followed by each coord's fields, indented."""
        pieces = ["coordinates:"]
        for coord in self.items:
            body = "\n".join(" " + line for line in str(coord).splitlines())
            pieces.append("\n" + body + "\n")
        return "".join(pieces)

    @classmethod
    def default_coords_from_model(cls, model: PyMCStateSpace | Component) -> Self:
        """Build the six default coords from a model's state, observed-state and shock names.

        Each of the three name sequences is used for a dimension and for its ``AUX``
        counterpart.
        """
        # TODO: Need to figure out how to include Component type was causing circular import issues
        state_labels = tuple(model.state_names)
        observed_labels = tuple(model.observed_states)
        shock_labels = tuple(model.shock_names)

        dimensions = (
            ALL_STATE_DIM,
            ALL_STATE_AUX_DIM,
            OBS_STATE_DIM,
            OBS_STATE_AUX_DIM,
            SHOCK_DIM,
            SHOCK_AUX_DIM,
        )
        label_sets = (
            state_labels,
            state_labels,
            observed_labels,
            observed_labels,
            shock_labels,
            shock_labels,
        )

        built = []
        for dim, labels in zip(dimensions, label_sets):
            built.append(Coord(dimension=dim, labels=labels))
        return cls(built)

    def to_dict(self):
        """Map each dimension to its labels, leaving out coords that have no labels."""
        mapping = {}
        for coord in self.items:
            if len(coord.labels) > 0:
                mapping[coord.dimension] = coord.labels
        return mapping


@dataclass(frozen=True)
class State(Property):
    """Immutable record describing one state."""

    name: str  # key under which ``StateInfo`` indexes this record
    observed: bool  # read by ``StateInfo`` to split observed from unobserved states
    shared: bool


@dataclass(frozen=True)
class StateInfo(Info[State]):
def __init__(self, states: list[State]):
    """Freeze ``states`` into a tuple and register them under their ``name``."""
    frozen_items = tuple(states)
    super().__init__(items=frozen_items, key_field="name")

def __str__(self) -> str:
    """Two-line summary: the state names, then the matching ``observed`` flags."""
    labels = []
    flags = []
    for state in self.items:
        labels.append(state.name)
        flags.append(state.observed)
    return f"states: {labels}\nobserved: {flags}"

@property
def observed_states(self) -> tuple[State, ...]:
    """The ``State`` records whose ``observed`` flag is truthy, in stored order."""
    # NOTE(review): the original author questioned whether this accessor is still needed.
    return tuple(filter(lambda state: state.observed, self.items))
Comment on lines +201 to +203
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to keep it (as an alias for observed_state_names), then pick one to be the "canonical" name and just return that one from the other ones, rather than re-writing the loop in several places

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm leaning more towards just using observed_state_names instead. If you are okay with that I am going to remove observed_states


@property
def observed_state_names(self) -> tuple[str, ...]:
    """Names of the states whose ``observed`` flag is truthy, in stored order.

    The return annotation is ``tuple[str, ...]`` because the values are the ``name``
    strings, not the ``State`` records themselves.
    """
    return tuple(s.name for s in self.items if s.observed)

@property
def unobserved_state_names(self) -> tuple[str, ...]:
    """Names of the states whose ``observed`` flag is falsy, in stored order.

    The return annotation is ``tuple[str, ...]`` because the values are the ``name``
    strings, not the ``State`` records themselves.
    """
    return tuple(s.name for s in self.items if not s.observed)

def merge(self, other: StateInfo, allow_duplicates: bool = False) -> StateInfo:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why doesn't the base class version work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is what is happening: for shared states, `_get_combined_shapes` counts the unique names across components. It needs to combine the state names without raising, but it needs the unique names, not the duplicates. It is subtle, but if you look closely, the `StateInfo.merge` override performs a unique merge even when `allow_duplicates` is False, whereas the base-class version raises an error. Obviously, this may not be the best way to handle this; maybe we need to look at `_get_combined_shapes` and see whether it makes more sense to alter how that works.

"""Combine states from two StateInfo objects."""
if not isinstance(other, StateInfo):
raise TypeError(f"Cannot merge {type(other).__name__} with StateInfo")

overlapping = set(self.names) & set(other.names)
if overlapping and not allow_duplicates:
# This is necessary for shared states
warnings.warn(
f"Duplicate state names found: {overlapping}. Merge will ONLY retain unique states",
UserWarning,
)
return StateInfo(
states=list(self.items)
+ [item for item in other.items if item.name not in overlapping]
)

return StateInfo(states=list(self.items) + list(other.items))


@dataclass(frozen=True)
class Shock(Property):
    """Immutable record describing one shock."""

    name: str  # key under which ``ShockInfo`` indexes this record


@dataclass(frozen=True)
class ShockInfo(Info[Shock]):
    """Container of ``Shock`` records, indexed by ``name``."""

    def __init__(self, shocks: list[Shock]):
        """Freeze ``shocks`` into a tuple and pass it to the base initializer."""
        frozen_items = tuple(shocks)
        super().__init__(items=frozen_items, key_field="name")
Loading
Loading