Skip to content

Commit

Permalink
.throttle: add per_minute / per_hour
Browse files Browse the repository at this point in the history
  • Loading branch information
ebonnal committed Oct 8, 2024
1 parent dd3f8cd commit ba0785d
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 33 deletions.
23 changes: 18 additions & 5 deletions streamable/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
OSConcurrentMappingIterable,
RaisingIterator,
ThrottlingIntervalIterator,
ThrottlingPerSecondIterator,
ThrottlingPerPeriodIterator,
TruncatingOnCountIterator,
TruncatingOnPredicateIterator,
)
Expand All @@ -37,7 +37,7 @@
validate_group_size,
validate_iterator,
validate_throttle_interval,
validate_throttle_per_second,
validate_throttle_per_period,
validate_truncate_args,
)

Expand Down Expand Up @@ -155,14 +155,27 @@ def observe(iterator: Iterator[T], what: str) -> Iterator[T]:
def throttle(
iterator: Iterator[T],
per_second: int = cast(int, float("inf")),
per_minute: int = cast(int, float("inf")),
per_hour: int = cast(int, float("inf")),
interval: datetime.timedelta = datetime.timedelta(0),
) -> Iterator[T]:
validate_iterator(iterator)
validate_throttle_per_second(per_second)
validate_throttle_per_period("per_second", per_second)
validate_throttle_per_period("per_minute", per_minute)
validate_throttle_per_period("per_hour", per_hour)
validate_throttle_interval(interval)

if per_second < float("inf"):
iterator = ThrottlingPerSecondIterator(iterator, per_second)
if any(
per_period < float("inf") for per_period in (per_second, per_minute, per_hour)
):
iterator = ThrottlingPerPeriodIterator(
iterator,
[
ThrottlingPerPeriodIterator.RestrictivePeriod(1, per_second),
ThrottlingPerPeriodIterator.RestrictivePeriod(60, per_minute),
ThrottlingPerPeriodIterator.RestrictivePeriod(3660, per_hour),
],
)
if interval > datetime.timedelta(0):
iterator = ThrottlingIntervalIterator(iterator, interval.total_seconds())
return iterator
Expand Down
56 changes: 42 additions & 14 deletions streamable/iters.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,31 +299,59 @@ def __next__(self) -> T:
return elem


class ThrottlingPerSecondIterator(Iterator[T]):
def __init__(self, iterator: Iterator[T], per_second: int) -> None:
class ThrottlingPerPeriodIterator(Iterator[T]):
class RestrictivePeriod(NamedTuple):
seconds: int
max_yields: int

def __init__(
self,
iterator: Iterator[T],
restrictive_periods: List[RestrictivePeriod],
) -> None:
self.iterator = iterator
self.per_second = per_second

self.second: int = -1
self.yields_in_second = 0
self.restrictive_periods = [
restrictive_period
for restrictive_period in restrictive_periods
if restrictive_period.max_yields < float("inf")
]

self.floor_time_in_period_unit: List[int] = [-1] * len(self.restrictive_periods)

self.yields_in_periods = [0] * len(self.restrictive_periods)

self.offset: Optional[float] = None

def __next__(self) -> T:
current_time = time.time()
if not self.offset:
self.offset = current_time
current_time -= self.offset
current_second = int(current_time)
if self.second != current_second:
self.second = current_second
self.yields_in_second = 0

if self.yields_in_second >= self.per_second:
# sleep until next second
time.sleep(ceil(current_time) - current_time)

to_sleep = 0.0
for index, (period, max_yields_per_period) in enumerate(
self.restrictive_periods
):
time_in_period_unit = current_time / period
floor_time_in_period_unit = int(time_in_period_unit)
if self.floor_time_in_period_unit[index] != floor_time_in_period_unit:
self.floor_time_in_period_unit[index] = floor_time_in_period_unit
self.yields_in_periods[index] = 0

if self.yields_in_periods[index] >= max_yields_per_period:
# sleep until next period
to_sleep = max(
to_sleep,
(ceil(time_in_period_unit) - time_in_period_unit) * period,
)

if to_sleep:
time.sleep(to_sleep)
return next(self)

self.yields_in_second += 1
for index in range(len(self.yields_in_periods)):
self.yields_in_periods[index] += 1
return next(self.iterator)


Expand Down
29 changes: 22 additions & 7 deletions streamable/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
validate_group_interval,
validate_group_size,
validate_throttle_interval,
validate_throttle_per_second,
validate_throttle_per_period,
validate_truncate_args,
)

Expand Down Expand Up @@ -354,23 +354,31 @@ def observe(self, what: str = "elements") -> "Stream[T]":
def throttle(
self,
per_second: int = cast(int, float("inf")),
per_minute: int = cast(int, float("inf")),
per_hour: int = cast(int, float("inf")),
interval: datetime.timedelta = datetime.timedelta(0),
) -> "Stream[T]":
"""
Slows the iteration down to ensure both:
Slows iteration to respect:
- a maximum number of yields `per_second`
- a minimum `interval` between yields`
- a maximum number of yields `per_minute`
- a maximum number of yields `per_hour`
- a minimum `interval` elapses between yields
Args:
per_second (float, optional): Maximum number of yields per second (no limit by default).
per_minute (float, optional): Maximum number of yields per minute (no limit by default).
per_hour (float, optional): Maximum number of yields per hour (no limit by default).
interval (datetime.timedelta, optional): Minimum span of time between yields (no limit by default).
Returns:
Stream[T]: A stream yielding upstream elements slower, according to `per_second` and `interval` limits.
Stream[T]: A stream yielding upstream elements under the provided rate constraints.
"""
validate_throttle_per_second(per_second)
validate_throttle_per_period("per_second", per_second)
validate_throttle_per_period("per_minute", per_minute)
validate_throttle_per_period("per_hour", per_hour)
validate_throttle_interval(interval)
return ThrottleStream(self, per_second, interval)
return ThrottleStream(self, per_second, per_minute, per_hour, interval)

def truncate(
self, count: Optional[int] = None, when: Optional[Callable[[T], Any]] = None
Expand Down Expand Up @@ -544,10 +552,17 @@ def accept(self, visitor: "Visitor[V]") -> V:

class ThrottleStream(DownStream[T, T]):
def __init__(
self, upstream: Stream[T], per_second: int, interval: datetime.timedelta
self,
upstream: Stream[T],
per_second: int,
per_minute: int,
per_hour: int,
interval: datetime.timedelta,
) -> None:
super().__init__(upstream)
self._per_second = per_second
self._per_minute = per_minute
self._per_hour = per_hour
self._interval = interval

def accept(self, visitor: "Visitor[V]") -> V:
Expand Down
6 changes: 3 additions & 3 deletions streamable/util/validationtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def validate_group_interval(interval: Optional[datetime.timedelta]):
raise ValueError(f"`interval` should be positive but got {repr(interval)}.")


def validate_throttle_per_second(per_second: int):
if per_second < 1:
def validate_throttle_per_period(per_period_arg_name: str, value: int):
if value < 1:
raise ValueError(
f"`per_second` is the maximum number of elements to yield per second, it must be >= 1 but got {per_second}."
f"`{per_period_arg_name}` is the maximum number of elements to yield per second, it must be >= 1 but got {value}."
)


Expand Down
6 changes: 5 additions & 1 deletion streamable/visitors/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ def visit_observe_stream(self, stream: ObserveStream[T]) -> Iterator[T]:

def visit_throttle_stream(self, stream: ThrottleStream[T]) -> Iterator[T]:
return functions.throttle(
stream.upstream.accept(self), stream._per_second, stream._interval
stream.upstream.accept(self),
stream._per_second,
stream._per_minute,
stream._per_hour,
stream._interval,
)

def visit_truncate_stream(self, stream: TruncateStream[T]) -> Iterator[T]:
Expand Down
2 changes: 1 addition & 1 deletion streamable/visitors/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def visit_observe_stream(self, stream: ObserveStream[T]) -> str:

def visit_throttle_stream(self, stream: ThrottleStream[T]) -> str:
self.methods_reprs.append(
f"throttle(per_second={stream._per_second}, interval={self._friendly_repr(stream._interval)})"
f"throttle(per_second={stream._per_second}, per_minute={stream._per_minute}, per_hour={stream._per_hour}, interval={self._friendly_repr(stream._interval)})"
)
return stream.upstream.accept(self)

Expand Down
47 changes: 45 additions & 2 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from streamable import Stream
from streamable.functions import NoopStopIteration
from streamable.iters import ThrottlingPerPeriodIterator
from streamable.util.functiontools import sidify

T = TypeVar("T")
Expand Down Expand Up @@ -254,7 +255,7 @@ class CustomCallable:
.group(size=100, by=None, interval=None)
.observe('groups')
.flatten(concurrency=4)
.throttle(per_second=64, interval=datetime.timedelta(seconds=1))
.throttle(per_second=64, per_minute=inf, per_hour=inf, interval=datetime.timedelta(seconds=1))
.observe('foos')
.catch(TypeError, when=bool, finally_raise=True)
.catch(TypeError, when=bool, replacement=None, finally_raise=True)
Expand Down Expand Up @@ -1044,9 +1045,18 @@ def test_throttle(self) -> None:
msg="`throttle` should raise error when called with `per_second` < 1.",
):
list(Stream([1]).throttle(per_second=0))
with self.assertRaises(
ValueError,
msg="`throttle` should raise error when called with `per_minute` < 1.",
):
list(Stream([1]).throttle(per_minute=0))
with self.assertRaises(
ValueError,
msg="`throttle` should raise error when called with `per_hour` < 1.",
):
list(Stream([1]).throttle(per_hour=0))

# test interval

interval_seconds = 0.3
super_slow_elem_pull_seconds = 1
N = 10
Expand Down Expand Up @@ -1088,6 +1098,18 @@ def test_throttle(self) -> None:
msg="`throttle` should avoid 'ValueError: sleep length must be non-negative' when upstream is slower than `interval`",
)

# test periods pruning
stream = Stream(range(11)).throttle(per_second=2)
self.assertEqual(
len(cast(ThrottlingPerPeriodIterator, iter(stream)).restrictive_periods),
1,
)
stream = Stream(range(11)).throttle(per_second=2, per_hour=1000)
self.assertEqual(
len(cast(ThrottlingPerPeriodIterator, iter(stream)).restrictive_periods),
2,
)

# test per_second

N = 11
Expand All @@ -1106,6 +1128,27 @@ def test_throttle(self) -> None:
msg="`throttle` must slow according to `per_second`",
)

# per_second and per_minute
N = 11
assert N % 2
per_minute = 8
per_second = 2
duration, res = timestream(
Stream(range(11)).throttle(per_second=per_second, per_minute=per_minute)
)
self.assertEqual(
res,
list(range(11)),
msg="`throttle` with `per_second` must yield upstream elements",
)
expected_duration = (N // per_minute) * 60 + (N % per_minute) // per_second
self.assertAlmostEqual(
duration,
expected_duration,
delta=0.01 * expected_duration,
msg="`throttle` must slow according to `per_second` and `per_minute`",
)

# test both

expected_duration = 2
Expand Down

0 comments on commit ba0785d

Please sign in to comment.