Represent statespace metadata with dataclasses #607
jessegrabowski
left a comment
This is a great first pass, much cleaner than what we have now.
We can also keep all of the existing properties like
Reflecting on it, I am convinced this is the way to go. It's 1000x more ergonomic. I made some changes to your initial code to make the API more "dictionary like", and to reduce code duplication. I moved everything to
@jessegrabowski, this is looking really cool! What can I do to help push this forward?
Delete the new We should keep your notebook with the plan to add it as a new example for the docs. Or it can be merged into the custom statespace notebook. So that should also be updated to import from the new
Perfect! I'll work on that today!! It is really looking cool!
@jessegrabowski, I agree with all of your comments above. I am going to start making those changes.
…uplicate with warning 2. removed unnecessary imports from __init__ after deleting regression_dataclass 3. updated components and structural classes to only utilize dataclasses and pull other objects from <foo>_info dataclasses 4. updated tests to conform to dataclass api
2. created tests for add and merge methods 3. added utility to convert from snake to pascal and integrated it in error messaging
… and placed default shock and state setters
jessegrabowski
left a comment
Incomplete review, I'll continue tomorrow AM
    # if key in index:
    #     raise ValueError(f"Duplicate {self.key_field} '{key}' detected.")  # This needs to be possible for shared states
That shouldn't happen here, though. It should come up in merge or add, right? And we handle it there with the allow_duplicates flag.
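To make the point above concrete, here is a minimal sketch of duplicate handling living in `merge` and `add` rather than in the index-building step. The class bodies are hypothetical stand-ins, not the PR's code: only the names `StateInfo`, `State`, `merge`, `add`, and `allow_duplicates` come from the thread, and the behavior when duplicates are allowed (keep the first occurrence so a shared state appears once) is an assumption.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class State:
    # Hypothetical minimal item; the real class may carry more fields.
    name: str
    observed: bool = False


@dataclass(frozen=True)
class StateInfo:
    items: tuple[State, ...] = ()

    @property
    def names(self) -> tuple[str, ...]:
        return tuple(s.name for s in self.items)

    def merge(self, other: StateInfo, allow_duplicates: bool = False) -> StateInfo:
        overlap = set(self.names) & set(other.names)
        if overlap and not allow_duplicates:
            raise ValueError(f"Duplicate names detected: {sorted(overlap)}")
        # Assumption: a shared state is kept once, using the left-hand copy.
        new_items = tuple(s for s in other.items if s.name not in overlap)
        return StateInfo(items=self.items + new_items)

    def add(self, item: State, allow_duplicates: bool = False) -> StateInfo:
        # Adding a single item is a merge with a one-element container.
        return self.merge(StateInfo(items=(item,)), allow_duplicates=allow_duplicates)
```

With this layout the constructor never needs its own duplicate check, because every path that can introduce a repeated name goes through `merge`.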
            raise AttributeError(f"Items missing attribute '{self.key_field}': {missing_attr}")
        object.__setattr__(self, "_index", index)

    def _key(self, item: T) -> str:
Is this used?
    @property
    def observed_states(self) -> tuple[State, ...]:  # Is this needed??
        return tuple(s for s in self.items if s.observed)
If you want to keep it (as an alias for observed_state_names), then pick one to be the "canonical" name and just return that one from the others, rather than rewriting the loop in several places.
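One way to read this suggestion, sketched with hypothetical stand-in classes: a single accessor owns the filtering logic, and the other accessor derives from it. Which of the two should be canonical is not settled in the thread; choosing `observed_states` here is an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class State:
    # Hypothetical minimal item used only for this illustration.
    name: str
    observed: bool = False


@dataclass(frozen=True)
class StateInfo:
    items: tuple[State, ...] = ()

    @property
    def observed_states(self) -> tuple[State, ...]:
        # Canonical accessor: the only place that filters on `observed`.
        return tuple(s for s in self.items if s.observed)

    @property
    def observed_state_names(self) -> tuple[str, ...]:
        # Derived from the canonical accessor instead of repeating the loop.
        return tuple(s.name for s in self.observed_states)
```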
    def unobserved_state_names(self) -> tuple[State, ...]:
        return tuple(s.name for s in self.items if not s.observed)

    def merge(self, other: StateInfo, allow_duplicates: bool = False) -> StateInfo:
Why doesn't the base class version work?
    )

    self.param_info = ParameterInfo(parameters=[beta_parameter, sigma_parameter])
    self.param_names = self.param_info.names
    self.param_names = self.param_info.names
        self.param_names = self.param_info.names
    else:
        self.param_info = ParameterInfo(parameters=[beta_parameter])
        self.param_names = self.param_info.names
    self.param_names = self.param_info.names
    self.param_names = self.param_info.names
        self.param_info = ParameterInfo(parameters=[beta_parameter])
        self.param_names = self.param_info.names

    def _set_data(self) -> None:
nit: the name _set_data is too close to pm.set_data, which changes the actual model data. _set_data_info here?
    # TODO: discuss if copying is still needed since these are now immutable
    self._coord_info = coords_info.copy()
Indeed, it shouldn't be necessary
    if is_dataclass(param_info):
        param_info = param_info.add(initial_state_cov_param)
If param_info isn't a dataclass at this point what else can it be? None? Can we just check for that instead?
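A rough sketch of what the explicit `None` check could look like. It assumes that `None` really is the only alternative to a `ParameterInfo`, and that starting from an empty container is the right fallback; neither is confirmed in the thread. The classes and the function name `with_initial_state_cov` are hypothetical.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Parameter:
    # Hypothetical minimal parameter record.
    name: str


@dataclass(frozen=True)
class ParameterInfo:
    parameters: tuple[Parameter, ...] = ()

    def add(self, parameter: Parameter) -> ParameterInfo:
        # Frozen container: return a new object instead of mutating.
        return ParameterInfo(parameters=self.parameters + (parameter,))


def with_initial_state_cov(
    param_info: ParameterInfo | None, initial_state_cov_param: Parameter
) -> ParameterInfo:
    # Assumption: None is the only non-ParameterInfo value that can arrive here.
    if param_info is None:
        param_info = ParameterInfo()
    return param_info.add(initial_state_cov_param)
```

Checking for `None` states the intent directly, whereas `is_dataclass` would also accept any unrelated dataclass instance.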
    self.state_names = list(state_names) if state_names is not None else []
    self.observed_state_names = (
        list(observed_state_names) if observed_state_names is not None else []
    self.param_info = ParameterInfo(
Should we change the component signature to just take the Info objects directly?
    self.state_names = self.state_info.unobserved_state_names
    self.observed_state_names = self.state_info.observed_state_names
    self.param_names = self.param_info.names
    self.data_names = [d.name for d in self.data_info if not d.is_exogenous]
    self.exog_names = self.data_info.exogenous_names
    self.shock_names = self.shock_info.names

    self.param_info = {}
    self.data_info = {}
    self.coords = self.coord_info.to_dict()
    self.param_dims = [p.dims for p in self.param_info]

    self.param_counts = {}
    self.needs_exog_data = self.data_info.needs_exogenous_data
These should all be properties
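As an illustration of the property-based approach, the sketch below derives `param_names` and `param_dims` from `param_info` on access instead of caching them as attributes in `__init__`. The `Component`, `ParameterInfo`, and `Parameter` classes are simplified, hypothetical stand-ins; only the attribute names are taken from the diff.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Parameter:
    # Hypothetical minimal parameter record.
    name: str
    dims: tuple[str, ...] = ()


@dataclass(frozen=True)
class ParameterInfo:
    parameters: tuple[Parameter, ...] = ()

    @property
    def names(self) -> tuple[str, ...]:
        return tuple(p.name for p in self.parameters)


class Component:
    def __init__(self, param_info: ParameterInfo) -> None:
        self.param_info = param_info

    @property
    def param_names(self) -> tuple[str, ...]:
        # Computed on access, so it cannot drift out of sync with param_info.
        return self.param_info.names

    @property
    def param_dims(self) -> list[tuple[str, ...]]:
        return [p.dims for p in self.param_info.parameters]
```

Because nothing is copied, replacing `param_info` (for example after a merge) is immediately reflected in every derived name list.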
    def _set_parameters(self) -> None:
        raise NotImplementedError

    def _set_shocks(self) -> None:
        return ShockInfo(shocks=[Shock(name=f"shock_{n}") for n in range(self.k_posdef or 0)])

    def _set_states(self) -> None:
These don't return None
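For reference, a small sketch of the default shock setter with a return annotation that matches what it actually returns. `Shock`, `ShockInfo`, and `Component` are hypothetical minimal versions (using a tuple where the diff uses a list); the body of `_set_shocks` follows the line quoted in the diff.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Shock:
    # Hypothetical minimal shock record.
    name: str


@dataclass(frozen=True)
class ShockInfo:
    shocks: tuple[Shock, ...] = ()

    @property
    def names(self) -> tuple[str, ...]:
        return tuple(s.name for s in self.shocks)


class Component:
    def __init__(self, k_posdef: int | None = None) -> None:
        self.k_posdef = k_posdef
        self.shock_info = self._set_shocks()

    def _set_shocks(self) -> ShockInfo:
        # Annotated with ShockInfo, not None, because a value is returned.
        return ShockInfo(
            shocks=tuple(Shock(name=f"shock_{n}") for n in range(self.k_posdef or 0))
        )
```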
    if not isinstance(self_prop, list | dict):
    if not is_dataclass(self_prop):
        # TODO: This works right now because we are only passing <foo>_info info names into _combine_property
Do we still need _combine_property? Can we just call left.merge(right) since everything is an Info?
    @@ -0,0 +1,5 @@
    import re
Do we need a whole new file for this?
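The file under discussion is not shown beyond its `import re` line. Going by the commit message about a snake-to-pascal utility, a helper of roughly this size is presumably what it contains. The function name `snake_to_pascal` and its exact behavior are assumptions for illustration only.

```python
import re


def snake_to_pascal(name: str) -> str:
    # Split on runs of underscores, drop empty parts, capitalize each piece.
    parts = [part for part in re.split(r"_+", name) if part]
    return "".join(part.capitalize() for part in parts)
```

A function this small could plausibly live next to the Info classes that use it for error messages, which is the question the reviewer raises.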
This is a draft proposal for #598.
The idea is to handle each component separately using _set_{component} methods, with all information stored in dataclasses for easy mapping. I believe this will simplify our tests of these components and reduce redundancy where the same information is spread across multiple sub-components, like data_names and data_info.
@jessegrabowski, let me know what you think. I put a little notebook together to showcase the changes.