From ad8f2bc328535646d51483c9e553b1313d1d5a64 Mon Sep 17 00:00:00 2001 From: tayheau Date: Tue, 9 Dec 2025 10:32:22 +0100 Subject: [PATCH 1/6] first commit to create branch --- src/spikeinterface/core/sorting_tools.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 90c7e18a99..15337671f5 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -185,8 +185,13 @@ def random_spikes_selection( random_spikes_indices: np.array Selected spike indices coresponding to the sorting spike vector. """ + rng_methods = ("uniform", "percentage") - if method == "uniform": + if method == "all": + spikes = sorting.to_spike_vector() + random_spikes_indices = np.arange(spikes.size) + + elif method in rng_methods: rng = np.random.default_rng(seed=seed) spikes = sorting.to_spike_vector(concatenated=False) @@ -211,17 +216,20 @@ def random_spikes_selection( inds_in_seg_abs = inds_in_seg + cum_sizes[segment_index] all_unit_indices.append(inds_in_seg_abs) all_unit_indices = np.concatenate(all_unit_indices) + + if method == "uniform": + rng_size = min(max_spikes_per_unit, all_unit_indices.size) + elif method == "percentage": + rng_size = min(max_spikes_per_unit, all_unit_indices.size * percentage) + selected_unit_indices = rng.choice( - all_unit_indices, size=min(max_spikes_per_unit, all_unit_indices.size), replace=False, shuffle=False + all_unit_indices, size=rng_size, replace=False, shuffle=False ) random_spikes_indices.append(selected_unit_indices) random_spikes_indices = np.concatenate(random_spikes_indices) random_spikes_indices = np.sort(random_spikes_indices) - elif method == "all": - spikes = sorting.to_spike_vector() - random_spikes_indices = np.arange(spikes.size) else: raise ValueError(f"random_spikes_selection(): method must be 'all' or 'uniform'") From e6c9c2b1d459b02d099bf08f2881812e7ed4a965 Mon Sep 17 00:00:00 2001 From: tayheau Date: Tue, 9 Dec 2025 17:54:54 +0100 Subject: [PATCH 2/6] temporal bin, rate cap and percentage sampling --- src/spikeinterface/core/sorting_tools.py | 93 ++++++++++++++++--- .../waveforms/temporal_pca.py | 4 +- src/spikeinterface/widgets/utils.py | 8 +- 3 files changed, 90 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 15337671f5..c41d0102d3 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -3,8 +3,11 @@ import warnings import importlib.util +from typing import Literal + import numpy as np +from spikeinterface.widgets.utils import get_segment_durations from spikeinterface.core.base import BaseExtractor from spikeinterface.core.basesorting import BaseSorting from spikeinterface.core.numpyextractors import NumpySorting @@ -148,12 +151,16 @@ def vector_to_list_of_spiketrain_numba(sample_indices, unit_indices, num_units): return vector_to_list_of_spiketrain_numba -# TODO later : implement other method like "maximum_rate", "by_percent", ... +# stratified sampling (isi / amplitude / pca distance ? ) def random_spikes_selection( sorting: BaseSorting, - num_samples: int | None = None, - method: str = "uniform", + num_samples: list[int] | None = None, + method: Literal["uniform", "all", "percentage", "maximum_rate", "temporal_bins"] = "uniform", max_spikes_per_unit: int = 500, + percentage: float | None = None, + maximum_rate: float | None = None, + bin_size_s: float | None = None, + k_per_bin: int | None = None, margin_size: int | None = None, seed: int | None = None, ): @@ -167,14 +174,22 @@ def random_spikes_selection( ---------- sorting: BaseSorting The sorting object - num_samples: list of int + num_samples: list[int] | None, default: None The number of samples per segment. Can be retrieved from recording with num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] - method: "uniform" | "all", default: "uniform" - The method to use. Only "uniform" is implemented for now + method: "uniform" | "percentage" | "maximum_rate" | "all", default: "uniform" + The method to use. max_spikes_per_unit: int, default: 500 - The number of spikes per units + The maximum number of spikes per units + percentage: float | None, default: None + In case of `percentage` method. The proportion of spikes per units. + maximum_rate: float | None, default: None + In case of `maximum_rate` method. The cap rate per units. + bin_size_s: float | None, default: None + In case of `temporal_bins` method. The duration of a temporal bin. + k_per_bin: int | None, default: None + In case of `temporal_bins` method. Maximum number of spikes per bins. margin_size: None | int, default: None A margin on each border of segments to avoid border spikes seed: None | int, default: None @@ -185,7 +200,7 @@ def random_spikes_selection( random_spikes_indices: np.array Selected spike indices coresponding to the sorting spike vector. """ - rng_methods = ("uniform", "percentage") + rng_methods = ("uniform", "percentage", "maximum_rate", "temporal_bins") if method == "all": spikes = sorting.to_spike_vector() @@ -194,6 +209,8 @@ def random_spikes_selection( elif method in rng_methods: rng = np.random.default_rng(seed=seed) + # since un concatenated + # spikes = [ [ (sample_index, unit_index, segment_index), (), ... ], [ (), ... ]] spikes = sorting.to_spike_vector(concatenated=False) cum_sizes = np.cumsum([0] + [s.size for s in spikes]) @@ -203,10 +220,15 @@ def random_spikes_selection( random_spikes_indices = [] for unit_index, unit_id in enumerate(sorting.unit_ids): all_unit_indices = [] + all_unit_trains = [] for segment_index in range(sorting.get_num_segments()): - # this is local index + # this is local segment index + trains_in_seg = spike_trains[segment_index][unit_id] inds_in_seg = spike_indices[segment_index][unit_id] if margin_size is not None: + if num_samples is None: + raise ValueError("num_samples must be provided when margin_size is used") + local_spikes = spikes[segment_index][inds_in_seg] mask = (local_spikes["sample_index"] >= margin_size) & ( local_spikes["sample_index"] < (num_samples[segment_index] - margin_size) @@ -219,8 +241,56 @@ def random_spikes_selection( if method == "uniform": rng_size = min(max_spikes_per_unit, all_unit_indices.size) + elif method == "percentage": - rng_size = min(max_spikes_per_unit, all_unit_indices.size * percentage) + if percentage is None or not (0 < percentage <= 1) : + raise ValueError(f"percentage must be in the interval (0, 1]") + + rng_size = min(max_spikes_per_unit, int(all_unit_indices.size * percentage)) + + elif method == "maximum_rate": + if maximum_rate is None: + raise ValueError(f"maximum_rate must be defined") + + t_duration = np.sum(get_segment_durations(sorting)) + rng_size = min(int(t_duration * maximum_rate), max_spikes_per_unit, all_unit_indices.size) + + elif method == "temporal_bins": + # expressed bin sampling as a dual sub sorting problem to be fully vectorized + + if None in (k_per_bin, bin_size_s): + missing = [] + if k_per_bin is None: + missing.append("k_per_bin") + if bin_size_s is None: + missing.append("bin_size_s") + print(f"the following args need to be defined when using the 'temporal bins' method : {', '.join(missing)}") + + sampling_frequency = sorting.get_sampling_frequency() + bin_size_freq = int(bin_size_s * sampling_frequency) + + unit_spikes = np.concat(spikes)[all_unit_indices] + + # local to segment so will loop and reset + bin_index = unit_spikes["sample_index"] // bin_size_freq + segment_index = unit_spikes["segment_index"] + + group_values = np.stack((segment_index, bin_index), axis = 1) + _, group_keys = np.unique(group_values, return_inverse = True, axis= 0) + + score = rng.random(all_unit_indices.size) + order = np.lexsort((score, group_keys)) + + ordered_unit_indices = all_unit_indices[order] + + group_start = np.r_[0, np.flatnonzero(np.diff(group_keys)) + 1 ] + counts = np.diff(np.r_[group_start, ordered_unit_indices.size]) + + ranks = np.arange(ordered_unit_indices.size, step=1) - np.repeat(group_start, counts) + selection_mask = ranks <= k_per_bin + selected = ordered_unit_indices[selection_mask] + random_spikes_indices.append(selected) + continue selected_unit_indices = rng.choice( all_unit_indices, size=rng_size, replace=False, shuffle=False @@ -231,11 +301,10 @@ def random_spikes_selection( random_spikes_indices = np.sort(random_spikes_indices) else: - raise ValueError(f"random_spikes_selection(): method must be 'all' or 'uniform'") + raise ValueError(f"random_spikes_selection(): method must be 'all' or any in {', '.join(rng_methods)}") return random_spikes_indices - ### MERGING ZONE ### def apply_merges_to_sorting( sorting: BaseSorting, diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index b1d3d5deaf..4720ff9098 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -96,11 +96,11 @@ def fit( model_folder_path: str, detect_peaks_params: dict, peak_selection_params: dict, - job_kwargs: dict = None, + job_kwargs: dict | None = None, ms_before: float = 1.0, ms_after: float = 1.0, whiten: bool = True, - radius_um: float = None, + radius_um: float | None = None, ) -> "IncrementalPCA": """ Train a pca model using the data in the recording object and the parameters provided. diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 50406b109e..923a950979 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -401,7 +401,7 @@ def validate_segment_indices(segment_indices: list[int] | None, sorting: BaseSor return segment_indices -def get_segment_durations(sorting: BaseSorting, segment_indices: list[int]) -> list[float]: +def get_segment_durations(sorting: BaseSorting, segment_indices: list[int] = None) -> list[float]: """ Calculate the duration of each segment in a sorting object. @@ -410,11 +410,17 @@ def get_segment_durations(sorting: BaseSorting, segment_indices: list[int]) -> l sorting : BaseSorting The sorting object containing spike data + segment_indices : list[int] | None + List of the segment indices to process. Default to None. + Returns ------- list[float] List of segment durations in seconds """ + if segment_indices is None: + segment_indices = range(sorting.get_num_segments()) + spikes = sorting.to_spike_vector() segment_boundaries = [ From bef42f1deea6c9e43c3a36e7e26c86ce283fed71 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Dec 2025 13:57:44 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sorting_tools.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index c41d0102d3..f17ae9a600 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -174,7 +174,7 @@ def random_spikes_selection( ---------- sorting: BaseSorting The sorting object - num_samples: list[int] | None, default: None + num_samples: list[int] | None, default: None The number of samples per segment. Can be retrieved from recording with num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] @@ -243,7 +243,7 @@ def random_spikes_selection( rng_size = min(max_spikes_per_unit, all_unit_indices.size) elif method == "percentage": - if percentage is None or not (0 < percentage <= 1) : + if percentage is None or not (0 < percentage <= 1): raise ValueError(f"percentage must be in the interval (0, 1]") rng_size = min(max_spikes_per_unit, int(all_unit_indices.size * percentage)) @@ -264,7 +264,9 @@ def random_spikes_selection( missing.append("k_per_bin") if bin_size_s is None: missing.append("bin_size_s") - print(f"the following args need to be defined when using the 'temporal bins' method : {', '.join(missing)}") + print( + f"the following args need to be defined when using the 'temporal bins' method : {', '.join(missing)}" + ) sampling_frequency = sorting.get_sampling_frequency() bin_size_freq = int(bin_size_s * sampling_frequency) @@ -275,15 +277,15 @@ def random_spikes_selection( bin_index = unit_spikes["sample_index"] // bin_size_freq segment_index = unit_spikes["segment_index"] - group_values = np.stack((segment_index, bin_index), axis = 1) - _, group_keys = np.unique(group_values, return_inverse = True, axis= 0) + group_values = np.stack((segment_index, bin_index), axis=1) + _, group_keys = np.unique(group_values, return_inverse=True, axis=0) score = rng.random(all_unit_indices.size) order = np.lexsort((score, group_keys)) ordered_unit_indices = all_unit_indices[order] - - group_start = np.r_[0, np.flatnonzero(np.diff(group_keys)) + 1 ] + + group_start = np.r_[0, np.flatnonzero(np.diff(group_keys)) + 1] counts = np.diff(np.r_[group_start, ordered_unit_indices.size]) ranks = np.arange(ordered_unit_indices.size, step=1) - np.repeat(group_start, counts) @@ -292,9 +294,7 @@ def random_spikes_selection( random_spikes_indices.append(selected) continue - selected_unit_indices = rng.choice( - all_unit_indices, size=rng_size, replace=False, shuffle=False - ) + selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False) random_spikes_indices.append(selected_unit_indices) random_spikes_indices = np.concatenate(random_spikes_indices) @@ -305,6 +305,7 @@ def random_spikes_selection( return random_spikes_indices + ### MERGING ZONE ### def apply_merges_to_sorting( sorting: BaseSorting, From b037441dc1819c93e6e57f853427c041d4a2fc54 Mon Sep 17 00:00:00 2001 From: tayheau Date: Thu, 18 Dec 2025 15:24:12 +0100 Subject: [PATCH 4/6] lazy loading get_segment_duration --- src/spikeinterface/core/sorting_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index f17ae9a600..4ab38bc2a9 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -7,7 +7,6 @@ import numpy as np -from spikeinterface.widgets.utils import get_segment_durations from spikeinterface.core.base import BaseExtractor from spikeinterface.core.basesorting import BaseSorting from spikeinterface.core.numpyextractors import NumpySorting @@ -207,6 +206,7 @@ def random_spikes_selection( random_spikes_indices = np.arange(spikes.size) elif method in rng_methods: + from spikeinterface.widgets.utils import get_segment_durations rng = np.random.default_rng(seed=seed) # since un concatenated From 625bbef6401bcb69e3124225564024757190143b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Dec 2025 14:24:41 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sorting_tools.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 4ab38bc2a9..0412d00a3a 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -207,6 +207,7 @@ def random_spikes_selection( elif method in rng_methods: from spikeinterface.widgets.utils import get_segment_durations + rng = np.random.default_rng(seed=seed) # since un concatenated From acfcb660b21294d74af7859b8b5f0e0049ec574e Mon Sep 17 00:00:00 2001 From: tayheau Date: Thu, 18 Dec 2025 18:24:28 +0100 Subject: [PATCH 6/6] removed unused var --- src/spikeinterface/core/sorting_tools.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 0412d00a3a..6e84d31013 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -221,10 +221,8 @@ def random_spikes_selection( random_spikes_indices = [] for unit_index, unit_id in enumerate(sorting.unit_ids): all_unit_indices = [] - all_unit_trains = [] for segment_index in range(sorting.get_num_segments()): # this is local segment index - trains_in_seg = spike_trains[segment_index][unit_id] inds_in_seg = spike_indices[segment_index][unit_id] if margin_size is not None: if num_samples is None: @@ -265,7 +263,7 @@ def random_spikes_selection( missing.append("k_per_bin") if bin_size_s is None: missing.append("bin_size_s") - print( + raise ValueError( f"the following args need to be defined when using the 'temporal bins' method : {', '.join(missing)}" )