Skip to content

Commit

Permalink
fixed overwriting old contours and added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alasdairwilson committed Jul 11, 2024
1 parent a491a35 commit ac22797
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 26 deletions.
66 changes: 42 additions & 24 deletions ripplemapper/analyse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,24 @@ def add_boundary_contours(ripple_images: list[RippleImage] | RippleImage | Rippl
ripple_images = [ripple_images]
for ripple_image in ripple_images:
if len(ripple_image.contours) > 0:
for contour in ripple_image.contours:
if 'Upper Boundary' in contour.method or 'Lower Boundary' in contour.method:
if overwrite:
warnings.warn(f"Overwriting boundary contour for image: {ripple_image.source_file}")
ripple_image.contours.remove(contour)
else:
warnings.warn(f"Boundary contour already exists, skipping image: {ripple_image.source_file}")
continue
indexes = []
for i in range(len(ripple_image.contours)):
if ripple_image.contours[i].method == 'Upper Boundary':
indexes.append(i)
if ripple_image.contours[i].method == 'Lower Boundary':
indexes.append(i)
if len(indexes) > 0:
if overwrite:
warnings.warn(f"Overwriting boundary contours for image: {ripple_image.source_file}")
if len(indexes) == 1:
ripple_image.contours.pop(indexes[0])
if len(indexes) == 2:
ripple_image.contours.pop(indexes[0])
# they have now moved by 1.
ripple_image.contours.pop(indexes[1]-1)
else:
warnings.warn(f"Boundary contours already exist, skipping image: {ripple_image.source_file}")
continue
edges = detect_edges(ripple_image.image)
processed_edges = process_edges(edges)
contours = find_contours(processed_edges, level=level)
Expand All @@ -47,14 +57,20 @@ def add_a_star_contours(ripple_images: list[RippleImage] | RippleImage | RippleI
if len(ripple_image.contours) < 2:
warnings.warn(f"RippleImage object must have at least two contours, skipping image: {ripple_image.source_file}")
continue
for contour in ripple_image.contours:
if 'A* traversal' in contour.method:
if overwrite:
warnings.warn(f"Overwriting A* contour for image: {ripple_image.source_file}")
ripple_image.contours.remove(contour)
else:
warnings.warn(f"A* contour already exists, skipping image: {ripple_image.source_file}")
continue
methods = [contour.method for contour in ripple_image.contours]
if 'A* traversal' in methods:
if overwrite:
warnings.warn(f"Overwriting A* contour for image: {ripple_image.source_file}")
# find me the method index that matches 'A* traversal'
for contour in ripple_image.contours:
print(contour.method)
if contour.method == 'A* traversal':
ripple_image.contours.remove(contour)
print(ripple_image.contours)
else:
warnings.warn(f"A* contour already exists, skipping image: {ripple_image.source_file}")
continue

cont1 = np.flip(ripple_image.contours[contour_index[0]].values).astype(np.int32).T
cont2 = np.flip(ripple_image.contours[contour_index[1]].values).astype(np.int32).T
contour = combine_contours(cont1, cont2)
Expand All @@ -76,14 +92,16 @@ def add_chan_vese_contours(ripple_images: list[RippleImage] | RippleImage | Ripp
ripple_images = [ripple_images]
for ripple_image in ripple_images:
if len(ripple_image.contours) > 0:
for contour in ripple_image.contours:
if 'Chan-Vese' in contour.method:
if overwrite:
warnings.warn(f"Overwriting Chan-Vese contour for image: {ripple_image.source_file}")
ripple_image.contours.remove(contour)
else:
warnings.warn(f"Chan-Vese contour already exists, skipping image: {ripple_image.source_file}")
continue
methods = [contour.method for contour in ripple_image.contours]
if 'Chan-Vese' in methods:
if overwrite:
warnings.warn(f"Overwriting Chan-Vese contour for image: {ripple_image.source_file}")
for contour in ripple_image.contours:
if contour.method == 'Chan-Vese':
ripple_image.contours.remove(contour)
else:
warnings.warn(f"Chan-Vese contour already exists, skipping image: {ripple_image.source_file}")
continue
if use_gradients:
grad = np.sum(np.abs(np.gradient(ripple_image.image)), axis=0)
img = cv2.GaussianBlur(grad / np.max(grad), (7,7), 0)+(1-(ripple_image.image/np.max(ripple_image.image)))
Expand Down
4 changes: 2 additions & 2 deletions ripplemapper/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

__all__ = ["load_image", "load_tif", "load_dir", "load_dir_to_obj", "load"]

def load(file: str | PosixPath):
def load(file: str | PosixPath | WindowsPath):
"""Load a file into a ripplemapper object based on file extension."""
from ripplemapper.classes import (RippleContour, RippleImage,
RippleImageSeries)
Expand All @@ -26,7 +26,7 @@ def load(file: str | PosixPath):

# TODO (ADW): Add support for other image file types just use load_tif for now.
# should probably be looping in this function rather than the dispatched functions but... it's fine for now.
def load_image(file: str | PosixPath) -> np.ndarray:
def load_image(file: str | PosixPath | WindowsPath) -> np.ndarray:
"""Load an image file based on file extension."""
# TODO (ADW): this needs to be refactored to allow lists.
if isinstance(file, PosixPath) | isinstance(file, WindowsPath):
Expand Down
57 changes: 57 additions & 0 deletions ripplemapper/tests/test_analyse.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,50 @@ def test_add_boundary_contours(loaded_example_image):
assert loaded_example_image.contours[0].method == 'Upper Boundary'
assert loaded_example_image.contours[1].method == 'Lower Boundary'

def test_overwrite_boundary_contours(loaded_example_image):
add_boundary_contours(loaded_example_image)
with pytest.warns(UserWarning, match="Overwriting boundary contour"):
add_boundary_contours(loaded_example_image, overwrite=True)
assert len(loaded_example_image.contours) == 2
assert loaded_example_image.contours[0].method == 'Upper Boundary'
assert loaded_example_image.contours[1].method == 'Lower Boundary'

def test_overwrite_single_boundary_contour(loaded_example_image):
add_boundary_contours(loaded_example_image)
loaded_example_image.contours.pop(0)
assert len(loaded_example_image.contours) == 1
with pytest.warns(UserWarning, match="Overwriting boundary contour"):
add_boundary_contours(loaded_example_image, overwrite=True)
assert len(loaded_example_image.contours) == 2
assert loaded_example_image.contours[0].method == 'Upper Boundary'
assert loaded_example_image.contours[1].method == 'Lower Boundary'

def test_add_a_star_contours(loaded_example_image):
add_boundary_contours(loaded_example_image)
add_a_star_contours(loaded_example_image)
assert len(loaded_example_image.contours) == 3
assert loaded_example_image.contours[2].method == 'A* traversal'

def test_overwrite_a_star_contours(loaded_example_image):
add_boundary_contours(loaded_example_image)
add_a_star_contours(loaded_example_image)
with pytest.warns(UserWarning, match="Overwriting A"):
add_a_star_contours(loaded_example_image, overwrite=True)
assert len(loaded_example_image.contours) == 3
assert loaded_example_image.contours[2].method == 'A* traversal'

def test_add_chan_vese_contours(loaded_example_image):
add_chan_vese_contours(loaded_example_image)
assert len(loaded_example_image.contours) == 1
assert loaded_example_image.contours[0].method == 'Chan-Vese'

def test_overwrite_chan_vese_contours(loaded_example_image):
add_chan_vese_contours(loaded_example_image)
with pytest.warns(UserWarning, match="Overwriting Chan-Vese contour"):
add_chan_vese_contours(loaded_example_image, overwrite=True)
assert len(loaded_example_image.contours) == 1
assert loaded_example_image.contours[0].method == 'Chan-Vese'

def test_remove_small_bumps(loaded_example_contour):
smoothed_contour = remove_small_bumps(loaded_example_contour)
assert smoothed_contour.values.shape[1] <= loaded_example_contour.values.shape[1]
Expand Down Expand Up @@ -57,5 +90,29 @@ def test_remove_small_bumps_from_image_series(loaded_example_image_series):
for image in loaded_example_image_series.images:
assert len(image.contours) == 2

def test_add_boundary_contours_emits_warning(loaded_example_image):
add_boundary_contours(loaded_example_image)
with pytest.warns(UserWarning, match="Boundary contours already exist, skipping image"):
add_boundary_contours(loaded_example_image)
assert len(loaded_example_image.contours) == 2

def test_add_a_star_contours_emits_warning(loaded_example_image):
with pytest.warns(UserWarning, match="RippleImage object must have at least two contours, skipping image:"):
add_a_star_contours(loaded_example_image)
assert len(loaded_example_image.contours) == 0

def test_add_a_star_overwrite_warning(loaded_example_image):
add_boundary_contours(loaded_example_image)
add_a_star_contours(loaded_example_image)
with pytest.warns(UserWarning, match="contour already exists, skipping image"):
add_a_star_contours(loaded_example_image, overwrite=False)
assert len(loaded_example_image.contours) == 3

def test_add_chan_vese_contours_emits_warning(loaded_example_image):
add_chan_vese_contours(loaded_example_image)
with pytest.warns(UserWarning, match="Chan-Vese contour already exists, skipping image:"):
add_chan_vese_contours(loaded_example_image)
assert len(loaded_example_image.contours) == 1

if __name__ == '__main__':
pytest.main()
7 changes: 7 additions & 0 deletions ripplemapper/tests/test_vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ def test_plot_image_calls_imshow(loaded_example_image):
assert plt.imshow.called
plt.imshow.assert_called_with(loaded_example_image.image, cmap='gray')

def test_plot_image_without_data(loaded_example_image):
plt.imshow = MagicMock()
with pytest.warns(UserWarning, match="Image not loaded for image: "):
loaded_example_image.image = None
plot_image(loaded_example_image, include_contours=False)
assert plt.imshow.called

def test_plot_image_with_contours_calls_plot(loaded_example_image_with_contours):
plt.imshow = MagicMock()
plt.plot = MagicMock()
Expand Down
4 changes: 4 additions & 0 deletions ripplemapper/visualisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def plot_image(ripple_image, include_contours: bool=True, cmap: str='gray', **k
x_min = min(x_min, np.min(contour.values[1]))
y_max = max(y_max, np.max(contour.values[0]))
y_min = min(y_min, np.min(contour.values[0]))
if x_min == np.inf:
x_min = 0
if y_min == np.inf:
y_min = 0
x_min, x_max, y_min, y_max = int(np.floor(x_min)), int(np.ceil(x_max)), int(np.floor(y_min)), int(np.ceil(y_max))
plt.imshow(np.zeros((y_max, x_max)), cmap=cmap, **kwargs)
plt.xlim(x_min, x_max)
Expand Down

0 comments on commit ac22797

Please sign in to comment.