diff --git a/src/spikeinterface/preprocessing/generic.py b/src/spikeinterface/preprocessing/generic.py new file mode 100644 index 0000000000..f8bf0284be --- /dev/null +++ b/src/spikeinterface/preprocessing/generic.py @@ -0,0 +1,36 @@ +from functools import partial + +from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment +from spikeinterface.core.core_tools import define_function_from_class + + +class GenericPreprocessor(BasePreprocessor): + def __init__(self, recording, function, **function_kwargs): + super().__init__(recording) + self._serializability["json"] = False + + # Heavy computation can be done at the __init__ if needed + self.function_to_apply = partial(function, **function_kwargs) + + # Initialize segments + for segment in recording._recording_segments: + processed_segment = GenericPreprocessorSegment(segment, self.function_to_apply) + self.add_recording_segment(processed_segment) + + self._kwargs = {"recording": recording, "func": function} + self._kwargs.update(**function_kwargs) + + +class GenericPreprocessorSegment(BasePreprocessorSegment): + def __init__(self, parent_segment, function_to_apply): + super().__init__(parent_segment) + self.function_to_apply = function_to_apply # Function to apply to the traces + + def get_traces(self, start_frame, end_frame, channel_indices): + # Fetch the traces from the parent segment + traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) + # Apply the function to the traces + return self.function_to_apply(traces) + + +generic_preprocessor = define_function_from_class(GenericPreprocessor, name="generic_preprocessor") diff --git a/src/spikeinterface/preprocessing/preprocessinglist.py b/src/spikeinterface/preprocessing/preprocessinglist.py index bdf5f2219c..0708948d42 100644 --- a/src/spikeinterface/preprocessing/preprocessinglist.py +++ b/src/spikeinterface/preprocessing/preprocessinglist.py @@ -1,6 +1,8 @@ from __future__ import annotations ### PREPROCESSORS ### + + from .resample import ResampleRecording, resample from .decimate import DecimateRecording, decimate from .filter import ( @@ -43,7 +45,7 @@ from .depth_order import DepthOrderRecording, depth_order from .astype import AstypeRecording, astype from .unsigned_to_signed import UnsignedToSignedRecording, unsigned_to_signed - +from .generic import GenericPreprocessor, generic_preprocessor preprocessers_full_list = [ # filter stuff diff --git a/src/spikeinterface/preprocessing/tests/test_generic_preprocessor.py b/src/spikeinterface/preprocessing/tests/test_generic_preprocessor.py new file mode 100644 index 0000000000..7f6bc5f319 --- /dev/null +++ b/src/spikeinterface/preprocessing/tests/test_generic_preprocessor.py @@ -0,0 +1,20 @@ +import numpy as np +import pytest + +from spikeinterface.core.generate import generate_recording +from spikeinterface.preprocessing import GenericPreprocessor + + +def test_basic_use(): + + recording = generate_recording(num_channels=4, durations=[1.0]) + recording = recording.rename_channels(["a", "b", "c", "d"]) + function = np.mean # function to apply to the traces + + # Initialize the preprocessor + preprocessor = GenericPreprocessor(recording, function) + + traces = preprocessor.get_traces(channel_ids=["a", "d"]) + expected_traces = np.mean(recording.get_traces(channel_ids=["a", "d"])) + + np.testing.assert_allclose(traces, expected_traces)