From abba28e64bff85beacd466f817a1dc3550aa7159 Mon Sep 17 00:00:00 2001 From: Martin Schorb Date: Fri, 9 Feb 2024 15:47:50 +0100 Subject: [PATCH] implement test up to luigi task --- mobie/import_data/image.py | 6 ++++-- mobie/import_data/utils.py | 10 ++++++++-- test/test_image_data.py | 14 ++++++++++++++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/mobie/import_data/image.py b/mobie/import_data/image.py index 6cb7026..a4f60a1 100644 --- a/mobie/import_data/image.py +++ b/mobie/import_data/image.py @@ -43,7 +43,9 @@ def import_image_data(in_path, in_key, out_path, raise NotImplementedError("Selection of sub-arrays only possible with OME-Zarr output.") if selected_input_channel: - if len(selected_input_channel) < 2: + if type(selected_input_channel) is int: + selected_input_channel = [0, selected_input_channel] + elif len(selected_input_channel) < 2: # if only one element, we assume relevant image stack dimension is 0 (like channel for multi-channel tifs). selected_input_channel = [0, selected_input_channel[0]] elif len(selected_input_channel) > 2: @@ -52,7 +54,7 @@ def import_image_data(in_path, in_key, out_path, with open_file(in_path, mode="r") as f: shape = f[in_key].shape newshape = list(shape) - _unused_ = newshape.pop(selected_input_channel[1]) + _unused_ = newshape.pop(selected_input_channel[0]) roi_begin = [0] * len(shape) roi_end = list(shape) diff --git a/mobie/import_data/utils.py b/mobie/import_data/utils.py index 8c12180..28ab0f3 100644 --- a/mobie/import_data/utils.py +++ b/mobie/import_data/utils.py @@ -1,5 +1,6 @@ import json import os +import numpy as np import luigi import nifty.distributed as ndist @@ -47,12 +48,17 @@ def compute_node_labels(seg_path, seg_key, return data -def check_input_data(in_path, in_key, resolution, require3d, channel): +def check_input_data(in_path, in_key, resolution, require3d, channel, roi_begin=None, roi_end=None): # TODO to support data with channel, we need to support downscaling with channels if channel is not None: raise NotImplementedError with open_file(in_path, "r") as f: ndim = f[in_key].ndim + if any((roi_begin, roi_end)): + # reduce singleton dimensons + if any(np.array(roi_end) - np.array(roi_begin) == 1): + ndim = ndim - np.sum(np.array(roi_end) - np.array(roi_begin) == 1) + if require3d and ndim != 3: raise ValueError(f"Expect 3d data, got ndim={ndim}") if len(resolution) != ndim: @@ -73,7 +79,7 @@ def downscale(in_path, in_key, out_path, config_dir = os.path.join(tmp_folder, "configs") # ome.zarr can also be written in 2d, all other formats require 3d require3d = metadata_format != "ome.zarr" - check_input_data(in_path, in_key, resolution, require3d, channel) + check_input_data(in_path, in_key, resolution, require3d, channel, roi_begin=roi_begin, roi_end=roi_end) write_global_config(config_dir, block_shape=block_shape, require3d=require3d, roi_begin=roi_begin, roi_end=roi_end) diff --git a/test/test_image_data.py b/test/test_image_data.py index 7904c23..c12889e 100644 --- a/test/test_image_data.py +++ b/test/test_image_data.py @@ -314,6 +314,20 @@ def test_skip_metadata(self): self.check_data(os.path.join(self.root, self.dataset_name), im_name) + def test_input_channel(self): + path1 = os.path.join(self.test_folder, '3ch.h5') + key = 'data' + self.make_hdf5_data(path1, key, shape=(3,128,128)) + + mobie.add_image(path1, key, self.root, self.dataset_name, '3ch_test', + resolution=(1, 1), scale_factors=[[2,2]], + chunks=(64, 64), tmp_folder=self.tmp_folder, + file_format='ome.zarr', + target="local", max_jobs=self.max_jobs, selected_input_channel=1) + + + pass + # # data validation #