Skip to content

Commit

Permalink
Format and improve error handling (#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhilwoodruff authored Feb 6, 2024
1 parent 8609546 commit 504f4ef
Show file tree
Hide file tree
Showing 19 changed files with 192 additions and 80 deletions.
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: patch
changes:
added:
- Improvements to error handling in bad variable declarations.
26 changes: 15 additions & 11 deletions policyengine_core/charts/bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ def bar_chart(
"""

hover_text_labels = [
hover_text_function(index, value)
if hover_text_function is not None
else None
(
hover_text_function(index, value)
if hover_text_function is not None
else None
)
for index, value in data.items()
]

Expand All @@ -57,9 +59,11 @@ def bar_chart(
positive_colour if v > 0 else negative_colour
for v in data.values
],
hovertemplate="%{customdata[0]}<extra></extra>"
if hover_text_labels is not None
else None,
hovertemplate=(
"%{customdata[0]}<extra></extra>"
if hover_text_labels is not None
else None
),
)
)
return format_fig(fig)
Expand Down Expand Up @@ -94,11 +98,11 @@ def cross_section_bar_chart(
"Category": category,
"Cross section": cross_section_value,
"Value": value,
"Hover text": hover_text_function(
cross_section_value, category, value
)
if hover_text_function is not None
else None,
"Hover text": (
hover_text_function(cross_section_value, category, value)
if hover_text_function is not None
else None
),
}
df = df.append(row, ignore_index=True)

Expand Down
1 change: 1 addition & 0 deletions policyengine_core/country_template/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
See https://openfisca.org/doc/key-concepts/person,_entities,_role.html
"""

from typing import Any

from policyengine_core.entities import build_entity
Expand Down
9 changes: 5 additions & 4 deletions policyengine_core/holders/holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,11 @@ def get_memory_usage(self) -> dict:
usage.update(
dict(
nb_requests=nb_requests,
nb_requests_by_array=nb_requests
/ float(usage["nb_arrays"])
if usage["nb_arrays"] > 0
else numpy.nan,
nb_requests_by_array=(
nb_requests / float(usage["nb_arrays"])
if usage["nb_arrays"] > 0
else numpy.nan
),
)
)

Expand Down
3 changes: 1 addition & 2 deletions policyengine_core/parameters/at_instant_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,4 @@ def get_at_instant(self, instant: Instant) -> Any:
return self._get_at_instant(instant)

@abc.abstractmethod
def _get_at_instant(self, instant):
...
def _get_at_instant(self, instant): ...
4 changes: 4 additions & 0 deletions policyengine_core/parameters/operations/uprate_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode:
# Apply the uprater and add to the parameter
value_at_start = parameter(last_instant)
uprater_at_start = uprating_parameter(last_instant)
if uprater_at_start is None:
raise ValueError(
f"Failed to uprate using {uprating_parameter.name} at {last_instant} for {parameter.name} at {entry_instant} because the uprating parameter is not defined at {last_instant}."
)
uprater_at_entry = uprating_parameter(entry_instant)
uprater_change = uprater_at_entry / uprater_at_start
uprated_value = value_at_start * uprater_change
Expand Down
6 changes: 3 additions & 3 deletions policyengine_core/parameters/parameter_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ class ParameterNode(AtInstantLike):
A node in the legislation `parameter tree <https://openfisca.org/doc/coding-the-legislation/legislation_parameters.html>`_.
"""

_allowed_keys: typing.Optional[
typing.Iterable[str]
] = None # By default, no restriction on the keys
_allowed_keys: typing.Optional[typing.Iterable[str]] = (
None # By default, no restriction on the keys
)

parent: "ParameterNode" = None
"""The parent of the node, or None if the node is the root of the tree."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ def build_from_node(
# Recursively vectorize the children of the node
vectorial_subnodes = tuple(
[
VectorialParameterNodeAtInstant.build_from_node(
node[subnode_name]
).vector
if isinstance(
node[subnode_name], parameters.ParameterNodeAtInstant
(
VectorialParameterNodeAtInstant.build_from_node(
node[subnode_name]
).vector
if isinstance(
node[subnode_name], parameters.ParameterNodeAtInstant
)
else node[subnode_name]
)
else node[subnode_name]
for subnode_name in subnodes_name
]
)
Expand All @@ -44,9 +46,11 @@ def build_from_node(
dtype=[
(
subnode_name,
subnode.dtype
if isinstance(subnode, numpy.recarray)
else "float",
(
subnode.dtype
if isinstance(subnode, numpy.recarray)
else "float"
),
)
for (subnode_name, subnode) in zip(
subnodes_name, vectorial_subnodes
Expand Down
6 changes: 3 additions & 3 deletions policyengine_core/periods/instant_.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def __str__(self) -> str:
"""
instant_str = config.str_by_instant_cache.get(self)
if instant_str is None:
config.str_by_instant_cache[
self
] = instant_str = self.date.isoformat()
config.str_by_instant_cache[self] = instant_str = (
self.date.isoformat()
)
return instant_str

@property
Expand Down
6 changes: 3 additions & 3 deletions policyengine_core/reforms/reform.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,9 @@ def api_id(self):
sanitised_period_values = {}
for period, value in period_values.items():
period = period_(period)
sanitised_period_values[
f"{period.start}.{period.stop}"
] = value
sanitised_period_values[f"{period.start}.{period.stop}"] = (
value
)
sanitised_parameter_values[path] = sanitised_period_values

response = requests.post(
Expand Down
1 change: 1 addition & 0 deletions policyengine_core/simulations/individual_sim.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
IndividualSim and any other interfaces to intialising and running simulations on hypothetical situations. Deprecated.
"""

from functools import partial
from typing import Dict, List

Expand Down
37 changes: 25 additions & 12 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ def build_from_dataset(self) -> None:
entity_id_field in data
), f"Missing {entity_id_field} column in the dataset. Each person entity must have an ID array defined for ETERNITY."

get_eternity_array = (
lambda ds: ds[list(ds.keys())[0]]
get_eternity_array = lambda ds: (
ds[list(ds.keys())[0]]
if self.dataset.data_format == Dataset.TIME_PERIOD_ARRAYS
else ds
)
Expand Down Expand Up @@ -466,6 +466,11 @@ def calculate_dataframe(
period = periods.period(period)
elif period is None and self.default_calculation_period is not None:
period = periods.period(self.default_calculation_period)

# Check each variable exists
for variable_name in variable_names:
if variable_name not in self.tax_benefit_system.variables:
raise ValueError(f"Variable {variable_name} does not exist.")
df = pd.DataFrame()
entities = [
self.tax_benefit_system.get_variable(variable_name).entity.key
Expand Down Expand Up @@ -785,7 +790,9 @@ def _run_formula(
)
values = values + parameter(period.start)
except:
pass
raise ValueError(
f"In the variable '{variable.name}', the 'adds' attribute is a list that contains a string '{added_variable}' that does not match any variable or parameter."
)
if variable.subtracts is not None and len(variable.subtracts) > 0:
if isinstance(variable.subtracts, str):
try:
Expand Down Expand Up @@ -820,7 +827,9 @@ def _run_formula(
)
values = values + parameter(period.start)
except:
pass
raise ValueError(
f"In the variable '{variable.name}', the 'subtracts' attribute is a list that contains a string '{subtracted_variable}' that does not match any variable or parameter."
)
return values

if self.trace and not isinstance(
Expand Down Expand Up @@ -873,9 +882,11 @@ def _check_period_consistency(
"Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole {2}. You can use the ADD option to sum '{0}' over the requested period.".format(
variable.name,
period,
"month"
if variable.definition_period == periods.MONTH
else "year",
(
"month"
if variable.definition_period == periods.MONTH
else "year"
),
)
)

Expand Down Expand Up @@ -1244,8 +1255,9 @@ def extract_person(
variable
).get_known_periods()
if len(known_periods) > 0:
value = self.get_holder(variable).get_array(
known_periods[0]
first_known_period = known_periods[0]
value = self.calculate(
variable, first_known_period
)[group_index]
situation[entity.plural][entity.key][variable] = {
str(known_periods[0]): value
Expand All @@ -1271,9 +1283,10 @@ def extract_person(
variable
).get_known_periods()
if len(known_periods) > 0:
value = self.get_holder(variable).get_array(
known_periods[0]
)[person_index]
first_known_period = known_periods[0]
value = self.calculate(variable, first_known_period)[
person_index
]
situation[person.plural][person_name][variable] = {
str(known_periods[0]): value
}
Expand Down
26 changes: 15 additions & 11 deletions policyengine_core/simulations/simulation_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@

class SimulationBuilder:
def __init__(self):
self.default_period: Period = None # Simulation period used for variables when no period is defined
self.persons_plural: str = None # Plural name for person entity in current tax and benefits system
self.default_period: Period = (
None # Simulation period used for variables when no period is defined
)
self.persons_plural: str = (
None # Plural name for person entity in current tax and benefits system
)

# JSON input - Memory of known input values. Indexed by variable or axis name.
self.input_buffer: typing.Dict[
Expand All @@ -55,9 +59,9 @@ def __init__(self):
self.has_axes = False
self.axes_entity_counts: typing.Dict[Entity.plural, int] = {}
self.axes_entity_ids: typing.Dict[Entity.plural, typing.List[int]] = {}
self.axes_memberships: typing.Dict[
Entity.plural, typing.List[int]
] = {}
self.axes_memberships: typing.Dict[Entity.plural, typing.List[int]] = (
{}
)
self.axes_roles: typing.Dict[Entity.plural, typing.List[int]] = {}

def build_from_dict(
Expand Down Expand Up @@ -484,12 +488,12 @@ def add_group_entity(
entity_ids = entity_ids + list(persons_to_allocate)
for person_id in persons_to_allocate:
person_index = persons_ids.index(person_id)
self.memberships[entity.plural][
person_index
] = entity_ids.index(person_id)
self.roles[entity.plural][
person_index
] = entity.flattened_roles[0]
self.memberships[entity.plural][person_index] = (
entity_ids.index(person_id)
)
self.roles[entity.plural][person_index] = (
entity.flattened_roles[0]
)
# Adjust previously computed ids and counts
self.entity_ids[entity.plural] = entity_ids
self.entity_counts[entity.plural] = len(entity_ids)
Expand Down
10 changes: 6 additions & 4 deletions policyengine_core/taxbenefitsystems/tax_benefit_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ class TaxBenefitSystem:
_parameters_at_instant_cache: Optional[Dict[Any, Any]] = None
person_key_plural: str = None
preprocess_parameters: str = None
baseline: "TaxBenefitSystem" = None # Baseline tax-benefit system. Used only by reforms. Note: Reforms can be chained.
baseline: "TaxBenefitSystem" = (
None # Baseline tax-benefit system. Used only by reforms. Note: Reforms can be chained.
)
cache_blacklist = None
decomposition_file_path = None
variable_module_metadata: dict = None
Expand Down Expand Up @@ -183,9 +185,9 @@ def base_tax_benefit_system(self) -> "TaxBenefitSystem":
baseline = self.baseline
if baseline is None:
return self
self._base_tax_benefit_system = (
base_tax_benefit_system
) = baseline.base_tax_benefit_system
self._base_tax_benefit_system = base_tax_benefit_system = (
baseline.base_tax_benefit_system
)
return base_tax_benefit_system

def instantiate_entities(self) -> Dict[str, Population]:
Expand Down
9 changes: 3 additions & 6 deletions policyengine_core/taxscales/tax_scale_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,17 @@ def __ne__(self, _other: object) -> typing.NoReturn:
)

@abc.abstractmethod
def __repr__(self) -> str:
...
def __repr__(self) -> str: ...

@abc.abstractmethod
def calc(
self,
tax_base: NumericalArray,
right: bool,
) -> numpy.float_:
...
) -> numpy.float_: ...

@abc.abstractmethod
def to_dict(self) -> dict:
...
def to_dict(self) -> dict: ...

def copy(self) -> typing.Any:
new = commons.empty_clone(self)
Expand Down
8 changes: 5 additions & 3 deletions policyengine_core/variables/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@ def get_neutralized_variable(variable):
result = variable.clone()
result.is_neutralized = True
result.label = (
"[Neutralized]"
if variable.label is None
else "[Neutralized] {}".format(variable.label),
(
"[Neutralized]"
if variable.label is None
else "[Neutralized] {}".format(variable.label)
),
)

return result
Expand Down
Loading

0 comments on commit 504f4ef

Please sign in to comment.