diff --git a/streamable/functions.py b/streamable/functions.py index b04ccd6..78456fd 100644 --- a/streamable/functions.py +++ b/streamable/functions.py @@ -24,7 +24,7 @@ OSConcurrentMappingIterable, RaisingIterator, ThrottlingIntervalIterator, - ThrottlingPerSecondIterator, + ThrottlingPerPeriodIterator, TruncatingOnCountIterator, TruncatingOnPredicateIterator, ) @@ -37,7 +37,7 @@ validate_group_size, validate_iterator, validate_throttle_interval, - validate_throttle_per_second, + validate_throttle_per_period, validate_truncate_args, ) @@ -155,14 +155,27 @@ def observe(iterator: Iterator[T], what: str) -> Iterator[T]: def throttle( iterator: Iterator[T], per_second: int = cast(int, float("inf")), + per_minute: int = cast(int, float("inf")), + per_hour: int = cast(int, float("inf")), interval: datetime.timedelta = datetime.timedelta(0), ) -> Iterator[T]: validate_iterator(iterator) - validate_throttle_per_second(per_second) + validate_throttle_per_period("per_second", per_second) + validate_throttle_per_period("per_minute", per_minute) + validate_throttle_per_period("per_hour", per_hour) validate_throttle_interval(interval) - if per_second < float("inf"): - iterator = ThrottlingPerSecondIterator(iterator, per_second) + if any( + per_period < float("inf") for per_period in (per_second, per_minute, per_hour) + ): + iterator = ThrottlingPerPeriodIterator( + iterator, + [ + ThrottlingPerPeriodIterator.RestrictivePeriod(1, per_second), + ThrottlingPerPeriodIterator.RestrictivePeriod(60, per_minute), + ThrottlingPerPeriodIterator.RestrictivePeriod(3660, per_hour), + ], + ) if interval > datetime.timedelta(0): iterator = ThrottlingIntervalIterator(iterator, interval.total_seconds()) return iterator diff --git a/streamable/iters.py b/streamable/iters.py index aefe3b6..0ea8414 100644 --- a/streamable/iters.py +++ b/streamable/iters.py @@ -299,13 +299,28 @@ def __next__(self) -> T: return elem -class ThrottlingPerSecondIterator(Iterator[T]): - def __init__(self, iterator: Iterator[T], per_second: int) -> None: +class ThrottlingPerPeriodIterator(Iterator[T]): + class RestrictivePeriod(NamedTuple): + seconds: int + max_yields: int + + def __init__( + self, + iterator: Iterator[T], + restrictive_periods: List[RestrictivePeriod], + ) -> None: self.iterator = iterator - self.per_second = per_second - self.second: int = -1 - self.yields_in_second = 0 + self.restrictive_periods = [ + restrictive_period + for restrictive_period in restrictive_periods + if restrictive_period.max_yields < float("inf") + ] + + self.floor_time_in_period_unit: List[int] = [-1] * len(self.restrictive_periods) + + self.yields_in_periods = [0] * len(self.restrictive_periods) + self.offset: Optional[float] = None def __next__(self) -> T: @@ -313,17 +328,30 @@ def __next__(self) -> T: if not self.offset: self.offset = current_time current_time -= self.offset - current_second = int(current_time) - if self.second != current_second: - self.second = current_second - self.yields_in_second = 0 - - if self.yields_in_second >= self.per_second: - # sleep until next second - time.sleep(ceil(current_time) - current_time) + + to_sleep = 0.0 + for index, (period, max_yields_per_period) in enumerate( + self.restrictive_periods + ): + time_in_period_unit = current_time / period + floor_time_in_period_unit = int(time_in_period_unit) + if self.floor_time_in_period_unit[index] != floor_time_in_period_unit: + self.floor_time_in_period_unit[index] = floor_time_in_period_unit + self.yields_in_periods[index] = 0 + + if self.yields_in_periods[index] >= max_yields_per_period: + # sleep until next period + to_sleep = max( + to_sleep, + (ceil(time_in_period_unit) - time_in_period_unit) * period, + ) + + if to_sleep: + time.sleep(to_sleep) return next(self) - self.yields_in_second += 1 + for index in range(len(self.yields_in_periods)): + self.yields_in_periods[index] += 1 return next(self.iterator) diff --git a/streamable/stream.py b/streamable/stream.py index 95d8aca..1785cf9 100644 --- a/streamable/stream.py +++ b/streamable/stream.py @@ -27,7 +27,7 @@ validate_group_interval, validate_group_size, validate_throttle_interval, - validate_throttle_per_second, + validate_throttle_per_period, validate_truncate_args, ) @@ -354,23 +354,31 @@ def observe(self, what: str = "elements") -> "Stream[T]": def throttle( self, per_second: int = cast(int, float("inf")), + per_minute: int = cast(int, float("inf")), + per_hour: int = cast(int, float("inf")), interval: datetime.timedelta = datetime.timedelta(0), ) -> "Stream[T]": """ - Slows the iteration down to ensure both: + Slows iteration to respect: - a maximum number of yields `per_second` - - a minimum `interval` between yields` + - a maximum number of yields `per_minute` + - a maximum number of yields `per_hour` + - a minimum `interval` elapses between yields Args: per_second (float, optional): Maximum number of yields per second (no limit by default). + per_minute (float, optional): Maximum number of yields per minute (no limit by default). + per_hour (float, optional): Maximum number of yields per hour (no limit by default). interval (datetime.timedelta, optional): Minimum span of time between yields (no limit by default). Returns: - Stream[T]: A stream yielding upstream elements slower, according to `per_second` and `interval` limits. + Stream[T]: A stream yielding upstream elements under the provided rate constraints. """ - validate_throttle_per_second(per_second) + validate_throttle_per_period("per_second", per_second) + validate_throttle_per_period("per_minute", per_minute) + validate_throttle_per_period("per_hour", per_hour) validate_throttle_interval(interval) - return ThrottleStream(self, per_second, interval) + return ThrottleStream(self, per_second, per_minute, per_hour, interval) def truncate( self, count: Optional[int] = None, when: Optional[Callable[[T], Any]] = None @@ -544,10 +552,17 @@ def accept(self, visitor: "Visitor[V]") -> V: class ThrottleStream(DownStream[T, T]): def __init__( - self, upstream: Stream[T], per_second: int, interval: datetime.timedelta + self, + upstream: Stream[T], + per_second: int, + per_minute: int, + per_hour: int, + interval: datetime.timedelta, ) -> None: super().__init__(upstream) self._per_second = per_second + self._per_minute = per_minute + self._per_hour = per_hour self._interval = interval def accept(self, visitor: "Visitor[V]") -> V: diff --git a/streamable/util/validationtools.py b/streamable/util/validationtools.py index 81b1211..1ecec2c 100644 --- a/streamable/util/validationtools.py +++ b/streamable/util/validationtools.py @@ -29,10 +29,10 @@ def validate_group_interval(interval: Optional[datetime.timedelta]): raise ValueError(f"`interval` should be positive but got {repr(interval)}.") -def validate_throttle_per_second(per_second: int): - if per_second < 1: +def validate_throttle_per_period(per_period_arg_name: str, value: int): + if value < 1: raise ValueError( - f"`per_second` is the maximum number of elements to yield per second, it must be >= 1 but got {per_second}." + f"`{per_period_arg_name}` is the maximum number of elements to yield per second, it must be >= 1 but got {value}." ) diff --git a/streamable/visitors/iterator.py b/streamable/visitors/iterator.py index 407ff4b..3ea7f3b 100644 --- a/streamable/visitors/iterator.py +++ b/streamable/visitors/iterator.py @@ -101,7 +101,11 @@ def visit_observe_stream(self, stream: ObserveStream[T]) -> Iterator[T]: def visit_throttle_stream(self, stream: ThrottleStream[T]) -> Iterator[T]: return functions.throttle( - stream.upstream.accept(self), stream._per_second, stream._interval + stream.upstream.accept(self), + stream._per_second, + stream._per_minute, + stream._per_hour, + stream._interval, ) def visit_truncate_stream(self, stream: TruncateStream[T]) -> Iterator[T]: diff --git a/streamable/visitors/representation.py b/streamable/visitors/representation.py index 28cf230..c0d675e 100644 --- a/streamable/visitors/representation.py +++ b/streamable/visitors/representation.py @@ -85,7 +85,7 @@ def visit_observe_stream(self, stream: ObserveStream[T]) -> str: def visit_throttle_stream(self, stream: ThrottleStream[T]) -> str: self.methods_reprs.append( - f"throttle(per_second={stream._per_second}, interval={self._friendly_repr(stream._interval)})" + f"throttle(per_second={stream._per_second}, per_minute={stream._per_minute}, per_hour={stream._per_hour}, interval={self._friendly_repr(stream._interval)})" ) return stream.upstream.accept(self) diff --git a/tests/test_stream.py b/tests/test_stream.py index 3e6c502..4987674 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -23,6 +23,7 @@ from streamable import Stream from streamable.functions import NoopStopIteration +from streamable.iters import ThrottlingPerPeriodIterator from streamable.util.functiontools import sidify T = TypeVar("T") @@ -254,7 +255,7 @@ class CustomCallable: .group(size=100, by=None, interval=None) .observe('groups') .flatten(concurrency=4) - .throttle(per_second=64, interval=datetime.timedelta(seconds=1)) + .throttle(per_second=64, per_minute=inf, per_hour=inf, interval=datetime.timedelta(seconds=1)) .observe('foos') .catch(TypeError, when=bool, finally_raise=True) .catch(TypeError, when=bool, replacement=None, finally_raise=True) @@ -1044,9 +1045,18 @@ def test_throttle(self) -> None: msg="`throttle` should raise error when called with `per_second` < 1.", ): list(Stream([1]).throttle(per_second=0)) + with self.assertRaises( + ValueError, + msg="`throttle` should raise error when called with `per_minute` < 1.", + ): + list(Stream([1]).throttle(per_minute=0)) + with self.assertRaises( + ValueError, + msg="`throttle` should raise error when called with `per_hour` < 1.", + ): + list(Stream([1]).throttle(per_hour=0)) # test interval - interval_seconds = 0.3 super_slow_elem_pull_seconds = 1 N = 10 @@ -1088,6 +1098,18 @@ def test_throttle(self) -> None: msg="`throttle` should avoid 'ValueError: sleep length must be non-negative' when upstream is slower than `interval`", ) + # test periods pruning + stream = Stream(range(11)).throttle(per_second=2) + self.assertEqual( + len(cast(ThrottlingPerPeriodIterator, iter(stream)).restrictive_periods), + 1, + ) + stream = Stream(range(11)).throttle(per_second=2, per_hour=1000) + self.assertEqual( + len(cast(ThrottlingPerPeriodIterator, iter(stream)).restrictive_periods), + 2, + ) + # test per_second N = 11 @@ -1106,6 +1128,27 @@ def test_throttle(self) -> None: msg="`throttle` must slow according to `per_second`", ) + # per_second and per_minute + N = 11 + assert N % 2 + per_minute = 8 + per_second = 2 + duration, res = timestream( + Stream(range(11)).throttle(per_second=per_second, per_minute=per_minute) + ) + self.assertEqual( + res, + list(range(11)), + msg="`throttle` with `per_second` must yield upstream elements", + ) + expected_duration = (N // per_minute) * 60 + (N % per_minute) // per_second + self.assertAlmostEqual( + duration, + expected_duration, + delta=0.01 * expected_duration, + msg="`throttle` must slow according to `per_second` and `per_minute`", + ) + # test both expected_duration = 2