feat: add multiprocessing to ContextualCountEmbedder (#450)
* feat: add multiprocessing to ContextualCountEmbedder

* chore: added changelog entry

* chore: change tests and make multiprocessing conditional
RaczeQ authored May 5, 2024
1 parent e80e2f4 commit 68eea1f
Showing 3 changed files with 185 additions and 26 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Support for Python 3.12 after upgrading QuackOSM to `0.7.0` and DuckDB to `0.10.2`

### Changed

- Refactored `ContextualCountEmbedder` by adding multiprocessing for faster transformations

## [0.7.3] - 2024-04-21

### Changed
152 changes: 133 additions & 19 deletions srai/embedders/contextual_count_embedder.py
@@ -8,13 +8,18 @@
1. https://arxiv.org/abs/2111.00990
"""

from collections.abc import Iterator
from typing import Optional, Union
from collections.abc import Collection, Iterator
from functools import partial
from math import ceil
from multiprocessing import cpu_count
from typing import Any, Optional, Union

import geopandas as gpd
import numpy as np
import numpy.typing as npt
import pandas as pd
from tqdm import trange
from tqdm.contrib.concurrent import process_map

from srai.embedders.count_embedder import CountEmbedder
from srai.loaders.osm_loaders.filters import GroupedOsmTagsFilter, OsmTagsFilter
@@ -34,6 +39,8 @@ def __init__(
Union[list[str], OsmTagsFilter, GroupedOsmTagsFilter]
] = None,
count_subcategories: bool = False,
num_of_multiprocessing_workers: int = -1,
multiprocessing_activation_threshold: Optional[int] = None,
) -> None:
"""
Init ContextualCountEmbedder.
@@ -55,6 +62,13 @@
count_subcategories (bool, optional): Whether to count all subcategories individually
or count features only on the highest level based on features column name.
Defaults to False.
            num_of_multiprocessing_workers (int, optional): Number of workers used for
                multiprocessing. Defaults to -1, which results in using the total number
                of available CPU threads. `0` and `1` values disable multiprocessing.
                Similar to the `n_jobs` parameter from the `scikit-learn` library.
            multiprocessing_activation_threshold (int, optional): Number of regions required
                to start processing on multiple processes. Activating multiprocessing for a
                small number of regions might not be beneficial. Defaults to 100.

        Raises:
ValueError: If `neighbourhood_distance` is negative.
@@ -68,6 +82,13 @@
if self.neighbourhood_distance < 0:
raise ValueError("Neighbourhood distance must be positive.")

self.num_of_multiprocessing_workers = _parse_num_of_multiprocessing_workers(
num_of_multiprocessing_workers
)
self.multiprocessing_activation_threshold = _parse_multiprocessing_activation_threshold(
multiprocessing_activation_threshold
)

def transform(
self,
regions_gdf: gpd.GeoDataFrame,
@@ -165,7 +186,8 @@ def _get_concatenated_embeddings(self, counts_df: pd.DataFrame) -> pd.DataFrame:

for distance, averaged_values in self._get_averaged_values_for_distances(counts_df):
result_array[
:, number_of_base_columns * distance : number_of_base_columns * (distance + 1)
:,
number_of_base_columns * distance : number_of_base_columns * (distance + 1),
] = averaged_values

return pd.DataFrame(data=result_array, index=counts_df.index, columns=columns)
@@ -193,25 +215,117 @@ def _get_averaged_values_for_distances(

number_of_base_columns = len(counts_df.columns)

for distance in range(1, self.neighbourhood_distance + 1):
neighbours_series = counts_df.index.map(
lambda region_id, neighbour_distance=distance: counts_df.index.intersection(
self.neighbourhood.get_neighbours_at_distance(
region_id, neighbour_distance, include_center=False
)
).values
)
activate_multiprocessing = (
self.num_of_multiprocessing_workers > 1
and len(counts_df.index) >= self.multiprocessing_activation_threshold
)

for distance in trange(
1,
self.neighbourhood_distance + 1,
desc="Generating embeddings for neighbours",
):
if len(counts_df.index) == 0:
continue

if activate_multiprocessing:
fn_neighbours = partial(
_get_existing_neighbours_at_distance,
neighbour_distance=distance,
counts_index=counts_df.index,
neighbourhood=self.neighbourhood,
)
neighbours_series = process_map(
fn_neighbours,
counts_df.index,
max_workers=self.num_of_multiprocessing_workers,
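                    # chunk the index so each worker receives roughly 4 chunks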
chunksize=ceil(
len(counts_df.index) / (4 * self.num_of_multiprocessing_workers)
),
disable=True,
)
else:
neighbours_series = counts_df.index.map(
lambda region_id, neighbour_distance=distance: counts_df.index.intersection(
self.neighbourhood.get_neighbours_at_distance(
region_id, neighbour_distance, include_center=False
)
).values
)

if len(neighbours_series) == 0:
continue

            averaged_values_stacked = np.stack(
                neighbours_series.map(
                    lambda region_ids: (
                        np.nan_to_num(np.nanmean(counts_df.loc[region_ids].values, axis=0))
                        if len(region_ids) > 0
                        else np.zeros((number_of_base_columns,))
                    )
                ).values
            )
            if activate_multiprocessing:
                fn_embeddings = partial(
                    _get_embeddings_for_neighbours,
                    counts_df=counts_df,
                    number_of_base_columns=number_of_base_columns,
                )

                averaged_values_stacked = np.stack(
                    process_map(
                        fn_embeddings,
                        neighbours_series,
                        max_workers=self.num_of_multiprocessing_workers,
                        chunksize=ceil(
                            len(neighbours_series) / (4 * self.num_of_multiprocessing_workers)
                        ),
                        disable=True,
                    )
                )
            else:
                averaged_values_stacked = np.stack(
                    neighbours_series.map(
                        lambda region_ids: (
                            np.nan_to_num(np.nanmean(counts_df.loc[region_ids].values, axis=0))
                            if len(region_ids) > 0
                            else np.zeros((number_of_base_columns,))
                        )
                    ).values
                )

yield distance, averaged_values_stacked


def _parse_num_of_multiprocessing_workers(num_of_multiprocessing_workers: int) -> int:
if num_of_multiprocessing_workers == 0:
num_of_multiprocessing_workers = 1
elif num_of_multiprocessing_workers < 0:
num_of_multiprocessing_workers = cpu_count()

return num_of_multiprocessing_workers


def _parse_multiprocessing_activation_threshold(
multiprocessing_activation_threshold: Optional[int],
) -> int:
if not multiprocessing_activation_threshold:
multiprocessing_activation_threshold = 100

return multiprocessing_activation_threshold


def _get_existing_neighbours_at_distance(
region_id: IndexType,
neighbour_distance: int,
neighbourhood: Neighbourhood[IndexType],
counts_index: pd.Index,
) -> Any:
return counts_index.intersection(
neighbourhood.get_neighbours_at_distance(
region_id, neighbour_distance, include_center=False
)
).values


def _get_embeddings_for_neighbours(
region_ids: Collection[IndexType],
counts_df: pd.DataFrame,
number_of_base_columns: int,
) -> Any:
return (
np.nan_to_num(np.nanmean(counts_df.loc[region_ids].values, axis=0))
if len(region_ids) > 0
else np.zeros((number_of_base_columns,))
)
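Taken together, the new constructor arguments follow scikit-learn's `n_jobs` convention. A minimal usage sketch, not part of this commit: `regions_gdf`, `features_gdf` and `joint_gdf` are placeholder inputs assumed to come from the usual regionalizer / loader / joiner pipeline, and the parameter values are illustrative only.

from srai.embedders import ContextualCountEmbedder
from srai.neighbourhoods import H3Neighbourhood

# Sketch only: the three input GeoDataFrames are assumed to be prepared elsewhere.
embedder = ContextualCountEmbedder(
    neighbourhood=H3Neighbourhood(),
    neighbourhood_distance=5,
    num_of_multiprocessing_workers=-1,  # -1 uses all available CPU threads
    multiprocessing_activation_threshold=100,  # stay single-process below 100 regions
)
embeddings = embedder.transform(
    regions_gdf=regions_gdf, features_gdf=features_gdf, joint_gdf=joint_gdf
)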
55 changes: 48 additions & 7 deletions tests/embedders/test_contextual_count_embedder.py
@@ -1,20 +1,23 @@
"""ContextualCountEmbedder tests."""

from contextlib import nullcontext as does_not_raise
from typing import TYPE_CHECKING, Any, Union
from pathlib import Path
from typing import Any, Union

import geopandas as gpd
import pandas as pd
import pytest
from pandas.testing import assert_frame_equal
from parametrization import Parametrization as P
from shapely.geometry import Polygon

from srai.constants import REGIONS_INDEX
from srai.constants import REGIONS_INDEX, WGS84_CRS
from srai.embedders import ContextualCountEmbedder
from srai.loaders.osm_loaders.filters import OsmTagsFilter
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.loaders.osm_loaders.filters import GEOFABRIK_LAYERS, OsmTagsFilter
from srai.neighbourhoods import H3Neighbourhood

if TYPE_CHECKING: # pragma: no cover
import geopandas as gpd
from srai.regionalizers import H3Regionalizer


def _create_features_dataframe(data: dict[str, Any]) -> pd.DataFrame:
@@ -721,7 +724,9 @@ def test_correct_embedding(

expected_result_df = request.getfixturevalue(expected_embedding_fixture)
assert_frame_equal(
embedding_df.sort_index(axis=1), expected_result_df.sort_index(axis=1), check_dtype=False
embedding_df.sort_index(axis=1),
expected_result_df.sort_index(axis=1),
check_dtype=False,
)


@@ -873,3 +878,39 @@ def test_incorrect_indexes(
concatenate_vectors=concatenate_features,
neighbourhood_distance=neighbourhood_distance,
).transform(regions_gdf=regions_gdf, features_gdf=features_gdf, joint_gdf=joint_gdf)


def test_bigger_example() -> None:
"""Test bigger example to get multiprocessing in action."""
geometry = gpd.GeoDataFrame(
geometry=[
Polygon(
[
(7.416769421059001, 43.7346112362936),
(7.416769421059001, 43.730681304758946),
(7.4218262821731, 43.730681304758946),
(7.4218262821731, 43.7346112362936),
]
)
],
crs=WGS84_CRS,
)

regions = H3Regionalizer(resolution=13).transform(geometry)
features = OSMPbfLoader(
pbf_file=Path(__file__).parent.parent
/ "loaders"
/ "osm_loaders"
/ "test_files"
/ "monaco.osm.pbf"
).load(area=regions, tags=GEOFABRIK_LAYERS)
joint = IntersectionJoiner().transform(regions=regions, features=features)
embeddings = ContextualCountEmbedder(
neighbourhood=H3Neighbourhood(),
neighbourhood_distance=10,
expected_output_features=GEOFABRIK_LAYERS,
).transform(regions_gdf=regions, features_gdf=features, joint_gdf=joint)

assert len(embeddings) == len(
regions
), f"Mismatched number of rows ({len(embeddings)}, {len(regions)})"
