Skip to content

Commit

Permalink
Update tests and validation
Browse files Browse the repository at this point in the history
  • Loading branch information
bjhardcastle committed Feb 2, 2024
1 parent 67df936 commit 165a174
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 90 deletions.
216 changes: 127 additions & 89 deletions src/npc_sync/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
import datetime
import io
import logging
import warnings
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Any, Literal, Union
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Union

import h5py
import npc_io
Expand All @@ -31,6 +30,9 @@

SyncPathOrDataset: TypeAlias = Union[npc_io.PathLike, h5py.File, "SyncDataset"]

FIRST_SOUND_ON_SYNC_DATE = datetime.date(2023, 8, 31)
"""Prior to this date, there's no sync line with "sound running" signal, on any rig: need to
use NI-DAQ analog recording on OpenEphys PXI to get sound onset times."""

def get_sync_data(sync_path_or_data: SyncPathOrDataset) -> SyncDataset:
"""Open a path or file-like object and return a SyncDataset object."""
Expand All @@ -54,6 +56,22 @@ def get_bit(uint_array: npt.NDArray, bit: int) -> npt.NDArray[np.uint8]:
return np.bitwise_and(uint_array, 2**bit).astype(bool).astype(np.uint8)


def get_sync_line_for_stim_onset(
waveform_type: str | Literal["sound", "audio", "opto"],
date: datetime.date | None = None,
) -> int:
if any(label in waveform_type for label in ("aud", "sound")):
if date and date < FIRST_SOUND_ON_SYNC_DATE:
raise ValueError(
f"Sound only recorded on sync since {FIRST_SOUND_ON_SYNC_DATE.isoformat()}: {date = }"
)
return 1
elif "opto" in waveform_type:
return 11
else:
raise ValueError(f"Unexpected value: {waveform_type = }")


class SyncDataset:
"""
A sync dataset. Contains methods for loading
Expand All @@ -66,14 +84,11 @@ class SyncDataset:
Examples
--------
>>> dset = SyncDataset('my_h5_file.h5') # doctest: +SKIP
>>> logger.info(dset.meta_data) # doctest: +SKIP
>>> dset.stats() # doctest: +SKIP
>>> dset.close() # doctest: +SKIP
>>> dset = SyncDataset('s3://aind-ephys-data/ecephys_676909_2023-12-14_12-43-11/behavior_videos/20231214T124311.h5')
>>> dset.validate(opto=True, audio=True)
>>> with SyncDataset('my_h5_file.h5') as d: # doctest: +SKIP
... logger.info(dset.meta_data)
... dset.stats()
... dset.validate()
The sync file documentation from MPE can be found at
sharepoint > Instrumentation > Shared Documents > Sync_line_labels_discussion_2020-01-27-.xlsx # NOQA E501
Expand All @@ -82,35 +97,66 @@ class SyncDataset:
"""
required_lines: ClassVar[list[str | int]] = [
"barcode_ephys",
"vsync_stim",
"stim_photodiode",
"stim_running",
*[f"{cam}_cam_{suffix}" for cam in ("beh", "eye", "face") for suffix in ("frame_readout", "exposing")],
"lick_sensor",
]

def __init__(self, path) -> None:
if isinstance(path, self.__class__):
self = path
else:
self.dfile = self.load(path)
self._check_line_labels()

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.dfile.filename})"

def _check_line_labels(self) -> None:
if hasattr(self, "line_labels"):
deprecated_keys = set(self.line_labels)
if deprecated_keys:
warnings.warn(
(
f"The loaded sync file contains the "
f"following deprecated line label keys: "
f"{deprecated_keys}. Consider updating the "
f"sync file line labels."
),
stacklevel=2,
)
else:
warnings.warn(
("The loaded sync file has no line labels and may " "not be valid."),
stacklevel=2,
)
def validate(self, opto: bool = False, audio: bool = False) -> None:
"""
Check all members of `self.required_lines` are present and have events.
Check vsync and photodiode events can be interpreted to deduce stim blocks.
- if opto or audio are True, work out which line indices correspond and
check those too
"""
self._check_line_labels(opto=opto, audio=audio)
self._check_stim_photodiode()
self._check_vsyncs()

def _check_line_labels(self, opto: bool = False, audio: bool = False) -> None:
if not hasattr(self, "line_labels"):
raise AssertionError("Sync file has no line labels.")
lines = self.required_lines
if opto:
lines.append(self.get_line_for_stim_onset("opto"))
if audio and self.start_time.date() >= FIRST_SOUND_ON_SYNC_DATE:
lines.append(self.get_line_for_stim_onset("audio"))
for line in lines:
self._check_line(line)

def _check_line(self, label_or_index: str | int) -> None:
try:
stats = self.line_stats(label_or_index)
except IndexError:
raise AssertionError(f"Sync file has no line {label_or_index}")
if stats is None:
raise AssertionError(f"Sync file has no events on line {label_or_index}")

def _check_stim_photodiode(self) -> None:
try:
_ = self.expected_diode_flip_rate
except ValueError as exc:
raise AssertionError("Frame rate estimated from diode flips is abnormal.") from exc

def _check_vsyncs(self) -> None:
try:
_ = self.vsync_times_in_blocks
except ValueError as exc:
raise AssertionError("vsyncs should be divisible into blocks corresponding to individual stims presented, but they appear abnormal.") from exc

def _process_times(self) -> npt.NDArray[np.int64]:
"""
Expand Down Expand Up @@ -151,7 +197,10 @@ def load(self, path) -> h5py.File:
self.line_labels: Sequence[str] = self.meta_data["line_labels"]
self.times = self._process_times()
return self.dfile


def get_line_for_stim_onset(self, waveform_type: Literal["sound", "audio", "opto"]) -> int:
return get_sync_line_for_stim_onset(waveform_type=waveform_type, date=self.start_time.date())

@property
def sample_freq(self) -> float:
try:
Expand Down Expand Up @@ -517,78 +566,63 @@ def line_stats(self, line, print_results=True) -> dict[str, Any] | None:
total_falling = len(falling)

# get labels
label = self.line_labels[line]
label = self.line_labels[bit]

if total_events <= 0:
if print_results:
logger.info("*" * 70)
logger.info("No events on line: %s" % line)
logger.info("*" * 70)
return None
elif total_events <= 10:
if print_results:
logger.info("*" * 70)
logger.info("Sparse events on line: %s" % line)
logger.info("Rising: %s" % total_rising)
logger.info("Falling: %s" % total_falling)
logger.info("*" * 70)
return {
"line": line,
"bit": bit,
"total_rising": total_rising,
"total_falling": total_falling,
"avg_freq": None,
"duty_cycle": None,
}
else:
# period
period = self.period(line)
# period
period = self.period(line)

avg_period = period["avg"]
max_period = period["max"]
min_period = period["min"]
period_sd = period["sd"]
avg_period = period["avg"]
max_period = period["max"]
min_period = period["min"]
period_sd = period["sd"]

# freq
avg_freq = self.frequency(line)

# duty cycle
duty_cycle = self.duty_cycle(line)

if print_results:
logger.info("*" * 70)
# freq
avg_freq = self.frequency(line)

# duty cycle
duty_cycle = self.duty_cycle(line)

if print_results:
logger.info("*" * 70)
if total_events <= 10:
logger.warning("Sparse events on line: %s" % line)
else:
logger.info("Quick stats for line: %s" % line)
logger.info("Label: %s" % label)
logger.info("Bit: %i" % bit)
logger.info("Data points: %i" % total_data_points)
logger.info("Total transitions: %i" % total_events)
logger.info("Rising edges: %i" % total_rising)
logger.info("Falling edges: %i" % total_falling)
logger.info("Average period: %s" % avg_period)
logger.info("Minimum period: %s" % min_period)
logger.info("Max period: %s" % max_period)
logger.info("Period SD: %s" % period_sd)
logger.info("Average freq: %s" % avg_freq)
logger.info("Duty cycle: %s" % duty_cycle)

logger.info("*" * 70)
logger.info("Label: %s" % label)
logger.info("Bit: %i" % bit)
logger.info("Data points: %i" % total_data_points)
logger.info("Total transitions: %i" % total_events)
logger.info("Rising edges: %i" % total_rising)
logger.info("Falling edges: %i" % total_falling)
logger.info("Average period: %s" % avg_period)
logger.info("Minimum period: %s" % min_period)
logger.info("Max period: %s" % max_period)
logger.info("Period SD: %s" % period_sd)
logger.info("Average freq: %s" % avg_freq)
logger.info("Duty cycle: %s" % duty_cycle)
logger.info("*" * 70)

return {
"line": line,
"label": label,
"bit": bit,
"total_data_points": total_data_points,
"total_events": total_events,
"total_rising": total_rising,
"total_falling": total_falling,
"avg_period": avg_period,
"min_period": min_period,
"max_period": max_period,
"period_sd": period_sd,
"avg_freq": avg_freq,
"duty_cycle": duty_cycle,
}
return {
"line": line,
"label": label,
"bit": bit,
"total_data_points": total_data_points,
"total_events": total_events,
"total_rising": total_rising,
"total_falling": total_falling,
"avg_period": avg_period,
"min_period": min_period,
"max_period": max_period,
"period_sd": period_sd,
"avg_freq": avg_freq,
"duty_cycle": duty_cycle,
}

def period(
self, line: str | int, edge: Literal["rising", "falling"] = "rising"
Expand Down Expand Up @@ -1302,6 +1336,8 @@ def plot_stim_onsets(self) -> matplotlib.figure.Figure:
legend = axes[ind].get_legend()
if ind > 0 and legend is not None:
legend.remove()
fig.set_size_inches(10, 5 * len(fig.axes))
fig.subplots_adjust(hspace=0.3)
return fig

def plot_stim_offsets(self) -> matplotlib.figure.Figure:
Expand Down Expand Up @@ -1346,6 +1382,8 @@ def plot_stim_offsets(self) -> matplotlib.figure.Figure:
legend = axes[ind].get_legend()
if ind > 0 and legend is not None:
legend.remove()
fig.set_size_inches(10, 5 * len(fig.axes))
fig.subplots_adjust(hspace=0.3)
return fig

def plot_diode_measured_sync_square_flips(
Expand Down
Binary file added tests/output/test_stim_ends.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/output/test_stim_starts.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@


def test_import_package():
pass
pass
12 changes: 12 additions & 0 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import npc_sync


def test_plots():
dset = npc_sync.SyncDataset('s3://aind-ephys-data/ecephys_662892_2023-08-21_12-43-45/behavior/20230821T124345.h5')
dset.plot_stim_onsets().savefig('tests/output/test_stim_starts.png')
dset.plot_stim_offsets().savefig('tests/output/test_stim_ends.png')

if __name__ == '__main__':
import pytest

pytest.main(['-s', __file__])

0 comments on commit 165a174

Please sign in to comment.