Commit 6770bca: fix tests maybe

kain88-de committed Oct 30, 2018
1 parent 688ca22
Showing 5 changed files with 40 additions and 36 deletions.
9 changes: 6 additions & 3 deletions conftest.py
@@ -9,6 +9,7 @@
 # Released under the GNU Public Licence, v2 or any higher version
 
 from dask import distributed
+import dask
 import pytest
 
 
@@ -24,9 +25,11 @@ def client(tmpdir_factory, request):
     lc.close()
 
 
-@pytest.fixture(scope='session', params=('distributed', 'multiprocessing'))
+@pytest.fixture(scope='session', params=('distributed', 'multiprocessing', 'single-threaded'))
 def scheduler(request, client):
     if request.param == 'distributed':
-        return client
+        arg = client
     else:
-        return request.param
+        arg = request.param
+    with dask.config.set(scheduler=arg):
+        yield
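
The reworked fixture leans on dask's context-local configuration: anything computed inside the `with dask.config.set(scheduler=arg)` block picks that scheduler up automatically, which is why the fixture now yields instead of returning a value. A minimal sketch of the mechanism outside pytest (the dask.bag computation is illustrative, not part of this commit):

    import dask
    import dask.bag as db

    numbers = db.from_sequence(range(8), npartitions=4)

    with dask.config.set(scheduler='single-threaded'):
        # compute() gets no scheduler argument; dask reads it from the
        # configuration installed by the surrounding context manager.
        print(numbers.map(lambda x: x * x).sum().compute())
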
43 changes: 30 additions & 13 deletions pmda/leaflet.py
@@ -231,7 +231,6 @@ def run(self,
             start=None,
             stop=None,
             step=None,
-            scheduler=None,
             n_jobs=-1,
             cutoff=15.0):
         """Perform the calculation
@@ -244,35 +243,53 @@
             stop frame of analysis
         step : int, optional
             number of frames to skip between each analysed frame
-        scheduler : dask scheduler, optional
-            Use dask scheduler, defaults to multiprocessing. This can be used
-            to spread work to a distributed scheduler
         n_jobs : int, optional
             number of tasks to start, if `-1` use number of logical cpu cores.
             This argument will be ignored when the distributed scheduler is
             used
         """
-        if scheduler is None:
+        # are we using a distributed scheduler or should we use multiprocessing?
+        scheduler = dask.config.get('scheduler', None)
+        if scheduler is None and client is None:
             scheduler = 'multiprocessing'
+        elif scheduler is None:
+            # maybe we can grab a global worker
+            try:
+                from dask import distributed
+                scheduler = distributed.worker.get_client()
+            except ValueError:
+                pass
+            except ImportError:
+                pass
 
         if n_jobs == -1:
             n_jobs = cpu_count()
 
+        # we could not find a global scheduler to use and we ask for a single
+        # job. Therefore we run this on the single threaded scheduler for
+        # debugging.
+        if scheduler is None and n_jobs == 1:
+            scheduler = 'single-threaded'
+
         if n_blocks is None:
             if scheduler == 'multiprocessing':
-                n_jobs = cpu_count()
                 n_blocks = n_jobs
             elif isinstance(scheduler, distributed.Client):
-                n_jobs = len(scheduler.ncores())
+                n_blocks = len(scheduler.ncores())
             else:
-                raise ValueError(
-                    "Couldn't guess ideal number of jobs from scheduler."
-                    "Please provide `n_jobs` in call to method.")
-
-        with timeit() as b_universe:
-            universe = mda.Universe(self._top, self._traj)
+                n_blocks = 1
+                warnings.warn(
+                    "Couldn't guess ideal number of blocks from scheduler. "
+                    "Set n_blocks=1. Please provide `n_blocks` in call to method.")
+
+        scheduler_kwargs = {'scheduler': scheduler}
+        if scheduler == 'multiprocessing':
+            scheduler_kwargs['num_workers'] = n_jobs
+
+        with timeit() as b_universe:
+            universe = mda.Universe(self._top, self._traj)
 
         start, stop, step = self._trajectory.check_slice_indices(
             start, stop, step)
         with timeit() as total:
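
With this change `run()` resolves the scheduler in a fixed order: an explicitly configured dask scheduler wins, then a globally registered distributed client, then the single-threaded scheduler as a debugging fallback when only one job is requested, and multiprocessing otherwise. A standalone sketch of that order (`detect_scheduler` is an illustrative helper, not part of pmda):

    import dask

    def detect_scheduler(n_jobs=1):
        # 1. an explicitly configured scheduler wins
        scheduler = dask.config.get('scheduler', None)
        if scheduler is None:
            # 2. otherwise try to grab a globally registered client
            try:
                from dask import distributed
                scheduler = distributed.worker.get_client()
            except (ValueError, ImportError):
                pass
        if scheduler is None and n_jobs == 1:
            # 3. nothing found and a single job: single-threaded, for debugging
            scheduler = 'single-threaded'
        # 4. otherwise fall back to multiprocessing
        return scheduler if scheduler is not None else 'multiprocessing'
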
3 changes: 0 additions & 3 deletions pmda/parallel.py
@@ -278,9 +278,6 @@ def run(self,
             stop frame of analysis
         step : int, optional
             number of frames to skip between each analysed frame
-        scheduler : dask scheduler, optional
-            Use dask scheduler, defaults to multiprocessing. This can be used
-            to spread work to a distributed scheduler
         n_jobs : int, optional
             number of jobs to start, if `-1` use number of logical cpu cores.
             This argument will be ignored when the distributed scheduler is
6 changes: 3 additions & 3 deletions pmda/test/test_custom.py
@@ -27,13 +27,13 @@ def test_AnalysisFromFunction(scheduler):
     u = mda.Universe(PSF, DCD)
     step = 2
     ana1 = custom.AnalysisFromFunction(custom_function, u, u.atoms).run(
-        step=step, scheduler=scheduler
+        step=step
     )
     ana2 = custom.AnalysisFromFunction(custom_function, u, u.atoms).run(
-        step=step, scheduler=scheduler
+        step=step
     )
     ana3 = custom.AnalysisFromFunction(custom_function, u, u.atoms).run(
-        step=step, scheduler=scheduler
+        step=step
     )
 
     results = []
15 changes: 1 addition & 14 deletions pmda/test/test_parallel.py
@@ -60,11 +60,6 @@ def analysis():
     return ana
 
 
-def test_wrong_scheduler(analysis):
-    with pytest.raises(ValueError):
-        analysis.run(scheduler=2)
-
-
 @pytest.mark.parametrize('n_jobs', (1, 2))
 def test_all_frames(analysis, n_jobs):
     analysis.run(n_jobs=n_jobs)
@@ -91,16 +86,8 @@ def test_no_frames(analysis, n_jobs):
     assert analysis.timing.universe == 0
 
 
-@pytest.fixture(scope='session', params=('distributed', 'multiprocessing'))
-def scheduler(request, client):
-    if request.param == 'distributed':
-        return client
-    else:
-        return request.param
-
-
 def test_scheduler(analysis, scheduler):
-    analysis.run(scheduler=scheduler)
+    analysis.run()
 
 
 def test_nframes_less_nblocks_warning(analysis):
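
For callers the net effect is that the scheduler is no longer a `run()` keyword; it is selected through dask's configuration or a global distributed client, which is what the shared `scheduler` fixture from conftest.py now provides to these tests. A hedged usage sketch reusing the `AnalysisFromFunction` API from the diff above (the radius-of-gyration function is illustrative and assumes the wrapped function receives an AtomGroup):

    import dask
    import MDAnalysis as mda
    from MDAnalysisTests.datafiles import PSF, DCD
    from pmda import custom

    def radius_of_gyration(atomgroup):
        return atomgroup.radius_of_gyration()

    u = mda.Universe(PSF, DCD)
    ana = custom.AnalysisFromFunction(radius_of_gyration, u, u.atoms)

    with dask.config.set(scheduler='multiprocessing'):
        ana.run(n_jobs=2)  # scheduler comes from the config, not a keyword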
