Skip to content

Commit

Permalink
implement test up to luigi task
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Schorb committed Feb 9, 2024
1 parent 9b22618 commit abba28e
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
6 changes: 4 additions & 2 deletions mobie/import_data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def import_image_data(in_path, in_key, out_path,
raise NotImplementedError("Selection of sub-arrays only possible with OME-Zarr output.")

if selected_input_channel:
if len(selected_input_channel) < 2:
if type(selected_input_channel) is int:
selected_input_channel = [0, selected_input_channel]
elif len(selected_input_channel) < 2:
# if only one element, we assume relevant image stack dimension is 0 (like channel for multi-channel tifs).
selected_input_channel = [0, selected_input_channel[0]]
elif len(selected_input_channel) > 2:
Expand All @@ -52,7 +54,7 @@ def import_image_data(in_path, in_key, out_path,
with open_file(in_path, mode="r") as f:
shape = f[in_key].shape
newshape = list(shape)
_unused_ = newshape.pop(selected_input_channel[1])
_unused_ = newshape.pop(selected_input_channel[0])

roi_begin = [0] * len(shape)
roi_end = list(shape)
Expand Down
10 changes: 8 additions & 2 deletions mobie/import_data/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
import numpy as np

import luigi
import nifty.distributed as ndist
Expand Down Expand Up @@ -47,12 +48,17 @@ def compute_node_labels(seg_path, seg_key,
return data


def check_input_data(in_path, in_key, resolution, require3d, channel):
def check_input_data(in_path, in_key, resolution, require3d, channel, roi_begin=None, roi_end=None):
# TODO to support data with channel, we need to support downscaling with channels
if channel is not None:
raise NotImplementedError
with open_file(in_path, "r") as f:
ndim = f[in_key].ndim
if any((roi_begin, roi_end)):
# reduce singleton dimensons
if any(np.array(roi_end) - np.array(roi_begin) == 1):
ndim = ndim - np.sum(np.array(roi_end) - np.array(roi_begin) == 1)

if require3d and ndim != 3:
raise ValueError(f"Expect 3d data, got ndim={ndim}")
if len(resolution) != ndim:
Expand All @@ -73,7 +79,7 @@ def downscale(in_path, in_key, out_path,
config_dir = os.path.join(tmp_folder, "configs")
# ome.zarr can also be written in 2d, all other formats require 3d
require3d = metadata_format != "ome.zarr"
check_input_data(in_path, in_key, resolution, require3d, channel)
check_input_data(in_path, in_key, resolution, require3d, channel, roi_begin=roi_begin, roi_end=roi_end)
write_global_config(config_dir, block_shape=block_shape, require3d=require3d,
roi_begin=roi_begin, roi_end=roi_end)

Expand Down
14 changes: 14 additions & 0 deletions test/test_image_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,20 @@ def test_skip_metadata(self):
self.check_data(os.path.join(self.root, self.dataset_name), im_name)


def test_input_channel(self):
path1 = os.path.join(self.test_folder, '3ch.h5')
key = 'data'
self.make_hdf5_data(path1, key, shape=(3,128,128))

mobie.add_image(path1, key, self.root, self.dataset_name, '3ch_test',
resolution=(1, 1), scale_factors=[[2,2]],
chunks=(64, 64), tmp_folder=self.tmp_folder,
file_format='ome.zarr',
target="local", max_jobs=self.max_jobs, selected_input_channel=1)


pass

#
# data validation
#
Expand Down

0 comments on commit abba28e

Please sign in to comment.