diff --git a/src/probeinterface/probe.py b/src/probeinterface/probe.py index 2572708..1d26f8c 100644 --- a/src/probeinterface/probe.py +++ b/src/probeinterface/probe.py @@ -297,8 +297,11 @@ def get_shank_count(self) -> int: """ Return the number of shanks for this probe. """ - assert self.shank_ids is not None - n = len(np.unique(self.shank_ids)) + # assert self.shank_ids is not None + if self.shank_ids is None: + n = 1 + else: + n = len(np.unique(self.shank_ids)) return n def set_contacts( @@ -380,7 +383,8 @@ def set_contacts( self.set_contact_ids(contact_ids) if shank_ids is None: - self._shank_ids = np.zeros(n, dtype=str) + # self._shank_ids = np.zeros(n, dtype=str) + self._shank_ids = None else: self._shank_ids = np.asarray(shank_ids).astype(str) if self.shank_ids.size != n: @@ -601,11 +605,15 @@ def get_shanks(self): """ Return the list of Shank objects for this Probe """ - assert self.shank_ids is not None, "Can only get shanks if `shank_ids` exist" - shanks = [] - for shank_id in np.unique(self.shank_ids): - shank = Shank(probe=self, shank_id=shank_id) - shanks.append(shank) + # assert self.shank_ids is not None, "Can only get shanks if `shank_ids` exist" + if self.shank_ids is None: + # has a unique shank + shanks = [Shank(probe=self, shank_id=None)] + else: + shanks = [] + for shank_id in np.unique(self.shank_ids): + shank = Shank(probe=self, shank_id=shank_id) + shanks.append(shank) return shanks def __eq__(self, other): @@ -1032,7 +1040,11 @@ def to_numpy(self, complete: bool = False) -> np.array: param_shape.append(k) for k in param_shape: dtype += [(k, "float64")] - dtype += [("shank_ids", "U64"), ("contact_ids", "U64")] + + if self._shank_ids is not None: + dtype += [("shank_ids", "U64")] + + dtype += [("contact_ids", "U64")] if self._contact_sides is not None: dtype += [ @@ -1060,7 +1072,8 @@ def to_numpy(self, complete: bool = False) -> np.array: for k, v in p.items(): arr[k][i] = v - arr["shank_ids"] = self.shank_ids + if self._shank_ids is not None: + arr["shank_ids"] = self.shank_ids if self._contact_sides is not None: arr["contact_sides"] = self.contact_sides diff --git a/src/probeinterface/shank.py b/src/probeinterface/shank.py index bd2b448..4d872e0 100644 --- a/src/probeinterface/shank.py +++ b/src/probeinterface/shank.py @@ -15,7 +15,10 @@ def __init__(self, probe, shank_id): self.shank_id = shank_id def get_indices(self): - (inds,) = np.nonzero(self.probe.shank_ids == self.shank_id) + if self.probe.shank_ids is None: + inds = np.arange(self.probe.get_contact_count(), dtype=int) + else: + inds = np.flatnonzero(self.probe.shank_ids == self.shank_id) return inds def get_contact_count(self): diff --git a/tests/test_io/test_3brain.py b/tests/test_io/test_3brain.py index 1128a23..a6d3fe8 100644 --- a/tests/test_io/test_3brain.py +++ b/tests/test_io/test_3brain.py @@ -43,3 +43,7 @@ def test_3brain(): assert np.all(np.isclose(np.diff(unique_rows), contact_pitch)), file unique_cols = np.unique(probe.contact_positions[:, 0]) assert np.all(np.isclose(np.diff(unique_cols), contact_pitch)) + + +if __name__ == "__main__": + test_3brain() \ No newline at end of file diff --git a/tests/test_io/test_io.py b/tests/test_io/test_io.py index f038c52..1d7ebb6 100644 --- a/tests/test_io/test_io.py +++ b/tests/test_io/test_io.py @@ -83,7 +83,8 @@ def test_BIDS_format(tmp_path): probe.set_contact_ids(probe_el_ids) # switch to more generic dtype for shank_ids - probe.set_shank_ids(probe.shank_ids.astype(str)) + if probe.shank_ids is not None: + probe.set_shank_ids(probe.shank_ids.astype(str)) write_BIDS_probe(folder_path, probegroup) @@ -103,7 +104,8 @@ def test_BIDS_format(tmp_path): t = np.array([list(probe_read.contact_ids).index(elid) for elid in probe_orig.contact_ids]) assert all(probe_orig.contact_ids == probe_read.contact_ids[t]) - assert all(probe_orig.shank_ids == probe_read.shank_ids[t]) + if probe_orig.shank_ids is not None: + assert all(probe_orig.shank_ids == probe_read.shank_ids[t]) assert all(probe_orig.contact_shapes == probe_read.contact_shapes[t]) assert probe_orig.ndim == probe_read.ndim assert probe_orig.si_units == probe_read.si_units @@ -206,8 +208,11 @@ def test_prb(tmp_path): if __name__ == "__main__": - # test_probeinterface_format() - # test_BIDS_format() + import tempfile + tmp_path = Path(tempfile.mkdtemp()) + + # test_probeinterface_format(tmp_path) + test_BIDS_format(tmp_path) # test_BIDS_format_empty() # test_BIDS_format_minimal() pass diff --git a/tests/test_probe.py b/tests/test_probe.py index 48a3b82..f697427 100644 --- a/tests/test_probe.py +++ b/tests/test_probe.py @@ -229,10 +229,10 @@ def test_double_side_probe(): if __name__ == "__main__": - test_probe() + import tempfile + tmp_path = Path(tempfile.mkdtemp()) - tmp_path = Path("tmp") - tmp_path.mkdir(exist_ok=True) - test_save_to_zarr(tmp_path) + test_probe() + test_save_to_zarr(tmp_path) test_double_side_probe()