Skip to content

Commit

Permalink
Add functionality for extracting households (#144)
Browse files Browse the repository at this point in the history
* Add functionality to sample a household from a microsim
Fixes #143

* Versioning

* Add random sample feature
  • Loading branch information
nikhilwoodruff authored Jan 2, 2024
1 parent 70de623 commit 5200aeb
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 0 deletions.
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: minor
changes:
added:
- Simulation helper to extract individual households from a microsimulation.
110 changes: 110 additions & 0 deletions policyengine_core/simulations/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,3 +1182,113 @@ def derivative(
new_value = alt_sim.calculate(variable, period)
difference = new_value - original_value
return difference / delta

def sample_person(self) -> dict:
    """
    Pick a random person from the simulation and extract them.

    The person index is drawn with ``np.random.randint`` over the person
    count, then handed to ``extract_person`` with its default arguments.

    Returns:
        dict: The situation JSON built by ``extract_person`` for the
            chosen person (their inputs and their containing entities).
    """
    chosen_index = np.random.randint(self.persons.count)
    return self.extract_person(chosen_index)

def extract_person(
    self,
    index: int = 0,
    exclude_entities: tuple = ("state",),
) -> dict:
    """
    Extract a person from the simulation. Returns a situation JSON with their inputs (including their containing entities).

    For every group entity that is not excluded, the group containing the
    person at ``index`` is located and all of that group's members are
    extracted too. Only the first known period of each input variable is
    exported.

    Args:
        index (int): The index of the person to extract.
        exclude_entities (tuple): Keys of group entities to leave out of
            the situation.

    Returns:
        dict: A JSON-compatible dictionary with one entry per included
            group entity (keyed by the entity's plural) and one entry for
            the people (keyed by the person entity's plural).
    """
    # Resolve the input variables once and bucket them by entity key, so
    # the per-entity / per-person work below does not rescan every variable.
    variables_by_entity = {}
    for variable in self.input_variables:
        owner_key = self.tax_benefit_system.get_variable(variable).entity.key
        variables_by_entity.setdefault(owner_key, []).append(variable)

    def _first_known_inputs(entity_key, position):
        """Map each input variable of the entity to {first known period: value at position}."""
        values = {}
        for variable in variables_by_entity.get(entity_key, ()):
            holder = self.get_holder(variable)
            known_periods = holder.get_known_periods()
            if len(known_periods) > 0:
                first_period = known_periods[0]
                values[variable] = {
                    str(first_period): holder.get_array(first_period)[
                        position
                    ]
                }
        return values

    situation = {}
    groups = []  # (entity, set of member person indices) per included group
    person = None

    for population in self.populations.values():
        entity = population.entity
        if entity.is_person:
            person = entity
            continue
        if entity.key in exclude_entities:
            continue

        membership = np.asarray(population.members_entity_id)
        group_index = membership[index]
        # Everyone who shares the target person's group for this entity.
        member_indices = np.flatnonzero(membership == group_index).tolist()
        groups.append((entity, set(member_indices)))

        situation[entity.plural] = {
            entity.key: {
                "members": [],
                **_first_known_inputs(entity.key, group_index),
            },
        }

    if person is None:
        # No population flagged as the person entity: fall back to the
        # conventional key.
        person = self.populations["person"].entity

    # Sorted so that person names and member lists come out in a
    # deterministic order.
    people_indices = sorted(set().union(*(members for _, members in groups)))
    if not people_indices:
        # No group entity was included; still extract the requested person.
        people_indices = [index]

    situation[person.plural] = {}
    for person_index in people_indices:
        person_name = f"{person.key}_{person_index + 1}"
        for entity, members in groups:
            if person_index in members:
                situation[entity.plural][entity.key]["members"].append(
                    person_name
                )
        situation[person.plural][person_name] = _first_known_inputs(
            person.key, person_index
        )

    # Round-trip through JSON so NumPy scalars/arrays become plain Python
    # values.
    return json.loads(json.dumps(situation, cls=NpEncoder))


class NpEncoder(json.JSONEncoder):
    """JSON encoder that turns NumPy values into plain Python values.

    NumPy integers, floats, booleans and arrays are converted to ``int``,
    ``float``, ``bool`` and ``list`` respectively. Any other object the
    base encoder cannot handle is serialised as ``str(obj)``.
    """

    # (NumPy type, converter) pairs, checked in this order.
    _CONVERTERS = (
        (np.integer, int),
        (np.floating, float),
        (np.bool_, bool),
        (np.ndarray, lambda array: array.tolist()),
    )

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``."""
        for numpy_type, convert in self._CONVERTERS:
            if isinstance(obj, numpy_type):
                return convert(obj)
        return str(obj)
2 changes: 2 additions & 0 deletions policyengine_core/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from policyengine_core.enums import EnumArray

from .test_from_situation import generate_test_from_situation


def assert_near(
value,
Expand Down
23 changes: 23 additions & 0 deletions policyengine_core/tools/test_from_situation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import yaml
from pathlib import Path
import numpy as np
import json


def generate_test_from_situation(situation: dict, file_path: str) -> None:
    """Generate a YAML test file from a situation.

    Writes a one-element YAML list whose single entry holds the situation
    under ``input`` and an empty mapping under ``output``.

    Args:
        situation (dict): The situation to generate the test from.
        file_path (str): Path of the YAML file to write. An existing file
            at that path is overwritten.
    """

    yaml_contents = [
        {
            "input": situation,
            "output": {},
        }
    ]

    # Write-only mode: the file is never read back here.
    with open(Path(file_path), "w", encoding="utf-8") as f:
        yaml.dump(yaml_contents, f)

0 comments on commit 5200aeb

Please sign in to comment.