Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DataFusionBlock enhancements to extract and fuse optical flow features into the model #17

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ MODEL:
DATASETS:
TRAIN: ("coco_2017_train_fake", "ytvis_2019_train",)
TEST: ("ytvis_2019_val",)
DATASET_RATIO: (1.0,)
SOLVER:
IMS_PER_BATCH: 16
BASE_LR: 0.0001
Expand Down
8 changes: 6 additions & 2 deletions demo_video/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from mask2former_video import add_maskformer2_video_config
from predictor import VisualizationDemo
import imageio
import random
from maskfreevis.config import get_cfg
from maskfreevis.data_fusion_modeling import add_data_fusion_block_config

# constants
WINDOW_NAME = "mask2former video demo"
Expand All @@ -31,6 +34,7 @@ def setup_cfg(args):
add_deeplab_config(cfg)
add_maskformer2_config(cfg)
add_maskformer2_video_config(cfg)
add_data_fusion_block_config(cfg)
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
Expand Down Expand Up @@ -107,14 +111,14 @@ def test_opencv_video_format(codec, file_ext):
# assert args.input, "The input path(s) was not found"
print('args input:', args.input)
args.input = args.input[0]
for file_name in os.listdir(args.input):
for file_name in random.sample(os.listdir(args.input), 20):
input_path_list = sorted([args.input + file_name + '/' + f for f in os.listdir(args.input + file_name)])
print('input path list:', input_path_list)
if len(input_path_list) == 0:
continue
vid_frames = []
for path in input_path_list:
img = read_image(path, format="BGR")
img = read_image(path, format=cfg.INPUT.FORMAT)
vid_frames.append(img)
start_time = time.time()
with autocast():
Expand Down
42 changes: 36 additions & 6 deletions demo_video/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,22 @@
from collections import deque
import cv2
import torch
from visualizer import TrackVisualizer
import copy
from detectron2.data import MetadataCatalog
from detectron2.engine.defaults import DefaultPredictor
from detectron2.structures import Instances
from detectron2.utils.video_visualizer import VideoVisualizer
from detectron2.utils.visualizer import ColorMode
from detectron2.utils.visualizer import ColorMode, Visualizer
from maskfreevis.data_fusion_modeling import extract_optical_flow_dense_matrix

try:
from .visualizer import TrackVisualizer
except:
from visualizer import TrackVisualizer


class VisualizationDemo(object):
def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
def __init__(self, cfg, metadata=None, instance_mode=ColorMode.IMAGE, parallel=False):
"""
Args:
cfg (CfgNode):
Expand All @@ -24,6 +31,8 @@ def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
self.metadata = MetadataCatalog.get(
cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
)
if metadata is not None:
self.metadata = metadata
self.cpu_device = torch.device("cpu")
self.instance_mode = instance_mode
self.parallel = parallel
Expand Down Expand Up @@ -87,6 +96,10 @@ class VideoPredictor(DefaultPredictor):
inputs = cv2.imread("input.jpg")
outputs = pred(inputs)
"""
def __init__(self, cfg):
    """Build the predictor and record whether data fusion is enabled.

    Args:
        cfg: a detectron2 config node (CfgNode). Besides everything the
            base ``DefaultPredictor`` consumes (model, weights, input
            format, test-time augmentation), this subclass reads
            ``cfg.MODEL.DATAFUSION.STATUS``.

    Side effects:
        Delegates full model/checkpoint/augmentation setup to
        ``DefaultPredictor.__init__``.
        Caches ``cfg.MODEL.DATAFUSION.STATUS`` as
        ``self.data_fusion_status`` — presumably a boolean flag that
        gates optical-flow extraction in ``__call__``; confirm the
        config schema added by ``add_data_fusion_block_config``.
    """
    super().__init__(cfg)
    self.data_fusion_status = cfg.MODEL.DATAFUSION.STATUS

def __call__(self, frames):
"""
Args:
Expand All @@ -96,18 +109,35 @@ def __call__(self, frames):
the output of the model for one image only.
See :doc:`/tutorials/models` for details about the format.
"""
with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
with torch.inference_mode(): # https://github.com/sphinx-doc/sphinx/issues/4258
input_frames = []
for original_image in frames:
optical_flow_matrixes = []
for image_idx, original_image in enumerate(frames):
# Apply pre-processing to image.
if self.input_format == "RGB":
# whether the model expects BGR inputs or RGB
prev_frame = copy.deepcopy(frames[image_idx - 1])
current_frame = copy.deepcopy(frames[image_idx])
original_image = original_image[:, :, ::-1]
else:
prev_frame = copy.deepcopy(frames[image_idx - 1][:, :, ::-1])
current_frame = copy.deepcopy(frames[image_idx][:, :, ::-1])

if self.data_fusion_status:
optical_flow_matrix = extract_optical_flow_dense_matrix(prev_frame, current_frame)
optical_flow_matrix = self.aug.get_transform(optical_flow_matrix).apply_image(optical_flow_matrix)
optical_flow_matrix = torch.as_tensor(optical_flow_matrix.astype("float32").transpose(2, 0, 1))
optical_flow_matrixes.append(optical_flow_matrix)
else:
del prev_frame
del current_frame

height, width = original_image.shape[:2]
image = self.aug.get_transform(original_image).apply_image(original_image)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
input_frames.append(image)
inputs = {"image": input_frames, "height": height, "width": width}

inputs = {"image": input_frames, "height": height, "width": width, "optical_flow": optical_flow_matrixes}
predictions = self.model([inputs])
return predictions

Expand Down
Loading