Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions src/probeinterface/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,11 @@ def get_shank_count(self) -> int:
"""
Return the number of shanks for this probe.
"""
assert self.shank_ids is not None
n = len(np.unique(self.shank_ids))
# assert self.shank_ids is not None
if self.shank_ids is None:
n = 1
else:
n = len(np.unique(self.shank_ids))
return n

def set_contacts(
Expand Down Expand Up @@ -380,7 +383,8 @@ def set_contacts(
self.set_contact_ids(contact_ids)

if shank_ids is None:
self._shank_ids = np.zeros(n, dtype=str)
# self._shank_ids = np.zeros(n, dtype=str)
self._shank_ids = None
else:
self._shank_ids = np.asarray(shank_ids).astype(str)
if self.shank_ids.size != n:
Expand Down Expand Up @@ -601,11 +605,15 @@ def get_shanks(self):
"""
Return the list of Shank objects for this Probe
"""
assert self.shank_ids is not None, "Can only get shanks if `shank_ids` exist"
shanks = []
for shank_id in np.unique(self.shank_ids):
shank = Shank(probe=self, shank_id=shank_id)
shanks.append(shank)
# assert self.shank_ids is not None, "Can only get shanks if `shank_ids` exist"
if self.shank_ids is None:
# has a unique shank
shanks = [Shank(probe=self, shank_id=None)]
else:
shanks = []
for shank_id in np.unique(self.shank_ids):
shank = Shank(probe=self, shank_id=shank_id)
shanks.append(shank)
return shanks

def __eq__(self, other):
Expand Down Expand Up @@ -1032,7 +1040,11 @@ def to_numpy(self, complete: bool = False) -> np.array:
param_shape.append(k)
for k in param_shape:
dtype += [(k, "float64")]
dtype += [("shank_ids", "U64"), ("contact_ids", "U64")]

if self._shank_ids is not None:
dtype += [("shank_ids", "U64")]

dtype += [("contact_ids", "U64")]

if self._contact_sides is not None:
dtype += [
Expand Down Expand Up @@ -1060,7 +1072,8 @@ def to_numpy(self, complete: bool = False) -> np.array:
for k, v in p.items():
arr[k][i] = v

arr["shank_ids"] = self.shank_ids
if self._shank_ids is not None:
arr["shank_ids"] = self.shank_ids

if self._contact_sides is not None:
arr["contact_sides"] = self.contact_sides
Expand Down
5 changes: 4 additions & 1 deletion src/probeinterface/shank.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ def __init__(self, probe, shank_id):
self.shank_id = shank_id

def get_indices(self):
(inds,) = np.nonzero(self.probe.shank_ids == self.shank_id)
if self.probe.shank_ids is None:
inds = np.arange(self.probe.get_contact_count(), dtype=int)
else:
inds = np.flatnonzero(self.probe.shank_ids == self.shank_id)
return inds

def get_contact_count(self):
Expand Down
4 changes: 4 additions & 0 deletions tests/test_io/test_3brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,7 @@ def test_3brain():
assert np.all(np.isclose(np.diff(unique_rows), contact_pitch)), file
unique_cols = np.unique(probe.contact_positions[:, 0])
assert np.all(np.isclose(np.diff(unique_cols), contact_pitch))


if __name__ == "__main__":
test_3brain()
13 changes: 9 additions & 4 deletions tests/test_io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def test_BIDS_format(tmp_path):
probe.set_contact_ids(probe_el_ids)

# switch to more generic dtype for shank_ids
probe.set_shank_ids(probe.shank_ids.astype(str))
if probe.shank_ids is not None:
probe.set_shank_ids(probe.shank_ids.astype(str))

write_BIDS_probe(folder_path, probegroup)

Expand All @@ -103,7 +104,8 @@ def test_BIDS_format(tmp_path):
t = np.array([list(probe_read.contact_ids).index(elid) for elid in probe_orig.contact_ids])

assert all(probe_orig.contact_ids == probe_read.contact_ids[t])
assert all(probe_orig.shank_ids == probe_read.shank_ids[t])
if probe_orig.shank_ids is not None:
assert all(probe_orig.shank_ids == probe_read.shank_ids[t])
assert all(probe_orig.contact_shapes == probe_read.contact_shapes[t])
assert probe_orig.ndim == probe_read.ndim
assert probe_orig.si_units == probe_read.si_units
Expand Down Expand Up @@ -206,8 +208,11 @@ def test_prb(tmp_path):


if __name__ == "__main__":
# test_probeinterface_format()
# test_BIDS_format()
import tempfile
tmp_path = Path(tempfile.mkdtemp())

# test_probeinterface_format(tmp_path)
test_BIDS_format(tmp_path)
# test_BIDS_format_empty()
# test_BIDS_format_minimal()
pass
8 changes: 4 additions & 4 deletions tests/test_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,10 @@ def test_double_side_probe():


if __name__ == "__main__":
test_probe()
import tempfile
tmp_path = Path(tempfile.mkdtemp())

tmp_path = Path("tmp")
tmp_path.mkdir(exist_ok=True)
test_save_to_zarr(tmp_path)

test_probe()
test_save_to_zarr(tmp_path)
test_double_side_probe()
Loading