Skip to content

Commit

Permalink
pass correct mpi size to ScanWindow; bug fixed in ScanWindow legend
Browse files Browse the repository at this point in the history
Closes #46.
  • Loading branch information
leofang committed May 28, 2019
1 parent d207237 commit 787a35e
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 51 deletions.
48 changes: 8 additions & 40 deletions nsls2ptycho/core/ptycho_recon.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from PyQt5 import QtCore
from datetime import datetime
from nsls2ptycho.core.ptycho_param import Param
import mpi4py
mpi4py.rc.initialize = False
from mpi4py import MPI
import sys, os
import pickle # dump param into disk
import subprocess # call mpirun from shell
Expand All @@ -17,6 +14,7 @@
print('[!] Unable to import core.HXN_databroker packages some features will '
'be unavailable')
print('[!] (import error: {})'.format(ex))
from nsls2ptycho.core.utils import use_mpi_machinefile


class PtychoReconWorker(QtCore.QThread):
Expand Down Expand Up @@ -86,46 +84,16 @@ def _parse_one_line(self):
return stdout_2.split()

def recon_api(self, param:Param, update_fcn=None):
# working version
# "1" is just a placeholder to be overwritten soon
mpirun_command = ["mpirun", "-n", "1", "python", "-W", "ignore", "-m","nsls2ptycho.core.ptycho.recon_ptycho_gui"]

if param.gpu_flag:
num_processes = str(len(param.gpus))
mpirun_command[2] = str(len(param.gpus))
elif param.mpi_file_path == '':
mpirun_command[2] = str(param.processes) if param.processes > 1 else str(1)
else:
num_processes = str(param.processes) if param.processes > 1 else str(1)
mpirun_command = ["mpirun", "-n", num_processes, "python", "-W", "ignore", "-m","nsls2ptycho.core.ptycho.recon_ptycho_gui"]
mpirun_command = use_mpi_machinefile(mpirun_command, param.mpi_file_path)

if 'MPICH' in MPI.get_vendor()[0]:
mpirun_command.insert(-2, "-u") # force flush asap (MPICH is weird...)

# use MPI machine file if available, assuming each line of which is:
# ip_address slots=n max-slots=n --- Open MPI
# ip_address:n --- MPICH
if param.mpi_file_path != '':
with open(param.mpi_file_path, 'r') as f:
node_count = 0
if MPI.get_vendor()[0] == 'Open MPI':
for line in f:
line = line.split()
node_count += int(line[1].split('=')[-1])
mpirun_command.insert(3, "-machinefile")
# use mpirun to find where MPI is installed
import shutil
path = os.path.split(shutil.which('mpirun'))[0]
if path[-3:] == 'bin':
path = path[:-3]
mpirun_command[4:4] = ["--prefix", path, "-x", "PATH", "-x", "LD_LIBRARY_PATH"]
elif 'MPICH' in MPI.get_vendor()[0]:
for line in f:
line = line.split(":")
node_count += int(line[1])
mpirun_command.insert(3, "-f")
else:
raise RuntimeError("mpi4py is built on top of unrecognized MPI library. "
"Only Open MPI and MPICH are tested.")
mpirun_command[2] = str(node_count) # use all available nodes
mpirun_command.insert(4, param.mpi_file_path)
#param.gpus = range(node_count)
#print(" ".join(mpirun_command))

try:
self.return_value = None
with subprocess.Popen(mpirun_command,
Expand Down
55 changes: 55 additions & 0 deletions nsls2ptycho/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
import matplotlib.cm as cm
import numpy as np

import mpi4py
mpi4py.rc.initialize = False
from mpi4py import MPI

from nsls2ptycho.core.ptycho.utils import split


Expand All @@ -26,10 +30,12 @@ def plot_point_process_distribution(pts, mpi_size, colormap=cm.jet):
plt.scatter(pts[0, a[i][0]:a[i][1]], pts[1, a[i][0]:a[i][1]], c=colors[i])
plt.show()


def find_owner(filename):
# from https://stackoverflow.com/a/1830635
return getpwuid(os.stat(filename).st_uid).pw_name


def clean_shared_memory(pid=None):
'''
This function cleans up shared memory segments created by the GUI or a buggy Open MPI.
Expand All @@ -54,3 +60,52 @@ def clean_shared_memory(pid=None):
s.unlink()

print("Done.")


def get_mpi_num_processes(mpi_file_path):
# use MPI machine file if available, assuming each line of which is:
# ip_address slots=n max-slots=n --- Open MPI
# ip_address:n --- MPICH
with open(mpi_file_path, 'r') as f:
node_count = 0
if MPI.get_vendor()[0] == 'Open MPI':
for line in f:
line = line.split()
node_count += int(line[1].split('=')[-1])
elif 'MPICH' in MPI.get_vendor()[0]:
for line in f:
line = line.split(":")
node_count += int(line[1])
else:
# TODO: support MVAPICH?
raise RuntimeError("mpi4py is built on top of unrecognized MPI library. "
"Only Open MPI and MPICH are tested.")

return node_count


def use_mpi_machinefile(mpirun_command, mpi_file_path):
# use MPI machine file if available, assuming each line of which is:
# ip_address slots=n max-slots=n --- Open MPI
# ip_address:n --- MPICH
node_count = get_mpi_num_processes(mpi_file_path)

if MPI.get_vendor()[0] == 'Open MPI':
mpirun_command.insert(3, "-machinefile")
# use mpirun to find where MPI is installed
import shutil
path = os.path.split(shutil.which('mpirun'))[0]
if path.endswith('bin'):
path = path[:-3]
mpirun_command[4:4] = ["--prefix", path, "-x", "PATH", "-x", "LD_LIBRARY_PATH"]
elif 'MPICH' in MPI.get_vendor()[0]:
mpirun_command.insert(-2, "-u") # force flush asap (MPICH is weird...)
mpirun_command.insert(3, "-f")
else:
# TODO: support MVAPICH?
raise RuntimeError("mpi4py is built on top of unrecognized MPI library. "
"Only Open MPI and MPICH are tested.")
mpirun_command[2] = str(node_count) # use all available nodes
mpirun_command.insert(4, mpi_file_path)

return mpirun_command
20 changes: 14 additions & 6 deletions nsls2ptycho/core/widgets/mplcanvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,21 +115,23 @@ def update_scatter(self, pts, mpi_size, colormap=cm.jet):
- mpi_size: number of MPI processes
- colormap
'''
label_set = set([i for i in range(9)] + [i for i in range(mpi_size, mpi_size-6, -1)])
labels = []
idx = False
if len(self.line_handlers) == 0:
a = split(pts.shape[1], mpi_size)
colors = colormap(np.linspace(0, 1, len(a)))
label_set = set([i for i in range(9)] + [i for i in range(mpi_size, mpi_size-6, -1)])
for i in range(mpi_size):
if mpi_size <=15 or i in label_set:
label = 'Process %i'%i
s = rcParams['lines.markersize']**2 # matplotlib default
labels.append(label)
elif i==mpi_size-6 and i not in label_set:
label = r' $\vdots$'
s = 0
labels.append(label)
idx = True
else:
label = '_nolegend_' # matplotlib undocumented secret...
s = rcParams['lines.markersize']**2 # matplotlib default
h = self.axes.scatter(pts[0, a[i][0]:a[i][1]], pts[1, a[i][0]:a[i][1]], c=colors[i], label=label, s=s)
h = self.axes.scatter(pts[0, a[i][0]:a[i][1]], pts[1, a[i][0]:a[i][1]], c=colors[i], label=label)
self.line_handlers.append(h)
else: # assuming mpi_size is unchanged
a = split(pts.shape[1], mpi_size)
Expand All @@ -144,7 +146,13 @@ def update_scatter(self, pts, mpi_size, colormap=cm.jet):

# we have a rectangular window, make the plot align to its center left
self.axes.set_aspect(aspect='equal', anchor='W')
self.axes.legend(bbox_to_anchor=(0.98, 1.0), fancybox=True)
legend = self.axes.legend(bbox_to_anchor=(0.98, 1.0), fancybox=True)

# for the label \vdots, remove its marker
if idx:
legend.legendHandles[9].set_sizes([0])
#self.axes.legend(legend.legendHandles, labels, bbox_to_anchor=(0.98, 1.0), fancybox=True)

self.draw()


Expand Down
6 changes: 4 additions & 2 deletions nsls2ptycho/ptycho_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from PyQt5.QtWidgets import QFileDialog, QAction

from nsls2ptycho.ui import ui_ptycho
from nsls2ptycho.core.utils import clean_shared_memory
from nsls2ptycho.core.utils import clean_shared_memory, get_mpi_num_processes
from nsls2ptycho.core.ptycho_param import Param
from nsls2ptycho.core.ptycho_recon import PtychoReconWorker, PtychoReconFakeWorker, HardWorker
from nsls2ptycho.core.ptycho_qt_utils import PtychoStream
Expand Down Expand Up @@ -545,8 +545,10 @@ def start(self, batch_mode=False):
# copied from nsls2ptycho/core/ptycho_recon.py
if self.param.gpu_flag:
num_processes = str(len(self.param.gpus))
else:
elif self.param.mpi_file_path == '':
num_processes = str(self.param.processes) if self.param.processes > 1 else str(1)
else:
num_processes = str(get_mpi_num_processes(self.param.mpi_file_path))
self.scanWindow.update_image(self._scan_points, int(num_processes))


Expand Down
6 changes: 3 additions & 3 deletions run-ptycho
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ export LD_LIBRARY_PATH=${CUDA_HOME}/extras/CUPTI/lib64:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=${CUDA_HOME}/nvvm/lib64:$LD_LIBRARY_PATH

# Numba
export NUMBAPRO_LIBDEVICE=/usr/local/cuda/nvvm/libdevice/
export NUMBAPRO_NVVM=/usr/local/cuda/nvvm/lib64/libnvvm.so
export NUMBAPRO_LIBDEVICE=$CUDA_HOME/nvvm/libdevice/
export NUMBAPRO_NVVM=$CUDA_HOME/nvvm/lib64/libnvvm.so

# use the production conda environment
source /opt/conda/bin/activate ptycho_production

PTYCHO_HOME=/home/$USER/.ptycho_gui
PTYCHO_HOME=$HOME/.ptycho_gui
if [ ! -d $PTYCHO_HOME ]; then
mkdir $PTYCHO_HOME
fi
Expand Down

0 comments on commit 787a35e

Please sign in to comment.