Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic period adjustment #119

Merged
merged 6 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
- bump: minor
changes:
added:
- Automatic period adjustment helper functionality.
changed:
- Default error threshold for tests widened to 1e-3.
29 changes: 4 additions & 25 deletions policyengine_core/holders/holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,10 @@ def set_input(
return warnings.warn(warning_message, Warning)
if self.variable.value_type in (float, int) and isinstance(array, str):
array = tools.eval_expression(array)
if self.variable.set_input:
if (
self.variable.set_input
and period.unit != self.variable.definition_period
):
return self.variable.set_input(self, period, array)
return self._set(period, array, branch_name)

Expand Down Expand Up @@ -285,30 +288,6 @@ def _set(
raise ValueError(
"A period must be specified to set values, except for variables with periods.ETERNITY as as period_definition."
)
if (
self.variable.definition_period != period.unit
or period.size > 1
):
name = self.variable.name
period_size_adj = (
f"{period.unit}"
if (period.size == 1)
else f"{period.size}-{period.unit}s"
)
error_message = os.linesep.join(
[
f'Unable to set a value for variable "{name}" for {period_size_adj}-long period "{period}".',
f'"{name}" can only be set for one {self.variable.definition_period} at a time. Please adapt your input.',
f'If you are the maintainer of "{name}", you can consider adding it a set_input attribute to enable automatic period casting.',
]
)

raise PeriodMismatchError(
self.variable.name,
period,
self.variable.definition_period,
error_message,
)

should_store_on_disk = (
self._on_disk_storable
Expand Down
39 changes: 30 additions & 9 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from policyengine_core.errors import CycleError, SpiralError
from policyengine_core.holders.holder import Holder
from policyengine_core.periods import Period
from policyengine_core.periods.config import ETERNITY
from policyengine_core.periods.config import ETERNITY, MONTH, YEAR
from policyengine_core.periods.helpers import period
from policyengine_core.tracers import (
FullTracer,
Expand All @@ -27,7 +27,7 @@
from policyengine_core.experimental import MemoryConfig
from policyengine_core.populations import Population
from policyengine_core.tracers import SimpleTracer
from policyengine_core.variables import Variable
from policyengine_core.variables import Variable, QuantityType
from policyengine_core.reforms.reform import Reform
from policyengine_core.parameters import get_parameter

Expand Down Expand Up @@ -454,13 +454,14 @@
variable_name, check_existence=True
)

self._check_period_consistency(period, variable)

# Check if we've neutralized via parameters.
try:
if self.tax_benefit_system.parameters(period).gov.abolitions[
variable.name
]:
if (
variable.is_neutralized
or self.tax_benefit_system.parameters(period).gov.abolitions[
variable.name
]
):
return holder.default_array()
except Exception as e:
pass
Expand All @@ -470,6 +471,20 @@
if cached_array is not None:
return cached_array

if variable.definition_period == MONTH and period.unit == YEAR:
if variable.quantity_type == QuantityType.STOCK:
contained_months = period.get_subperiods(MONTH)
return self.calculate(variable_name, contained_months[-1])

Check warning on line 477 in policyengine_core/simulations/simulation.py

View check run for this annotation

Codecov / codecov/patch

policyengine_core/simulations/simulation.py#L476-L477

Added lines #L476 - L477 were not covered by tests
else:
return self.calculate_add(variable_name, period)

Check warning on line 479 in policyengine_core/simulations/simulation.py

View check run for this annotation

Codecov / codecov/patch

policyengine_core/simulations/simulation.py#L479

Added line #L479 was not covered by tests
elif variable.definition_period == YEAR and period.unit == MONTH:
if variable.quantity_type == QuantityType.STOCK:
return self.calculate(variable_name, period.this_year)

Check warning on line 482 in policyengine_core/simulations/simulation.py

View check run for this annotation

Codecov / codecov/patch

policyengine_core/simulations/simulation.py#L482

Added line #L482 was not covered by tests
else:
return self.calculate_divide(variable_name, period)

Check warning on line 484 in policyengine_core/simulations/simulation.py

View check run for this annotation

Codecov / codecov/patch

policyengine_core/simulations/simulation.py#L484

Added line #L484 was not covered by tests

self._check_period_consistency(period, variable)

if variable.defined_for is not None:
mask = (
self.calculate(
Expand Down Expand Up @@ -607,10 +622,13 @@
)
)

return sum(
result = sum(
self.calculate(variable_name, sub_period)
for sub_period in period.get_subperiods(variable.definition_period)
)
holder = self.get_holder(variable.name)
holder.put_in_cache(result, period, self.branch_name)
return result

def calculate_divide(
self,
Expand Down Expand Up @@ -640,9 +658,12 @@

if period.unit == periods.MONTH:
computation_period = period.this_year
return (
result = (
self.calculate(variable_name, period=computation_period) / 12.0
)
holder = self.get_holder(variable.name)
holder.put_in_cache(result, period, self.branch_name)
return result
elif period.unit == periods.YEAR:
return self.calculate(variable_name, period)

Expand Down
3 changes: 2 additions & 1 deletion policyengine_core/tools/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from policyengine_core.scripts import build_tax_benefit_system
from policyengine_core.reforms import Reform, set_parameter
from policyengine_core.populations import ADD, DIVIDE

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -464,7 +465,7 @@ def assert_near(
import numpy as np

if absolute_error_margin is None and relative_error_margin is None:
absolute_error_margin = 0
absolute_error_margin = 1e-3
if not isinstance(value, np.ndarray):
value = np.array(value)
if isinstance(value, EnumArray):
Expand Down
35 changes: 27 additions & 8 deletions policyengine_core/variables/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
from policyengine_core.entities import Entity
from policyengine_core.enums import Enum, EnumArray
from policyengine_core.periods import Period
from policyengine_core.holders import (
set_input_dispatch_by_period,
set_input_divide_by_period,
)
from policyengine_core.periods import DAY, ETERNITY

from . import config, helpers

Expand Down Expand Up @@ -176,13 +181,6 @@ def __init__(self, baseline_variable=None):
periods.ETERNITY,
),
)
self.quantity_type = self.set(
attr,
"quantity_type",
required=False,
allowed_values=(QuantityType.STOCK, QuantityType.FLOW),
default=QuantityType.FLOW,
)
self.label = self.set(
attr, "label", allowed_type=str, setter=self.set_label
)
Expand All @@ -192,13 +190,34 @@ def __init__(self, baseline_variable=None):
attr, "cerfa_field", allowed_type=(str, dict)
)
self.unit = self.set(attr, "unit", allowed_type=str)
self.quantity_type = self.set(
attr,
"quantity_type",
required=False,
allowed_values=(QuantityType.STOCK, QuantityType.FLOW),
default=QuantityType.STOCK
if (
self.value_type in (bool, int, Enum, str, datetime.date)
or self.unit == "/1"
)
else QuantityType.FLOW,
)
self.documentation = self.set(
attr,
"documentation",
allowed_type=str,
setter=self.set_documentation,
)
self.set_input = self.set_set_input(attr.pop("set_input", None))
self.set_input = self.set_set_input(
attr.pop(
"set_input",
set_input_dispatch_by_period
if self.quantity_type == QuantityType.STOCK
else set_input_divide_by_period,
)
)
if self.definition_period in (DAY, ETERNITY):
self.set_input = None
self.calculate_output = self.set_calculate_output(
attr.pop("calculate_output", None)
)
Expand Down
5 changes: 0 additions & 5 deletions tests/core/test_calculate_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,6 @@ def simulation(tax_benefit_system):
)


def test_calculate_output_default(simulation):
with pytest.raises(ValueError):
simulation.calculate_output("simple_variable", 2017)


def test_calculate_output_add(simulation):
simulation.set_input("variable_with_calculate_output_add", "2017-01", [10])
simulation.set_input("variable_with_calculate_output_add", "2017-05", [20])
Expand Down
26 changes: 0 additions & 26 deletions tests/core/test_countries.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,6 @@ def test_non_existing_variable(simulation):
simulation.calculate("non_existent_variable", PERIOD)


@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect=True)
def test_calculate_variable_with_wrong_definition_period(simulation):
year = str(PERIOD.this_year)

with pytest.raises(ValueError) as error:
simulation.calculate("basic_income", year)

error_message = str(error.value)
expected_words = ["period", year, "month", "basic_income", "ADD"]

for word in expected_words:
assert (
word in error_message
), f"Expected '{word}' in error message '{error_message}'"


@pytest.mark.parametrize("simulation", [({}, PERIOD)], indirect=True)
def test_divide_option_on_month_defined_variable(simulation):
with pytest.raises(ValueError):
Expand All @@ -107,16 +91,6 @@ def test_divide_option_with_complex_period(simulation):
), f"Expected '{word}' in error message '{error_message}'"


def test_input_with_wrong_period(tax_benefit_system):
year = str(PERIOD.this_year)
variables = {"basic_income": {year: 12000}}
simulation_builder = SimulationBuilder()
simulation_builder.set_default_period(PERIOD)

with pytest.raises(ValueError):
simulation_builder.build_from_variables(tax_benefit_system, variables)


def test_variable_with_reference(make_simulation, isolated_tax_benefit_system):
variables = {"salary": 4000}
simulation = make_simulation(
Expand Down
31 changes: 2 additions & 29 deletions tests/core/test_holders.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,33 +59,6 @@ def test_set_input_enum_item(couple):
assert result == housing.HousingOccupancyStatus.free_lodger


def test_yearly_input_month_variable(couple):
with pytest.raises(PeriodMismatchError) as error:
couple.set_input("rent", 2019, 3000)
assert (
'Unable to set a value for variable "rent" for year-long period'
in error.value.message
)


def test_3_months_input_month_variable(couple):
with pytest.raises(PeriodMismatchError) as error:
couple.set_input("rent", "month:2019-01:3", 3000)
assert (
'Unable to set a value for variable "rent" for 3-months-long period'
in error.value.message
)


def test_month_input_year_variable(couple):
with pytest.raises(PeriodMismatchError) as error:
couple.set_input("housing_tax", "2019-01", 3000)
assert (
'Unable to set a value for variable "housing_tax" for month-long period'
in error.value.message
)


def test_enum_dtype(couple):
simulation = couple
status_occupancy = numpy.asarray([2], dtype=numpy.int16)
Expand Down Expand Up @@ -157,8 +130,8 @@ def test_get_memory_usage_with_trace(single):
memory_usage = salary_holder.get_memory_usage()
assert memory_usage["nb_requests"] == 15
assert (
memory_usage["nb_requests_by_array"] == 1.25
) # 15 calculations / 12 arrays
memory_usage["nb_requests_by_array"] == 15 / 13
) # 15 calculations / 13 arrays (12 months plus the year is cached too)


def test_set_input_dispatch_by_period(single):
Expand Down
15 changes: 0 additions & 15 deletions tests/core/test_reforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,6 @@ def apply(self):
assert_near(goes_to_school, [True], absolute_error_margin=0)


def test_neutralization_optimization(make_simulation, tax_benefit_system):
reform = WithBasicIncomeNeutralized(tax_benefit_system)

period = "2017-01"
simulation = make_simulation(reform, {}, period)
simulation.debug = True

simulation.calculate("basic_income", period="2013-01")
simulation.calculate_add("basic_income", period="2013")

# As basic_income is neutralized, it should not be cached
basic_income_holder = simulation.persons.get_holder("basic_income")
assert basic_income_holder.get_known_periods() == []


def test_input_variable_neutralization(make_simulation, tax_benefit_system):
class test_salary_neutralization(Reform):
def apply(self):
Expand Down