Merge pull request #333 from roboflow/grounding-dino-enhancements
Grounding DINO enhancements and bugfixes
probicheaux authored Mar 26, 2024
2 parents 2185ed7 + a3bdef3 commit 4c3809d
Showing 4 changed files with 46 additions and 16 deletions.
14 changes: 10 additions & 4 deletions docs/foundation/grounding_dino.md
@@ -29,7 +29,11 @@ results = model.infer(
"type": "url",
"value": "https://media.roboflow.com/fruit.png",
},
-"text": ["apple"]
+"text": ["apple"],
+
+# Optional params
+"box_threshold": 0.5,
+"text_threshold": 0.5
}
)

@@ -40,8 +44,10 @@ In this code, we load Grounding DINO, run Grounding DINO on an image, and annota

Above, replace:

-1. `coffee cup` with the object you want to detect.
-2. `image.jpg` with the path to the image in which you want to detect objects.
+1. `apple` with the object you want to detect.
+2. `fruit.png` with the path to the image in which you want to detect objects.

+Additionally, you can tweak the optional `box_threshold` and `text_threshold` params for your specific use case. Both values default to 0.5 if not set. See the <a href="https://github.com/IDEA-Research/GroundingDINO?tab=readme-ov-file#star-explanationstips-for-grounding-dino-inputs-and-outputs">Grounding DINO README</a> for an explanation of the model's thresholds.

To use Grounding DINO with Inference, you will need a Roboflow API key. If you don't already have a Roboflow account, <a href="https://app.roboflow.com" target="_blank">sign up for a free Roboflow account</a>. Then, retrieve your API key from the Roboflow dashboard. Run the following command to set your API key in your coding environment:
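The exact command sits in the collapsed portion of the file above; purely as an illustrative alternative (an assumption, not part of this diff), the key can also be set from Python before the model is loaded:

```python
import os

# Hypothetical alternative to the shell command in the collapsed docs section:
# expose the Roboflow API key to Inference for this process only.
os.environ["ROBOFLOW_API_KEY"] = "YOUR_ROBOFLOW_API_KEY"
```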

@@ -55,4 +61,4 @@ Then, run the Python script you have created:
python app.py
```

The predictions from your model will be printed to the console.
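As a quick, hedged illustration of the new options, the sketch below passes the thresholds as keyword arguments, mirroring the `infer()` signature added in this PR; the import path and the `api_key` constructor argument are assumptions rather than something this diff documents:

```python
# Minimal sketch, assuming GroundingDINO can be constructed directly with an API key.
from inference.models.grounding_dino import GroundingDINO

model = GroundingDINO(api_key="YOUR_ROBOFLOW_API_KEY")

results = model.infer(
    image={"type": "url", "value": "https://media.roboflow.com/fruit.png"},
    text=["apple"],
    box_threshold=0.4,   # lower values keep more, lower-confidence boxes
    text_threshold=0.5,  # 0.5 is the default when omitted
)
print(results)
```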
2 changes: 2 additions & 0 deletions inference/core/entities/requests/groundingdino.py
@@ -12,4 +12,6 @@ class GroundingDINOInferenceRequest(DynamicClassBaseInferenceRequest):
text (List[str]): A list of strings.
"""

box_threshold: Optional[float] = 0.5
grounding_dino_version_id: Optional[str] = "default"
text_threshold: Optional[float] = 0.5
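A small sketch of how the new optional fields might be supplied on a request; the field names come from this diff, while the `image` payload and any other fields inherited from `DynamicClassBaseInferenceRequest` are assumptions:

```python
from inference.core.entities.requests.groundingdino import GroundingDINOInferenceRequest

# Sketch only: fields other than box_threshold / text_threshold are assumptions.
request = GroundingDINOInferenceRequest(
    image={"type": "url", "value": "https://media.roboflow.com/fruit.png"},
    text=["apple"],
    box_threshold=0.4,   # optional, defaults to 0.5
    text_threshold=0.5,  # optional, defaults to 0.5
)
```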
9 changes: 9 additions & 0 deletions inference/core/utils/image_utils.py
@@ -47,6 +47,15 @@ def load_image_rgb(value: Any, disable_preproc_auto_orient: bool = False) -> np.ndarray:
return np_image


def load_image_bgr(value: Any, disable_preproc_auto_orient: bool = False) -> np.ndarray:
np_image, is_bgr = load_image(
value=value, disable_preproc_auto_orient=disable_preproc_auto_orient
)
if not is_bgr:
    # load_image returned RGB, so swap channels to BGR
    # (the R<->B swap is symmetric, so COLOR_BGR2RGB is equivalent to COLOR_RGB2BGR here)
    np_image = cv2.cvtColor(np_image, cv2.COLOR_BGR2RGB)
return np_image


def load_image(
value: Any,
disable_preproc_auto_orient: bool = False,
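Returning to the `load_image_bgr` helper added above, a brief hypothetical usage sketch (the file path is an assumption):

```python
from inference.core.utils.image_utils import load_image_bgr

# Load any accepted image reference and get back a numpy array whose channels
# are in B, G, R order, ready for OpenCV-style consumers.
bgr_image = load_image_bgr("fruit.png")
print(bgr_image.shape)  # (height, width, 3)
```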
37 changes: 25 additions & 12 deletions inference/models/grounding_dino/grounding_dino.py
@@ -13,9 +13,9 @@
ObjectDetectionInferenceResponse,
ObjectDetectionPrediction,
)
-from inference.core.env import MODEL_CACHE_DIR
+from inference.core.env import CLASS_AGNOSTIC_NMS, MODEL_CACHE_DIR
from inference.core.models.roboflow import RoboflowCoreModel
-from inference.core.utils.image_utils import load_image_rgb, xyxy_to_xywh
+from inference.core.utils.image_utils import load_image_bgr, xyxy_to_xywh


class GroundingDINO(RoboflowCoreModel):
@@ -37,17 +37,17 @@ def __init__(

super().__init__(*args, model_id=model_id, **kwargs)

-GROUDNING_DINO_CACHE_DIR = os.path.join(MODEL_CACHE_DIR, model_id)
+GROUNDING_DINO_CACHE_DIR = os.path.join(MODEL_CACHE_DIR, model_id)

GROUNDING_DINO_CONFIG_PATH = os.path.join(
-    GROUDNING_DINO_CACHE_DIR, "GroundingDINO_SwinT_OGC.py"
+    GROUNDING_DINO_CACHE_DIR, "GroundingDINO_SwinT_OGC.py"
)
# GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(
# GROUDNING_DINO_CACHE_DIR, "groundingdino_swint_ogc.pth"
# )

-if not os.path.exists(GROUDNING_DINO_CACHE_DIR):
-    os.makedirs(GROUDNING_DINO_CACHE_DIR)
+if not os.path.exists(GROUNDING_DINO_CACHE_DIR):
+    os.makedirs(GROUNDING_DINO_CACHE_DIR)

if not os.path.exists(GROUNDING_DINO_CONFIG_PATH):
url = "https://raw.githubusercontent.com/roboflow/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py"
@@ -60,7 +60,7 @@ def __init__(
self.model = Model(
model_config_path=GROUNDING_DINO_CONFIG_PATH,
model_checkpoint_path=os.path.join(
-    GROUDNING_DINO_CACHE_DIR, "groundingdino_swint_ogc.pth"
+    GROUNDING_DINO_CACHE_DIR, "groundingdino_swint_ogc.pth"
),
device="cuda" if torch.cuda.is_available() else "cpu",
)
@@ -75,7 +75,7 @@ def preproc_image(self, image: Any):
Returns:
np.array: The preprocessed image.
"""
-np_image = load_image_rgb(image)
+np_image = load_image_bgr(image)
return np_image

def infer_from_request(
@@ -89,7 +89,13 @@ def infer_from_request(
return result

def infer(
-self, image: Any = None, text: list = None, class_filter: list = None, **kwargs
+self,
+image: InferenceRequestImage,
+text: list[str] = None,
+class_filter: list = None,
+box_threshold=0.5,
+text_threshold=0.5,
+**kwargs
):
"""
Run inference on a provided image.
@@ -109,12 +115,17 @@ detections = self.model.predict_with_classes(
detections = self.model.predict_with_classes(
image=image,
classes=text,
-box_threshold=0.5,
-text_threshold=0.5,
+box_threshold=box_threshold,
+text_threshold=text_threshold,
)

self.class_names = text

if CLASS_AGNOSTIC_NMS:
detections = detections.with_nms(class_agnostic=True)
else:
detections = detections.with_nms()

xywh_bboxes = [xyxy_to_xywh(detection) for detection in detections.xyxy]

t2 = perf_counter() - t1
@@ -133,7 +144,9 @@
}
)
for i, pred in enumerate(detections.xyxy)
-if not class_filter or self.class_names[int(pred[6])] in class_filter
+if not class_filter
+or self.class_names[int(pred[6])] in class_filter
+and detections.class_id[i] is not None
],
image=InferenceResponseImage(width=img_dims[1], height=img_dims[0]),
time=t2,
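The new `CLASS_AGNOSTIC_NMS` branch calls `with_nms` on the detections object, which in this code path is a supervision `Detections`. As a standalone sketch of the difference between the two modes (the boxes and scores below are made up):

```python
import numpy as np
import supervision as sv

# Two heavily overlapping boxes assigned to different classes (made-up values).
detections = sv.Detections(
    xyxy=np.array([[10, 10, 110, 110], [12, 12, 112, 112]], dtype=float),
    confidence=np.array([0.9, 0.8]),
    class_id=np.array([0, 1]),
)

# Per-class NMS keeps both boxes because their class ids differ...
per_class = detections.with_nms()
# ...while class-agnostic NMS suppresses the lower-confidence duplicate.
agnostic = detections.with_nms(class_agnostic=True)

print(len(per_class), len(agnostic))  # expected: 2 1
```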
