Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
ebonnal committed Oct 14, 2024
1 parent 2a26a3e commit 47607b6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
17 changes: 12 additions & 5 deletions streamable/iters.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Union,
cast,
)
import weakref

from streamable.util.functiontools import catch_and_raise_as

Expand Down Expand Up @@ -401,12 +402,13 @@ def __init__(
self.buffer_size = buffer_size
self.ordered = ordered

def _context_manager(self) -> ContextManager:
@property
def _context_manager(self) -> Callable[[], Optional[ContextManager]]:
@contextmanager
def dummy_context_manager_generator():
yield

return dummy_context_manager_generator()
return lambda: dummy_context_manager_generator()

@abstractmethod
def _launch_task(
Expand All @@ -420,7 +422,11 @@ def _future_result_collection(
) -> FutureResultCollection[Union[U, RaisingIterator.ExceptionContainer]]: ...

def __iter__(self) -> Iterator[Union[U, RaisingIterator.ExceptionContainer]]:
with self._context_manager():
context_manager = self._context_manager()
if context_manager is None:
raise ValueError("context manager is None")

with context_manager:
future_results = self._future_result_collection()

# queue tasks up to buffer_size
Expand Down Expand Up @@ -452,12 +458,13 @@ def __init__(
self.executor: Executor
self.via_processes = via_processes

def _context_manager(self) -> ContextManager:
@property
def _context_manager(self) -> Callable[[], Optional[ContextManager]]:
if self.via_processes:
self.executor = ProcessPoolExecutor(max_workers=self.concurrency)
else:
self.executor = ThreadPoolExecutor(max_workers=self.concurrency)
return self.executor
return weakref.ref(self.executor)

# picklable
@staticmethod
Expand Down
6 changes: 1 addition & 5 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,7 @@ def test_map(self, concurrency) -> None:
)
def test_process_concurrency(
self, ordered, order_mutation
) -> None: # pragma: no cover
import sys

if sys.version < "3.9.0":
return
) -> None:

lambda_identity = lambda x: x * 10

Expand Down

0 comments on commit 47607b6

Please sign in to comment.