diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 90c7e18a99..6e84d31013 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -3,6 +3,8 @@ import warnings import importlib.util +from typing import Literal + import numpy as np from spikeinterface.core.base import BaseExtractor @@ -148,12 +150,16 @@ def vector_to_list_of_spiketrain_numba(sample_indices, unit_indices, num_units): return vector_to_list_of_spiketrain_numba -# TODO later : implement other method like "maximum_rate", "by_percent", ... +# stratified sampling (isi / amplitude / pca distance ? ) def random_spikes_selection( sorting: BaseSorting, - num_samples: int | None = None, - method: str = "uniform", + num_samples: list[int] | None = None, + method: Literal["uniform", "all", "percentage", "maximum_rate", "temporal_bins"] = "uniform", max_spikes_per_unit: int = 500, + percentage: float | None = None, + maximum_rate: float | None = None, + bin_size_s: float | None = None, + k_per_bin: int | None = None, margin_size: int | None = None, seed: int | None = None, ): @@ -167,14 +173,22 @@ def random_spikes_selection( ---------- sorting: BaseSorting The sorting object - num_samples: list of int + num_samples: list[int] | None, default: None The number of samples per segment. Can be retrieved from recording with num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] - method: "uniform" | "all", default: "uniform" - The method to use. Only "uniform" is implemented for now + method: "uniform" | "percentage" | "maximum_rate" | "all", default: "uniform" + The method to use. max_spikes_per_unit: int, default: 500 - The number of spikes per units + The maximum number of spikes per units + percentage: float | None, default: None + In case of `percentage` method. The proportion of spikes per units. + maximum_rate: float | None, default: None + In case of `maximum_rate` method. The cap rate per units. + bin_size_s: float | None, default: None + In case of `temporal_bins` method. The duration of a temporal bin. + k_per_bin: int | None, default: None + In case of `temporal_bins` method. Maximum number of spikes per bins. margin_size: None | int, default: None A margin on each border of segments to avoid border spikes seed: None | int, default: None @@ -185,10 +199,19 @@ def random_spikes_selection( random_spikes_indices: np.array Selected spike indices coresponding to the sorting spike vector. """ + rng_methods = ("uniform", "percentage", "maximum_rate", "temporal_bins") + + if method == "all": + spikes = sorting.to_spike_vector() + random_spikes_indices = np.arange(spikes.size) + + elif method in rng_methods: + from spikeinterface.widgets.utils import get_segment_durations - if method == "uniform": rng = np.random.default_rng(seed=seed) + # since un concatenated + # spikes = [ [ (sample_index, unit_index, segment_index), (), ... ], [ (), ... ]] spikes = sorting.to_spike_vector(concatenated=False) cum_sizes = np.cumsum([0] + [s.size for s in spikes]) @@ -199,9 +222,12 @@ def random_spikes_selection( for unit_index, unit_id in enumerate(sorting.unit_ids): all_unit_indices = [] for segment_index in range(sorting.get_num_segments()): - # this is local index + # this is local segment index inds_in_seg = spike_indices[segment_index][unit_id] if margin_size is not None: + if num_samples is None: + raise ValueError("num_samples must be provided when margin_size is used") + local_spikes = spikes[segment_index][inds_in_seg] mask = (local_spikes["sample_index"] >= margin_size) & ( local_spikes["sample_index"] < (num_samples[segment_index] - margin_size) @@ -211,19 +237,70 @@ def random_spikes_selection( inds_in_seg_abs = inds_in_seg + cum_sizes[segment_index] all_unit_indices.append(inds_in_seg_abs) all_unit_indices = np.concatenate(all_unit_indices) - selected_unit_indices = rng.choice( - all_unit_indices, size=min(max_spikes_per_unit, all_unit_indices.size), replace=False, shuffle=False - ) + + if method == "uniform": + rng_size = min(max_spikes_per_unit, all_unit_indices.size) + + elif method == "percentage": + if percentage is None or not (0 < percentage <= 1): + raise ValueError(f"percentage must be in the interval (0, 1]") + + rng_size = min(max_spikes_per_unit, int(all_unit_indices.size * percentage)) + + elif method == "maximum_rate": + if maximum_rate is None: + raise ValueError(f"maximum_rate must be defined") + + t_duration = np.sum(get_segment_durations(sorting)) + rng_size = min(int(t_duration * maximum_rate), max_spikes_per_unit, all_unit_indices.size) + + elif method == "temporal_bins": + # expressed bin sampling as a dual sub sorting problem to be fully vectorized + + if None in (k_per_bin, bin_size_s): + missing = [] + if k_per_bin is None: + missing.append("k_per_bin") + if bin_size_s is None: + missing.append("bin_size_s") + raise ValueError( + f"the following args need to be defined when using the 'temporal bins' method : {', '.join(missing)}" + ) + + sampling_frequency = sorting.get_sampling_frequency() + bin_size_freq = int(bin_size_s * sampling_frequency) + + unit_spikes = np.concat(spikes)[all_unit_indices] + + # local to segment so will loop and reset + bin_index = unit_spikes["sample_index"] // bin_size_freq + segment_index = unit_spikes["segment_index"] + + group_values = np.stack((segment_index, bin_index), axis=1) + _, group_keys = np.unique(group_values, return_inverse=True, axis=0) + + score = rng.random(all_unit_indices.size) + order = np.lexsort((score, group_keys)) + + ordered_unit_indices = all_unit_indices[order] + + group_start = np.r_[0, np.flatnonzero(np.diff(group_keys)) + 1] + counts = np.diff(np.r_[group_start, ordered_unit_indices.size]) + + ranks = np.arange(ordered_unit_indices.size, step=1) - np.repeat(group_start, counts) + selection_mask = ranks <= k_per_bin + selected = ordered_unit_indices[selection_mask] + random_spikes_indices.append(selected) + continue + + selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False) random_spikes_indices.append(selected_unit_indices) random_spikes_indices = np.concatenate(random_spikes_indices) random_spikes_indices = np.sort(random_spikes_indices) - elif method == "all": - spikes = sorting.to_spike_vector() - random_spikes_indices = np.arange(spikes.size) else: - raise ValueError(f"random_spikes_selection(): method must be 'all' or 'uniform'") + raise ValueError(f"random_spikes_selection(): method must be 'all' or any in {', '.join(rng_methods)}") return random_spikes_indices diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index b1d3d5deaf..4720ff9098 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -96,11 +96,11 @@ def fit( model_folder_path: str, detect_peaks_params: dict, peak_selection_params: dict, - job_kwargs: dict = None, + job_kwargs: dict | None = None, ms_before: float = 1.0, ms_after: float = 1.0, whiten: bool = True, - radius_um: float = None, + radius_um: float | None = None, ) -> "IncrementalPCA": """ Train a pca model using the data in the recording object and the parameters provided. diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 50406b109e..923a950979 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -401,7 +401,7 @@ def validate_segment_indices(segment_indices: list[int] | None, sorting: BaseSor return segment_indices -def get_segment_durations(sorting: BaseSorting, segment_indices: list[int]) -> list[float]: +def get_segment_durations(sorting: BaseSorting, segment_indices: list[int] = None) -> list[float]: """ Calculate the duration of each segment in a sorting object. @@ -410,11 +410,17 @@ def get_segment_durations(sorting: BaseSorting, segment_indices: list[int]) -> l sorting : BaseSorting The sorting object containing spike data + segment_indices : list[int] | None + List of the segment indices to process. Default to None. + Returns ------- list[float] List of segment durations in seconds """ + if segment_indices is None: + segment_indices = range(sorting.get_num_segments()) + spikes = sorting.to_spike_vector() segment_boundaries = [