From dc023411aed737eb573f8165fd08608d468ed8a6 Mon Sep 17 00:00:00 2001 From: Carsen Stringer Date: Thu, 15 Aug 2024 13:04:03 -0400 Subject: [PATCH] fixing bug with net_avg and flow_threshold --- cellpose_napari/_dock_widget.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/cellpose_napari/_dock_widget.py b/cellpose_napari/_dock_widget.py index 3e2e60b..c09342e 100644 --- a/cellpose_napari/_dock_widget.py +++ b/cellpose_napari/_dock_widget.py @@ -57,12 +57,11 @@ def _deco(func): @thread_worker @no_grad() def run_cellpose(image, model_type, custom_model, channels, channel_axis, diameter, - net_avg, resample, cellprob_threshold, - model_match_threshold, do_3D, stitch_threshold): + resample, cellprob_threshold, + flow_threshold, do_3D, stitch_threshold): from cellpose import models - flow_threshold = (31.0 - model_match_threshold) / 10. - if model_match_threshold==0.0: + if flow_threshold==0.0: flow_threshold = 0.0 logger.debug('flow_threshold=0 => no masks thrown out due to model mismatch') logger.debug(f'computing masks with cellprob_threshold={cellprob_threshold}, flow_threshold={flow_threshold}') @@ -74,7 +73,6 @@ def run_cellpose(image, model_type, custom_model, channels, channel_axis, diamet channels=channels, channel_axis=channel_axis, diameter=diameter, - net_avg=net_avg, resample=resample, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold, @@ -101,15 +99,15 @@ def compute_diameter(image, channels, model_type): return diam @thread_worker - def compute_masks(masks_orig, flows_orig, cellprob_threshold, model_match_threshold): + def compute_masks(masks_orig, flows_orig, cellprob_threshold, flow_threshold): import cv2 from cellpose.utils import fill_holes_and_remove_small_masks from cellpose.dynamics import get_masks from cellpose.transforms import resize_image #print(flows_orig[3].shape, flows_orig[2].shape, masks_orig.shape) - flow_threshold = (31.0 - model_match_threshold) / 10. - if model_match_threshold==0.0: + flow_threshold = (31.0 - flow_threshold) / 10. + if flow_threshold==0.0: flow_threshold = 0.0 logger.debug('flow_threshold=0 => no masks thrown out due to model mismatch') logger.debug(f'computing masks with cellprob_threshold={cellprob_threshold}, flow_threshold={flow_threshold}') @@ -131,9 +129,8 @@ def compute_masks(masks_orig, flows_orig, cellprob_threshold, model_match_thresh compute_diameter_shape = dict(widget_type='PushButton', text='compute diameter from shape layer', tooltip='create shape layer with circles and/or squares, select above, and diameter will be estimated from it'), compute_diameter_button = dict(widget_type='PushButton', text='compute diameter from image', tooltip='cellpose model will estimate diameter from image using specified channels'), cellprob_threshold = dict(widget_type='FloatSlider', name='cellprob_threshold', value=0.0, min=-8.0, max=8.0, step=0.2, tooltip='cell probability threshold (set lower to get more cells and larger cells)'), - model_match_threshold = dict(widget_type='FloatSlider', name='model_match_threshold', value=27.0, min=0.0, max=30.0, step=0.2, tooltip='threshold on gradient match to accept a mask (set lower to get more cells)'), + flow_threshold = dict(widget_type='FloatSlider', name='flow_threshold', value=0.4, min=0.0, max=3.0, step=0.05, tooltip='threshold on gradient match to accept a mask (set higher to get more cells, or to zero to turn off)'), compute_masks_button = dict(widget_type='PushButton', text='recompute last masks with new cellprob + model match', enabled=False), - net_average = dict(widget_type='CheckBox', text='average 4 nets', value=True, tooltip='average 4 different fit networks (default) or if not checked run only 1 network (fast)'), resample_dynamics = dict(widget_type='CheckBox', text='resample dynamics', value=False, tooltip='if False, mask estimation with dynamics run on resized image with diameter=30; if True, flows are resized to original image size before dynamics and mask estimation (turn on for more smooth masks)'), process_3D = dict(widget_type='CheckBox', text='process stack as 3D', value=False, tooltip='use default 3D processing where flows in X, Y, and Z are computed and dynamics run in 3D to create masks'), stitch_threshold_3D = dict(widget_type='LineEdit', label='stitch threshold slices', value=0, tooltip='across time or Z, stitch together masks with IoU threshold of "stitch threshold" to create 3D segmentation'), @@ -153,9 +150,8 @@ def widget(#label_logo, compute_diameter_shape, compute_diameter_button, cellprob_threshold, - model_match_threshold, + flow_threshold, compute_masks_button, - net_average, resample_dynamics, process_3D, stitch_threshold_3D, @@ -256,10 +252,9 @@ def _new_segmentation(segmentation): max(0, optional_nuclear_channel)], channel_axis=widget.channel_axis, diameter=float(diameter), - net_avg=net_average, resample=resample_dynamics, cellprob_threshold=cellprob_threshold, - model_match_threshold=model_match_threshold, + flow_threshold=flow_threshold, do_3D=(process_3D and float(stitch_threshold_3D)==0 and image_layer.ndim>2), stitch_threshold=float(stitch_threshold_3D) if image_layer.ndim>2 else 0.0) cp_worker.returned.connect(_new_segmentation) @@ -290,7 +285,7 @@ def _compute_masks(e: Any): mask_worker = compute_masks(widget.masks_orig, widget.flows_orig, widget.cellprob_threshold.value, - widget.model_match_threshold.value) + widget.flow_threshold.value) mask_worker.returned.connect(update_masks) mask_worker.start()