From 57453d0ffc5488991ba9e302cf7d809b7d191208 Mon Sep 17 00:00:00 2001 From: ryanhammonds Date: Tue, 5 Jan 2021 19:34:42 -0800 Subject: [PATCH 1/2] plotly figures for t1w --- pynets/reports/plotting.py | 58 ++++++++++++++++++++++++++++++++++++++ requirements.txt | 1 + 2 files changed, 59 insertions(+) create mode 100644 pynets/reports/plotting.py diff --git a/pynets/reports/plotting.py b/pynets/reports/plotting.py new file mode 100644 index 00000000..4344158e --- /dev/null +++ b/pynets/reports/plotting.py @@ -0,0 +1,58 @@ +"""Plotting functions to embed into html reports.""" + +import nibabel as nib +import os +import numpy as np +import plotly.graph_objs as go +from plotly.subplots import make_subplots + + +def plot_t1w(t1w): + """Plot t1w images using plotly. + + Parameters + ---------- + t1w : str + Path to a t1w image. + + Returns + ------- + fig : plotly.graph_objs.Figure + A plotly figure that is ready-to-embed using the to_html method. + """ + + # Load data from file + t1w_arr = nib.load(t1w).get_fdata() + + # Space out z-slices + z_max = np.shape(t1w_arr)[2] + pad = 30 + z_slices = np.linspace(pad, z_max-pad, num=21, dtype=int) + + # Init figure + nrows = 3 + ncols = 7 + fig = make_subplots(nrows, ncols, vertical_spacing=0.005, horizontal_spacing=0.005) + + # Get subplot positions + fig_idxs = [(row, col) for row in range(1, nrows+1) + for col in range(1, ncols+1)] + + for idx, z_slice in enumerate(z_slices): + + # Slice and rotate the t1w array + img_slice = np.rot90(t1w_arr[:, :, z_slice], k=3) + + # Get subplot coords + x_coord, y_coord = fig_idxs[idx] + + fig.add_trace(go.Heatmap(z=img_slice, showscale=False, colorscale="gray"), + x_coord, y_coord) + + # Update axes + fig.update_xaxes(showticklabels=False, row=x_coord, col=y_coord) + fig.update_yaxes(showticklabels=False, row=x_coord, col=y_coord) + + fig.update_layout(width=800, height=500) + + return fig diff --git a/requirements.txt b/requirements.txt index cedd103e..66e76a79 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,3 +36,4 @@ git+https://github.com/dPys/nilearn.git@enh/parc_conn git+https://github.com/dPys/deepbrain.git@master urllib3>=1.25.4 mplcyberpunk>=0.1.11 +plotly>=4.14.1 From e7d579f0a2475ad468498b6999d645dd22061021 Mon Sep 17 00:00:00 2001 From: ryanhammonds Date: Thu, 7 Jan 2021 18:49:46 -0800 Subject: [PATCH 2/2] plotly figures for segs --- pynets/reports/plotting.py | 51 ++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/pynets/reports/plotting.py b/pynets/reports/plotting.py index 4344158e..d187f769 100644 --- a/pynets/reports/plotting.py +++ b/pynets/reports/plotting.py @@ -7,13 +7,17 @@ from plotly.subplots import make_subplots -def plot_t1w(t1w): +def plot_t1w_with_segs(t1w, wm, csf): """Plot t1w images using plotly. Parameters ---------- t1w : str Path to a t1w image. + wm : str + Path to wm image. + csf : str + Path to csf image. Returns ------- @@ -23,11 +27,8 @@ def plot_t1w(t1w): # Load data from file t1w_arr = nib.load(t1w).get_fdata() - - # Space out z-slices - z_max = np.shape(t1w_arr)[2] - pad = 30 - z_slices = np.linspace(pad, z_max-pad, num=21, dtype=int) + wm_arr = nib.load(wm).get_fdata() + csf_arr = nib.load(csf).get_fdata() # Init figure nrows = 3 @@ -38,21 +39,35 @@ def plot_t1w(t1w): fig_idxs = [(row, col) for row in range(1, nrows+1) for col in range(1, ncols+1)] - for idx, z_slice in enumerate(z_slices): + # Plot segs + _add_overlay(fig, t1w_arr, "gray", 1.0, fig_idxs) + _add_overlay(fig, wm_arr, "ice", 0.5, fig_idxs) + _add_overlay(fig, csf_arr, "ice", 0.5, fig_idxs) - # Slice and rotate the t1w array - img_slice = np.rot90(t1w_arr[:, :, z_slice], k=3) + fig.update_layout(width=800, height=500) - # Get subplot coords - x_coord, y_coord = fig_idxs[idx] + return fig - fig.add_trace(go.Heatmap(z=img_slice, showscale=False, colorscale="gray"), - x_coord, y_coord) - # Update axes - fig.update_xaxes(showticklabels=False, row=x_coord, col=y_coord) - fig.update_yaxes(showticklabels=False, row=x_coord, col=y_coord) +def _add_overlay(fig, data_arr, colorscale, opacity, fig_idxs): + """Add an overlay to the figure.""" - fig.update_layout(width=800, height=500) + # Space out z_slices + z_max = np.shape(data_arr)[2] + pad = 30 + z_slices = np.linspace(pad, z_max-pad, num=21, dtype=int) - return fig + for idx, z_slice in enumerate(z_slices): + + # Get subplot position + x_pos, y_pos = fig_idxs[idx] + + # Don't plot if array is all zeros + if np.mean(data_arr[:, :, z_slice]) > 0: + + fig.add_trace(go.Heatmap(z=np.rot90(data_arr[:, :, z_slice], k=3), showscale=False, + colorscale=colorscale, opacity=opacity), + x_pos, y_pos) + + fig.update_xaxes(showticklabels=False, row=x_pos, col=y_pos) + fig.update_yaxes(showticklabels=False, row=x_pos, col=y_pos)