Skip to content

Commit

Permalink
Add method for propagating units to descendants for ParameterScale
Browse files Browse the repository at this point in the history
…objects (#162)

* feat: add method for propagating units to descendants for `ParameterScale` objects

* feat: increase test coverage

* chore: add changelog

* feat: increase test coverage

* fix: update sphinx version

* fix: change style to reduce nested ifs.
  • Loading branch information
abhcs authored Mar 4, 2024
1 parent 1c3d6a1 commit e90bc1d
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 1 deletion.
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: minor
changes:
changed:
- Add method for propagating units to descendants for ParameterScale objects
16 changes: 16 additions & 0 deletions policyengine_core/parameters/parameter_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,22 @@ def __repr__(self) -> str:
]
)

def propagate_units(self) -> None:
unit_keys = filter(
lambda k: k in self.metadata,
parameters.ParameterScaleBracket.allowed_unit_keys(),
)
for unit_key in unit_keys:
child_key = unit_key[:-5]
for bracket in self.brackets:
if (
child_key in bracket.children
and "unit" not in bracket.children[child_key].metadata
):
bracket.children[child_key].metadata["unit"] = (
self.metadata[unit_key]
)

def get_descendants(self) -> Iterable:
for bracket in self.brackets:
yield bracket
Expand Down
4 changes: 4 additions & 0 deletions policyengine_core/parameters/parameter_scale_bracket.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ class ParameterScaleBracket(ParameterNode):
["amount", "threshold", "rate", "average_rate", "base"]
)

@staticmethod
def allowed_unit_keys():
return [key + "_unit" for key in ParameterScaleBracket._allowed_keys]

def get_descendants(self) -> Iterable[Parameter]:
for key in self._allowed_keys:
if key in self.children:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"coverage",
"furo",
"mypy==0.991",
"sphinx==4.5.0",
"sphinx==5.0.0",
"sphinx-argparse==0.4.0",
"sphinx-math-dollar==1.2.1",
"types-PyYAML==6.0.12.2",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
description: Propagate units in metadata of a scaled parameter to its descendants
metadata:
type: single_amount
threshold_unit: child
amount_unit: currency-USD
rate_unit: /1
label: Test unit propagation

brackets:
- threshold:
values:
1995-01-01: 0
amount:
values:
2017-01-01: 8_340
2018-01-01: 8_510
2019-01-01: 8_650
2020-01-01: 8_790
2021-01-01: 11_610
2022-01-01: 9_160
2023-01-01: 9_800
2024-01-01: 10_330
- threshold:
values:
1995-01-01: 1
amount:
values:
2017-01-01: 18_340
2018-01-01: 18_700
2019-01-01: 19_030
2020-01-01: 19_330
2021-01-01: 19_520
2022-01-01: 20_130
2023-01-01: 21_560
2024-01-01: 22_720
metadata:
unit: US dollars
- threshold:
values:
1995-01-01: 2
metadata:
a: b
amount:
values:
2017-01-01: 18_340
2018-01-01: 18_700
2019-01-01: 19_030
2020-01-01: 19_330
2021-01-01: 19_520
2022-01-01: 20_130
2023-01-01: 21_560
2024-01-01: 22_720
- threshold:
values:
1995-01-01: 3
metadata:
a: b
amount:
values:
2017-01-01: 18_340
2018-01-01: 18_700
2019-01-01: 19_030
2020-01-01: 19_330
2021-01-01: 19_520
2022-01-01: 20_130
2023-01-01: 21_560
2024-01-01: 22_720
metadata:
a: b
16 changes: 16 additions & 0 deletions tests/core/parameter_validation/test_propagate_units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os

from policyengine_core.parameters import load_parameter_file

BASE_DIR = os.path.dirname(os.path.abspath(__file__))


def test_propagate_units():
path = os.path.join(BASE_DIR, "parameter_for_unit_propagation.yaml")
parameter = load_parameter_file(path)
parameter.propagate_units()
for i in range(4):
assert parameter.brackets[i].threshold.metadata["unit"] == "child"
assert parameter.brackets[i].amount.metadata["unit"] == (
"US dollars" if i == 1 else "currency-USD"
)

0 comments on commit e90bc1d

Please sign in to comment.