Skip to content

Commit

Permalink
.throttle: per_minute/per_hour
Browse files Browse the repository at this point in the history
  • Loading branch information
ebonnal committed Oct 7, 2024
1 parent 5a2a43a commit c2ebac9
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 13 deletions.
17 changes: 14 additions & 3 deletions streamable/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
OSConcurrentMappingIterable,
RaisingIterator,
ThrottlingIntervalIterator,
ThrottlingPerSecondIterator,
ThrottlingPerPeriodIterator,
TruncatingOnCountIterator,
TruncatingOnPredicateIterator,
)
Expand All @@ -37,6 +37,8 @@
validate_group_size,
validate_iterator,
validate_throttle_interval,
validate_throttle_per_hour,
validate_throttle_per_minute,
validate_throttle_per_second,
validate_truncate_args,
)
Expand Down Expand Up @@ -155,14 +157,23 @@ def observe(iterator: Iterator[T], what: str) -> Iterator[T]:
def throttle(
iterator: Iterator[T],
per_second: int = cast(int, float("inf")),
per_minute: int = cast(int, float("inf")),
per_hour: int = cast(int, float("inf")),
interval: datetime.timedelta = datetime.timedelta(0),
) -> Iterator[T]:
validate_iterator(iterator)
validate_throttle_per_second(per_second)
validate_throttle_per_minute(per_minute)
validate_throttle_per_hour(per_hour)
validate_throttle_interval(interval)

if per_second < float("inf"):
iterator = ThrottlingPerSecondIterator(iterator, per_second)
if per_second + per_minute + per_hour < float("inf"):
iterator = ThrottlingPerPeriodIterator(
iterator,
per_second,
per_minute,
per_hour,
)
if interval > datetime.timedelta(0):
iterator = ThrottlingIntervalIterator(iterator, interval.total_seconds())
return iterator
Expand Down
56 changes: 53 additions & 3 deletions streamable/iters.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,31 +299,81 @@ def __next__(self) -> T:
return elem


class ThrottlingPerSecondIterator(Iterator[T]):
def __init__(self, iterator: Iterator[T], per_second: int) -> None:
class ThrottlingPerPeriodIterator(Iterator[T]):
def __init__(
self,
iterator: Iterator[T],
per_second: int,
per_minute: int,
per_hour: int,
) -> None:
self.iterator = iterator

self.per_second = per_second
self.per_minute = per_minute
self.per_hour = per_hour

self.second: int = -1
self.minute: int = -1
self.hour: int = -1

self.yields_in_second = 0
self.yields_in_minute = 0
self.yields_in_hour = 0

self.offset: Optional[float] = None

def __next__(self) -> T:
current_time = time.time()
if not self.offset:
self.offset = current_time
current_time -= self.offset

current_second = int(current_time)
if self.second != current_second:
self.second = current_second
self.yields_in_second = 0

current_minute = current_second // 60
if self.minute != current_minute:
self.minute = current_minute
self.yields_in_minute = 0

current_hour = current_minute // 60
if self.hour != current_hour:
self.hour = current_hour
self.yields_in_hour = 0

if self.yields_in_second >= self.per_second:
# sleep until next second
time.sleep(ceil(current_time) - current_time)
per_second_constraint_sleep = ceil(current_time) - current_time
return next(self)

if self.yields_in_minute >= self.per_minute:
# sleep until next minute
per_minute_constraint_sleep = (
ceil(current_time / 60) - current_time / 60
) * 60
return next(self)

if self.yields_in_hour >= self.per_hour:
# sleep until next hour
per_hour_constraint_sleep = (
ceil(current_time / 3600) - current_time / 3600
) * 3600
return next(self)

to_sleep = max(
per_second_constraint_sleep,
per_minute_constraint_sleep,
per_hour_constraint_sleep,
)
if to_sleep:
time.sleep(to_sleep)

self.yields_in_second += 1
self.yields_in_minute += 1
self.yields_in_hour += 1
return next(self.iterator)


Expand Down
27 changes: 22 additions & 5 deletions streamable/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
validate_group_interval,
validate_group_size,
validate_throttle_interval,
validate_throttle_per_hour,
validate_throttle_per_minute,
validate_throttle_per_second,
validate_truncate_args,
)
Expand Down Expand Up @@ -354,23 +356,31 @@ def observe(self, what: str = "elements") -> "Stream[T]":
def throttle(
self,
per_second: int = cast(int, float("inf")),
per_minute: int = cast(int, float("inf")),
per_hour: int = cast(int, float("inf")),
interval: datetime.timedelta = datetime.timedelta(0),
) -> "Stream[T]":
"""
Slows the iteration down to ensure both:
Slows iteration to respect:
- a maximum number of yields `per_second`
- a minimum `interval` between yields`
- a maximum number of yields `per_minute`
- a maximum number of yields `per_hour`
- a minimum `interval` elapses between yields
Args:
per_second (float, optional): Maximum number of yields per second (no limit by default).
per_minute (float, optional): Maximum number of yields per minute (no limit by default).
per_hour (float, optional): Maximum number of yields per hour (no limit by default).
interval (datetime.timedelta, optional): Minimum span of time between yields (no limit by default).
Returns:
Stream[T]: A stream yielding upstream elements slower, according to `per_second` and `interval` limits.
Stream[T]: A stream yielding upstream elements under the provided rate constraints.
"""
validate_throttle_per_second(per_second)
validate_throttle_per_minute(per_minute)
validate_throttle_per_hour(per_hour)
validate_throttle_interval(interval)
return ThrottleStream(self, per_second, interval)
return ThrottleStream(self, per_second, per_minute, per_hour, interval)

def truncate(
self, count: Optional[int] = None, when: Optional[Callable[[T], Any]] = None
Expand Down Expand Up @@ -544,10 +554,17 @@ def accept(self, visitor: "Visitor[V]") -> V:

class ThrottleStream(DownStream[T, T]):
def __init__(
self, upstream: Stream[T], per_second: int, interval: datetime.timedelta
self,
upstream: Stream[T],
per_second: int,
per_minute: int,
per_hour: int,
interval: datetime.timedelta,
) -> None:
super().__init__(upstream)
self._per_second = per_second
self._per_minute = per_minute
self._per_hour = per_hour
self._interval = interval

def accept(self, visitor: "Visitor[V]") -> V:
Expand Down
14 changes: 14 additions & 0 deletions streamable/util/validationtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ def validate_throttle_per_second(per_second: int):
)


def validate_throttle_per_minute(per_minute: int):
if per_minute < 1:
raise ValueError(
f"`per_minute` is the maximum number of elements to yield per minute, it must be >= 1 but got {per_minute}."
)


def validate_throttle_per_hour(per_hour: int):
if per_hour < 1:
raise ValueError(
f"`per_hour` is the maximum number of elements to yield per hour, it must be >= 1 but got {per_hour}."
)


def validate_throttle_interval(interval: datetime.timedelta):
if interval < datetime.timedelta(0):
raise ValueError(
Expand Down
6 changes: 5 additions & 1 deletion streamable/visitors/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ def visit_observe_stream(self, stream: ObserveStream[T]) -> Iterator[T]:

def visit_throttle_stream(self, stream: ThrottleStream[T]) -> Iterator[T]:
return functions.throttle(
stream.upstream.accept(self), stream._per_second, stream._interval
stream.upstream.accept(self),
stream._per_second,
stream._per_minute,
stream._per_hour,
stream._interval,
)

def visit_truncate_stream(self, stream: TruncateStream[T]) -> Iterator[T]:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ class CustomCallable:
.group(size=100, by=None, interval=None)
.observe('groups')
.flatten(concurrency=4)
.throttle(per_second=64, interval=datetime.timedelta(seconds=1))
.throttle(per_second=64, per_minute=inf, per_hour=inf, interval=datetime.timedelta(seconds=1))
.observe('foos')
.catch(TypeError, when=bool, finally_raise=True)
.catch(TypeError, when=bool, replacement=None, finally_raise=True)
Expand Down

0 comments on commit c2ebac9

Please sign in to comment.