From 504f4ef7a90d072b86f1b31673d2d90318825a9d Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff <35577657+nikhilwoodruff@users.noreply.github.com> Date: Tue, 6 Feb 2024 11:28:09 +0000 Subject: [PATCH] Format and improve error handling (#152) --- changelog_entry.yaml | 4 ++ policyengine_core/charts/bar.py | 26 +++++++------ .../country_template/entities.py | 1 + policyengine_core/holders/holder.py | 9 +++-- .../parameters/at_instant_like.py | 3 +- .../operations/uprate_parameters.py | 4 ++ .../parameters/parameter_node.py | 6 +-- .../vectorial_parameter_node_at_instant.py | 22 ++++++----- policyengine_core/periods/instant_.py | 6 +-- policyengine_core/reforms/reform.py | 6 +-- .../simulations/individual_sim.py | 1 + policyengine_core/simulations/simulation.py | 37 +++++++++++++------ .../simulations/simulation_builder.py | 26 +++++++------ .../taxbenefitsystems/tax_benefit_system.py | 10 +++-- policyengine_core/taxscales/tax_scale_like.py | 9 ++--- policyengine_core/variables/helpers.py | 8 ++-- policyengine_core/variables/variable.py | 22 ++++++----- tests/core/variables/test_adds.py | 36 ++++++++++++++++++ tests/core/variables/test_subtracts.py | 36 ++++++++++++++++++ 19 files changed, 192 insertions(+), 80 deletions(-) create mode 100644 tests/core/variables/test_adds.py create mode 100644 tests/core/variables/test_subtracts.py diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29bb..d53617c8c 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: patch + changes: + added: + - Improvements to error handling in bad variable declarations. diff --git a/policyengine_core/charts/bar.py b/policyengine_core/charts/bar.py index bd3366df5..8f54ed169 100644 --- a/policyengine_core/charts/bar.py +++ b/policyengine_core/charts/bar.py @@ -32,9 +32,11 @@ def bar_chart( """ hover_text_labels = [ - hover_text_function(index, value) - if hover_text_function is not None - else None + ( + hover_text_function(index, value) + if hover_text_function is not None + else None + ) for index, value in data.items() ] @@ -57,9 +59,11 @@ def bar_chart( positive_colour if v > 0 else negative_colour for v in data.values ], - hovertemplate="%{customdata[0]}" - if hover_text_labels is not None - else None, + hovertemplate=( + "%{customdata[0]}" + if hover_text_labels is not None + else None + ), ) ) return format_fig(fig) @@ -94,11 +98,11 @@ def cross_section_bar_chart( "Category": category, "Cross section": cross_section_value, "Value": value, - "Hover text": hover_text_function( - cross_section_value, category, value - ) - if hover_text_function is not None - else None, + "Hover text": ( + hover_text_function(cross_section_value, category, value) + if hover_text_function is not None + else None + ), } df = df.append(row, ignore_index=True) diff --git a/policyengine_core/country_template/entities.py b/policyengine_core/country_template/entities.py index c085f8a77..f68b42d63 100644 --- a/policyengine_core/country_template/entities.py +++ b/policyengine_core/country_template/entities.py @@ -5,6 +5,7 @@ See https://openfisca.org/doc/key-concepts/person,_entities,_role.html """ + from typing import Any from policyengine_core.entities import build_entity diff --git a/policyengine_core/holders/holder.py b/policyengine_core/holders/holder.py index 084d86943..972a1ff28 100644 --- a/policyengine_core/holders/holder.py +++ b/policyengine_core/holders/holder.py @@ -157,10 +157,11 @@ def get_memory_usage(self) -> dict: usage.update( dict( nb_requests=nb_requests, - nb_requests_by_array=nb_requests - / float(usage["nb_arrays"]) - if usage["nb_arrays"] > 0 - else numpy.nan, + nb_requests_by_array=( + nb_requests / float(usage["nb_arrays"]) + if usage["nb_arrays"] > 0 + else numpy.nan + ), ) ) diff --git a/policyengine_core/parameters/at_instant_like.py b/policyengine_core/parameters/at_instant_like.py index 503b23592..ecde86143 100644 --- a/policyengine_core/parameters/at_instant_like.py +++ b/policyengine_core/parameters/at_instant_like.py @@ -18,5 +18,4 @@ def get_at_instant(self, instant: Instant) -> Any: return self._get_at_instant(instant) @abc.abstractmethod - def _get_at_instant(self, instant): - ... + def _get_at_instant(self, instant): ... diff --git a/policyengine_core/parameters/operations/uprate_parameters.py b/policyengine_core/parameters/operations/uprate_parameters.py index d94def379..530ed1839 100644 --- a/policyengine_core/parameters/operations/uprate_parameters.py +++ b/policyengine_core/parameters/operations/uprate_parameters.py @@ -85,6 +85,10 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode: # Apply the uprater and add to the parameter value_at_start = parameter(last_instant) uprater_at_start = uprating_parameter(last_instant) + if uprater_at_start is None: + raise ValueError( + f"Failed to uprate using {uprating_parameter.name} at {last_instant} for {parameter.name} at {entry_instant} because the uprating parameter is not defined at {last_instant}." + ) uprater_at_entry = uprating_parameter(entry_instant) uprater_change = uprater_at_entry / uprater_at_start uprated_value = value_at_start * uprater_change diff --git a/policyengine_core/parameters/parameter_node.py b/policyengine_core/parameters/parameter_node.py index f8952c2ca..050b806b0 100644 --- a/policyengine_core/parameters/parameter_node.py +++ b/policyengine_core/parameters/parameter_node.py @@ -28,9 +28,9 @@ class ParameterNode(AtInstantLike): A node in the legislation `parameter tree `_. """ - _allowed_keys: typing.Optional[ - typing.Iterable[str] - ] = None # By default, no restriction on the keys + _allowed_keys: typing.Optional[typing.Iterable[str]] = ( + None # By default, no restriction on the keys + ) parent: "ParameterNode" = None """The parent of the node, or None if the node is the root of the tree.""" diff --git a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py index e4914ed3e..dbd1d14b9 100644 --- a/policyengine_core/parameters/vectorial_parameter_node_at_instant.py +++ b/policyengine_core/parameters/vectorial_parameter_node_at_instant.py @@ -27,13 +27,15 @@ def build_from_node( # Recursively vectorize the children of the node vectorial_subnodes = tuple( [ - VectorialParameterNodeAtInstant.build_from_node( - node[subnode_name] - ).vector - if isinstance( - node[subnode_name], parameters.ParameterNodeAtInstant + ( + VectorialParameterNodeAtInstant.build_from_node( + node[subnode_name] + ).vector + if isinstance( + node[subnode_name], parameters.ParameterNodeAtInstant + ) + else node[subnode_name] ) - else node[subnode_name] for subnode_name in subnodes_name ] ) @@ -44,9 +46,11 @@ def build_from_node( dtype=[ ( subnode_name, - subnode.dtype - if isinstance(subnode, numpy.recarray) - else "float", + ( + subnode.dtype + if isinstance(subnode, numpy.recarray) + else "float" + ), ) for (subnode_name, subnode) in zip( subnodes_name, vectorial_subnodes diff --git a/policyengine_core/periods/instant_.py b/policyengine_core/periods/instant_.py index d226c558a..4288e3a4a 100644 --- a/policyengine_core/periods/instant_.py +++ b/policyengine_core/periods/instant_.py @@ -35,9 +35,9 @@ def __str__(self) -> str: """ instant_str = config.str_by_instant_cache.get(self) if instant_str is None: - config.str_by_instant_cache[ - self - ] = instant_str = self.date.isoformat() + config.str_by_instant_cache[self] = instant_str = ( + self.date.isoformat() + ) return instant_str @property diff --git a/policyengine_core/reforms/reform.py b/policyengine_core/reforms/reform.py index 3b31cfe00..363feafd9 100644 --- a/policyengine_core/reforms/reform.py +++ b/policyengine_core/reforms/reform.py @@ -206,9 +206,9 @@ def api_id(self): sanitised_period_values = {} for period, value in period_values.items(): period = period_(period) - sanitised_period_values[ - f"{period.start}.{period.stop}" - ] = value + sanitised_period_values[f"{period.start}.{period.stop}"] = ( + value + ) sanitised_parameter_values[path] = sanitised_period_values response = requests.post( diff --git a/policyengine_core/simulations/individual_sim.py b/policyengine_core/simulations/individual_sim.py index 4ea542a21..3a82a8109 100644 --- a/policyengine_core/simulations/individual_sim.py +++ b/policyengine_core/simulations/individual_sim.py @@ -1,6 +1,7 @@ """ IndividualSim and any other interfaces to intialising and running simulations on hypothetical situations. Deprecated. """ + from functools import partial from typing import Dict, List diff --git a/policyengine_core/simulations/simulation.py b/policyengine_core/simulations/simulation.py index f89d91a3e..40ec8b30b 100644 --- a/policyengine_core/simulations/simulation.py +++ b/policyengine_core/simulations/simulation.py @@ -241,8 +241,8 @@ def build_from_dataset(self) -> None: entity_id_field in data ), f"Missing {entity_id_field} column in the dataset. Each person entity must have an ID array defined for ETERNITY." - get_eternity_array = ( - lambda ds: ds[list(ds.keys())[0]] + get_eternity_array = lambda ds: ( + ds[list(ds.keys())[0]] if self.dataset.data_format == Dataset.TIME_PERIOD_ARRAYS else ds ) @@ -466,6 +466,11 @@ def calculate_dataframe( period = periods.period(period) elif period is None and self.default_calculation_period is not None: period = periods.period(self.default_calculation_period) + + # Check each variable exists + for variable_name in variable_names: + if variable_name not in self.tax_benefit_system.variables: + raise ValueError(f"Variable {variable_name} does not exist.") df = pd.DataFrame() entities = [ self.tax_benefit_system.get_variable(variable_name).entity.key @@ -785,7 +790,9 @@ def _run_formula( ) values = values + parameter(period.start) except: - pass + raise ValueError( + f"In the variable '{variable.name}', the 'adds' attribute is a list that contains a string '{added_variable}' that does not match any variable or parameter." + ) if variable.subtracts is not None and len(variable.subtracts) > 0: if isinstance(variable.subtracts, str): try: @@ -820,7 +827,9 @@ def _run_formula( ) values = values + parameter(period.start) except: - pass + raise ValueError( + f"In the variable '{variable.name}', the 'subtracts' attribute is a list that contains a string '{subtracted_variable}' that does not match any variable or parameter." + ) return values if self.trace and not isinstance( @@ -873,9 +882,11 @@ def _check_period_consistency( "Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole {2}. You can use the ADD option to sum '{0}' over the requested period.".format( variable.name, period, - "month" - if variable.definition_period == periods.MONTH - else "year", + ( + "month" + if variable.definition_period == periods.MONTH + else "year" + ), ) ) @@ -1244,8 +1255,9 @@ def extract_person( variable ).get_known_periods() if len(known_periods) > 0: - value = self.get_holder(variable).get_array( - known_periods[0] + first_known_period = known_periods[0] + value = self.calculate( + variable, first_known_period )[group_index] situation[entity.plural][entity.key][variable] = { str(known_periods[0]): value @@ -1271,9 +1283,10 @@ def extract_person( variable ).get_known_periods() if len(known_periods) > 0: - value = self.get_holder(variable).get_array( - known_periods[0] - )[person_index] + first_known_period = known_periods[0] + value = self.calculate(variable, first_known_period)[ + person_index + ] situation[person.plural][person_name][variable] = { str(known_periods[0]): value } diff --git a/policyengine_core/simulations/simulation_builder.py b/policyengine_core/simulations/simulation_builder.py index b39e953f0..e36afdb2d 100644 --- a/policyengine_core/simulations/simulation_builder.py +++ b/policyengine_core/simulations/simulation_builder.py @@ -32,8 +32,12 @@ class SimulationBuilder: def __init__(self): - self.default_period: Period = None # Simulation period used for variables when no period is defined - self.persons_plural: str = None # Plural name for person entity in current tax and benefits system + self.default_period: Period = ( + None # Simulation period used for variables when no period is defined + ) + self.persons_plural: str = ( + None # Plural name for person entity in current tax and benefits system + ) # JSON input - Memory of known input values. Indexed by variable or axis name. self.input_buffer: typing.Dict[ @@ -55,9 +59,9 @@ def __init__(self): self.has_axes = False self.axes_entity_counts: typing.Dict[Entity.plural, int] = {} self.axes_entity_ids: typing.Dict[Entity.plural, typing.List[int]] = {} - self.axes_memberships: typing.Dict[ - Entity.plural, typing.List[int] - ] = {} + self.axes_memberships: typing.Dict[Entity.plural, typing.List[int]] = ( + {} + ) self.axes_roles: typing.Dict[Entity.plural, typing.List[int]] = {} def build_from_dict( @@ -484,12 +488,12 @@ def add_group_entity( entity_ids = entity_ids + list(persons_to_allocate) for person_id in persons_to_allocate: person_index = persons_ids.index(person_id) - self.memberships[entity.plural][ - person_index - ] = entity_ids.index(person_id) - self.roles[entity.plural][ - person_index - ] = entity.flattened_roles[0] + self.memberships[entity.plural][person_index] = ( + entity_ids.index(person_id) + ) + self.roles[entity.plural][person_index] = ( + entity.flattened_roles[0] + ) # Adjust previously computed ids and counts self.entity_ids[entity.plural] = entity_ids self.entity_counts[entity.plural] = len(entity_ids) diff --git a/policyengine_core/taxbenefitsystems/tax_benefit_system.py b/policyengine_core/taxbenefitsystems/tax_benefit_system.py index 51c8b17de..f8ced5930 100644 --- a/policyengine_core/taxbenefitsystems/tax_benefit_system.py +++ b/policyengine_core/taxbenefitsystems/tax_benefit_system.py @@ -69,7 +69,9 @@ class TaxBenefitSystem: _parameters_at_instant_cache: Optional[Dict[Any, Any]] = None person_key_plural: str = None preprocess_parameters: str = None - baseline: "TaxBenefitSystem" = None # Baseline tax-benefit system. Used only by reforms. Note: Reforms can be chained. + baseline: "TaxBenefitSystem" = ( + None # Baseline tax-benefit system. Used only by reforms. Note: Reforms can be chained. + ) cache_blacklist = None decomposition_file_path = None variable_module_metadata: dict = None @@ -183,9 +185,9 @@ def base_tax_benefit_system(self) -> "TaxBenefitSystem": baseline = self.baseline if baseline is None: return self - self._base_tax_benefit_system = ( - base_tax_benefit_system - ) = baseline.base_tax_benefit_system + self._base_tax_benefit_system = base_tax_benefit_system = ( + baseline.base_tax_benefit_system + ) return base_tax_benefit_system def instantiate_entities(self) -> Dict[str, Population]: diff --git a/policyengine_core/taxscales/tax_scale_like.py b/policyengine_core/taxscales/tax_scale_like.py index b9a7e5f14..dbedea580 100644 --- a/policyengine_core/taxscales/tax_scale_like.py +++ b/policyengine_core/taxscales/tax_scale_like.py @@ -48,20 +48,17 @@ def __ne__(self, _other: object) -> typing.NoReturn: ) @abc.abstractmethod - def __repr__(self) -> str: - ... + def __repr__(self) -> str: ... @abc.abstractmethod def calc( self, tax_base: NumericalArray, right: bool, - ) -> numpy.float_: - ... + ) -> numpy.float_: ... @abc.abstractmethod - def to_dict(self) -> dict: - ... + def to_dict(self) -> dict: ... def copy(self) -> typing.Any: new = commons.empty_clone(self) diff --git a/policyengine_core/variables/helpers.py b/policyengine_core/variables/helpers.py index 04fab093f..5eca45f5d 100644 --- a/policyengine_core/variables/helpers.py +++ b/policyengine_core/variables/helpers.py @@ -49,9 +49,11 @@ def get_neutralized_variable(variable): result = variable.clone() result.is_neutralized = True result.label = ( - "[Neutralized]" - if variable.label is None - else "[Neutralized] {}".format(variable.label), + ( + "[Neutralized]" + if variable.label is None + else "[Neutralized] {}".format(variable.label) + ), ) return result diff --git a/policyengine_core/variables/variable.py b/policyengine_core/variables/variable.py index 00877a49f..4674d26c1 100644 --- a/policyengine_core/variables/variable.py +++ b/policyengine_core/variables/variable.py @@ -198,12 +198,14 @@ def __init__(self, baseline_variable=None): "quantity_type", required=False, allowed_values=(QuantityType.STOCK, QuantityType.FLOW), - default=QuantityType.STOCK - if ( - self.value_type in (bool, int, Enum, str, datetime.date) - or self.unit == "/1" - ) - else QuantityType.FLOW, + default=( + QuantityType.STOCK + if ( + self.value_type in (bool, int, Enum, str, datetime.date) + or self.unit == "/1" + ) + else QuantityType.FLOW + ), ) self.documentation = self.set( attr, @@ -214,9 +216,11 @@ def __init__(self, baseline_variable=None): self.set_input = self.set_set_input( attr.pop( "set_input", - set_input_dispatch_by_period - if self.quantity_type == QuantityType.STOCK - else set_input_divide_by_period, + ( + set_input_dispatch_by_period + if self.quantity_type == QuantityType.STOCK + else set_input_divide_by_period + ), ) ) if self.definition_period in (DAY, ETERNITY): diff --git a/tests/core/variables/test_adds.py b/tests/core/variables/test_adds.py new file mode 100644 index 000000000..ff6f5bdd9 --- /dev/null +++ b/tests/core/variables/test_adds.py @@ -0,0 +1,36 @@ +import numpy as np + +from policyengine_core.entities import Entity +from policyengine_core.model_api import * +from policyengine_core.simulations import SimulationBuilder +from policyengine_core.taxbenefitsystems import TaxBenefitSystem + + +def test_bad_adds_raises_error(): + """A basic test of an ill-defined adds attribute in a variable.""" + Person = Entity("person", "people", "Person", "A person") + system = TaxBenefitSystem([Person]) + + class some_income(Variable): + value_type = float + entity = Person + definition_period = ETERNITY + label = "Income" + adds = ["income_not_defined"] + + system.add_variables(some_income) + + simulation = SimulationBuilder().build_from_dict( + system, + { + "people": { + "person": {}, + }, + }, + ) + + try: + simulation.calculate("some_income") + raise Exception("Should have raised an error.") + except ValueError as e: + pass diff --git a/tests/core/variables/test_subtracts.py b/tests/core/variables/test_subtracts.py new file mode 100644 index 000000000..eeba872fa --- /dev/null +++ b/tests/core/variables/test_subtracts.py @@ -0,0 +1,36 @@ +import numpy as np + +from policyengine_core.entities import Entity +from policyengine_core.model_api import * +from policyengine_core.simulations import SimulationBuilder +from policyengine_core.taxbenefitsystems import TaxBenefitSystem + + +def test_bad_subtracts_raises_error(): + """A basic test of an ill-defined subtracts attribute in a variable.""" + Person = Entity("person", "people", "Person", "A person") + system = TaxBenefitSystem([Person]) + + class some_income(Variable): + value_type = float + entity = Person + definition_period = ETERNITY + label = "Income" + subtracts = ["income_not_defined"] + + system.add_variables(some_income) + + simulation = SimulationBuilder().build_from_dict( + system, + { + "people": { + "person": {}, + }, + }, + ) + + try: + simulation.calculate("some_income") + raise Exception("Should have raised an error.") + except ValueError as e: + pass