Skip to content

Commit

Permalink
__rshift__ + no typing errors outside test
Browse files Browse the repository at this point in the history
  • Loading branch information
ebonnal committed Dec 3, 2023
1 parent 9baca3c commit 03aa35c
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 8 deletions.
2 changes: 1 addition & 1 deletion kioss/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from kioss._plan import SourcePipe as Pipe
from kioss._plan import APipe as Pipe
from kioss._util import LOGGER
from kioss import _plan, _visitor
_plan.APipe.ITERATOR_GENERATING_VISITOR_CLASS = _visitor.IteratorGeneratingVisitor
2 changes: 1 addition & 1 deletion kioss/_concurrent_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __iter__(self) -> Iterator[Union[R, _ExceptionContainer]]:



class ThreadedFlatteningIteratorWrapper(ThreadedMappingIteratorWrapper[Union[T, _Skip]]):
class ThreadedFlatteningIteratorWrapper(ThreadedMappingIteratorWrapper[T]):
_SKIP: _Skip = _Skip()
_BUFFER_SIZE = 32
_INIT_RETRY_BACKFOFF = 0.0005
Expand Down
7 changes: 6 additions & 1 deletion kioss/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import (
Any,
Callable,
Collection,
Generic,
Iterable,
Iterator,
List,
Expand All @@ -22,7 +24,7 @@

class APipe(Iterable[T], ABC):
ITERATOR_GENERATING_VISITOR_CLASS: "Optional[Type[_visitor.IteratorGeneratingVisitor]]" = None
def __init__(self, upstream: "APipe[U]"):
def __init__(self, upstream: "APipe"):
self.upstream = upstream
if self.ITERATOR_GENERATING_VISITOR_CLASS is None:
raise ValueError("ITERATOR_GENERATING_VISITOR_CLASS not instantiated")
Expand Down Expand Up @@ -238,6 +240,9 @@ def register_error_sample(error):

return samples

@classmethod
def __rshift__(cls, source: Callable[[], Iterator[T]]) -> "SourcePipe[T]":
return SourcePipe(source)

class SourcePipe(APipe[T]):
def __init__(self, source: Callable[[], Iterator[T]]):
Expand Down
11 changes: 6 additions & 5 deletions tests/test_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,11 @@ def store_error_types(error):
# test rasing:
self.assertRaises(
ValueError,
Pipe([map(int, "12-3")].__iter__).flatten(n_threads=n_threads).collect,
Pipe([map(int, "12-3")].__iter__).flatten(n_threads=n_threads).collect, # type: ignore
)
self.assertRaises(
ValueError,
Pipe(lambda: map(int, "-")).flatten(n_threads=n_threads).collect,
Pipe(lambda: map(int, "-")).flatten(n_threads=n_threads).collect, # type: ignore
)

def test_add(self):
Expand Down Expand Up @@ -220,10 +220,11 @@ def test_map(self, n_threads: int):
),
set(map(func, range(1, N))),
)
l: List[List[int]] = [[1], [], [3]]
self.assertSetEqual(
set(
Pipe([[1], [], [3]].__iter__)
.map(iter)
Pipe(l.__iter__)
.map(lambda l: iter(l))
.map(next, n_threads=n_threads)
.catch(RuntimeError)
),
Expand Down Expand Up @@ -505,7 +506,7 @@ def test_invalid_source(self):
@parameterized.expand([[1], [2], [3]])
def test_invalid_flatten_upstream(self, n_threads: int):
self.assertRaises(
TypeError, Pipe(range(3).__iter__).flatten(n_threads=n_threads).collect
TypeError, Pipe(range(3).__iter__).flatten(n_threads=n_threads).collect # type: ignore
)

def test_planning_and_execution_decoupling(self):
Expand Down

0 comments on commit 03aa35c

Please sign in to comment.