Skip to content

Commit

Permalink
Merge pull request #335 from roboflow/use-nms-setting-from-request
Browse files Browse the repository at this point in the history
Grounding dino: Use class agnostic nms value from request
  • Loading branch information
probicheaux authored Mar 26, 2024
2 parents 4c3809d + 1058bcd commit 56458cd
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 8 deletions.
2 changes: 2 additions & 0 deletions inference/core/entities/requests/groundingdino.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from inference.core.entities.requests.dynamic_class_base import (
DynamicClassBaseInferenceRequest,
)
from inference.core.env import CLASS_AGNOSTIC_NMS


class GroundingDINOInferenceRequest(DynamicClassBaseInferenceRequest):
Expand All @@ -15,3 +16,4 @@ class GroundingDINOInferenceRequest(DynamicClassBaseInferenceRequest):
box_threshold: Optional[float] = 0.5
grounding_dino_version_id: Optional[str] = "default"
text_threshold: Optional[float] = 0.5
class_agnostic_nms: Optional[bool] = CLASS_AGNOSTIC_NMS
1 change: 1 addition & 0 deletions inference/core/models/classification_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def infer(
Args:
image (Any): The image or list of images to be processed.
- can be a BGR numpy array, filepath, InferenceRequestImage, PIL Image, byte-string, etc.
disable_preproc_auto_orient (bool, optional): If true, the auto orient preprocessing step is disabled for this call. Default is False.
disable_preproc_contrast (bool, optional): If true, the auto contrast preprocessing step is disabled for this call. Default is False.
disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False.
Expand Down
10 changes: 2 additions & 8 deletions inference/models/grounding_dino/grounding_dino.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ def __init__(
GROUNDING_DINO_CONFIG_PATH = os.path.join(
GROUNDING_DINO_CACHE_DIR, "GroundingDINO_SwinT_OGC.py"
)
# GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(
# GROUDNING_DINO_CACHE_DIR, "groundingdino_swint_ogc.pth"
# )

if not os.path.exists(GROUNDING_DINO_CACHE_DIR):
os.makedirs(GROUNDING_DINO_CACHE_DIR)
Expand All @@ -53,10 +50,6 @@ def __init__(
url = "https://raw.githubusercontent.com/roboflow/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py"
urllib.request.urlretrieve(url, GROUNDING_DINO_CONFIG_PATH)

# if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
# url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
# urllib.request.urlretrieve(url, GROUNDING_DINO_CHECKPOINT_PATH)

self.model = Model(
model_config_path=GROUNDING_DINO_CONFIG_PATH,
model_checkpoint_path=os.path.join(
Expand Down Expand Up @@ -95,6 +88,7 @@ def infer(
class_filter: list = None,
box_threshold=0.5,
text_threshold=0.5,
class_agnostic_nms=CLASS_AGNOSTIC_NMS,
**kwargs
):
"""
Expand All @@ -121,7 +115,7 @@ def infer(

self.class_names = text

if CLASS_AGNOSTIC_NMS:
if class_agnostic_nms:
detections = detections.with_nms(class_agnostic=True)
else:
detections = detections.with_nms()
Expand Down

0 comments on commit 56458cd

Please sign in to comment.