Skip to content

Commit

Permalink
Migrates python tensorflow code to operate on R4.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 324309871
  • Loading branch information
nickgeorge committed Aug 14, 2020
1 parent 184d5aa commit 93a9666
Show file tree
Hide file tree
Showing 21 changed files with 193 additions and 282 deletions.
36 changes: 20 additions & 16 deletions py/google/fhir/labels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ py_library(
],
srcs_version = "PY3",
deps = [
"//proto/stu3:codes_py_pb2",
"//proto/stu3:datatypes_py_pb2",
"//proto/stu3:resources_py_pb2",
"//proto/r4/core:codes_py_pb2",
"//proto/r4/core:datatypes_py_pb2",
"//proto/r4/core/resources:bundle_and_contained_resource_py_pb2",
"//proto/r4/core/resources:encounter_py_pb2",
"//proto/r4/core/resources:patient_py_pb2",
],
)

Expand All @@ -36,12 +38,13 @@ py_test(
"encounter_test.py",
],
data = [
"//testdata/stu3:labels",
"//testdata/r4:labels",
],
python_version = "PY3",
deps = [
":encounter",
"//proto/stu3:resources_py_pb2",
"//proto/r4/core/resources:bundle_and_contained_resource_py_pb2",
"//proto/r4/core/resources:encounter_py_pb2",
"@absl_py//absl/testing:absltest",
"@com_google_protobuf//:protobuf_python",
],
Expand All @@ -55,9 +58,10 @@ py_library(
srcs_version = "PY3",
deps = [
":encounter",
"//proto/stu3:datatypes_py_pb2",
"//proto/stu3:ml_extensions_py_pb2",
"//proto/stu3:resources_py_pb2",
"//proto/r4:ml_extensions_py_pb2",
"//proto/r4/core:datatypes_py_pb2",
"//proto/r4/core/resources:encounter_py_pb2",
"//proto/r4/core/resources:patient_py_pb2",
],
)

Expand All @@ -67,14 +71,14 @@ py_test(
"label_test.py",
],
data = [
"//testdata/stu3:labels",
"//testdata/r4:labels",
],
python_version = "PY3",
deps = [
":label",
"//proto/stu3:datatypes_py_pb2",
"//proto/stu3:extensions_py_pb2",
"//proto/stu3:resources_py_pb2",
"//proto/r4:ml_extensions_py_pb2",
"//proto/r4/core:datatypes_py_pb2",
"//proto/r4/core/resources:encounter_py_pb2",
"@absl_py//absl/testing:absltest",
"@absl_py//absl/testing:parameterized",
"@com_google_protobuf//:protobuf_python",
Expand All @@ -97,8 +101,8 @@ py_library(
deps = [
":encounter",
":label",
"//proto/stu3:ml_extensions_py_pb2",
"//proto/stu3:resources_py_pb2",
"//proto/r4/core/resources:bundle_and_contained_resource_py_pb2",
"//proto/r4:ml_extensions_py_pb2",
"@absl_py//absl:app",
"@absl_py//absl/flags",
requirement("apache_beam"),
Expand All @@ -111,14 +115,14 @@ py_test(
"bundle_to_label_test.py",
],
data = [
"//testdata/stu3:labels",
"//testdata/r4:labels",
],
python_version = "PY3",
deps = [
":bundle_to_label_lib",
":label",
"@com_google_protobuf//:protobuf_python",
"//proto/stu3:resources_py_pb2",
"//proto/r4/core/resources:bundle_and_contained_resource_py_pb2",
"@absl_py//absl/testing:absltest",
requirement("apache_beam"),
requirement("nose"),
Expand Down
14 changes: 7 additions & 7 deletions py/google/fhir/labels/bundle_to_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import apache_beam as beam
from apache_beam.options import pipeline_options

from proto.stu3 import ml_extensions_pb2
from proto.stu3 import resources_pb2
from proto.r4 import ml_extensions_pb2
from proto.r4.core.resources import bundle_and_contained_resource_pb2
from py.google.fhir.labels import encounter
from py.google.fhir.labels import label

Expand All @@ -44,7 +44,7 @@
'The setup file for Dataflow dependencies')


@beam.typehints.with_input_types(resources_pb2.Bundle)
@beam.typehints.with_input_types(bundle_and_contained_resource_pb2.Bundle)
@beam.typehints.with_output_types(ml_extensions_pb2.EventLabel)
class LengthOfStayRangeLabelAt24HoursFn(beam.DoFn):
"""Converts Bundle into length of stay range at 24 hours label.
Expand All @@ -58,12 +58,12 @@ def __init__(self, for_synthea: bool = False):
self._for_synthea = for_synthea

def process(
self,
bundle: resources_pb2.Bundle) -> Iterator[ml_extensions_pb2.EventLabel]:
self, bundle: bundle_and_contained_resource_pb2.Bundle
) -> Iterator[ml_extensions_pb2.EventLabel]:
"""Iterate through bundle and yield label.
Args:
bundle: input stu3.Bundle proto
bundle: input R4 Bundle proto
Yields:
R4 EventLabel proto.
Expand Down Expand Up @@ -106,7 +106,7 @@ def main(argv: List[str]):
p = beam.Pipeline(options=GetPipelineOptions())
bundles = p | 'read' >> beam.io.ReadFromTFRecord(
flags.FLAGS.input_path,
coder=beam.coders.ProtoCoder(resources_pb2.Bundle))
coder=beam.coders.ProtoCoder(bundle_and_contained_resource_pb2.Bundle))
labels = bundles | 'BundleToLabel' >> beam.ParDo(
LengthOfStayRangeLabelAt24HoursFn(for_synthea=flags.FLAGS.for_synthea))
_ = labels | beam.io.WriteToTFRecord(
Expand Down
8 changes: 3 additions & 5 deletions py/google/fhir/labels/bundle_to_label_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,11 @@
from apache_beam.testing import test_pipeline
from apache_beam.testing import util
from google.protobuf import text_format
from proto.stu3 import resources_pb2
from proto.r4.core.resources import bundle_and_contained_resource_pb2
from py.google.fhir.labels import bundle_to_label
from py.google.fhir.labels import label


# TODO(kunzhang, cykoo): Move this to a proper location.
_TESTDATA_PATH = 'com_google_fhir/testdata/stu3/labels'
_TESTDATA_PATH = 'com_google_fhir/testdata/r4/labels'


class BundleToLabelTest(absltest.TestCase):
Expand All @@ -42,7 +40,7 @@ def _VerifyPipeline(self, for_synthea: bool):
bundle_text_file = 'bundle_1.pbtxt'
if for_synthea:
bundle_text_file = 'bundle_synthea.pbtxt'
bundle = resources_pb2.Bundle()
bundle = bundle_and_contained_resource_pb2.Bundle()
with open(os.path.join(self._test_data_dir, bundle_text_file)) as f:
text_format.Parse(f.read(), bundle)
enc = bundle.entry[0].resource.encounter
Expand Down
38 changes: 22 additions & 16 deletions py/google/fhir/labels/encounter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
import datetime
from typing import Iterator, Optional

from proto.stu3 import codes_pb2
from proto.stu3 import datatypes_pb2
from proto.stu3 import resources_pb2
from proto.r4.core import codes_pb2
from proto.r4.core import datatypes_pb2
from proto.r4.core.resources import bundle_and_contained_resource_pb2
from proto.r4.core.resources import encounter_pb2
from proto.r4.core.resources import patient_pb2


ENCOUNTER_CLASS_CODESYSTEM = 'http://hl7.org/fhir/v3/ActCode'
Expand All @@ -47,42 +49,44 @@ def ToTime(date_and_time: datatypes_pb2.DateTime) -> datetime.datetime:
return datetime.datetime.utcfromtimestamp(date_and_time.value_us / 1000000)


def EncounterIsFinished(encounter: resources_pb2.Encounter) -> bool:
def EncounterIsFinished(encounter: encounter_pb2.Encounter) -> bool:
return (encounter.period.HasField('start') and
encounter.period.HasField('end') and
encounter.status.value ==
codes_pb2.EncounterStatusCode.FINISHED)


def EncounterIsValidHospitalization(encounter: resources_pb2.Encounter) -> bool:
def EncounterIsValidHospitalization(encounter: encounter_pb2.Encounter) -> bool:
enc_class = encounter.class_value
return (EncounterIsFinished(encounter) and
enc_class.system.value == ENCOUNTER_CLASS_CODESYSTEM and
enc_class.code.value == CLASS_INPATIENT)


def EncounterIsValidHospitalizationForSynthea(
encounter: resources_pb2.Encounter) -> bool:
encounter: encounter_pb2.Encounter) -> bool:
enc_class = encounter.class_value
return (EncounterIsFinished(encounter) and
enc_class.code.value == 'inpatient')


def AtDuration(encounter: resources_pb2.Encounter,
def AtDuration(encounter: encounter_pb2.Encounter,
hours: int) -> datetime.datetime:
# encounter.start + hours
result = ToTime(encounter.period.start) + datetime.timedelta(hours=hours)
assert result <= ToTime(encounter.period.end)
return result


def EncounterLengthDays(encounter: resources_pb2.Encounter) -> float:
def EncounterLengthDays(encounter: encounter_pb2.Encounter) -> float:
# Needs a float to properly put encounters in ranges.
length_delta = ToTime(encounter.period.end) - ToTime(encounter.period.start)
return float(length_delta.total_seconds()) / SECS_PER_DAY


def GetPatient(bundle: resources_pb2.Bundle) -> Optional[resources_pb2.Patient]:
def GetPatient(
bundle: bundle_and_contained_resource_pb2.Bundle
) -> Optional[patient_pb2.Patient]:
for entry in bundle.entry:
if entry.resource.HasField('patient'):
return entry.resource.patient
Expand All @@ -93,7 +97,8 @@ def GetPatient(bundle: resources_pb2.Bundle) -> Optional[resources_pb2.Patient]:
# Use generator to be memory efficient.
#
def AllEncounters(
bundle: resources_pb2.Bundle) -> Iterator[resources_pb2.Encounter]:
bundle: bundle_and_contained_resource_pb2.Bundle
) -> Iterator[encounter_pb2.Encounter]:
"""Yields all encounters in a bundle.
Args:
Expand All @@ -108,8 +113,8 @@ def AllEncounters(


def InpatientEncounters(
bundle: resources_pb2.Bundle,
for_synthea: bool = False) -> Iterator[resources_pb2.Encounter]:
bundle: bundle_and_contained_resource_pb2.Bundle,
for_synthea: bool = False) -> Iterator[encounter_pb2.Encounter]:
"""Yields all inpatient encounters in a bundle.
Args:
Expand All @@ -126,9 +131,10 @@ def InpatientEncounters(
yield encounter


def InpatientEncountersLongerThan(bundle: resources_pb2.Bundle,
n_hours: int,
for_synthea: bool = False):
def InpatientEncountersLongerThan(
bundle: bundle_and_contained_resource_pb2.Bundle,
n_hours: int,
for_synthea: bool = False):
"""Yields all inpatient encounters in a bundle that are longer than N hours.
Args:
Expand All @@ -146,6 +152,6 @@ def InpatientEncountersLongerThan(bundle: resources_pb2.Bundle,


# One line wrapper for 24.
def Inpatient24HrEncounters(bundle: resources_pb2.Bundle,
def Inpatient24HrEncounters(bundle: bundle_and_contained_resource_pb2.Bundle,
for_synthea: bool = False):
return InpatientEncountersLongerThan(bundle, 24, for_synthea)
14 changes: 7 additions & 7 deletions py/google/fhir/labels/encounter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@

from absl.testing import absltest
from google.protobuf import text_format
from proto.stu3 import resources_pb2
from proto.r4.core.resources import bundle_and_contained_resource_pb2
from proto.r4.core.resources import encounter_pb2
from py.google.fhir.labels import encounter


_TESTDATA_PATH = 'com_google_fhir/testdata/stu3/labels'
_TESTDATA_PATH = 'com_google_fhir/testdata/r4/labels'


class EncounterTest(absltest.TestCase):
Expand All @@ -32,16 +32,16 @@ def setUp(self):
super(EncounterTest, self).setUp()
self._test_data_dir = os.path.join(absltest.get_default_test_srcdir(),
_TESTDATA_PATH)
self._enc = resources_pb2.Encounter()
self._enc = encounter_pb2.Encounter()
with open(os.path.join(self._test_data_dir, 'encounter_1.pbtxt')) as f:
text_format.Parse(f.read(), self._enc)
self._bundle = resources_pb2.Bundle()
self._bundle = bundle_and_contained_resource_pb2.Bundle()
self._bundle.entry.add().resource.encounter.CopyFrom(self._enc)
self._synthea_enc = resources_pb2.Encounter()
self._synthea_enc = encounter_pb2.Encounter()
with open(os.path.join(self._test_data_dir,
'encounter_synthea.pbtxt')) as f:
text_format.Parse(f.read(), self._synthea_enc)
self._synthea_bundle = resources_pb2.Bundle()
self._synthea_bundle = bundle_and_contained_resource_pb2.Bundle()
self._synthea_bundle.entry.add().resource.encounter.CopyFrom(
self._synthea_enc)

Expand Down
18 changes: 10 additions & 8 deletions py/google/fhir/labels/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
import datetime
from typing import Iterator, Optional

from proto.stu3 import datatypes_pb2
from proto.stu3 import ml_extensions_pb2
from proto.stu3 import resources_pb2
from proto.r4 import ml_extensions_pb2
from proto.r4.core import datatypes_pb2
from proto.r4.core.resources import encounter_pb2
from proto.r4.core.resources import patient_pb2
from py.google.fhir.labels import encounter


Expand Down Expand Up @@ -60,9 +61,10 @@ def ToMicroSeconds(dt: datetime.datetime) -> int:


# Note: this API only composes encounter-level labels.
def ComposeLabel(patient: resources_pb2.Patient, enc: resources_pb2.Encounter,
label_name: str, label_val: str,
label_time: datetime.datetime) -> ml_extensions_pb2.EventLabel:
def ComposeLabel(
patient: patient_pb2.Patient, enc: encounter_pb2.Encounter, label_name: str,
label_val: str,
label_time: datetime.datetime) -> ml_extensions_pb2.EventLabel:
"""Compose an event_label proto given inputs.
Args:
Expand Down Expand Up @@ -92,8 +94,8 @@ def ComposeLabel(patient: resources_pb2.Patient, enc: resources_pb2.Encounter,


def LengthOfStayRangeAt24Hours(
patient: resources_pb2.Patient,
enc: resources_pb2.Encounter) -> Iterator[ml_extensions_pb2.EventLabel]:
patient: patient_pb2.Patient,
enc: encounter_pb2.Encounter) -> Iterator[ml_extensions_pb2.EventLabel]:
"""Generate length of stay range labels at 24 hours after admission.
Args:
Expand Down
17 changes: 9 additions & 8 deletions py/google/fhir/labels/label_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@
from absl.testing import absltest
from absl.testing import parameterized
from google.protobuf import text_format
from proto.stu3 import datatypes_pb2
from proto.stu3 import ml_extensions_pb2
from proto.stu3 import resources_pb2
from proto.r4 import ml_extensions_pb2
from proto.r4.core import datatypes_pb2
from proto.r4.core.resources import encounter_pb2
from proto.r4.core.resources import patient_pb2
from py.google.fhir.labels import label


_TESTDATA_PATH = 'com_google_fhir/testdata/stu3/labels'
_TESTDATA_PATH = 'com_google_fhir/testdata/r4/labels'


class LabelTest(parameterized.TestCase):
Expand All @@ -36,10 +37,10 @@ def setUp(self):
super(LabelTest, self).setUp()
self._test_data_dir = os.path.join(absltest.get_default_test_srcdir(),
_TESTDATA_PATH)
self._enc = resources_pb2.Encounter()
self._enc = encounter_pb2.Encounter()
with open(os.path.join(self._test_data_dir, 'encounter_1.pbtxt')) as f:
text_format.Parse(f.read(), self._enc)
self._patient = resources_pb2.Patient()
self._patient = patient_pb2.Patient()
self._patient.id.value = 'Patient/1'

self._expected_label = ml_extensions_pb2.EventLabel()
Expand Down Expand Up @@ -92,7 +93,7 @@ def testComposeLabel(self):
{'end_us': 1234827090000000, 'label_val': 'less_or_equal_3'}
)
def testLengthOfStayRangeAt24Hours(self, end_us, label_val):
enc = resources_pb2.Encounter()
enc = encounter_pb2.Encounter()
enc.CopyFrom(self._enc)
enc.period.end.value_us = end_us
labels = [l for l in label.LengthOfStayRangeAt24Hours(
Expand All @@ -107,7 +108,7 @@ def testLengthOfStayRangeAt24Hours(self, end_us, label_val):
self.assertEqual([expected_label], labels)

def testLengthOfStayRangeAt24HoursLT24Hours(self):
enc = resources_pb2.Encounter()
enc = encounter_pb2.Encounter()
enc.CopyFrom(self._enc)
enc.period.end.value_us = 1234567891000000
with self.assertRaises(AssertionError):
Expand Down
Loading

0 comments on commit 93a9666

Please sign in to comment.