From 504f4ef7a90d072b86f1b31673d2d90318825a9d Mon Sep 17 00:00:00 2001
From: Nikhil Woodruff <35577657+nikhilwoodruff@users.noreply.github.com>
Date: Tue, 6 Feb 2024 11:28:09 +0000
Subject: [PATCH] Format and improve error handling (#152)
---
changelog_entry.yaml | 4 ++
policyengine_core/charts/bar.py | 26 +++++++------
.../country_template/entities.py | 1 +
policyengine_core/holders/holder.py | 9 +++--
.../parameters/at_instant_like.py | 3 +-
.../operations/uprate_parameters.py | 4 ++
.../parameters/parameter_node.py | 6 +--
.../vectorial_parameter_node_at_instant.py | 22 ++++++-----
policyengine_core/periods/instant_.py | 6 +--
policyengine_core/reforms/reform.py | 6 +--
.../simulations/individual_sim.py | 1 +
policyengine_core/simulations/simulation.py | 37 +++++++++++++------
.../simulations/simulation_builder.py | 26 +++++++------
.../taxbenefitsystems/tax_benefit_system.py | 10 +++--
policyengine_core/taxscales/tax_scale_like.py | 9 ++---
policyengine_core/variables/helpers.py | 8 ++--
policyengine_core/variables/variable.py | 22 ++++++-----
tests/core/variables/test_adds.py | 36 ++++++++++++++++++
tests/core/variables/test_subtracts.py | 36 ++++++++++++++++++
19 files changed, 192 insertions(+), 80 deletions(-)
create mode 100644 tests/core/variables/test_adds.py
create mode 100644 tests/core/variables/test_subtracts.py
diff --git a/changelog_entry.yaml b/changelog_entry.yaml
index e69de29bb..d53617c8c 100644
--- a/changelog_entry.yaml
+++ b/changelog_entry.yaml
@@ -0,0 +1,4 @@
+- bump: patch
+ changes:
+ added:
+ - Improvements to error handling in bad variable declarations.
diff --git a/policyengine_core/charts/bar.py b/policyengine_core/charts/bar.py
index bd3366df5..8f54ed169 100644
--- a/policyengine_core/charts/bar.py
+++ b/policyengine_core/charts/bar.py
@@ -32,9 +32,11 @@ def bar_chart(
"""
hover_text_labels = [
- hover_text_function(index, value)
- if hover_text_function is not None
- else None
+ (
+ hover_text_function(index, value)
+ if hover_text_function is not None
+ else None
+ )
for index, value in data.items()
]
@@ -57,9 +59,11 @@ def bar_chart(
positive_colour if v > 0 else negative_colour
for v in data.values
],
- hovertemplate="%{customdata[0]}"
- if hover_text_labels is not None
- else None,
+ hovertemplate=(
+ "%{customdata[0]}"
+ if hover_text_labels is not None
+ else None
+ ),
)
)
return format_fig(fig)
@@ -94,11 +98,11 @@ def cross_section_bar_chart(
"Category": category,
"Cross section": cross_section_value,
"Value": value,
- "Hover text": hover_text_function(
- cross_section_value, category, value
- )
- if hover_text_function is not None
- else None,
+ "Hover text": (
+ hover_text_function(cross_section_value, category, value)
+ if hover_text_function is not None
+ else None
+ ),
}
df = df.append(row, ignore_index=True)
diff --git a/policyengine_core/country_template/entities.py b/policyengine_core/country_template/entities.py
index c085f8a77..f68b42d63 100644
--- a/policyengine_core/country_template/entities.py
+++ b/policyengine_core/country_template/entities.py
@@ -5,6 +5,7 @@
See https://openfisca.org/doc/key-concepts/person,_entities,_role.html
"""
+
from typing import Any
from policyengine_core.entities import build_entity
diff --git a/policyengine_core/holders/holder.py b/policyengine_core/holders/holder.py
index 084d86943..972a1ff28 100644
--- a/policyengine_core/holders/holder.py
+++ b/policyengine_core/holders/holder.py
@@ -157,10 +157,11 @@ def get_memory_usage(self) -> dict:
usage.update(
dict(
nb_requests=nb_requests,
- nb_requests_by_array=nb_requests
- / float(usage["nb_arrays"])
- if usage["nb_arrays"] > 0
- else numpy.nan,
+ nb_requests_by_array=(
+ nb_requests / float(usage["nb_arrays"])
+ if usage["nb_arrays"] > 0
+ else numpy.nan
+ ),
)
)
diff --git a/policyengine_core/parameters/at_instant_like.py b/policyengine_core/parameters/at_instant_like.py
index 503b23592..ecde86143 100644
--- a/policyengine_core/parameters/at_instant_like.py
+++ b/policyengine_core/parameters/at_instant_like.py
@@ -18,5 +18,4 @@ def get_at_instant(self, instant: Instant) -> Any:
return self._get_at_instant(instant)
@abc.abstractmethod
- def _get_at_instant(self, instant):
- ...
+ def _get_at_instant(self, instant): ...
diff --git a/policyengine_core/parameters/operations/uprate_parameters.py b/policyengine_core/parameters/operations/uprate_parameters.py
index d94def379..530ed1839 100644
--- a/policyengine_core/parameters/operations/uprate_parameters.py
+++ b/policyengine_core/parameters/operations/uprate_parameters.py
@@ -85,6 +85,10 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode:
# Apply the uprater and add to the parameter
value_at_start = parameter(last_instant)
uprater_at_start = uprating_parameter(last_instant)
+ if uprater_at_start is None:
+ raise ValueError(
+ f"Failed to uprate using {uprating_parameter.name} at {last_instant} for {parameter.name} at {entry_instant} because the uprating parameter is not defined at {last_instant}."
+ )
uprater_at_entry = uprating_parameter(entry_instant)
uprater_change = uprater_at_entry / uprater_at_start
uprated_value = value_at_start * uprater_change
diff --git a/policyengine_core/parameters/parameter_node.py b/policyengine_core/parameters/parameter_node.py
index f8952c2ca..050b806b0 100644
--- a/policyengine_core/parameters/parameter_node.py
+++ b/policyengine_core/parameters/parameter_node.py
@@ -28,9 +28,9 @@ class ParameterNode(AtInstantLike):
A node in the legislation `parameter tree `_.
"""
- _allowed_keys: typing.Optional[
- typing.Iterable[str]
- ] = None # By default, no restriction on the keys
+ _allowed_keys: typing.Optional[typing.Iterable[str]] = (
+ None # By default, no restriction on the keys
+ )
parent: "ParameterNode" = None
"""The parent of the node, or None if the node is the root of the tree."""
diff --git a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py
index e4914ed3e..dbd1d14b9 100644
--- a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py
+++ b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py
@@ -27,13 +27,15 @@ def build_from_node(
# Recursively vectorize the children of the node
vectorial_subnodes = tuple(
[
- VectorialParameterNodeAtInstant.build_from_node(
- node[subnode_name]
- ).vector
- if isinstance(
- node[subnode_name], parameters.ParameterNodeAtInstant
+ (
+ VectorialParameterNodeAtInstant.build_from_node(
+ node[subnode_name]
+ ).vector
+ if isinstance(
+ node[subnode_name], parameters.ParameterNodeAtInstant
+ )
+ else node[subnode_name]
)
- else node[subnode_name]
for subnode_name in subnodes_name
]
)
@@ -44,9 +46,11 @@ def build_from_node(
dtype=[
(
subnode_name,
- subnode.dtype
- if isinstance(subnode, numpy.recarray)
- else "float",
+ (
+ subnode.dtype
+ if isinstance(subnode, numpy.recarray)
+ else "float"
+ ),
)
for (subnode_name, subnode) in zip(
subnodes_name, vectorial_subnodes
diff --git a/policyengine_core/periods/instant_.py b/policyengine_core/periods/instant_.py
index d226c558a..4288e3a4a 100644
--- a/policyengine_core/periods/instant_.py
+++ b/policyengine_core/periods/instant_.py
@@ -35,9 +35,9 @@ def __str__(self) -> str:
"""
instant_str = config.str_by_instant_cache.get(self)
if instant_str is None:
- config.str_by_instant_cache[
- self
- ] = instant_str = self.date.isoformat()
+ config.str_by_instant_cache[self] = instant_str = (
+ self.date.isoformat()
+ )
return instant_str
@property
diff --git a/policyengine_core/reforms/reform.py b/policyengine_core/reforms/reform.py
index 3b31cfe00..363feafd9 100644
--- a/policyengine_core/reforms/reform.py
+++ b/policyengine_core/reforms/reform.py
@@ -206,9 +206,9 @@ def api_id(self):
sanitised_period_values = {}
for period, value in period_values.items():
period = period_(period)
- sanitised_period_values[
- f"{period.start}.{period.stop}"
- ] = value
+ sanitised_period_values[f"{period.start}.{period.stop}"] = (
+ value
+ )
sanitised_parameter_values[path] = sanitised_period_values
response = requests.post(
diff --git a/policyengine_core/simulations/individual_sim.py b/policyengine_core/simulations/individual_sim.py
index 4ea542a21..3a82a8109 100644
--- a/policyengine_core/simulations/individual_sim.py
+++ b/policyengine_core/simulations/individual_sim.py
@@ -1,6 +1,7 @@
"""
IndividualSim and any other interfaces to intialising and running simulations on hypothetical situations. Deprecated.
"""
+
from functools import partial
from typing import Dict, List
diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py
index f89d91a3e..40ec8b30b 100644
--- a/policyengine_core/simulations/simulation.py
+++ b/policyengine_core/simulations/simulation.py
@@ -241,8 +241,8 @@ def build_from_dataset(self) -> None:
entity_id_field in data
), f"Missing {entity_id_field} column in the dataset. Each person entity must have an ID array defined for ETERNITY."
- get_eternity_array = (
- lambda ds: ds[list(ds.keys())[0]]
+ get_eternity_array = lambda ds: (
+ ds[list(ds.keys())[0]]
if self.dataset.data_format == Dataset.TIME_PERIOD_ARRAYS
else ds
)
@@ -466,6 +466,11 @@ def calculate_dataframe(
period = periods.period(period)
elif period is None and self.default_calculation_period is not None:
period = periods.period(self.default_calculation_period)
+
+ # Check each variable exists
+ for variable_name in variable_names:
+ if variable_name not in self.tax_benefit_system.variables:
+ raise ValueError(f"Variable {variable_name} does not exist.")
df = pd.DataFrame()
entities = [
self.tax_benefit_system.get_variable(variable_name).entity.key
@@ -785,7 +790,9 @@ def _run_formula(
)
values = values + parameter(period.start)
except:
- pass
+ raise ValueError(
+ f"In the variable '{variable.name}', the 'adds' attribute is a list that contains a string '{added_variable}' that does not match any variable or parameter."
+ )
if variable.subtracts is not None and len(variable.subtracts) > 0:
if isinstance(variable.subtracts, str):
try:
@@ -820,7 +827,9 @@ def _run_formula(
)
values = values + parameter(period.start)
except:
- pass
+ raise ValueError(
+ f"In the variable '{variable.name}', the 'subtracts' attribute is a list that contains a string '{subtracted_variable}' that does not match any variable or parameter."
+ )
return values
if self.trace and not isinstance(
@@ -873,9 +882,11 @@ def _check_period_consistency(
"Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole {2}. You can use the ADD option to sum '{0}' over the requested period.".format(
variable.name,
period,
- "month"
- if variable.definition_period == periods.MONTH
- else "year",
+ (
+ "month"
+ if variable.definition_period == periods.MONTH
+ else "year"
+ ),
)
)
@@ -1244,8 +1255,9 @@ def extract_person(
variable
).get_known_periods()
if len(known_periods) > 0:
- value = self.get_holder(variable).get_array(
- known_periods[0]
+ first_known_period = known_periods[0]
+ value = self.calculate(
+ variable, first_known_period
)[group_index]
situation[entity.plural][entity.key][variable] = {
str(known_periods[0]): value
@@ -1271,9 +1283,10 @@ def extract_person(
variable
).get_known_periods()
if len(known_periods) > 0:
- value = self.get_holder(variable).get_array(
- known_periods[0]
- )[person_index]
+ first_known_period = known_periods[0]
+ value = self.calculate(variable, first_known_period)[
+ person_index
+ ]
situation[person.plural][person_name][variable] = {
str(known_periods[0]): value
}
diff --git a/policyengine_core/simulations/simulation_builder.py b/policyengine_core/simulations/simulation_builder.py
index b39e953f0..e36afdb2d 100644
--- a/policyengine_core/simulations/simulation_builder.py
+++ b/policyengine_core/simulations/simulation_builder.py
@@ -32,8 +32,12 @@
class SimulationBuilder:
def __init__(self):
- self.default_period: Period = None # Simulation period used for variables when no period is defined
- self.persons_plural: str = None # Plural name for person entity in current tax and benefits system
+ self.default_period: Period = (
+ None # Simulation period used for variables when no period is defined
+ )
+ self.persons_plural: str = (
+ None # Plural name for person entity in current tax and benefits system
+ )
# JSON input - Memory of known input values. Indexed by variable or axis name.
self.input_buffer: typing.Dict[
@@ -55,9 +59,9 @@ def __init__(self):
self.has_axes = False
self.axes_entity_counts: typing.Dict[Entity.plural, int] = {}
self.axes_entity_ids: typing.Dict[Entity.plural, typing.List[int]] = {}
- self.axes_memberships: typing.Dict[
- Entity.plural, typing.List[int]
- ] = {}
+ self.axes_memberships: typing.Dict[Entity.plural, typing.List[int]] = (
+ {}
+ )
self.axes_roles: typing.Dict[Entity.plural, typing.List[int]] = {}
def build_from_dict(
@@ -484,12 +488,12 @@ def add_group_entity(
entity_ids = entity_ids + list(persons_to_allocate)
for person_id in persons_to_allocate:
person_index = persons_ids.index(person_id)
- self.memberships[entity.plural][
- person_index
- ] = entity_ids.index(person_id)
- self.roles[entity.plural][
- person_index
- ] = entity.flattened_roles[0]
+ self.memberships[entity.plural][person_index] = (
+ entity_ids.index(person_id)
+ )
+ self.roles[entity.plural][person_index] = (
+ entity.flattened_roles[0]
+ )
# Adjust previously computed ids and counts
self.entity_ids[entity.plural] = entity_ids
self.entity_counts[entity.plural] = len(entity_ids)
diff --git a/policyengine_core/taxbenefitsystems/tax_benefit_system.py b/policyengine_core/taxbenefitsystems/tax_benefit_system.py
index 51c8b17de..f8ced5930 100644
--- a/policyengine_core/taxbenefitsystems/tax_benefit_system.py
+++ b/policyengine_core/taxbenefitsystems/tax_benefit_system.py
@@ -69,7 +69,9 @@ class TaxBenefitSystem:
_parameters_at_instant_cache: Optional[Dict[Any, Any]] = None
person_key_plural: str = None
preprocess_parameters: str = None
- baseline: "TaxBenefitSystem" = None # Baseline tax-benefit system. Used only by reforms. Note: Reforms can be chained.
+ baseline: "TaxBenefitSystem" = (
+ None # Baseline tax-benefit system. Used only by reforms. Note: Reforms can be chained.
+ )
cache_blacklist = None
decomposition_file_path = None
variable_module_metadata: dict = None
@@ -183,9 +185,9 @@ def base_tax_benefit_system(self) -> "TaxBenefitSystem":
baseline = self.baseline
if baseline is None:
return self
- self._base_tax_benefit_system = (
- base_tax_benefit_system
- ) = baseline.base_tax_benefit_system
+ self._base_tax_benefit_system = base_tax_benefit_system = (
+ baseline.base_tax_benefit_system
+ )
return base_tax_benefit_system
def instantiate_entities(self) -> Dict[str, Population]:
diff --git a/policyengine_core/taxscales/tax_scale_like.py b/policyengine_core/taxscales/tax_scale_like.py
index b9a7e5f14..dbedea580 100644
--- a/policyengine_core/taxscales/tax_scale_like.py
+++ b/policyengine_core/taxscales/tax_scale_like.py
@@ -48,20 +48,17 @@ def __ne__(self, _other: object) -> typing.NoReturn:
)
@abc.abstractmethod
- def __repr__(self) -> str:
- ...
+ def __repr__(self) -> str: ...
@abc.abstractmethod
def calc(
self,
tax_base: NumericalArray,
right: bool,
- ) -> numpy.float_:
- ...
+ ) -> numpy.float_: ...
@abc.abstractmethod
- def to_dict(self) -> dict:
- ...
+ def to_dict(self) -> dict: ...
def copy(self) -> typing.Any:
new = commons.empty_clone(self)
diff --git a/policyengine_core/variables/helpers.py b/policyengine_core/variables/helpers.py
index 04fab093f..5eca45f5d 100644
--- a/policyengine_core/variables/helpers.py
+++ b/policyengine_core/variables/helpers.py
@@ -49,9 +49,11 @@ def get_neutralized_variable(variable):
result = variable.clone()
result.is_neutralized = True
result.label = (
- "[Neutralized]"
- if variable.label is None
- else "[Neutralized] {}".format(variable.label),
+ (
+ "[Neutralized]"
+ if variable.label is None
+ else "[Neutralized] {}".format(variable.label)
+ ),
)
return result
diff --git a/policyengine_core/variables/variable.py b/policyengine_core/variables/variable.py
index 00877a49f..4674d26c1 100644
--- a/policyengine_core/variables/variable.py
+++ b/policyengine_core/variables/variable.py
@@ -198,12 +198,14 @@ def __init__(self, baseline_variable=None):
"quantity_type",
required=False,
allowed_values=(QuantityType.STOCK, QuantityType.FLOW),
- default=QuantityType.STOCK
- if (
- self.value_type in (bool, int, Enum, str, datetime.date)
- or self.unit == "/1"
- )
- else QuantityType.FLOW,
+ default=(
+ QuantityType.STOCK
+ if (
+ self.value_type in (bool, int, Enum, str, datetime.date)
+ or self.unit == "/1"
+ )
+ else QuantityType.FLOW
+ ),
)
self.documentation = self.set(
attr,
@@ -214,9 +216,11 @@ def __init__(self, baseline_variable=None):
self.set_input = self.set_set_input(
attr.pop(
"set_input",
- set_input_dispatch_by_period
- if self.quantity_type == QuantityType.STOCK
- else set_input_divide_by_period,
+ (
+ set_input_dispatch_by_period
+ if self.quantity_type == QuantityType.STOCK
+ else set_input_divide_by_period
+ ),
)
)
if self.definition_period in (DAY, ETERNITY):
diff --git a/tests/core/variables/test_adds.py b/tests/core/variables/test_adds.py
new file mode 100644
index 000000000..ff6f5bdd9
--- /dev/null
+++ b/tests/core/variables/test_adds.py
@@ -0,0 +1,36 @@
+import numpy as np
+
+from policyengine_core.entities import Entity
+from policyengine_core.model_api import *
+from policyengine_core.simulations import SimulationBuilder
+from policyengine_core.taxbenefitsystems import TaxBenefitSystem
+
+
+def test_bad_adds_raises_error():
+ """A basic test of an ill-defined adds attribute in a variable."""
+ Person = Entity("person", "people", "Person", "A person")
+ system = TaxBenefitSystem([Person])
+
+ class some_income(Variable):
+ value_type = float
+ entity = Person
+ definition_period = ETERNITY
+ label = "Income"
+ adds = ["income_not_defined"]
+
+ system.add_variables(some_income)
+
+ simulation = SimulationBuilder().build_from_dict(
+ system,
+ {
+ "people": {
+ "person": {},
+ },
+ },
+ )
+
+ try:
+ simulation.calculate("some_income")
+ raise Exception("Should have raised an error.")
+ except ValueError as e:
+ pass
diff --git a/tests/core/variables/test_subtracts.py b/tests/core/variables/test_subtracts.py
new file mode 100644
index 000000000..eeba872fa
--- /dev/null
+++ b/tests/core/variables/test_subtracts.py
@@ -0,0 +1,36 @@
+import numpy as np
+
+from policyengine_core.entities import Entity
+from policyengine_core.model_api import *
+from policyengine_core.simulations import SimulationBuilder
+from policyengine_core.taxbenefitsystems import TaxBenefitSystem
+
+
+def test_bad_subtracts_raises_error():
+ """A basic test of an ill-defined subtracts attribute in a variable."""
+ Person = Entity("person", "people", "Person", "A person")
+ system = TaxBenefitSystem([Person])
+
+ class some_income(Variable):
+ value_type = float
+ entity = Person
+ definition_period = ETERNITY
+ label = "Income"
+ subtracts = ["income_not_defined"]
+
+ system.add_variables(some_income)
+
+ simulation = SimulationBuilder().build_from_dict(
+ system,
+ {
+ "people": {
+ "person": {},
+ },
+ },
+ )
+
+ try:
+ simulation.calculate("some_income")
+ raise Exception("Should have raised an error.")
+ except ValueError as e:
+ pass