Skip to content

Commit

Permalink
Merge pull request #1609 from h-mayorquin/fix_time_metadata
Browse files Browse the repository at this point in the history
Add timing metadata for `SpikeGLXRawIO`
  • Loading branch information
zm711 authored Jan 17, 2025
2 parents 6f571f0 + a6fa96d commit d722b0b
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 4 deletions.
20 changes: 16 additions & 4 deletions neo/rawio/spikeglxrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,14 +224,25 @@ def _parse_header(self):
spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)

# deal with nb_segment and t_start/t_stop per segment
self._t_starts = {seg_index: 0.0 for seg_index in range(nb_segment)}

self._t_starts = {stream_name: {} for stream_name in stream_names}
self._t_stops = {seg_index: 0.0 for seg_index in range(nb_segment)}
for seg_index in range(nb_segment):
for stream_name in stream_names:

for stream_name in stream_names:
for seg_index in range(nb_segment):
info = self.signals_info_dict[seg_index, stream_name]

frame_start = float(info["meta"]["firstSample"])
sampling_frequency = info["sampling_rate"]
t_start = frame_start / sampling_frequency

self._t_starts[stream_name][seg_index] = t_start
t_stop = info["sample_length"] / info["sampling_rate"]
self._t_stops[seg_index] = max(self._t_stops[seg_index], t_stop)




# fille into header dict
self.header = {}
self.header["nb_block"] = 1
Expand Down Expand Up @@ -276,7 +287,8 @@ def _segment_t_stop(self, block_index, seg_index):
return self._t_stops[seg_index]

def _get_signal_t_start(self, block_index, seg_index, stream_index):
return 0.0
stream_name = self.header["signal_streams"][stream_index]["name"]
return self._t_starts[stream_name][seg_index]

def _event_count(self, event_channel_idx, block_index=None, seg_index=None):
timestamps, _, _ = self._get_event_timestamps(block_index, seg_index, event_channel_idx, None, None)
Expand Down
59 changes: 59 additions & 0 deletions neo/test/rawiotest/test_spikeglxrawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,65 @@ def test_nidq_digital_channel(self):
atol = 0.001
assert np.allclose(on_diff, 1, atol=atol)

def test_t_start_reading(self):
"""Test that t_start values are correctly read for all streams and segments."""

# Expected t_start values for each stream and segment
expected_t_starts = {
'imec0.ap': {
0: 15.319535472007237,
1: 15.339535431281986,
2: 21.284723325294053,
3: 21.3047232845688
},
'imec1.ap': {
0: 15.319554693264516,
1: 15.339521518106308,
2: 21.284735282142822,
3: 21.304702106984614
},
'imec0.lf': {
0: 15.3191688060872,
1: 15.339168765361949,
2: 21.284356659374016,
3: 21.304356618648765
},
'imec1.lf': {
0: 15.319321358082725,
1: 15.339321516521915,
2: 21.284568614155827,
3: 21.30456877259502
}
}

# Initialize the RawIO
rawio = SpikeGLXRawIO(self.get_local_path("spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI4"))
rawio.parse_header()

# Get list of stream names
stream_names = rawio.header["signal_streams"]["name"]

# Test t_start for each stream and segment
for stream_name, expected_values in expected_t_starts.items():
# Get stream index
stream_index = list(stream_names).index(stream_name)

# Check each segment
for seg_index, expected_t_start in expected_values.items():
actual_t_start = rawio.get_signal_t_start(
block_index=0,
seg_index=seg_index,
stream_index=stream_index
)

# Use numpy.testing for proper float comparison
np.testing.assert_allclose(
actual_t_start,
expected_t_start,
rtol=1e-9,
atol=1e-9,
err_msg=f"Mismatch in t_start for stream '{stream_name}', segment {seg_index}"
)

if __name__ == "__main__":
unittest.main()

0 comments on commit d722b0b

Please sign in to comment.