Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure backwards compatiblity between ophys and pohys modality #61

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions src/aind_data_schema_models/modalities.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
"""Module for Modality definitions"""

from typing import Any

from importlib_resources import files
from pydantic import ConfigDict, Field
from pydantic import BeforeValidator, ConfigDict, Field

from aind_data_schema_models.pid_names import BaseName
from aind_data_schema_models.utils import create_literal_class, read_csv


# This is a hotfix to allow users to use the old ophys modality abbreviation
# It should be removed in a future release
def _coerce_ophys_to_pophys(v: Any):
if isinstance(v, dict):
if v.get("abbreviation") == "ophys":
return Modality.POPHYS
return v


class ModalityModel(BaseName):
"""Base model config"""

Expand All @@ -21,7 +32,18 @@ class ModalityModel(BaseName):
base_model=ModalityModel,
discriminator="abbreviation",
class_module=__name__,
validators=[BeforeValidator(_coerce_ophys_to_pophys)],
)

Modality.OPHYS = Modality.POPHYS

Modality.abbreviation_map = {m().abbreviation: m() for m in Modality.ALL}
Modality.from_abbreviation = lambda x: Modality.abbreviation_map.get(x)


def _from_abbreviation(cls, v: Any):
if v == "ophys":
v = "pophys"
return cls.abbreviation_map.get(v)


Modality.from_abbreviation = lambda x: _from_abbreviation(Modality, x)
13 changes: 11 additions & 2 deletions src/aind_data_schema_models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Type, Union

from pydantic import BaseModel, ConfigDict, Field, create_model
from pydantic import AfterValidator, BaseModel, BeforeValidator, ConfigDict, Field, WrapValidator, create_model
from typing_extensions import Annotated


Expand Down Expand Up @@ -88,6 +88,7 @@ def create_literal_class(
base_model: Type[BaseModel] = BaseModel,
discriminator: str = "name",
field_handlers: Optional[dict] = None,
validators: Optional[List[Union[WrapValidator, AfterValidator, BeforeValidator]]] = None,
):
"""
Make a dynamic pydantic literal class
Expand Down Expand Up @@ -125,7 +126,15 @@ def create_literal_class(
setattr(cls, "ALL", tuple(all_models))

# Older versions of flake8 raise errors about 'ALL' being undefined
setattr(cls, "ONE_OF", Annotated[Union[getattr(cls, "ALL")], Field(discriminator=discriminator)]) # noqa: F821
setattr(
cls,
"ONE_OF",
Annotated[
Union[getattr(cls, "ALL")], # noqa: F821
Field(discriminator=discriminator),
*(validators if validators else []),
],
)

# add the model instances as class variables
for m in all_models:
Expand Down
31 changes: 31 additions & 0 deletions tests/test_modalities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,17 @@

import unittest

from pydantic import BaseModel

from aind_data_schema_models.modalities import Modality


class MockModel(BaseModel):
mod1: Modality.ONE_OF
mod2: Modality.ONE_OF
mod3: Modality.ONE_OF


class TestModality(unittest.TestCase):
"""Tests methods in Modality class"""

Expand All @@ -13,6 +21,29 @@ def test_from_abbreviation(self):

self.assertEqual(Modality.ECEPHYS, Modality.from_abbreviation("ecephys"))

def test_ophys_to_pophys_coercion(self):
"""Tests that ophys is coerced to pophys"""

_test_literal = """
{
"mod1":{"name":"Extracellular electrophysiology","abbreviation":"ecephys"},
"mod2":{"name":"Planar optical physiology","abbreviation":"pophys"},
"mod3":{"name":"foo bar","abbreviation":"ophys"}
}"""
t = MockModel(mod1=Modality.ECEPHYS, mod2=Modality.POPHYS, mod3=Modality.POPHYS)

self.assertEqual(t, MockModel.model_validate_json(_test_literal))

def test_ophys_to_pophys_from_abbreviation(self):
"""Tests that ophys is coerced to pophys from abbreviation"""

self.assertEqual(Modality.POPHYS, Modality.from_abbreviation("ophys"))

def test_ophys_attribute(self):
"""Tests that ophys attribute is available"""

self.assertEqual(Modality.OPHYS, Modality.POPHYS)


if __name__ == "__main__":
unittest.main()
Loading