From e2f065b59454302c3b51952d951f8f83d98f9fcf Mon Sep 17 00:00:00 2001 From: Dan Birman Date: Tue, 22 Oct 2024 09:32:36 -0700 Subject: [PATCH] feat: support modality and stage filters --- src/aind_data_schema/core/quality_control.py | 20 +++++- tests/test_quality_control.py | 68 ++++++++++++++++++-- 2 files changed, 81 insertions(+), 7 deletions(-) diff --git a/src/aind_data_schema/core/quality_control.py b/src/aind_data_schema/core/quality_control.py index 04661f84..9502555e 100644 --- a/src/aind_data_schema/core/quality_control.py +++ b/src/aind_data_schema/core/quality_control.py @@ -165,15 +165,29 @@ class QualityControl(AindCoreModel): evaluations: List[QCEvaluation] = Field(..., title="Evaluations") notes: Optional[str] = Field(default=None, title="Notes") - @property - def status(self) -> Status: + def status(self, modality: str = None, stage: Stage = None) -> Status: """Loop through all evaluations and return the overall status Any FAIL -> FAIL If no fails, then any PENDING -> PENDING All PASS -> PASS + + Parameters + ---------- + modality : str, optional + Modality.ONE_OF to filter by, by default None + stage : Stage, optional + Stage to filter by, by default None + + Returns + ------- + Status """ - eval_statuses = [evaluation.status for evaluation in self.evaluations] + eval_statuses = [ + evaluation.status + for evaluation in self.evaluations + if (not modality or evaluation.modality == modality) and (not stage or evaluation.stage == stage) + ] if any(status == Status.FAIL for status in eval_statuses): return Status.FAIL diff --git a/tests/test_quality_control.py b/tests/test_quality_control.py index 4da9186a..15e948c2 100644 --- a/tests/test_quality_control.py +++ b/tests/test_quality_control.py @@ -83,7 +83,7 @@ def test_overall_status(self): ) # check that overall status gets auto-set if it has never been set before - self.assertEqual(q.status, Status.PASS) + self.assertEqual(q.status(), Status.PASS) # Add a pending metric to the first evaluation q.evaluations[0].metrics.append( @@ -100,7 +100,7 @@ def test_overall_status(self): ) ) - self.assertEqual(q.status, Status.PENDING) + self.assertEqual(q.status(), Status.PENDING) # Add a failing metric to the first evaluation q.evaluations[0].metrics.append( @@ -115,7 +115,7 @@ def test_overall_status(self): ) ) - self.assertEqual(q.status, Status.FAIL) + self.assertEqual(q.status(), Status.FAIL) def test_evaluation_status(self): """test that evaluation status goes to pass/pending/fail correctly""" @@ -323,7 +323,6 @@ def test_multi_session(self): ], ) - print(context.exception) self.assertTrue( "is in a single-asset QCEvaluation and should not have evaluated_assets" in repr(context.exception) ) @@ -367,6 +366,67 @@ def test_multi_session(self): self.assertTrue("is in a multi-asset QCEvaluation and must have evaluated_assets" in repr(context.exception)) + def test_status_filters(self): + """Test that QualityControl.status(modality, stage) filters correctly""" + + test_eval = QCEvaluation( + name="Drift map", + modality=Modality.ECEPHYS, + stage=Stage.PROCESSING, + metrics=[ + QCMetric( + name="Multiple values example", + value={"stuff": "in_a_dict"}, + status_history=[ + QCStatus(evaluator="Bob", timestamp=datetime.fromisoformat("2020-10-10"), status=Status.PASS) + ], + ), + QCMetric( + name="Drift map pass/fail", + value=False, + description="Manual evaluation of whether the drift map looks good", + reference="s3://some-data-somewhere", + status_history=[ + QCStatus(evaluator="Bob", timestamp=datetime.fromisoformat("2020-10-10"), status=Status.PASS) + ], + ), + ], + ) + test_eval2 = QCEvaluation( + name="Drift map", + modality=Modality.BEHAVIOR, + stage=Stage.RAW, + metrics=[ + QCMetric( + name="Multiple values example", + value={"stuff": "in_a_dict"}, + status_history=[ + QCStatus(evaluator="Bob", timestamp=datetime.fromisoformat("2020-10-10"), status=Status.FAIL) + ], + ), + QCMetric( + name="Drift map pass/fail", + value=False, + description="Manual evaluation of whether the drift map looks good", + reference="s3://some-data-somewhere", + status_history=[ + QCStatus(evaluator="Bob", timestamp=datetime.fromisoformat("2020-10-10"), status=Status.PASS) + ], + ), + ], + ) + + # Confirm that the status filters work + q = QualityControl( + evaluations=[test_eval, test_eval2], + ) + + self.assertEqual(q.status(), Status.FAIL) + self.assertEqual(q.status(modality=Modality.BEHAVIOR), Status.FAIL) + self.assertEqual(q.status(modality=Modality.ECEPHYS), Status.PASS) + self.assertEqual(q.status(stage=Stage.RAW), Status.FAIL) + self.assertEqual(q.status(stage=Stage.PROCESSING), Status.PASS) + if __name__ == "__main__": unittest.main()