Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Micromed segments #1589

Merged
merged 11 commits into from
Jan 17, 2025
115 changes: 78 additions & 37 deletions neo/rawio/micromedrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def __init__(self, filename=""):

def _parse_header(self):

self._buffer_descriptions = {0: {0: {}}}

with open(self.filename, "rb") as fid:
f = StructFile(fid)
Expand All @@ -67,6 +66,7 @@ def _parse_header(self):
rec_datetime = datetime.datetime(year + 1900, month, day, hour, minute, sec)

Data_Start_Offset, Num_Chan, Multiplexer, Rate_Min, Bytes = f.read_f("IHHHH", offset=138)
sig_dtype = "u" + str(Bytes)

# header version
(header_version,) = f.read_f("b", offset=175)
Expand Down Expand Up @@ -99,25 +99,35 @@ def _parse_header(self):
if zname != zname2.decode("ascii").strip(" "):
raise NeoReadWriteError("expected the zone name to match")

# raw signals memmap
sig_dtype = "u" + str(Bytes)
signal_shape = get_memmap_shape(self.filename, sig_dtype, num_channels=Num_Chan, offset=Data_Start_Offset)
buffer_id = "0"
stream_id = "0"
self._buffer_descriptions[0][0][buffer_id] = {
"type": "raw",
"file_path": str(self.filename),
"dtype": sig_dtype,
"order": "C",
"file_offset": 0,
"shape": signal_shape,
}

# "TRONCA" zone define segments
zname2, pos, length = zones["TRONCA"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Light recommendation to have better name for pos :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ouch, it was already everywhere.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not crucial here :) But if any micromed issues come up I will send them your way.

f.seek(pos)
max_segments = 100
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where does this 100 come from. We probably need a comment for any magic number.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this magic number comes from another implementation in MATLAB. Maybe this should be exposed; I guess it is there to avoid an infinite while-loop on some datasets.

self.info_segments = []
for i in range(max_segments):
seg_start = int(np.frombuffer(f.read(4), dtype="u4")[0])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here. Could we get a comment for why we are reading 4 at a time?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4 is 4 bytes, the item size of dtype "u4".

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I meant why are you doing 'u4'? why not some other size for reading. How do you know reading u4 gets us what we need instead of some other size? If we do a comment then we know that either 1) this size is documented or 2) you are doing this empirically.

trace_offset = int(np.frombuffer(f.read(4), dtype="u4")[0])
if seg_start == 0 and trace_offset == 0:
break
else:
self.info_segments.append((seg_start, trace_offset))

if len(self.info_segments) == 0:
# one unique segment = general case
self.info_segments.append((0, 0))

nb_segment = len(self.info_segments)

# Reading Code Info
zname2, pos, length = zones["ORDER"]
f.seek(pos)
code = np.frombuffer(f.read(Num_Chan * 2), dtype="u2")

# unique stream and buffer
buffer_id = "0"
stream_id = "0"

units_code = {-1: "nV", 0: "uV", 1: "mV", 2: 1, 100: "percent", 101: "dimensionless", 102: "dimensionless"}
signal_channels = []
sig_grounds = []
Expand All @@ -140,10 +150,8 @@ def _parse_header(self):
(sampling_rate,) = f.read_f("H")
sampling_rate *= Rate_Min
chan_id = str(c)
signal_channels.append((chan_name, chan_id, sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id))

signal_channels.append(
(chan_name, chan_id, sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id)
)

signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)

Expand All @@ -155,6 +163,32 @@ def _parse_header(self):
raise NeoReadWriteError("The sampling rates must be the same across signal channels")
self._sampling_rate = float(np.unique(signal_channels["sampling_rate"])[0])

# memmap traces buffer
full_signal_shape = get_memmap_shape(self.filename, sig_dtype, num_channels=Num_Chan, offset=Data_Start_Offset)
seg_limits = [trace_offset for seg_start, trace_offset in self.info_segments] + [full_signal_shape[0]]
self._t_starts = []
self._buffer_descriptions = {0 :{}}
for seg_index in range(nb_segment):
seg_start, trace_offset = self.info_segments[seg_index]
self._t_starts.append(seg_start / self._sampling_rate)

start = seg_limits[seg_index]
stop = seg_limits[seg_index + 1]

shape = (stop - start, Num_Chan)
file_offset = Data_Start_Offset + start * np.dtype(sig_dtype).itemsize * Num_Chan
samuelgarcia marked this conversation as resolved.
Show resolved Hide resolved
self._buffer_descriptions[0][seg_index] = {}
self._buffer_descriptions[0][seg_index][buffer_id] = {
"type" : "raw",
"file_path" : str(self.filename),
"dtype" : sig_dtype,
"order": "C",
"file_offset" : file_offset,
"shape" : shape,
}



# Event channels
event_channels = []
event_channels.append(("Trigger", "", "event"))
Expand All @@ -176,13 +210,18 @@ def _parse_header(self):
dtype = np.dtype(ev_dtype)
rawevent = np.memmap(self.filename, dtype=dtype, mode="r", offset=pos, shape=length // dtype.itemsize)

keep = (
(rawevent["start"] >= rawevent["start"][0])
& (rawevent["start"] < signal_shape[0])
& (rawevent["start"] != 0)
)
rawevent = rawevent[keep]
self._raw_events.append(rawevent)
# important : all events timing are related to the first segment t_start
self._raw_events.append([])
for seg_index in range(nb_segment):
left_lim = seg_limits[seg_index]
right_lim = seg_limits[seg_index + 1]
keep = (
(rawevent["start"] >= left_lim)
& (rawevent["start"] < right_lim)
& (rawevent["start"] != 0)
)
self._raw_events[-1].append(rawevent[keep])


# No spikes
spike_channels = []
Expand All @@ -191,7 +230,7 @@ def _parse_header(self):
# fille into header dict
self.header = {}
self.header["nb_block"] = 1
self.header["nb_segment"] = [1]
self.header["nb_segment"] = [nb_segment]
self.header["signal_buffers"] = signal_buffers
self.header["signal_streams"] = signal_streams
self.header["signal_channels"] = signal_channels
Expand All @@ -216,38 +255,40 @@ def _source_name(self):
return self.filename

def _segment_t_start(self, block_index, seg_index):
return 0.0
return self._t_starts[seg_index]

def _segment_t_stop(self, block_index, seg_index):
sig_size = self.get_signal_size(block_index, seg_index, 0)
t_stop = sig_size / self._sampling_rate
return t_stop
duration = self.get_signal_size(block_index, seg_index, stream_index=0) / self._sampling_rate
return duration + self.segment_t_start(block_index, seg_index)

def _get_signal_t_start(self, block_index, seg_index, stream_index):
if stream_index != 0:
raise ValueError("`stream_index` must be 0")
return 0.0
assert stream_index == 0
return self._t_starts[seg_index]

def _spike_count(self, block_index, seg_index, unit_index):
return 0

def _event_count(self, block_index, seg_index, event_channel_index):
n = self._raw_events[event_channel_index].size
n = self._raw_events[event_channel_index][seg_index].size
return n

def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):

raw_event = self._raw_events[event_channel_index]
raw_event = self._raw_events[event_channel_index][seg_index]

# important : all events timing are related to the first segment t_start
seg_start0, _ = self.info_segments[0]

if t_start is not None:
keep = raw_event["start"] >= int(t_start * self._sampling_rate)
keep = raw_event["start"] + seg_start0 >= int(t_start * self._sampling_rate)
raw_event = raw_event[keep]

if t_stop is not None:
keep = raw_event["start"] <= int(t_stop * self._sampling_rate)
keep = raw_event["start"] + seg_start0 <= int(t_stop * self._sampling_rate)
raw_event = raw_event[keep]

timestamp = raw_event["start"]
timestamp = raw_event["start"] + seg_start0

if event_channel_index < 2:
durations = None
else:
Expand Down
38 changes: 37 additions & 1 deletion neo/test/rawiotest/test_micromedrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,50 @@

from neo.test.rawiotest.common_rawio_test import BaseTestRawIO

import numpy as np

class TestMicromedRawIO(
    BaseTestRawIO,
    unittest.TestCase,
):
    """Run the common rawio test battery on Micromed TRC files and check
    that multi-segment files are read consistently with the full recording."""

    rawioclass = MicromedRawIO
    entities_to_download = ["micromed"]
    # NOTE: "mircomed" is the actual (misspelled) name of the test fixtures —
    # do not "fix" these paths.
    entities_to_test = [
        "micromed/File_micromed_1.TRC",
        "micromed/File_mircomed2.TRC",
        "micromed/File_mircomed2_2segments.TRC",
    ]

    def test_micromed_multi_segments(self):
        file_full = self.get_local_path("micromed/File_mircomed2.TRC")
        file_split = self.get_local_path("micromed/File_mircomed2_2segments.TRC")

        # The second file contains 2 pieces of the first file, so it is
        # 2 segments with the same (but shortened) traces.
        # Note that traces in the split file can differ at the very end of each cut.

        reader1 = MicromedRawIO(file_full)
        reader1.parse_header()
        assert reader1.segment_count(block_index=0) == 1
        assert reader1.get_signal_t_start(block_index=0, seg_index=0, stream_index=0) == 0.0
        traces1 = reader1.get_analogsignal_chunk(stream_index=0)

        reader2 = MicromedRawIO(file_split)
        reader2.parse_header()
        assert reader2.segment_count(block_index=0) == 2

        # Check that each piece of the split file equals the matching slice
        # of the full file (except for a truncation at the end of each piece).
        # The sampling rate is shared by all segments, so fetch it once.
        sr = reader2.get_signal_sampling_rate(stream_index=0)
        for seg_index in range(2):
            t_start = reader2.get_signal_t_start(block_index=0, seg_index=seg_index, stream_index=0)
            assert t_start > 0
            # Read the segment traces BEFORE slicing traces1: the slice length
            # depends on traces2.shape (the original used traces2 before it
            # was assigned, raising NameError).
            traces2 = reader2.get_analogsignal_chunk(block_index=0, seg_index=seg_index, stream_index=0)
            ind_start = int(t_start * sr)
            traces1_chunk = traces1[ind_start : ind_start + traces2.shape[0]]
            # We drop the last 100 samples because the tool that cut the
            # traces truncates the last buffer.
            assert np.array_equal(traces2[:-100], traces1_chunk[:-100])



if __name__ == "__main__":
Expand Down
Loading