src/spikeinterface/core/sorting_tools.py (93 additions & 16 deletions)

@@ -3,6 +3,8 @@
import warnings
import importlib.util

+from typing import Literal

import numpy as np

from spikeinterface.core.base import BaseExtractor
@@ -148,12 +150,16 @@ def vector_to_list_of_spiketrain_numba(sample_indices, unit_indices, num_units):
    return vector_to_list_of_spiketrain_numba


# TODO later : implement other methods like "maximum_rate", "by_percent", ...
# stratified sampling (isi / amplitude / pca distance ? )
def random_spikes_selection(
    sorting: BaseSorting,
-    num_samples: int | None = None,
-    method: str = "uniform",
+    num_samples: list[int] | None = None,
+    method: Literal["uniform", "all", "percentage", "maximum_rate", "temporal_bins"] = "uniform",
    max_spikes_per_unit: int = 500,
+    percentage: float | None = None,
+    maximum_rate: float | None = None,
+    bin_size_s: float | None = None,
+    k_per_bin: int | None = None,
    margin_size: int | None = None,
    seed: int | None = None,
):
@@ -167,14 +173,22 @@ def random_spikes_selection(
    ----------
    sorting: BaseSorting
        The sorting object
-    num_samples: list of int
+    num_samples: list[int] | None, default: None
        The number of samples per segment.
        Can be retrieved from recording with
        num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())]
-    method: "uniform" | "all", default: "uniform"
-        The method to use. Only "uniform" is implemented for now
+    method: "uniform" | "percentage" | "maximum_rate" | "temporal_bins" | "all", default: "uniform"
+        The method to use.
    max_spikes_per_unit: int, default: 500
-        The number of spikes per units
+        The maximum number of spikes per unit
+    percentage: float | None, default: None
+        For the "percentage" method: the proportion of each unit's spikes to keep.
+    maximum_rate: float | None, default: None
+        For the "maximum_rate" method: the firing rate (in spikes/s) used to cap the number of spikes per unit.
+    bin_size_s: float | None, default: None
+        For the "temporal_bins" method: the duration of a temporal bin, in seconds.
+    k_per_bin: int | None, default: None
+        For the "temporal_bins" method: the maximum number of spikes kept per bin.
    margin_size: None | int, default: None
        A margin on each border of segments to avoid border spikes
    seed: None | int, default: None
@@ -185,10 +199,19 @@
    random_spikes_indices: np.array
        Selected spike indices corresponding to the sorting spike vector.
    """
+    rng_methods = ("uniform", "percentage", "maximum_rate", "temporal_bins")
+
+    if method == "all":
+        spikes = sorting.to_spike_vector()
+        random_spikes_indices = np.arange(spikes.size)
+
+    elif method in rng_methods:
+        from spikeinterface.widgets.utils import get_segment_durations
+
-    if method == "uniform":
        rng = np.random.default_rng(seed=seed)

+        # with concatenated=False the spike vector is a list with one array per segment:
        # spikes = [ [ (sample_index, unit_index, segment_index), (), ... ], [ (), ... ]]
        spikes = sorting.to_spike_vector(concatenated=False)
        cum_sizes = np.cumsum([0] + [s.size for s in spikes])

@@ -199,9 +222,12 @@
        for unit_index, unit_id in enumerate(sorting.unit_ids):
            all_unit_indices = []
            for segment_index in range(sorting.get_num_segments()):
-                # this is local index
+                # this is local segment index
                inds_in_seg = spike_indices[segment_index][unit_id]
                if margin_size is not None:
+                    if num_samples is None:
+                        raise ValueError("num_samples must be provided when margin_size is used")
+
                    local_spikes = spikes[segment_index][inds_in_seg]
                    mask = (local_spikes["sample_index"] >= margin_size) & (
                        local_spikes["sample_index"] < (num_samples[segment_index] - margin_size)
@@ -211,19 +237,70 @@
                inds_in_seg_abs = inds_in_seg + cum_sizes[segment_index]
                all_unit_indices.append(inds_in_seg_abs)
            all_unit_indices = np.concatenate(all_unit_indices)
-            selected_unit_indices = rng.choice(
-                all_unit_indices, size=min(max_spikes_per_unit, all_unit_indices.size), replace=False, shuffle=False
-            )
+
+            if method == "uniform":
+                rng_size = min(max_spikes_per_unit, all_unit_indices.size)
+
+            elif method == "percentage":
+                if percentage is None or not (0 < percentage <= 1):
+                    raise ValueError("percentage must be in the interval (0, 1]")
+
+                rng_size = min(max_spikes_per_unit, int(all_unit_indices.size * percentage))
+
+            elif method == "maximum_rate":
+                if maximum_rate is None:
+                    raise ValueError("maximum_rate must be defined")
+
+                t_duration = np.sum(get_segment_durations(sorting))
+                rng_size = min(int(t_duration * maximum_rate), max_spikes_per_unit, all_unit_indices.size)
+
+            elif method == "temporal_bins":
+                # per-bin sampling expressed as a sort-within-bins problem so it stays fully vectorized
+
+                if None in (k_per_bin, bin_size_s):
+                    missing = []
+                    if k_per_bin is None:
+                        missing.append("k_per_bin")
+                    if bin_size_s is None:
+                        missing.append("bin_size_s")
+                    raise ValueError(
+                        f"the following args need to be defined when using the 'temporal_bins' method: {', '.join(missing)}"
+                    )
+
+                sampling_frequency = sorting.get_sampling_frequency()
+                # number of samples in one temporal bin
+                bin_size_samples = int(bin_size_s * sampling_frequency)
+
+                unit_spikes = np.concatenate(spikes)[all_unit_indices]
+
+                # sample_index restarts at every segment, so spikes are grouped by (segment, bin) pairs
+                bin_index = unit_spikes["sample_index"] // bin_size_samples
+                segment_index = unit_spikes["segment_index"]
+
+                group_values = np.stack((segment_index, bin_index), axis=1)
+                # the spike vector is ordered by (segment, sample), so group_keys is non-decreasing
+                _, group_keys = np.unique(group_values, return_inverse=True, axis=0)
+
+                # order spikes by bin, and randomly within each bin
+                score = rng.random(all_unit_indices.size)
+                order = np.lexsort((score, group_keys))
+
+                ordered_unit_indices = all_unit_indices[order]
+
+                group_start = np.r_[0, np.flatnonzero(np.diff(group_keys)) + 1]
+                counts = np.diff(np.r_[group_start, ordered_unit_indices.size])
+
+                # rank of each spike within its bin, then keep the first k_per_bin of every bin
+                ranks = np.arange(ordered_unit_indices.size) - np.repeat(group_start, counts)
+                selection_mask = ranks < k_per_bin
+                selected = ordered_unit_indices[selection_mask]
+                random_spikes_indices.append(selected)
+                continue
+
+            selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False)
            random_spikes_indices.append(selected_unit_indices)

        random_spikes_indices = np.concatenate(random_spikes_indices)
        random_spikes_indices = np.sort(random_spikes_indices)

-    elif method == "all":
-        spikes = sorting.to_spike_vector()
-        random_spikes_indices = np.arange(spikes.size)
    else:
-        raise ValueError(f"random_spikes_selection(): method must be 'all' or 'uniform'")
+        raise ValueError(f"random_spikes_selection(): method must be 'all' or one of {', '.join(rng_methods)}")

    return random_spikes_indices
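
The "temporal_bins" branch never loops over bins: each spike is keyed by its (segment, bin) pair, shuffled inside its bin by the random score fed to np.lexsort, ranked within the bin, and cut at k_per_bin. A self-contained sketch of this select-k-per-group trick on toy data (the arrays below are illustrative, not from the PR):

import numpy as np

rng = np.random.default_rng(seed=0)

# toy (segment, bin) group key per spike, non-decreasing like a spike vector
group_keys = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])
k_per_bin = 2

score = rng.random(group_keys.size)
order = np.lexsort((score, group_keys))  # sort by group, randomly within each group

group_start = np.r_[0, np.flatnonzero(np.diff(group_keys)) + 1]
counts = np.diff(np.r_[group_start, group_keys.size])

# rank of each spike inside its group after the shuffle
ranks = np.arange(group_keys.size) - np.repeat(group_start, counts)
selected = order[ranks < k_per_bin]  # at most k_per_bin spikes kept per group
print(np.sort(selected))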

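For reference, a hypothetical call for each selection method (the argument values are illustrative and sorting is any BaseSorting):

from spikeinterface.core.sorting_tools import random_spikes_selection

# previous behavior: up to 500 random spikes per unit
inds = random_spikes_selection(sorting, method="uniform", max_spikes_per_unit=500, seed=42)

# keep 10% of each unit's spikes, still capped by max_spikes_per_unit
inds = random_spikes_selection(sorting, method="percentage", percentage=0.1, seed=42)

# cap each unit at 2 spikes/s: with 600 s of recording, at most min(1200, max_spikes_per_unit) spikes
inds = random_spikes_selection(sorting, method="maximum_rate", maximum_rate=2.0, seed=42)

# keep at most 50 random spikes per unit in every 10 s temporal bin
inds = random_spikes_selection(sorting, method="temporal_bins", bin_size_s=10.0, k_per_bin=50, seed=42)
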
@@ -96,11 +96,11 @@ def fit(
    model_folder_path: str,
    detect_peaks_params: dict,
    peak_selection_params: dict,
-    job_kwargs: dict = None,
+    job_kwargs: dict | None = None,
    ms_before: float = 1.0,
    ms_after: float = 1.0,
    whiten: bool = True,
-    radius_um: float = None,
+    radius_um: float | None = None,
) -> "IncrementalPCA":
"""
Train a pca model using the data in the recording object and the parameters provided.
src/spikeinterface/widgets/utils.py (7 additions & 1 deletion)

@@ -401,7 +401,7 @@ def validate_segment_indices(segment_indices: list[int] | None, sorting: BaseSorting
    return segment_indices


-def get_segment_durations(sorting: BaseSorting, segment_indices: list[int]) -> list[float]:
+def get_segment_durations(sorting: BaseSorting, segment_indices: list[int] | None = None) -> list[float]:
"""
Calculate the duration of each segment in a sorting object.

@@ -410,11 +410,17 @@ def get_segment_durations(sorting: BaseSorting, segment_indices: list[int]) -> list[float]:
    sorting : BaseSorting
        The sorting object containing spike data

+    segment_indices : list[int] | None
+        List of the segment indices to process. Defaults to None, in which case all segments are used.
+
    Returns
    -------
    list[float]
        List of segment durations in seconds
    """
+    if segment_indices is None:
+        segment_indices = list(range(sorting.get_num_segments()))
+
    spikes = sorting.to_spike_vector()

    segment_boundaries = [
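
With the new None default, callers such as the "maximum_rate" branch above can fetch all durations without pre-computing segment indices; a hypothetical call, assuming a multi-segment BaseSorting named sorting:

from spikeinterface.widgets.utils import get_segment_durations

durations = get_segment_durations(sorting)          # all segments
first_two = get_segment_durations(sorting, [0, 1])  # explicit subset
total_duration_s = sum(durations)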