Skip to content

Commit

Permalink
fix: batch mask dilation issue
Browse files Browse the repository at this point in the history
  • Loading branch information
ltdrdata committed Jul 24, 2024
1 parent 69af3ec commit 7aa30a9
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
2 changes: 1 addition & 1 deletion __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import importlib

version_code = [0, 82, 3]
version_code = [0, 82, 4]
version_str = f"V{version_code[0]}.{version_code[1]}" + (f'.{version_code[2]}' if len(version_code) > 2 else '')
print(f"### Loading: ComfyUI-Inspire-Pack ({version_str})")

Expand Down
23 changes: 13 additions & 10 deletions inspire/libs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,13 +291,13 @@ def clear(self):
self._data = {}


def make_2d_mask(mask):
def make_3d_mask(mask):
if len(mask.shape) == 4:
return mask.squeeze(0).squeeze(0)

elif len(mask.shape) == 3:
return mask.squeeze(0)

elif len(mask.shape) == 2:
return mask.unsqueeze(0)

return mask


Expand All @@ -306,14 +306,17 @@ def dilate_mask(mask: torch.Tensor, dilation_factor: float) -> torch.Tensor:
kernel_size = int(dilation_factor * 2) + 1
kernel = np.ones((abs(kernel_size), abs(kernel_size)), np.uint8)

mask = make_2d_mask(mask)
masks = make_3d_mask(mask).numpy()
dilated_masks = []
for m in masks:
if dilation_factor > 0:
m2 = cv2.dilate(m, kernel, iterations=1)
else:
m2 = cv2.erode(m, kernel, iterations=1)

if dilation_factor > 0:
mask_dilated = cv2.dilate(mask.numpy(), kernel, iterations=1)
else:
mask_dilated = cv2.erode(mask.numpy(), kernel, iterations=1)
dilated_masks.append(torch.from_numpy(m2))

return torch.from_numpy(mask_dilated).unsqueeze(0)
return torch.stack(dilated_masks)


def flatten_non_zero_override(masks: torch.Tensor):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "comfyui-inspire-pack"
description = "This extension provides various nodes to support Lora Block Weight and the Impact Pack. Provides many easily applicable regional features and applications for Variation Seed."
version = "0.82.3"
version = "0.82.4"
license = "LICENSE"
dependencies = ["matplotlib", "cachetools"]

Expand Down

0 comments on commit 7aa30a9

Please sign in to comment.