From 23cef31ac371d56149e82d14620da1198c58992a Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 16 Dec 2025 13:19:20 +0100
Subject: [PATCH] wip: load_analyzer_from_nwb function

---
 .../extractors/extractor_classes.py |   2 +
 .../extractors/nwbextractors.py     | 287 +++++++++++++++++-
 2 files changed, 288 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/extractors/extractor_classes.py b/src/spikeinterface/extractors/extractor_classes.py
index 4f0f586f18..7a91927ec9 100644
--- a/src/spikeinterface/extractors/extractor_classes.py
+++ b/src/spikeinterface/extractors/extractor_classes.py
@@ -35,6 +35,7 @@
     read_nwb_recording,
     read_nwb_sorting,
     read_nwb_timeseries,
+    load_analyzer_from_nwb,
 )
 from .cbin_ibl import CompressedBinaryIblExtractor, read_cbin_ibl
@@ -194,6 +195,7 @@
 __all__.extend(
     [
         "read_nwb",  # convenience function for multiple nwb formats
+        "load_analyzer_from_nwb",
         "recording_extractor_full_dict",
         "sorting_extractor_full_dict",
         "event_extractor_full_dict",
diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py
index 976e752a62..069f014970 100644
--- a/src/spikeinterface/extractors/nwbextractors.py
+++ b/src/spikeinterface/extractors/nwbextractors.py
@@ -7,7 +7,14 @@
 import numpy as np

 from spikeinterface import get_global_tmp_folder
-from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting, BaseSortingSegment
+from spikeinterface.core import (
+    BaseRecording,
+    BaseRecordingSegment,
+    BaseSorting,
+    BaseSortingSegment,
+    SortingAnalyzer,
+    get_default_analyzer_extension_params,
+)
 from spikeinterface.core.core_tools import define_function_from_class
@@ -1259,6 +1266,7 @@ def _fetch_sorting_segment_info_backend(
         # need this for later
         self.units_table = units_table
+        self._file = open_file

         return unit_ids, spike_times_data, spike_times_index_data
@@ -1789,3 +1797,280 @@ def read_nwb(file_path, load_recording=True, load_sorting=False, electrical_series_path=None):
         outputs = outputs[0]

     return outputs
+
+
+def load_analyzer_from_nwb(
+    file_path: str | Path,
+    t_start: float | None = None,
+    sampling_frequency: float | None = None,
+    electrical_series_path: str | None = None,
+    unit_table_path: str | None = None,
+    stream_mode: Literal["fsspec", "remfile", "zarr"] | None = None,
+    stream_cache_path: str | Path | None = None,
+    cache: bool = False,
+    storage_options: dict | None = None,
+    use_pynwb: bool = False,
+    group_name: str | None = None,
+    compute_extra: list[str] | None = ["unit_locations", "correlograms"],
+    compute_extra_params: dict | None = None,
+    verbose: bool = False,
+) -> SortingAnalyzer:
+    """
+    Load a SortingAnalyzer from an NWB file.
+
+    The sorting is read from the units table. If the electrical series can be read,
+    it is attached as the recording; otherwise, minimal recording attributes
+    (channel ids, locations, and a dummy probe) are reconstructed from the
+    electrodes table. "waveform_mean"/"waveform_sd" columns are loaded as the
+    "templates" extension and any template/quality metric columns as the
+    corresponding metrics extensions. Streaming and path-related parameters are
+    forwarded to NwbRecordingExtractor and NwbSortingExtractor.
+
+    Parameters
+    ----------
+    file_path : str or Path
+        Path to the NWB file.
+    t_start : float or None, default: None
+        Start time of the recording. If None and no recording can be loaded,
+        it is estimated from the spike times.
+    group_name : str or None, default: None
+        If the units span multiple electrode groups, the name of the group to load.
+    compute_extra : list of str or None, default: ["unit_locations", "correlograms"]
+        Extra extensions to compute after loading.
+    compute_extra_params : dict or None, default: None
+        Optional parameters for the extra extensions.
+    verbose : bool, default: False
+        If True, print progress information.
+
+    Returns
+    -------
+    analyzer : SortingAnalyzer
+        The in-memory analyzer.
+    """
+    import pandas as pd
+
+    from spikeinterface.metrics.template import ComputeTemplateMetrics
+    from spikeinterface.metrics.quality import ComputeQualityMetrics
+
+    # try to read the recording object; the analyzer can be built without it
+    try:
+        recording = NwbRecordingExtractor(
+            file_path=file_path,
+            electrical_series_path=electrical_series_path,
+            stream_mode=stream_mode,
+            stream_cache_path=stream_cache_path,
+            cache=cache,
+            storage_options=storage_options,
+            use_pynwb=use_pynwb,
+        )
+    except Exception:
+        if verbose:
+            print("Could not load recording, proceeding without it")
+        recording = None
+
+    t_start_tmp = 0 if t_start is None else t_start
+
+    sorting_tmp = NwbSortingExtractor(
+        file_path=file_path,
+        electrical_series_path=electrical_series_path,
+        unit_table_path=unit_table_path,
+        stream_mode=stream_mode,
+        stream_cache_path=stream_cache_path,
+        cache=cache,
+        storage_options=storage_options,
+        use_pynwb=use_pynwb,
+        t_start=t_start_tmp,
+        sampling_frequency=sampling_frequency,
+    )
+
+    if recording is None and t_start is None:
+        # re-estimate t_start from the spike times
+        if verbose:
+            print("Re-estimating t_start from spike_times")
+        t_start_new = np.min(sorting_tmp._sorting_segments[0].spike_times_data) - 0.001
+        if verbose:
+            print(f"Found new t_start: {t_start_new} s")
+        sorting = NwbSortingExtractor(
+            file_path=file_path,
+            electrical_series_path=electrical_series_path,
+            unit_table_path=unit_table_path,
+            stream_mode=stream_mode,
+            stream_cache_path=stream_cache_path,
+            cache=cache,
+            storage_options=storage_options,
+            use_pynwb=use_pynwb,
+            t_start=t_start_new,
+            sampling_frequency=sampling_frequency,
+        )
+    else:
+        sorting = sorting_tmp
+
+    if use_pynwb:
+        units = sorting.units_table
+        colnames = units.colnames
+        units = units.to_dataframe(index=True)
+    else:
+        units_dset = sorting._file["units"]
+        units = make_df(units_dset)
+        colnames = units.columns
+
+    electrodes_indices = None
+    if use_pynwb:
+        electrodes_table = sorting._nwbfile.electrodes.to_dataframe(index=True)
+        if "electrodes" in colnames:
+            electrodes_indices = units["electrodes"]
+    else:
+        electrodes_table = make_df(sorting._file["/general/extracellular_ephys/electrodes"])
+        if "electrodes" in colnames:
+            electrodes_indices = units["electrodes"][:]
+
+    if electrodes_indices is not None:
+        # here we assume all electrodes of a unit belong to the same group,
+        # so checking the first electrode of each unit is enough
+        if "group_name" in electrodes_table.columns:
+            group_names = np.array([electrodes_table.iloc[int(ei[0])]["group_name"] for ei in electrodes_indices])
+            if len(np.unique(group_names)) > 1:
+                if group_name is None:
+                    raise Exception(
+                        "More than one group, use the group_name option to select units. "
+                        f"Available groups: {np.unique(group_names)}"
+                    )
+                else:
+                    unit_mask = group_names == group_name
+                    if verbose:
+                        print(f"Selecting {sum(unit_mask)} / {len(units)} units from {group_name}")
+                    sorting = sorting.select_units(unit_ids=sorting.unit_ids[unit_mask])
+                    units = units.loc[units.index[unit_mask]]
+                    electrodes_indices = units["electrodes"]
+
+    # TODO: figure out sparsity
+
+    # handle recording if available
+    if recording is not None:
+        # check groups
+        group_names = np.unique(recording.get_channel_groups())
+        if group_name is not None and len(group_names) > 1:
+            recording = recording.split_by("group")[group_name]
+        rec_attributes = None
+    else:
+        # reconstruct minimal recording attributes from the union of the
+        # electrode indices referenced by the units
+        electrode_indices_all = []
+        for ei in electrodes_indices:
+            electrode_indices_all.extend(ei)
+        electrode_indices_all = np.sort(np.unique(electrode_indices_all))
+        if verbose:
+            print(f"Found {len(electrode_indices_all)} electrodes")
+        electrodes_table_sliced = electrodes_table.iloc[electrode_indices_all]
+        if "channel_name" in electrodes_table_sliced:
+            channel_ids = electrodes_table_sliced["channel_name"][:]
+        else:
+            # "id" is the dataframe index, not a column
+            channel_ids = electrodes_table_sliced.index.values
+        # estimate the number of samples from the last spike
+        num_samples = [sorting.to_spike_vector()[-1]["sample_index"]]
+        rec_attributes = dict(
+            channel_ids=channel_ids,
+            sampling_frequency=sorting.sampling_frequency,
+            num_channels=len(channel_ids),
+            num_samples=num_samples,
+            is_filtered=True,
+            dtype="float32",
+        )
+        # make a probegroup from the relative electrode locations
+        electrode_colnames = electrodes_table_sliced.columns
+        assert (
+            "rel_x" in electrode_colnames and "rel_y" in electrode_colnames
+        ), "'rel_x' and 'rel_y' should be columns of the electrodes table"
+        locations = np.array([electrodes_table_sliced["rel_x"][:], electrodes_table_sliced["rel_y"][:]]).T
+        probegroup = create_dummy_probegroup_from_locations(locations)
+        rec_attributes["probegroup"] = probegroup
+
+    # instantiate the analyzer in memory
+    analyzer = SortingAnalyzer.create_memory(
+        sorting=sorting, recording=recording, sparsity=None, rec_attributes=rec_attributes, return_in_uV=True
+    )
+
+    # templates from the waveform columns of the units table
+    if "waveform_mean" in units:
+        from spikeinterface.core.analyzer_extension_core import ComputeTemplates
+
+        # "random_spikes" is required by the templates extension
+        analyzer.compute("random_spikes", method="all")
+
+        templates_ext = ComputeTemplates(sorting_analyzer=analyzer)
+        templates_avg_data = np.array([t for t in units["waveform_mean"].values]).astype("float")
+        total_ms = templates_avg_data.shape[1] / analyzer.sampling_frequency * 1000
+        template_params = get_default_analyzer_extension_params("templates")
+        if total_ms != template_params["ms_before"] + template_params["ms_after"]:
+            if verbose:
+                print("Guessing correct template cutouts")
+            # assume one third of the cutout is before the peak
+            template_params["ms_before"] = int(1 / 3 * total_ms)
+            template_params["ms_after"] = total_ms - template_params["ms_before"]
+        template_params["operators"] = ["average", "std"]
+        templates_ext.set_params(**template_params)
+        templates_ext.data["average"] = templates_avg_data
+        if "waveform_sd" in units:
+            templates_std_data = np.array([t for t in units["waveform_sd"].values]).astype("float")
+        else:
+            templates_std_data = np.zeros_like(templates_avg_data)
+        templates_ext.data["std"] = templates_std_data
+        templates_ext.run_info["run_completed"] = True
+
+        analyzer.extensions["templates"] = templates_ext
+
+    template_metric_columns = ComputeTemplateMetrics.get_metric_columns()
+    quality_metric_columns = ComputeQualityMetrics.get_metric_columns()
+
+    tm = pd.DataFrame(index=sorting.unit_ids)
+    qm = pd.DataFrame(index=sorting.unit_ids)
+
+    # collect metric columns stored in the units table
+    for col in units.columns:
+        if col in template_metric_columns:
+            tm.loc[:, col] = units[col].values
+        if col in quality_metric_columns:
+            qm.loc[:, col] = units[col].values
+
+    if len(tm.columns) > 0:
+        if verbose:
+            print("Adding template metrics")
+        tm_ext = ComputeTemplateMetrics(analyzer)
+        tm_ext.data["metrics"] = tm
+        tm_ext.run_info["run_completed"] = True
+        analyzer.extensions["template_metrics"] = tm_ext
+    if len(qm.columns) > 0:
+        if verbose:
+            print("Adding quality metrics")
+        qm_ext = ComputeQualityMetrics(analyzer)
+        qm_ext.data["metrics"] = qm
+        qm_ext.run_info["run_completed"] = True
+        analyzer.extensions["quality_metrics"] = qm_ext
+
+    # compute extra extensions if required
+    if compute_extra is not None:
+        if verbose:
+            print(f"Computing extra extensions: {compute_extra}")
+        compute_extra_params = {} if compute_extra_params is None else compute_extra_params
+        analyzer.compute(compute_extra, extension_params=compute_extra_params)
+
+    return analyzer
+
+
+def create_dummy_probegroup_from_locations(locations, shape="circle", shape_params={"radius": 1}):
+    """
+    Creates a "dummy" probe group based on locations.
+
+    Parameters
+    ----------
+    locations : np.array
+        Array with 2D channel locations (num_channels, 2)
+    shape : str, default: "circle"
+        Electrode shapes
+    shape_params : dict, default: {"radius": 1}
+        Shape parameters
+
+    Returns
+    -------
+    probegroup : ProbeGroup
+        A probe group containing a single dummy probe
+    """
+    from probeinterface import Probe, ProbeGroup
+
+    ndim = locations.shape[1]
+    assert ndim == 2, "Only 2D locations are supported"
+    probe = Probe(ndim=2)
+    probe.set_contacts(locations, shapes=shape, shape_params=shape_params)
+    probe.set_device_channel_indices(np.arange(len(probe.contact_positions)))
+    probe.create_auto_shape()
+    probegroup = ProbeGroup()
+    probegroup.add_probe(probe)
+
+    return probegroup
+
+
+def make_df(group):
+    """Makes a pandas DataFrame from an hdf5/zarr NWB group"""
+    import pandas as pd
+
+    colnames = list(group.keys())
+    data = {}
+    for col in colnames:
+        if col.endswith("_index"):
+            # index datasets are only used below to unravel their ragged columns
+            continue
+        item = group[col][:]
+        if f"{col}_index" in colnames:
+            # ragged column: split the flat data at the stored end offsets
+            item = np.split(item, group[f"{col}_index"][:])[:-1]
+            data[col] = item
+        elif item.ndim > 1:
+            data[col] = [item_flat for item_flat in item]
+        else:
+            data[col] = item
+    df = pd.DataFrame(data=data)
+    df.set_index("id", inplace=True)
+    return df
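
A minimal usage sketch for the new function (the file name is a placeholder, and the compute_extra value simply restates the defaults from the patch):

    from spikeinterface.extractors import load_analyzer_from_nwb

    # "sorting_with_units.nwb" is hypothetical; any NWB file with a units table should work
    analyzer = load_analyzer_from_nwb(
        "sorting_with_units.nwb",
        compute_extra=["unit_locations", "correlograms"],
        verbose=True,
    )
    # in-memory SortingAnalyzer; templates and metrics are attached as
    # extensions if the corresponding columns exist in the units table
    print(analyzer)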