Skip to content

Commit

Permalink
76 add expected metadata files split by modality (#83)
Browse files Browse the repository at this point in the history
* feat: add expected modality model file

1 = required, 0 = optional, -1 = never present

* updated expected files

* refactor: allow passing the key to field_handler methods

* feat: developing ExpectedFiles class

* feat: adding required name column, fixing enum issues

* chore: lint

* doc: docstrings

* test: ExpectedFiles tests

* chore: lint

* refactor: ignored -> excluded

* test: fix tests

---------

Co-authored-by: Saskia de Vries <sejdevries@gmail.com>
  • Loading branch information
dbirman and saskiad authored Sep 25, 2024
1 parent 7ed845a commit 419163c
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 4 deletions.
67 changes: 66 additions & 1 deletion src/aind_data_schema_models/modalities.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Module for Modality definitions"""
"""Module for Modality and ExpectedFile definitions"""

from importlib_resources import files
from pydantic import ConfigDict, Field
from enum import IntEnum
from typing_extensions import Annotated

from aind_data_schema_models.pid_names import BaseName
from aind_data_schema_models.utils import create_literal_class, read_csv
Expand All @@ -25,3 +27,66 @@ class ModalityModel(BaseName):

Modality.abbreviation_map = {m().abbreviation: m() for m in Modality.ALL}
Modality.from_abbreviation = lambda x: Modality.abbreviation_map.get(x)


class FileRequirement(IntEnum):
"""Whether a file is required for a specific modality"""

REQUIRED = 1
OPTIONAL = 0
EXCLUDED = -1


class ExpectedFilesModel(BaseName):
"""Model config"""

model_config = ConfigDict(frozen=True)
name: str = Field(..., title="Modality name")
modality_abbreviation: str = Field(..., title="Modality abbreviation")
subject: FileRequirement = Field(..., title="Subject file requirement")
data_description: FileRequirement = Field(..., title="Data description file requirement")
procedures: FileRequirement = Field(..., title="Procedures file requirement")
session: FileRequirement = Field(..., title="Session file requirement")
rig: FileRequirement = Field(..., title="Processing file requirement")
processing: FileRequirement = Field(..., title="Processing file requirement")
acquisition: FileRequirement = Field(..., title="Acquisition file requirement")
instrument: FileRequirement = Field(..., title="Instrument file requirement")
quality_control: FileRequirement = Field(..., title="Quality control file requirement")


def map_file_requirement(value: int, record: dict, field: str):
"""Map integers to Annotated[FileRequirement, value]
Parameters
----------
value : int
File required value
record : dict
Full class dictionary
field : str
Field name that the FileRequirement value will be assigned to
"""
record[field] = Annotated[
FileRequirement,
Field(default=FileRequirement(int(value))),
]


ExpectedFiles = create_literal_class(
objects=read_csv(str(files("aind_data_schema_models.models").joinpath("modality_expected_files.csv"))),
class_name="ExpectedFiles",
base_model=ExpectedFilesModel,
discriminator="modality_abbreviation",
field_handlers={
"subject": map_file_requirement,
"data_description": map_file_requirement,
"procedures": map_file_requirement,
"session": map_file_requirement,
"rig": map_file_requirement,
"processing": map_file_requirement,
"acquisition": map_file_requirement,
"instrument": map_file_requirement,
"quality_control": map_file_requirement,
},
class_module=__name__,
)
15 changes: 15 additions & 0 deletions src/aind_data_schema_models/models/modality_expected_files.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
name,modality_abbreviation,subject,data_description,procedures,session,rig,processing,acquisition,instrument,quality_control
behavior,behavior,1,1,1,1,1,0,-1,-1,0
behavior-videos,behavior-videos,1,1,1,1,1,0,-1,-1,0
confocal,confocal,1,1,1,-1,-1,1,1,1,0
EMG,EMG,1,1,1,1,1,0,-1,-1,0
ecephys,ecephys,1,1,1,1,1,0,-1,-1,0
fib,fib,1,1,1,1,1,0,-1,-1,0
fMOST,fMOST,1,1,1,-1,-1,1,1,1,0
icephys,icephys,1,1,1,1,1,0,-1,-1,0
ISI,ISI,1,1,1,1,1,0,-1,-1,0
MRI,MRI,1,1,1,1,1,0,-1,-1,0
merfish,merfish,1,1,1,-1,-1,1,1,1,0
pophys,pophys,1,1,1,1,1,0,-1,-1,0
slap,slap,1,1,1,1,1,0,-1,-1,0
SPIM,SPIM,1,1,1,-1,-1,1,1,1,0
2 changes: 1 addition & 1 deletion src/aind_data_schema_models/registries.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class RegistryModel(BaseName):
Registry.from_abbreviation = lambda x: Registry.abbreviation_map.get(x)


def map_registry(abbreviation: str, record: dict):
def map_registry(abbreviation: str, record: dict, *args):
"""replace the "registry" key of a dictionary with a RegistryModel object"""
registry = Registry.from_abbreviation(abbreviation)
if registry:
Expand Down
2 changes: 1 addition & 1 deletion src/aind_data_schema_models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def create_literal_model(
fields = {}
for k, v in obj.items():
if k in field_handlers:
field_handlers[k](v, fields)
field_handlers[k](v, fields, k)
elif k in base_model.__annotations__.keys():
field_type = base_model.__annotations__[k]
if v is not None:
Expand Down
19 changes: 18 additions & 1 deletion tests/test_modalities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import unittest

from aind_data_schema_models.modalities import Modality
from aind_data_schema_models.modalities import Modality, ExpectedFiles, FileRequirement


class TestModality(unittest.TestCase):
Expand All @@ -14,5 +14,22 @@ def test_from_abbreviation(self):
self.assertEqual(Modality.ECEPHYS, Modality.from_abbreviation("ecephys"))


class TestExpectedFiles(unittest.TestCase):
"""Test methods in ExpectedFiles class"""

def test_expected_file_state(self):
"""Test that expected file states were set correctly"""

self.assertEqual(ExpectedFiles.ECEPHYS.subject, FileRequirement.REQUIRED)
self.assertEqual(ExpectedFiles.ECEPHYS.data_description, FileRequirement.REQUIRED)
self.assertEqual(ExpectedFiles.ECEPHYS.procedures, FileRequirement.REQUIRED)
self.assertEqual(ExpectedFiles.ECEPHYS.session, FileRequirement.REQUIRED)
self.assertEqual(ExpectedFiles.ECEPHYS.rig, FileRequirement.REQUIRED)
self.assertEqual(ExpectedFiles.ECEPHYS.processing, FileRequirement.OPTIONAL)
self.assertEqual(ExpectedFiles.ECEPHYS.acquisition, FileRequirement.EXCLUDED)
self.assertEqual(ExpectedFiles.ECEPHYS.instrument, FileRequirement.EXCLUDED)
self.assertEqual(ExpectedFiles.ECEPHYS.quality_control, FileRequirement.OPTIONAL)


if __name__ == "__main__":
unittest.main()

0 comments on commit 419163c

Please sign in to comment.