2 changes: 2 additions & 0 deletions src/spikeinterface/extractors/extractor_classes.py
@@ -35,6 +35,7 @@
read_nwb_recording,
read_nwb_sorting,
read_nwb_timeseries,
load_analyzer_from_nwb,
)

from .cbin_ibl import CompressedBinaryIblExtractor, read_cbin_ibl
@@ -194,6 +195,7 @@
__all__.extend(
[
"read_nwb", # convenience function for multiple nwb formats
"load_analyzer_from_nwb",
"recording_extractor_full_dict",
"sorting_extractor_full_dict",
"event_extractor_full_dict",
287 changes: 286 additions & 1 deletion src/spikeinterface/extractors/nwbextractors.py
@@ -7,7 +7,14 @@
import numpy as np

from spikeinterface import get_global_tmp_folder
from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting, BaseSortingSegment
from spikeinterface.core import (
BaseRecording,
BaseRecordingSegment,
BaseSorting,
BaseSortingSegment,
SortingAnalyzer,
get_default_analyzer_extension_params,
)
from spikeinterface.core.core_tools import define_function_from_class


@@ -1259,6 +1266,7 @@ def _fetch_sorting_segment_info_backend(

# need this for later
self.units_table = units_table
self._file = open_file

return unit_ids, spike_times_data, spike_times_index_data

@@ -1789,3 +1797,280 @@ def read_nwb(file_path, load_recording=True, load_sorting=False, electrical_seri
outputs = outputs[0]

return outputs


def load_analyzer_from_nwb(
Review comment (Member): you don't like the name read_nwb_as_analyzer(), to match the kilosort one?

file_path: str | Path,
t_start: float | None = None,
sampling_frequency: float | None = None,
electrical_series_path: str | None = None,
unit_table_path: str | None = None,
stream_mode: Literal["fsspec", "remfile", "zarr"] | None = None,
stream_cache_path: str | Path | None = None,
cache: bool = False,
storage_options: dict | None = None,
use_pynwb: bool = False,
group_name: str | None = None,
compute_extra: list[str] | None = ["unit_locations", "correlograms"],
compute_extra_params: dict | None = None,
verbose: bool = False,
) -> SortingAnalyzer:
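    """
    Load the units table of an NWB file as an in-memory SortingAnalyzer.

    The recording is attached when it can be read from the same file; otherwise the analyzer
    is built from the electrodes table alone (channel ids, locations and a dummy probe).
    Average waveforms stored in the units table are mapped to the "templates" extension, and
    any template-metric or quality-metric columns are mapped to the corresponding metrics
    extensions. Extensions listed in `compute_extra` are then computed on top of that.
    """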
import pandas as pd
from spikeinterface.metrics.template import ComputeTemplateMetrics
from spikeinterface.metrics.quality import ComputeQualityMetrics

# try to read recording object to get the analyzer
try:
recording = NwbRecordingExtractor(
file_path=file_path,
electrical_series_path=electrical_series_path,
stream_mode=stream_mode,
stream_cache_path=stream_cache_path,
cache=cache,
storage_options=storage_options,
use_pynwb=use_pynwb,
)
except Exception:
if verbose:
print("Could not load recording, proceeding without it")
recording = None

t_start_tmp = 0 if t_start is None else t_start

sorting_tmp = NwbSortingExtractor(
file_path=file_path,
electrical_series_path=electrical_series_path,
unit_table_path=unit_table_path,
stream_mode=stream_mode,
stream_cache_path=stream_cache_path,
cache=cache,
storage_options=storage_options,
use_pynwb=use_pynwb,
t_start=t_start_tmp,
sampling_frequency=sampling_frequency,
)
Review comment (Member Author, on lines +1838 to +1851): We could use session_start_time instead.
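A minimal sketch of that suggestion, assuming pynwb is installed; how the session_start_time datetime would be mapped to the float t_start offset is left open here:

    from pynwb import NWBHDF5IO

    with NWBHDF5IO(str(file_path), mode="r") as io:
        nwbfile = io.read()
        # session_start_time is a datetime; spike times in the file are stored relative to
        # timestamps_reference_time, which defaults to session_start_time
        session_start = nwbfile.session_start_time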


if recording is None and t_start is None:
# re-estimate t_start from spike times
if verbose:
print("Re-estimating t_start from spike_times")
t_start_new = np.min(sorting_tmp._sorting_segments[0].spike_times_data) - 0.001
if verbose:
print(f"Found new t_start: {t_start_new} s")
sorting = NwbSortingExtractor(
file_path=file_path,
electrical_series_path=electrical_series_path,
unit_table_path=unit_table_path,
stream_mode=stream_mode,
stream_cache_path=stream_cache_path,
cache=cache,
storage_options=storage_options,
use_pynwb=use_pynwb,
t_start=t_start_new,
sampling_frequency=sampling_frequency,
)
else:
sorting = sorting_tmp

if use_pynwb:
units = sorting.units_table
colnames = units.colnames
units = units.to_dataframe(index=True)
else:
units_dset = sorting._file["units"]
units = make_df(units_dset)
colnames = units.columns

electrodes_indices = None
if use_pynwb:
electrodes_table = sorting._nwbfile.electrodes.to_dataframe(index=True)
if "electrodes" in colnames:
electrodes_indices = units["electrodes"]
else:
electrodes_table = make_df(sorting._file["/general/extracellular_ephys/electrodes"])
if "electrodes" in colnames:
electrodes_indices = units["electrodes"][:]

if electrodes_indices is not None:
# here we assume all groups are the same for each unit, so we just check one.
if "group_name" in electrodes_table.columns:
group_names = np.array([electrodes_table.iloc[int(ei[0])]["group_name"] for ei in electrodes_indices])
if len(np.unique(group_names)) > 1:
if group_name is None:
raise Exception(
f"More than one group, use group_name option to select units. Available groups: {np.unique(group_names)}"
)
else:
unit_mask = group_names == group_name
if verbose:
print(f"Selecting {sum(unit_mask)} / {len(units)} units from {group_name}")
sorting = sorting.select_units(unit_ids=sorting.unit_ids[unit_mask])
units = units.loc[units.index[unit_mask]]
electrodes_indices = units["electrodes"]
Review comment (Member Author, on lines +1894 to +1909): we could use the same trick as the "aggregation_key" when instantiating a sorting analyzer from grouped recordings/sortings.


# TODO: figure out sparsity

# handle recording if available
if recording is not None:
# check groups
group_names = np.unique(recording.get_channel_groups())
if group_name is not None and len(group_names) > 1:
recording = recording.split_by("group")[group_name]
rec_attributes = None
else:
recording = None
rec_attributes = {}

# get sliced electrodes table from electrode_indices union
electrode_indices_all = []
for ei in electrodes_indices:
electrode_indices_all.extend(ei)
electrode_indices_all = np.sort(np.unique(electrode_indices_all))
if verbose:
print(f"Found {len(electrode_indices_all)} electrodes")
electrodes_table_sliced = electrodes_table.iloc[electrode_indices_all]
if "channel_name" in electrodes_table_sliced:
channel_ids = electrodes_table_sliced["channel_name"][:]
else:
channel_ids = electrodes_table_sliced["id"][:]
num_samples = [sorting.to_spike_vector()[-1]["sample_index"]]  # approximate num_samples from the last spike
rec_attributes = dict(
channel_ids=channel_ids,
sampling_frequency=sorting.sampling_frequency,
num_channels=len(channel_ids),
num_samples=num_samples,
is_filtered=True,
dtype="float32",
Review comment (Collaborator): why do we need the dtype and why is it fixed?
Review comment (Member Author): I think we should make this optional at the Analyzer level (same for is_filtered).

)
# make a probegroup
electrode_colnames = electrodes_table_sliced.columns
assert (
"rel_x" in electrode_colnames and "rel_y" in electrode_colnames
), "'rel_x' and 'rel_y' should be columns in the electrode name"
locations = np.array([electrodes_table_sliced["rel_x"][:], electrodes_table_sliced["rel_y"][:]]).T
probegroup = create_dummy_probegroup_from_locations(locations)
rec_attributes["probegroup"] = probegroup

# instantiate analyzer
analyzer = SortingAnalyzer.create_memory(
sorting=sorting, recording=recording, sparsity=None, rec_attributes=rec_attributes, return_in_uV=True
)

# templates
if "waveform_mean" in units:
from spikeinterface.core.analyzer_extension_core import ComputeTemplates, ComputeRandomSpikes

# instantiate templates
analyzer.compute("random_spikes", method="all")

templates_ext = ComputeTemplates(sorting_analyzer=analyzer)
templates_avg_data = np.array([t for t in units["waveform_mean"].values]).astype("float")
total_ms = templates_avg_data.shape[1] / analyzer.sampling_frequency * 1000
template_params = get_default_analyzer_extension_params("templates")
Review comment (Member): I think this is a strange guess. Do we expect NWB to have the same template params as the actual spikeinterface version?
Review comment (Member): Is there a proper way to do it? I think I would go directly to the 1/3 - 2/3 + warnings mechanism.

if total_ms != template_params["ms_before"] + template_params["ms_after"]:
if verbose:
print("Guessing correct template cutouts")
template_params["ms_before"] = int(1 / 3 * total_ms)
template_params["ms_after"] = total_ms - template_params["ms_before"]
template_params["operators"] = ["average", "std"]
templates_ext.set_params(**template_params)
templates_avg_data = np.array([t for t in units["waveform_mean"].values]).astype("float")
templates_ext.data["average"] = templates_avg_data
if "waveforms_sd" in units:
templates_std_data = np.array([t for t in units["waveform_sd"].values]).astype("float")
else:
templates_std_data = np.zeros_like(templates_avg_data)
templates_ext.data["std"] = templates_std_data
templates_ext.run_info["run_completed"] = True

analyzer.extensions["templates"] = templates_ext

template_metric_columns = ComputeTemplateMetrics.get_metric_columns()
quality_metric_columns = ComputeQualityMetrics.get_metric_columns()

tm = pd.DataFrame(index=sorting.unit_ids)  # "tm": template metrics found in the units table
Review comment (Collaborator): what is tm?

qm = pd.DataFrame(index=sorting.unit_ids)  # "qm": quality metrics found in the units table
Review comment (Member, on lines +1991 to +1992): Can we set the correct dtype from the new extension system?


for col in units.columns:
if col in template_metric_columns:
tm.loc[:, col] = units[col].values
if col in quality_metric_columns:
qm.loc[:, col] = units[col].values

if len(tm.columns) > 0:
if verbose:
print("Adding template metrics")
tm_ext = ComputeTemplateMetrics(analyzer)
tm_ext.data["metrics"] = tm
tm_ext.run_info["run_completed"] = True
analyzer.extensions["template_metrics"] = tm_ext
if len(qm.columns) > 0:
if verbose:
print("Adding quality metrics")
qm_ext = ComputeQualityMetrics(analyzer)
qm_ext.data["metrics"] = qm
qm_ext.run_info["run_completed"] = True
analyzer.extensions["quality_metrics"] = qm_ext

# compute extra required
if compute_extra is not None:
if verbose:
print(f"Computing extra extensions: {compute_extra}")
compute_extra_params = {} if compute_extra_params is None else compute_extra_params
analyzer.compute(compute_extra, **compute_extra_params)

return analyzer
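For orientation, a hedged usage sketch of the new helper; the file name is a placeholder, and group_name is only needed when the file contains several electrode groups:

    from spikeinterface.extractors import load_analyzer_from_nwb

    analyzer = load_analyzer_from_nwb("ecephys_session.nwb", verbose=True)  # hypothetical file
    print(analyzer)
    print(analyzer.get_loaded_extension_names())  # templates, metrics, plus compute_extra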


def create_dummy_probegroup_from_locations(locations, shape="circle", shape_params={"radius": 1}):
Review comment (Collaborator): we should make this private as we might want to change this.

"""
Creates a "dummy" probe based on locations.

Parameters
----------
locations : np.array
Array with channel locations (num_channels, 2); only 2D locations are supported
shape : str, default: "circle"
Electrode shapes
shape_params : dict, default: {"radius": 1}
Shape parameters

Returns
-------
probe : Probe
The created probe
"""
from probeinterface import Probe, ProbeGroup

ndim = locations.shape[1]
assert ndim == 2
probe = Probe(ndim=2)
probe.set_contacts(locations, shapes=shape, shape_params=shape_params)
probe.set_device_channel_indices(np.arange(len(probe.contact_positions)))
probe.create_auto_shape()
probegroup = ProbeGroup()
probegroup.add_probe(probe)

return probegroup
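A quick sketch of the dummy-probe helper with made-up 2D locations, to show what it produces:

    import numpy as np

    locations = np.array([[0.0, 0.0], [0.0, 20.0], [0.0, 40.0]])
    probegroup = create_dummy_probegroup_from_locations(locations)
    probe = probegroup.probes[0]
    print(probe.contact_positions)       # the three locations passed in
    print(probe.device_channel_indices)  # [0 1 2]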


def make_df(group):
Review comment (Collaborator): we should make this private as we might want to change this. Plus, this is a super generic name that we don't want to contaminate any namespace with.

"""Makes pandas DataFrame from hdf5/zarr NWB group"""
import pandas as pd

colnames = list(group.keys())
data = {}
for col in colnames:
if "_index" in col:
continue
item = group[col][:]
if f"{col}_index" in colnames:
item = np.split(item, group[f"{col}_index"][:])[:-1]
data[col] = item
elif item.ndim > 1:
data[col] = [item_flat for item_flat in item]
else:
data[col] = item
df = pd.DataFrame(data=data)
df.set_index("id", inplace=True)
return df
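A hedged usage sketch, assuming an NWB file on disk (NWB files are HDF5, so h5py can open them directly); the file name is a placeholder:

    import h5py

    with h5py.File("ecephys_session.nwb", "r") as f:  # hypothetical file
        units_df = make_df(f["units"])
        print(units_df.columns.tolist())  # e.g. spike_times, waveform_mean, metric columns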