Skip to content

Commit

Permalink
Merge pull request #61 from MDAnalysis/test-with-scheduler-fixture
Browse files Browse the repository at this point in the history
test custom with scheduler fixture
  • Loading branch information
kain88-de authored Sep 21, 2018
2 parents d56fd7e + 58b1281 commit 3d78fc7
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 36 deletions.
31 changes: 31 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

# -*- Mode: python; tab-width: 4; indent-tabs-mode:nil; coding:utf-8 -*-
# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4
#
# PMDA
# Copyright (c) 2017 The MDAnalysis Development Team and contributors
# (see the file AUTHORS for the full list of names)
#
# Released under the GNU Public Licence, v2 or any higher version

from dask import distributed, multiprocessing
import pytest

@pytest.fixture(scope="session", params=(1, 2))
def client(tmpdir_factory, request):
with tmpdir_factory.mktemp("dask_cluster").as_cwd():
lc = distributed.LocalCluster(n_workers=request.param, processes=True)
client = distributed.Client(lc)

yield client

client.close()
lc.close()


@pytest.fixture(scope='session', params=('distributed', 'multiprocessing'))
def scheduler(request, client):
if request.param == 'distributed':
return client
else:
return multiprocessing
2 changes: 1 addition & 1 deletion pmda/rdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def _conclude(self, ):
@staticmethod
def _reduce(res, result_single_frame):
""" 'add' action for an accumulator"""
if res == []:
if isinstance(res, list) and len(res) == 0:
# Convert res from an empty list to a numpy array
# which has the same shape as the single frame result
res = result_single_frame
Expand Down
32 changes: 17 additions & 15 deletions pmda/test/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,33 @@
# Released under the GNU Public Licence, v2 or any higher version
from __future__ import absolute_import, division

import pytest

import numpy as np

from numpy.testing import assert_equal

import MDAnalysis as mda
from pmda import custom

from MDAnalysisTests.datafiles import PSF, DCD
from MDAnalysisTests.util import no_deprecated_call
import pytest
from numpy.testing import assert_equal

from pmda import custom


def custom_function(mobile):
return mobile.center_of_geometry()


def test_AnalysisFromFunction():
def test_AnalysisFromFunction(scheduler):
u = mda.Universe(PSF, DCD)
step = 2
ana1 = custom.AnalysisFromFunction(custom_function, u,
u.atoms).run(step=step)
ana2 = custom.AnalysisFromFunction(custom_function, u,
u.atoms).run(step=step)
ana3 = custom.AnalysisFromFunction(custom_function, u,
u.atoms).run(step=step)
ana1 = custom.AnalysisFromFunction(custom_function, u, u.atoms).run(
step=step, scheduler=scheduler
)
ana2 = custom.AnalysisFromFunction(custom_function, u, u.atoms).run(
step=step, scheduler=scheduler
)
ana3 = custom.AnalysisFromFunction(custom_function, u, u.atoms).run(
step=step, scheduler=scheduler
)

results = []
for ts in u.trajectory[::step]:
Expand All @@ -53,8 +54,9 @@ def test_AnalysisFromFunction_otherAgs():
u2 = mda.Universe(PSF, DCD)
u3 = mda.Universe(PSF, DCD)
step = 2
ana = custom.AnalysisFromFunction(custom_function_2, u, u.atoms, u2.atoms,
u3.atoms).run(step=step)
ana = custom.AnalysisFromFunction(
custom_function_2, u, u.atoms, u2.atoms, u3.atoms
).run(step=step)

results = []
for ts in u.trajectory[::step]:
Expand Down
20 changes: 0 additions & 20 deletions pmda/test/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,26 +78,6 @@ def test_sub_frames(analysis, n_jobs):
np.testing.assert_almost_equal(analysis.res, [10, 20, 30, 40])


@pytest.fixture(scope="session")
def client(tmpdir_factory):
with tmpdir_factory.mktemp("dask_cluster").as_cwd():
lc = distributed.LocalCluster(n_workers=2, processes=True)
client = distributed.Client(lc)

yield client

client.close()
lc.close()


@pytest.fixture(scope='session', params=('distributed', 'multiprocessing'))
def scheduler(request, client):
if request.param == 'distributed':
return client
else:
return multiprocessing


def test_scheduler(analysis, scheduler):
analysis.run(scheduler=scheduler)

Expand Down

0 comments on commit 3d78fc7

Please sign in to comment.