diff --git a/kioss/__init__.py b/kioss/__init__.py index 693c771..c8d4c57 100644 --- a/kioss/__init__.py +++ b/kioss/__init__.py @@ -1,4 +1,4 @@ -from kioss._plan import SourcePipe as Pipe +from kioss._plan import APipe as Pipe from kioss._util import LOGGER from kioss import _plan, _visitor _plan.APipe.ITERATOR_GENERATING_VISITOR_CLASS = _visitor.IteratorGeneratingVisitor \ No newline at end of file diff --git a/kioss/_concurrent_exec.py b/kioss/_concurrent_exec.py index 39fbebf..6477fa4 100644 --- a/kioss/_concurrent_exec.py +++ b/kioss/_concurrent_exec.py @@ -86,7 +86,7 @@ def __iter__(self) -> Iterator[Union[R, _ExceptionContainer]]: -class ThreadedFlatteningIteratorWrapper(ThreadedMappingIteratorWrapper[Union[T, _Skip]]): +class ThreadedFlatteningIteratorWrapper(ThreadedMappingIteratorWrapper[T]): _SKIP: _Skip = _Skip() _BUFFER_SIZE = 32 _INIT_RETRY_BACKFOFF = 0.0005 diff --git a/kioss/_plan.py b/kioss/_plan.py index 9ec0461..ac45d22 100644 --- a/kioss/_plan.py +++ b/kioss/_plan.py @@ -2,6 +2,8 @@ from typing import ( Any, Callable, + Collection, + Generic, Iterable, Iterator, List, @@ -22,7 +24,7 @@ class APipe(Iterable[T], ABC): ITERATOR_GENERATING_VISITOR_CLASS: "Optional[Type[_visitor.IteratorGeneratingVisitor]]" = None - def __init__(self, upstream: "APipe[U]"): + def __init__(self, upstream: "APipe"): self.upstream = upstream if self.ITERATOR_GENERATING_VISITOR_CLASS is None: raise ValueError("ITERATOR_GENERATING_VISITOR_CLASS not instantiated") @@ -238,6 +240,9 @@ def register_error_sample(error): return samples + @classmethod + def __rshift__(cls, source: Callable[[], Iterator[T]]) -> "SourcePipe[T]": + return SourcePipe(source) class SourcePipe(APipe[T]): def __init__(self, source: Callable[[], Iterator[T]]): diff --git a/tests/test_legacy.py b/tests/test_legacy.py index 2a4d495..b2ab58d 100644 --- a/tests/test_legacy.py +++ b/tests/test_legacy.py @@ -181,11 +181,11 @@ def store_error_types(error): # test rasing: self.assertRaises( ValueError, - Pipe([map(int, "12-3")].__iter__).flatten(n_threads=n_threads).collect, + Pipe([map(int, "12-3")].__iter__).flatten(n_threads=n_threads).collect, # type: ignore ) self.assertRaises( ValueError, - Pipe(lambda: map(int, "-")).flatten(n_threads=n_threads).collect, + Pipe(lambda: map(int, "-")).flatten(n_threads=n_threads).collect, # type: ignore ) def test_add(self): @@ -220,10 +220,11 @@ def test_map(self, n_threads: int): ), set(map(func, range(1, N))), ) + l: List[List[int]] = [[1], [], [3]] self.assertSetEqual( set( - Pipe([[1], [], [3]].__iter__) - .map(iter) + Pipe(l.__iter__) + .map(lambda l: iter(l)) .map(next, n_threads=n_threads) .catch(RuntimeError) ), @@ -505,7 +506,7 @@ def test_invalid_source(self): @parameterized.expand([[1], [2], [3]]) def test_invalid_flatten_upstream(self, n_threads: int): self.assertRaises( - TypeError, Pipe(range(3).__iter__).flatten(n_threads=n_threads).collect + TypeError, Pipe(range(3).__iter__).flatten(n_threads=n_threads).collect # type: ignore ) def test_planning_and_execution_decoupling(self):