From 1b6d8b6102922c613f2ebf80e225ddb39a8eecdd Mon Sep 17 00:00:00 2001 From: David Marx Date: Tue, 14 Jun 2022 17:01:20 -0700 Subject: [PATCH 01/19] replaced build_loss() cls.TargetImage invocation with conventional init --- src/pytti/LossAug/LossOrchestratorClass.py | 34 ++++++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/src/pytti/LossAug/LossOrchestratorClass.py b/src/pytti/LossAug/LossOrchestratorClass.py index 931537e..b7173b3 100644 --- a/src/pytti/LossAug/LossOrchestratorClass.py +++ b/src/pytti/LossAug/LossOrchestratorClass.py @@ -2,7 +2,7 @@ from loguru import logger from PIL import Image -from pytti.image_models import PixelImage +from pytti.image_models import PixelImage, RGBImage # from pytti.LossAug import build_loss from pytti.LossAug import TVLoss, HSVLoss, OpticalFlowLoss, TargetFlowLoss @@ -15,20 +15,42 @@ ################################# +import torch LOSS_DICT = {"edge": EdgeLoss, "depth": DepthLoss} -def build_loss(weight_name, weight, name, img, pil_target): +def build_loss( + weight_name: str, + weight: str, + name: str, + img: RGBImage, + pil_target: Image, + device=None, +): # from pytti.LossAug import LOSS_DICT + if device is None: + device = img.device weight_name, suffix = weight_name.split("_", 1) if weight_name == "direct": - Loss = type(img).get_preferred_loss() + loss = type(img).get_preferred_loss() + else: + loss = LOSS_DICT[weight_name] + # out = Loss.TargetImage( + # f"{weight_name} {name}:{weight}", img.image_shape, pil_target + # ) + if pil_target is not None: + resized = pil_target.resize(img.image_shape, Image.LANCZOS) + comp = loss.make_comp(resized) else: - Loss = LOSS_DICT[weight_name] - out = Loss.TargetImage( - f"{weight_name} {name}:{weight}", img.image_shape, pil_target + # comp = loss.get_default_comp() + comp = torch.zeros(1, 1, 1, 1, device=device) + out = loss( + comp=comp, + weight=weight, + name=f"{weight_name} {name} (direct)", + image_shape=img.image_shape, ) out.set_enabled(pil_target is not None) return out From 06722e04858dc01ca8ff40f3584cb2d7113c1847 Mon Sep 17 00:00:00 2001 From: David Marx Date: Tue, 14 Jun 2022 17:44:29 -0700 Subject: [PATCH 02/19] isolated parse_subprompt(), eliminated another TargetImage invocation --- src/pytti/LossAug/MSELossClass.py | 24 +++++++------------ src/pytti/eval_tools.py | 17 +++++++++++++ src/pytti/workhorse.py | 40 +++++++++++++++++++++++++------ 3 files changed, 59 insertions(+), 22 deletions(-) diff --git a/src/pytti/LossAug/MSELossClass.py b/src/pytti/LossAug/MSELossClass.py index ff1e5dc..71a4d46 100644 --- a/src/pytti/LossAug/MSELossClass.py +++ b/src/pytti/LossAug/MSELossClass.py @@ -6,7 +6,8 @@ # from pytti.Notebook import Rotoscoper from pytti.rotoscoper import Rotoscoper -from pytti import fetch, parse, vram_usage_mode +from pytti import fetch, vram_usage_mode +from pytti.eval_tools import parse, parse_subprompt import torch @@ -36,25 +37,18 @@ def __init__( def TargetImage( cls, prompt_string, image_shape, pil_image=None, is_path=False, device=None ): - # Why is this prompt parsing stuff here? Deprecate in favor of centralized - # parsing functions (if feasible) - text, weight, stop = parse( - prompt_string, r"(? Date: Wed, 15 Jun 2022 00:25:20 -0700 Subject: [PATCH 03/19] added refactoring test suite --- src/pytti/workhorse.py | 8 +++++-- tests/test_loss_refactoring.py | 41 ++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) create mode 100644 tests/test_loss_refactoring.py diff --git a/src/pytti/workhorse.py b/src/pytti/workhorse.py index 171ea50..8752205 100644 --- a/src/pytti/workhorse.py +++ b/src/pytti/workhorse.py @@ -399,10 +399,14 @@ def do_run(): if prompt_string: loss_factory = type(img).get_preferred_loss() text, weight, stop, mask, pil_image = parse_subprompt( - prompt_string, is_path=is_path, pil_image=pil_image + # prompt_string, is_path=True, pil_image=pil_image + prompt_string, + is_path=True, + pil_image=init_image_pil, ) - + image_shape = img.image_shape if pil_image: + # im = pil_image.resize(image_shape, Image.LANCZOS) im = pil_image.resize(image_shape, Image.LANCZOS) comp = loss_factory.make_comp(im) else: diff --git a/tests/test_loss_refactoring.py b/tests/test_loss_refactoring.py new file mode 100644 index 0000000..b23a5af --- /dev/null +++ b/tests/test_loss_refactoring.py @@ -0,0 +1,41 @@ +import pytest + +from hydra import initialize, compose +from loguru import logger +from pytti.workhorse import _main as render_frames +from omegaconf import OmegaConf, open_dict +import torch +from pathlib import Path + +CONFIG_BASE_PATH = "config" +CONFIG_DEFAULTS = "default.yaml" + +TEST_DEVICE = "cuda:0" # "cuda:1" + + +# video_fpath = str(next(Path(".").glob("**/assets/*.mp4"))) +img_fpath = str(next(Path(".").glob("**/src/pytti/assets/*.jpg"))) + + +def run_cfg(cfg_str): + with initialize(config_path=CONFIG_BASE_PATH): + cfg_base = compose( + config_name=CONFIG_DEFAULTS, + overrides=[f"conf=_empty"], + ) + cfg_this = OmegaConf.create(cfg_str) + + with open_dict(cfg_base) as cfg: + cfg = OmegaConf.merge(cfg_base, cfg_this) + render_frames(cfg) + + +# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="requires multiple-GPUs") +def test_direct_init_weight(): + cfg_str = f"""# @package _global_ +scenes: a photograph of an apple +direct_image_prompts: '{img_fpath}:-1:-.5' +direct_init_weight: 1 +device: '{TEST_DEVICE}' +""" + run_cfg(cfg_str) From 561425a79b2f36c295ae50a0fde10ecbcaf3ea6d Mon Sep 17 00:00:00 2001 From: David Marx Date: Wed, 15 Jun 2022 00:38:33 -0700 Subject: [PATCH 04/19] added test --- tests/test_loss_refactoring.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_loss_refactoring.py b/tests/test_loss_refactoring.py index b23a5af..e102078 100644 --- a/tests/test_loss_refactoring.py +++ b/tests/test_loss_refactoring.py @@ -36,6 +36,23 @@ def test_direct_init_weight(): scenes: a photograph of an apple direct_image_prompts: '{img_fpath}:-1:-.5' direct_init_weight: 1 +semantic_iniit_weight: 1 +device: '{TEST_DEVICE}' +""" + run_cfg(cfg_str) + + +def test_stabilization_weights(): + cfg_str = f"""# @package _global_ +scenes: a photograph of an apple +depth_stabilization_weight: 1 +edge_stabilization_weight: 1 +direct_stabilization_weight: 1 +semantic_stabilization_weight: 1 +flow_stabilization_weight: 1 +steps_per_frame: 10 +steps_per_scene: 150 +#flow_long_term_samples: 1 device: '{TEST_DEVICE}' """ run_cfg(cfg_str) From 2cb75683982679ed9f6d1bcedcd13363d5cb8001 Mon Sep 17 00:00:00 2001 From: David Marx Date: Wed, 15 Jun 2022 00:58:15 -0700 Subject: [PATCH 05/19] added test, eliminated TargetImage in pixel.py --- src/pytti/image_models/pixel.py | 11 +++++++++-- tests/test_loss_refactoring.py | 12 ++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/pytti/image_models/pixel.py b/src/pytti/image_models/pixel.py index 2816bd4..59471cd 100644 --- a/src/pytti/image_models/pixel.py +++ b/src/pytti/image_models/pixel.py @@ -1,4 +1,3 @@ - from pytti import DEVICE, named_rearrange, replace_grad, vram_usage_mode from pytti.image_models.differentiable_image import DifferentiableImage from pytti.LossAug.HSVLossClass import HSVLoss @@ -441,7 +440,15 @@ def encode_image(self, pil_image, smart_encode=True, device=None): # no embedder needed without any prompts if smart_encode: - mse = HSVLoss.TargetImage("HSV loss", self.image_shape, pil_image) + # mse = HSVLoss.TargetImage("HSV loss", self.image_shape, pil_image) + # im = pil_image.resize(image_shape, Image.LANCZOS) + comp = HSVLoss.make_comp(pil_image) + mse = HSVLoss( + comp=comp, + name=text + "HSV loss", + image_shape=pil_image.shape, + device=device, + ) if self.hdr_loss is not None: before_weight = self.hdr_loss.weight.detach() diff --git a/tests/test_loss_refactoring.py b/tests/test_loss_refactoring.py index e102078..213e45a 100644 --- a/tests/test_loss_refactoring.py +++ b/tests/test_loss_refactoring.py @@ -56,3 +56,15 @@ def test_stabilization_weights(): device: '{TEST_DEVICE}' """ run_cfg(cfg_str) + + +def test_limited_palette_image_encode(): + cfg_str = f"""# @package _global_ +scenes: a photograph of an apple +direct_image_prompts: '{img_fpath}:-1:-.5' +direct_init_weight: 1 +semantic_iniit_weight: 1 +image_model: Limited Palette +device: '{TEST_DEVICE}' +""" + run_cfg(cfg_str) From 087f2267c1f9d997437aacd5702a1d8a42076fd3 Mon Sep 17 00:00:00 2001 From: David Marx Date: Wed, 15 Jun 2022 01:01:38 -0700 Subject: [PATCH 06/19] why did that test pass? hmm... --- src/pytti/image_models/pixel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytti/image_models/pixel.py b/src/pytti/image_models/pixel.py index 59471cd..251668d 100644 --- a/src/pytti/image_models/pixel.py +++ b/src/pytti/image_models/pixel.py @@ -445,7 +445,7 @@ def encode_image(self, pil_image, smart_encode=True, device=None): comp = HSVLoss.make_comp(pil_image) mse = HSVLoss( comp=comp, - name=text + "HSV loss", + name="HSV loss", image_shape=pil_image.shape, device=device, ) From bfca0eef6be5cda4f363fa33a410104b6bb578b2 Mon Sep 17 00:00:00 2001 From: David Marx Date: Wed, 15 Jun 2022 09:23:55 -0700 Subject: [PATCH 07/19] eliminated TargetImage for video flow --- src/pytti/LossAug/LossOrchestratorClass.py | 38 ++++++++++++++++------ tests/test_loss_refactoring.py | 16 ++++++++- 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/src/pytti/LossAug/LossOrchestratorClass.py b/src/pytti/LossAug/LossOrchestratorClass.py index b7173b3..250df99 100644 --- a/src/pytti/LossAug/LossOrchestratorClass.py +++ b/src/pytti/LossAug/LossOrchestratorClass.py @@ -7,6 +7,8 @@ # from pytti.LossAug import build_loss from pytti.LossAug import TVLoss, HSVLoss, OpticalFlowLoss, TargetFlowLoss from pytti.Perceptor.Prompt import parse_prompt +from pytti.eval_tools import parse_subprompt + from pytti.LossAug.BaseLossClass import Loss from pytti.LossAug.DepthLossClass import DepthLoss @@ -125,20 +127,36 @@ def configure_stabilization_augs(img, init_image_pil, params, loss_augs): def configure_optical_flows(img, params, loss_augs): - if params.animation_mode == "Video Source": if params.flow_stabilization_weight == "": params.flow_stabilization_weight = "0" - optical_flows = [ - OpticalFlowLoss.TargetImage( - f"optical flow stabilization (frame {-2**i}):{params.flow_stabilization_weight}", - img.image_shape, - ) - for i in range(params.flow_long_term_samples + 1) - ] - for optical_flow in optical_flows: + # optical_flows = [ + # OpticalFlowLoss.TargetImage( + # f"optical flow stabilization (frame {-2**i}):{params.flow_stabilization_weight}", + # img.image_shape, + # ) + # for i in range(params.flow_long_term_samples + 1) + # ] + optical_flows = [] + for i in range(params.flow_long_term_samples + 1): + # prompt_str = f"optical flow stabilization (frame {-2**i}):{params.flow_stabilization_weight}" + # text, weight, stop, mask, pil_image = parse_subprompt(prompt_str) + name = f"optical flow stabilization (frame {-2**i})" + weight = params.flow_stabilization_weight + comp = torch.zeros(1, 1, 1, 1) # ,device=device) + optical_flow = OpticalFlowLoss( + comp=comp, + weight=weight, + name=f"{name} (direct)", + image_shape=img.image_shape, + ) # , device=device) optical_flow.set_enabled(False) - loss_augs.extend(optical_flows) + loss_augs.append(optical_flow) + + ################################## + # for optical_flow in optical_flows: + # optical_flow.set_enabled(False) + # loss_augs.extend(optical_flows) elif params.animation_mode == "3D" and params.flow_stabilization_weight not in [ "0", "", diff --git a/tests/test_loss_refactoring.py b/tests/test_loss_refactoring.py index 213e45a..32553b4 100644 --- a/tests/test_loss_refactoring.py +++ b/tests/test_loss_refactoring.py @@ -13,7 +13,7 @@ TEST_DEVICE = "cuda:0" # "cuda:1" -# video_fpath = str(next(Path(".").glob("**/assets/*.mp4"))) +video_fpath = str(next(Path(".").glob("**/assets/*.mp4"))) img_fpath = str(next(Path(".").glob("**/src/pytti/assets/*.jpg"))) @@ -68,3 +68,17 @@ def test_limited_palette_image_encode(): device: '{TEST_DEVICE}' """ run_cfg(cfg_str) + + +def test_video_optical_flow(): + cfg_str = f"""# @package _global_ +scenes: a photograph of an apple +animation_mode: Video Source +video_path: {video_fpath} +flow_stabilization_weight: 1 +steps_per_frame: 10 +steps_per_scene: 150 +flow_long_term_samples: 3 +device: '{TEST_DEVICE}' +""" + run_cfg(cfg_str) From 7b36c3010acba58c05a7e3aa9d57a1f5e65cdedd Mon Sep 17 00:00:00 2001 From: David Marx Date: Wed, 15 Jun 2022 09:26:36 -0700 Subject: [PATCH 08/19] cleaned up deprecated code --- src/pytti/LossAug/LossOrchestratorClass.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/src/pytti/LossAug/LossOrchestratorClass.py b/src/pytti/LossAug/LossOrchestratorClass.py index 250df99..c65b7e0 100644 --- a/src/pytti/LossAug/LossOrchestratorClass.py +++ b/src/pytti/LossAug/LossOrchestratorClass.py @@ -130,17 +130,9 @@ def configure_optical_flows(img, params, loss_augs): if params.animation_mode == "Video Source": if params.flow_stabilization_weight == "": params.flow_stabilization_weight = "0" - # optical_flows = [ - # OpticalFlowLoss.TargetImage( - # f"optical flow stabilization (frame {-2**i}):{params.flow_stabilization_weight}", - # img.image_shape, - # ) - # for i in range(params.flow_long_term_samples + 1) - # ] - optical_flows = [] + # if flow stabilization weight is 0, shouldn't this next block just get skipped? + for i in range(params.flow_long_term_samples + 1): - # prompt_str = f"optical flow stabilization (frame {-2**i}):{params.flow_stabilization_weight}" - # text, weight, stop, mask, pil_image = parse_subprompt(prompt_str) name = f"optical flow stabilization (frame {-2**i})" weight = params.flow_stabilization_weight comp = torch.zeros(1, 1, 1, 1) # ,device=device) @@ -153,10 +145,6 @@ def configure_optical_flows(img, params, loss_augs): optical_flow.set_enabled(False) loss_augs.append(optical_flow) - ################################## - # for optical_flow in optical_flows: - # optical_flow.set_enabled(False) - # loss_augs.extend(optical_flows) elif params.animation_mode == "3D" and params.flow_stabilization_weight not in [ "0", "", From 9ea18d6333d5d6e0b016e2b9e539a5c60adc782a Mon Sep 17 00:00:00 2001 From: David Marx Date: Wed, 15 Jun 2022 10:53:09 -0700 Subject: [PATCH 09/19] refactor seems sound but emitting a cuda error now --- src/pytti/LossAug/LossOrchestratorClass.py | 46 ++++++++++++---------- tests/test_loss_refactoring.py | 13 ++++++ 2 files changed, 38 insertions(+), 21 deletions(-) diff --git a/src/pytti/LossAug/LossOrchestratorClass.py b/src/pytti/LossAug/LossOrchestratorClass.py index c65b7e0..d44e6f9 100644 --- a/src/pytti/LossAug/LossOrchestratorClass.py +++ b/src/pytti/LossAug/LossOrchestratorClass.py @@ -127,43 +127,47 @@ def configure_stabilization_augs(img, init_image_pil, params, loss_augs): def configure_optical_flows(img, params, loss_augs): + logger.debug(params.device) + _device = params.device + optical_flows = [] if params.animation_mode == "Video Source": if params.flow_stabilization_weight == "": params.flow_stabilization_weight = "0" - # if flow stabilization weight is 0, shouldn't this next block just get skipped? + # TODO: if flow stabilization weight is 0, shouldn't this next block just get skipped? for i in range(params.flow_long_term_samples + 1): - name = f"optical flow stabilization (frame {-2**i})" - weight = params.flow_stabilization_weight - comp = torch.zeros(1, 1, 1, 1) # ,device=device) optical_flow = OpticalFlowLoss( - comp=comp, - weight=weight, - name=f"{name} (direct)", + comp=torch.zeros(1, 1, 1, 1, device=_device), # ,device=DEVICE) + weight=params.flow_stabilization_weight, + name=f"optical flow stabilization (frame {-2**i}) (direct)", image_shape=img.image_shape, + device=_device, ) # , device=device) optical_flow.set_enabled(False) - loss_augs.append(optical_flow) + optical_flows.append(optical_flow) elif params.animation_mode == "3D" and params.flow_stabilization_weight not in [ "0", "", ]: - optical_flows = [ - TargetFlowLoss.TargetImage( - f"optical flow stabilization:{params.flow_stabilization_weight}", - img.image_shape, - device="cuda", - ) - ] - for optical_flow in optical_flows: - optical_flow.set_enabled(False) - loss_augs.extend(optical_flows) - else: - optical_flows = [] + optical_flow = TargetFlowLoss( + comp=torch.zeros(1, 1, 1, 1, device=_device), + weight=params.flow_stabilization_weight, + name="optical flow stabilization (direct)", + image_shape=img.image_shape, + device=_device, + ) + optical_flow.set_enabled(False) + optical_flows.append(optical_flow) + + loss_augs.extend(optical_flows) + + # this shouldn't be in this function based on the name. # other loss augs if params.smoothing_weight != 0: - loss_augs.append(TVLoss(weight=params.smoothing_weight)) + loss_augs.append( + TVLoss(weight=params.smoothing_weight) + ) # , device=params.device)) return img, loss_augs, optical_flows diff --git a/tests/test_loss_refactoring.py b/tests/test_loss_refactoring.py index 32553b4..dfdd3e1 100644 --- a/tests/test_loss_refactoring.py +++ b/tests/test_loss_refactoring.py @@ -82,3 +82,16 @@ def test_video_optical_flow(): device: '{TEST_DEVICE}' """ run_cfg(cfg_str) + + +def test_3D_optical_flow(): + cfg_str = f"""# @package _global_ +scenes: a photograph of an apple +animation_mode: 3D +video_path: {video_fpath} +flow_stabilization_weight: 1 +steps_per_frame: 10 +steps_per_scene: 150 +device: '{TEST_DEVICE}' +""" + run_cfg(cfg_str) From 00046d34f3e97e362afac867e685f41fc3e24ba0 Mon Sep 17 00:00:00 2001 From: David Marx Date: Wed, 15 Jun 2022 12:26:01 -0700 Subject: [PATCH 10/19] tinkering with device selection, cuda error still there even after restarting computer and reinstalling torch --- src/pytti/LossAug/OpticalFlowLossClass.py | 35 ++++++++++++++++++++--- src/pytti/workhorse.py | 2 +- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/src/pytti/LossAug/OpticalFlowLossClass.py b/src/pytti/LossAug/OpticalFlowLossClass.py index 86e8998..1619349 100644 --- a/src/pytti/LossAug/OpticalFlowLossClass.py +++ b/src/pytti/LossAug/OpticalFlowLossClass.py @@ -65,7 +65,8 @@ def init_GMA(checkpoint_path=None, device=None): logger.debug(checkpoint_path) global GMA if device is None: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) if GMA is None: with vram_usage_mode("GMA"): # migrate this to a hydra initialize/compose operation @@ -97,10 +98,31 @@ def init_GMA(checkpoint_path=None, device=None): "--mixed_precision", action="store_true", help="use mixed precision" ) args = parser.parse_args([]) - GMA = torch.nn.DataParallel(RAFTGMA(args), device_ids=[device]) + + # create new OrderedDict that does not contain `module.` prefix + # state_dict = torch.load(checkpoint_path) + state_dict = torch.load(checkpoint_path, map_location=device) + from collections import OrderedDict + + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith("module."): + k = k[7:] # remove `module.` + new_state_dict[k] = v + + # GMA = torch.nn.DataParallel(RAFTGMA(args), device_ids=[device]) + GMA = RAFTGMA(args) + # GMA = torch.nn.parallel.DistributedDataParallel(RAFTGMA(args).to(device), device_ids=[device]) # GMA = RAFTGMA(args) - GMA.load_state_dict(torch.load(checkpoint_path)) - GMA.to(device) + # GMA.load_state_dict(torch.load(checkpoint_path, map_location=device)) + # GMA.load_state_dict(torch.load(checkpoint_path)) + GMA.load_state_dict(new_state_dict) + logger.debug("gma state_dict loaded") + ########################### + # 1. Fix state dict (remove module prefixes) + # 2. load state dict into model without DataParallel + ########################### + GMA.to(device) # redundant? GMA.eval() @@ -195,6 +217,11 @@ def get_loss(self, input, img, device=None): padder = InputPadder(image1.shape) image1, image2 = padder.pad(image1, image2) _, flow = GMA(image1, image2, iters=3, test_mode=True) + logger.debug(device) + logger.debug((flow.shape, flow.device)) + logger.debug((self.comp.shape, self.comp.device)) + # logger.debug(GMA.device) # ugh... I bet this is another dataparallel thing. + # logger.debug(GMA.module.device) flow = flow.to(device, memory_format=torch.channels_last) return super().get_loss(TF.resize(flow, self.comp.shape[-2:]), img) / self.mag diff --git a/src/pytti/workhorse.py b/src/pytti/workhorse.py index 8752205..9a8eda6 100644 --- a/src/pytti/workhorse.py +++ b/src/pytti/workhorse.py @@ -198,7 +198,7 @@ def _main(cfg: DictConfig): with open_dict(params) as p: p.device = _device logger.debug(f"Using device {_device}") - torch.cuda.set_device(_device) + # torch.cuda.set_device(_device) # literal "off" in yaml interpreted as False if params.animation_mode == False: From 0f0b62a9a23403ab1132d4691a961f10b2ed5ecb Mon Sep 17 00:00:00 2001 From: David Marx Date: Wed, 15 Jun 2022 12:58:45 -0700 Subject: [PATCH 11/19] fixed attribute error --- src/pytti/image_models/pixel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pytti/image_models/pixel.py b/src/pytti/image_models/pixel.py index 251668d..5b0babf 100644 --- a/src/pytti/image_models/pixel.py +++ b/src/pytti/image_models/pixel.py @@ -446,7 +446,8 @@ def encode_image(self, pil_image, smart_encode=True, device=None): mse = HSVLoss( comp=comp, name="HSV loss", - image_shape=pil_image.shape, + #image_shape=pil_image.shape, + image_shape=self.image_shape, device=device, ) From 9ad0bb7fa6c15ab6fedd0b7d3a22b2a39a28dfab Mon Sep 17 00:00:00 2001 From: David Marx Date: Wed, 15 Jun 2022 14:04:16 -0700 Subject: [PATCH 12/19] specified device in make_comp call of build_loss --- src/pytti/LossAug/LossOrchestratorClass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytti/LossAug/LossOrchestratorClass.py b/src/pytti/LossAug/LossOrchestratorClass.py index d44e6f9..b678c3d 100644 --- a/src/pytti/LossAug/LossOrchestratorClass.py +++ b/src/pytti/LossAug/LossOrchestratorClass.py @@ -44,7 +44,7 @@ def build_loss( # ) if pil_target is not None: resized = pil_target.resize(img.image_shape, Image.LANCZOS) - comp = loss.make_comp(resized) + comp = loss.make_comp(resized, device=device) else: # comp = loss.get_default_comp() comp = torch.zeros(1, 1, 1, 1, device=device) From a7239a899ecd3478158cf7dd3d3c16dc8c136c02 Mon Sep 17 00:00:00 2001 From: David Marx Date: Wed, 15 Jun 2022 17:23:31 -0700 Subject: [PATCH 13/19] CUDA assertion error was caused by 3D mode with null transforms. added z translation to test, all passing. TO DO: document this edge case with a failing test --- tests/test_loss_refactoring.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_loss_refactoring.py b/tests/test_loss_refactoring.py index dfdd3e1..4dafa79 100644 --- a/tests/test_loss_refactoring.py +++ b/tests/test_loss_refactoring.py @@ -80,6 +80,9 @@ def test_video_optical_flow(): steps_per_scene: 150 flow_long_term_samples: 3 device: '{TEST_DEVICE}' +height: 512 +width: 512 +pixel_size: 1 """ run_cfg(cfg_str) @@ -93,5 +96,9 @@ def test_3D_optical_flow(): steps_per_frame: 10 steps_per_scene: 150 device: '{TEST_DEVICE}' +height: 512 +width: 512 +pixel_size: 1 +translate_z_3d: 10 """ run_cfg(cfg_str) From 0adc838f68279b04910ca1496530a3f7040d25d7 Mon Sep 17 00:00:00 2001 From: David Marx Date: Wed, 15 Jun 2022 17:48:29 -0700 Subject: [PATCH 14/19] add xfail test for null 3d transform cuda error --- tests/test_loss_refactoring.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/test_loss_refactoring.py b/tests/test_loss_refactoring.py index 4dafa79..1709180 100644 --- a/tests/test_loss_refactoring.py +++ b/tests/test_loss_refactoring.py @@ -102,3 +102,23 @@ def test_3D_optical_flow(): translate_z_3d: 10 """ run_cfg(cfg_str) + + +# RuntimeError: CUDA error: device-side assert triggered +@pytest.mark.xfail +def test_3d_null_transform_bug(): + cfg_str = f"""# @package _global_ +scenes: a photograph of an apple +animation_mode: 3D +video_path: {video_fpath} +flow_stabilization_weight: 1 +steps_per_frame: 10 +steps_per_scene: 150 +device: '{TEST_DEVICE}' +height: 512 +width: 512 +pixel_size: 1 +translate_z_3d: 0 +rotate_3d: [1,0,0,0] +""" + run_cfg(cfg_str) From 42014424bed14d484150ad0637eae5f1bafb7e2f Mon Sep 17 00:00:00 2001 From: David Marx Date: Thu, 16 Jun 2022 11:00:50 -0700 Subject: [PATCH 15/19] suppressed old LossOrchestration classes --- src/pytti/LossAug/LossOrchestratorClass.py | 460 ++++++++++----------- src/pytti/workhorse.py | 3 +- 2 files changed, 232 insertions(+), 231 deletions(-) diff --git a/src/pytti/LossAug/LossOrchestratorClass.py b/src/pytti/LossAug/LossOrchestratorClass.py index b678c3d..d9e1657 100644 --- a/src/pytti/LossAug/LossOrchestratorClass.py +++ b/src/pytti/LossAug/LossOrchestratorClass.py @@ -175,52 +175,52 @@ def configure_optical_flows(img, params, loss_augs): ####################################### -class LossBuilder: - - LOSS_DICT = {"edge": EdgeLoss, "depth": DepthLoss} - - def __init__(self, weight_name, weight, name, img, pil_target): - self.weight_name = weight_name - self.weight = weight - self.name = name - self.img = img - self.pil_target = pil_target - - # uh.... should the places this is beind used maybe just use Loss.__init__? - # TO DO: let's make this a class attribute on something - - @property - def weight_category(self): - return self.weight_name.split("_")[0] - - @property - def loss_factory(self): - weight_name = self.weight_category - if weight_name == "direct": - Loss = type(self.img).get_preferred_loss() - else: - Loss = self.LOSS_DICT[weight_name] - return Loss - - def build_loss(self) -> Loss: - """ - Given a weight name, weight, name, image, and target image, returns a loss object - - :param weight_name: The name of the loss function - :param weight: The weight of the loss - :param name: The name of the loss function - :param img: The image to be optimized - :param pil_target: The target image - :return: The loss function. - """ - Loss = self.loss_factory - out = Loss.TargetImage( - f"{self.weight_category} {self.name}:{self.weight}", - self.img.image_shape, - self.pil_target, - ) - out.set_enabled(self.pil_target is not None) - return out +# class LossBuilder: + +# LOSS_DICT = {"edge": EdgeLoss, "depth": DepthLoss} + +# def __init__(self, weight_name, weight, name, img, pil_target): +# self.weight_name = weight_name +# self.weight = weight +# self.name = name +# self.img = img +# self.pil_target = pil_target + +# # uh.... should the places this is beind used maybe just use Loss.__init__? +# # TO DO: let's make this a class attribute on something + +# @property +# def weight_category(self): +# return self.weight_name.split("_")[0] + +# @property +# def loss_factory(self): +# weight_name = self.weight_category +# if weight_name == "direct": +# Loss = type(self.img).get_preferred_loss() +# else: +# Loss = self.LOSS_DICT[weight_name] +# return Loss + +# def build_loss(self) -> Loss: +# """ +# Given a weight name, weight, name, image, and target image, returns a loss object + +# :param weight_name: The name of the loss function +# :param weight: The weight of the loss +# :param name: The name of the loss function +# :param img: The image to be optimized +# :param pil_target: The target image +# :return: The loss function. +# """ +# Loss = self.loss_factory +# out = Loss.TargetImage( +# f"{self.weight_category} {self.name}:{self.weight}", +# self.img.image_shape, +# self.pil_target, +# ) +# out.set_enabled(self.pil_target is not None) +# return out def _standardize_null(weight): @@ -232,187 +232,187 @@ def _standardize_null(weight): return weight -class LossConfigurator: - """ - Groups together procedures for initializing losses - """ - - def __init__( - self, - init_image_pil: Image.Image, - restore: bool, - img: PixelImage, - embedder, - prompts, - # params, - ######## - direct_image_prompts, - semantic_stabilization_weight, - init_image, - semantic_init_weight, - animation_mode, - flow_stabilization_weight, - flow_long_term_samples, - smoothing_weight, - ########### - direct_init_weight, - direct_stabilization_weight, - depth_stabilization_weight, - edge_stabilization_weight, - ): - self.init_image_pil = init_image_pil - self.img = img - self.embedder = embedder - self.prompts = prompts - - self.init_augs = [] - self.loss_augs = [] - self.optical_flows = [] - self.last_frame_semantic = None - self.semantic_init_prompt = None - - # self.params = params - self.restore = restore - - ### params - self.direct_image_prompts = direct_image_prompts - self.semantic_stabilization_weight = _standardize_null( - semantic_stabilization_weight - ) - self.init_image = init_image - self.semantic_init_weight = _standardize_null(semantic_init_weight) - self.animation_mode = animation_mode - self.flow_stabilization_weight = _standardize_null(flow_stabilization_weight) - self.flow_long_term_samples = flow_long_term_samples - self.smoothing_weight = _standardize_null(smoothing_weight) - - ###### - self.direct_init_weight = _standardize_null(direct_init_weight) - self.direct_stabilization_weight = _standardize_null( - direct_stabilization_weight - ) - self.depth_stabilization_weight = _standardize_null(depth_stabilization_weight) - self.edge_stabilization_weight = _standardize_null(edge_stabilization_weight) - - def process_direct_image_prompts(self): - # prompt parsing shouldn't go here. - self.loss_augs.extend( - type(self.img) - .get_preferred_loss() - .TargetImage(p.strip(), self.img.image_shape, is_path=True) - for p in self.direct_image_prompts.split("|") - if p.strip() - ) - - def process_semantic_stabilization(self): - last_frame_pil = self.init_image_pil - if not last_frame_pil: - last_frame_pil = self.img.decode_image() - self.last_frame_semantic = parse_prompt( - self.embedder, - f"stabilization:{self.semantic_stabilization_weight}", - last_frame_pil, - ) - self.last_frame_semantic.set_enabled(self.init_image_pil is not None) - for scene in self.prompts: - scene.append(self.last_frame_semantic) - - def configure_losses(self): - if self.init_image_pil is not None: - self.configure_init_image() - self.process_direct_image_prompts() - if self.semantic_stabilization_weight: - self.process_semantic_stabilization() - self.configure_stabilization_augs() - self.configure_optical_flows() - self.configure_aesthetic_losses() - - return ( - self.loss_augs, - self.init_augs, - self.stabilization_augs, - self.optical_flows, - self.semantic_init_prompt, - self.last_frame_semantic, - self.img, - ) - - def configure_init_image(self): - - if not self.restore: - # move these logging statements into .encode_image() - logger.info("Encoding image...") - self.img.encode_image(self.init_image_pil) - logger.info("Encoded Image:") - # pretty sure this assumes we're in a notebook - display.display(self.img.decode_image()) - - ## wrap this for the flexibility that the loop is pretending to provide... - # set up init image prompt - if self.direct_init_weight: - init_aug = LossBuilder( - "direct_init_weight", - self.direct_init_weight, - f"init image ({self.init_image})", - self.img, - self.init_image_pil, - ).build_loss() - self.loss_augs.append(init_aug) - self.init_augs.append(init_aug) - - ######## - if self.semantic_init_weight: - self.semantic_init_prompt = parse_prompt( - self.embedder, - f"init image [{self.init_image}]:{self.semantic_init_weight}", - self.init_image_pil, - ) - self.prompts[0].append(self.semantic_init_prompt) - - # stabilization - def configure_stabilization_augs(self): - d_augs = { - "direct_stabilization_weight": self.direct_stabilization_weight, - "depth_stabilization_weight": self.depth_stabilization_weight, - "edge_stabilization_weight": self.edge_stabilization_weight, - } - stabilization_augs = [ - LossBuilder( - k, v, "stabilization", self.img, self.init_image_pil - ).build_loss() - for k, v in d_augs.items() - if v - ] - self.stabilization_augs = stabilization_augs - self.loss_augs.extend(stabilization_augs) - - def configure_optical_flows(self): - optical_flows = None - - if self.animation_mode == "Video Source": - if self.flow_stabilization_weight == "": - self.flow_stabilization_weight = "0" - optical_flows = [ - OpticalFlowLoss.TargetImage( - f"optical flow stabilization (frame {-2**i}):{self.flow_stabilization_weight}", - self.img.image_shape, - ) - for i in range(self.flow_long_term_samples + 1) - ] - - elif self.animation_mode == "3D" and self.flow_stabilization_weight: - optical_flows = [ - TargetFlowLoss.TargetImage( - f"optical flow stabilization:{self.flow_stabilization_weight}", - self.img.image_shape, - ) - ] - - if optical_flows is not None: - for optical_flow in optical_flows: - optical_flow.set_enabled(False) - self.loss_augs.extend(optical_flows) - - def configure_aesthetic_losses(self): - if self.smoothing_weight != 0: - self.loss_augs.append(TVLoss(weight=self.smoothing_weight)) +# class LossConfigurator: +# """ +# Groups together procedures for initializing losses +# """ + +# def __init__( +# self, +# init_image_pil: Image.Image, +# restore: bool, +# img: PixelImage, +# embedder, +# prompts, +# # params, +# ######## +# direct_image_prompts, +# semantic_stabilization_weight, +# init_image, +# semantic_init_weight, +# animation_mode, +# flow_stabilization_weight, +# flow_long_term_samples, +# smoothing_weight, +# ########### +# direct_init_weight, +# direct_stabilization_weight, +# depth_stabilization_weight, +# edge_stabilization_weight, +# ): +# self.init_image_pil = init_image_pil +# self.img = img +# self.embedder = embedder +# self.prompts = prompts + +# self.init_augs = [] +# self.loss_augs = [] +# self.optical_flows = [] +# self.last_frame_semantic = None +# self.semantic_init_prompt = None + +# # self.params = params +# self.restore = restore + +# ### params +# self.direct_image_prompts = direct_image_prompts +# self.semantic_stabilization_weight = _standardize_null( +# semantic_stabilization_weight +# ) +# self.init_image = init_image +# self.semantic_init_weight = _standardize_null(semantic_init_weight) +# self.animation_mode = animation_mode +# self.flow_stabilization_weight = _standardize_null(flow_stabilization_weight) +# self.flow_long_term_samples = flow_long_term_samples +# self.smoothing_weight = _standardize_null(smoothing_weight) + +# ###### +# self.direct_init_weight = _standardize_null(direct_init_weight) +# self.direct_stabilization_weight = _standardize_null( +# direct_stabilization_weight +# ) +# self.depth_stabilization_weight = _standardize_null(depth_stabilization_weight) +# self.edge_stabilization_weight = _standardize_null(edge_stabilization_weight) + +# def process_direct_image_prompts(self): +# # prompt parsing shouldn't go here. +# self.loss_augs.extend( +# type(self.img) +# .get_preferred_loss() +# .TargetImage(p.strip(), self.img.image_shape, is_path=True) +# for p in self.direct_image_prompts.split("|") +# if p.strip() +# ) + +# def process_semantic_stabilization(self): +# last_frame_pil = self.init_image_pil +# if not last_frame_pil: +# last_frame_pil = self.img.decode_image() +# self.last_frame_semantic = parse_prompt( +# self.embedder, +# f"stabilization:{self.semantic_stabilization_weight}", +# last_frame_pil, +# ) +# self.last_frame_semantic.set_enabled(self.init_image_pil is not None) +# for scene in self.prompts: +# scene.append(self.last_frame_semantic) + +# def configure_losses(self): +# if self.init_image_pil is not None: +# self.configure_init_image() +# self.process_direct_image_prompts() +# if self.semantic_stabilization_weight: +# self.process_semantic_stabilization() +# self.configure_stabilization_augs() +# self.configure_optical_flows() +# self.configure_aesthetic_losses() + +# return ( +# self.loss_augs, +# self.init_augs, +# self.stabilization_augs, +# self.optical_flows, +# self.semantic_init_prompt, +# self.last_frame_semantic, +# self.img, +# ) + +# def configure_init_image(self): + +# if not self.restore: +# # move these logging statements into .encode_image() +# logger.info("Encoding image...") +# self.img.encode_image(self.init_image_pil) +# logger.info("Encoded Image:") +# # pretty sure this assumes we're in a notebook +# display.display(self.img.decode_image()) + +# ## wrap this for the flexibility that the loop is pretending to provide... +# # set up init image prompt +# if self.direct_init_weight: +# init_aug = LossBuilder( +# "direct_init_weight", +# self.direct_init_weight, +# f"init image ({self.init_image})", +# self.img, +# self.init_image_pil, +# ).build_loss() +# self.loss_augs.append(init_aug) +# self.init_augs.append(init_aug) + +# ######## +# if self.semantic_init_weight: +# self.semantic_init_prompt = parse_prompt( +# self.embedder, +# f"init image [{self.init_image}]:{self.semantic_init_weight}", +# self.init_image_pil, +# ) +# self.prompts[0].append(self.semantic_init_prompt) + +# # stabilization +# def configure_stabilization_augs(self): +# d_augs = { +# "direct_stabilization_weight": self.direct_stabilization_weight, +# "depth_stabilization_weight": self.depth_stabilization_weight, +# "edge_stabilization_weight": self.edge_stabilization_weight, +# } +# stabilization_augs = [ +# LossBuilder( +# k, v, "stabilization", self.img, self.init_image_pil +# ).build_loss() +# for k, v in d_augs.items() +# if v +# ] +# self.stabilization_augs = stabilization_augs +# self.loss_augs.extend(stabilization_augs) + +# def configure_optical_flows(self): +# optical_flows = None + +# if self.animation_mode == "Video Source": +# if self.flow_stabilization_weight == "": +# self.flow_stabilization_weight = "0" +# optical_flows = [ +# OpticalFlowLoss.TargetImage( +# f"optical flow stabilization (frame {-2**i}):{self.flow_stabilization_weight}", +# self.img.image_shape, +# ) +# for i in range(self.flow_long_term_samples + 1) +# ] + +# elif self.animation_mode == "3D" and self.flow_stabilization_weight: +# optical_flows = [ +# TargetFlowLoss.TargetImage( +# f"optical flow stabilization:{self.flow_stabilization_weight}", +# self.img.image_shape, +# ) +# ] + +# if optical_flows is not None: +# for optical_flow in optical_flows: +# optical_flow.set_enabled(False) +# self.loss_augs.extend(optical_flows) + +# def configure_aesthetic_losses(self): +# if self.smoothing_weight != 0: +# self.loss_augs.append(TVLoss(weight=self.smoothing_weight)) diff --git a/src/pytti/workhorse.py b/src/pytti/workhorse.py index 9a8eda6..9bb8803 100644 --- a/src/pytti/workhorse.py +++ b/src/pytti/workhorse.py @@ -52,7 +52,8 @@ vram_profiling, ) from pytti.LossAug.DepthLossClass import init_AdaBins -from pytti.LossAug.LossOrchestratorClass import LossConfigurator + +# from pytti.LossAug.LossOrchestratorClass import LossConfigurator logger.info("pytti loaded.") From 09fc46b597811d95f8c4c54873d7ad11b6f4aeba Mon Sep 17 00:00:00 2001 From: David Marx Date: Fri, 17 Jun 2022 16:32:23 -0700 Subject: [PATCH 16/19] cleaned up a deprecated code --- src/pytti/LossAug/LossOrchestratorClass.py | 243 +-------------------- src/pytti/LossAug/MSELossClass.py | 3 +- src/pytti/LossAug/OpticalFlowLossClass.py | 30 +-- src/pytti/image_models/pixel.py | 6 - src/pytti/workhorse.py | 85 +------ 5 files changed, 12 insertions(+), 355 deletions(-) diff --git a/src/pytti/LossAug/LossOrchestratorClass.py b/src/pytti/LossAug/LossOrchestratorClass.py index d9e1657..6082511 100644 --- a/src/pytti/LossAug/LossOrchestratorClass.py +++ b/src/pytti/LossAug/LossOrchestratorClass.py @@ -4,7 +4,6 @@ from pytti.image_models import PixelImage, RGBImage -# from pytti.LossAug import build_loss from pytti.LossAug import TVLoss, HSVLoss, OpticalFlowLoss, TargetFlowLoss from pytti.Perceptor.Prompt import parse_prompt from pytti.eval_tools import parse_subprompt @@ -30,7 +29,6 @@ def build_loss( pil_target: Image, device=None, ): - # from pytti.LossAug import LOSS_DICT if device is None: device = img.device @@ -39,9 +37,7 @@ def build_loss( loss = type(img).get_preferred_loss() else: loss = LOSS_DICT[weight_name] - # out = Loss.TargetImage( - # f"{weight_name} {name}:{weight}", img.image_shape, pil_target - # ) + if pil_target is not None: resized = pil_target.resize(img.image_shape, Image.LANCZOS) comp = loss.make_comp(resized, device=device) @@ -172,57 +168,6 @@ def configure_optical_flows(img, params, loss_augs): return img, loss_augs, optical_flows -####################################### - - -# class LossBuilder: - -# LOSS_DICT = {"edge": EdgeLoss, "depth": DepthLoss} - -# def __init__(self, weight_name, weight, name, img, pil_target): -# self.weight_name = weight_name -# self.weight = weight -# self.name = name -# self.img = img -# self.pil_target = pil_target - -# # uh.... should the places this is beind used maybe just use Loss.__init__? -# # TO DO: let's make this a class attribute on something - -# @property -# def weight_category(self): -# return self.weight_name.split("_")[0] - -# @property -# def loss_factory(self): -# weight_name = self.weight_category -# if weight_name == "direct": -# Loss = type(self.img).get_preferred_loss() -# else: -# Loss = self.LOSS_DICT[weight_name] -# return Loss - -# def build_loss(self) -> Loss: -# """ -# Given a weight name, weight, name, image, and target image, returns a loss object - -# :param weight_name: The name of the loss function -# :param weight: The weight of the loss -# :param name: The name of the loss function -# :param img: The image to be optimized -# :param pil_target: The target image -# :return: The loss function. -# """ -# Loss = self.loss_factory -# out = Loss.TargetImage( -# f"{self.weight_category} {self.name}:{self.weight}", -# self.img.image_shape, -# self.pil_target, -# ) -# out.set_enabled(self.pil_target is not None) -# return out - - def _standardize_null(weight): weight = str(weight).strip() if weight in ("", "None"): @@ -230,189 +175,3 @@ def _standardize_null(weight): if float(weight) == 0: weight = "" return weight - - -# class LossConfigurator: -# """ -# Groups together procedures for initializing losses -# """ - -# def __init__( -# self, -# init_image_pil: Image.Image, -# restore: bool, -# img: PixelImage, -# embedder, -# prompts, -# # params, -# ######## -# direct_image_prompts, -# semantic_stabilization_weight, -# init_image, -# semantic_init_weight, -# animation_mode, -# flow_stabilization_weight, -# flow_long_term_samples, -# smoothing_weight, -# ########### -# direct_init_weight, -# direct_stabilization_weight, -# depth_stabilization_weight, -# edge_stabilization_weight, -# ): -# self.init_image_pil = init_image_pil -# self.img = img -# self.embedder = embedder -# self.prompts = prompts - -# self.init_augs = [] -# self.loss_augs = [] -# self.optical_flows = [] -# self.last_frame_semantic = None -# self.semantic_init_prompt = None - -# # self.params = params -# self.restore = restore - -# ### params -# self.direct_image_prompts = direct_image_prompts -# self.semantic_stabilization_weight = _standardize_null( -# semantic_stabilization_weight -# ) -# self.init_image = init_image -# self.semantic_init_weight = _standardize_null(semantic_init_weight) -# self.animation_mode = animation_mode -# self.flow_stabilization_weight = _standardize_null(flow_stabilization_weight) -# self.flow_long_term_samples = flow_long_term_samples -# self.smoothing_weight = _standardize_null(smoothing_weight) - -# ###### -# self.direct_init_weight = _standardize_null(direct_init_weight) -# self.direct_stabilization_weight = _standardize_null( -# direct_stabilization_weight -# ) -# self.depth_stabilization_weight = _standardize_null(depth_stabilization_weight) -# self.edge_stabilization_weight = _standardize_null(edge_stabilization_weight) - -# def process_direct_image_prompts(self): -# # prompt parsing shouldn't go here. -# self.loss_augs.extend( -# type(self.img) -# .get_preferred_loss() -# .TargetImage(p.strip(), self.img.image_shape, is_path=True) -# for p in self.direct_image_prompts.split("|") -# if p.strip() -# ) - -# def process_semantic_stabilization(self): -# last_frame_pil = self.init_image_pil -# if not last_frame_pil: -# last_frame_pil = self.img.decode_image() -# self.last_frame_semantic = parse_prompt( -# self.embedder, -# f"stabilization:{self.semantic_stabilization_weight}", -# last_frame_pil, -# ) -# self.last_frame_semantic.set_enabled(self.init_image_pil is not None) -# for scene in self.prompts: -# scene.append(self.last_frame_semantic) - -# def configure_losses(self): -# if self.init_image_pil is not None: -# self.configure_init_image() -# self.process_direct_image_prompts() -# if self.semantic_stabilization_weight: -# self.process_semantic_stabilization() -# self.configure_stabilization_augs() -# self.configure_optical_flows() -# self.configure_aesthetic_losses() - -# return ( -# self.loss_augs, -# self.init_augs, -# self.stabilization_augs, -# self.optical_flows, -# self.semantic_init_prompt, -# self.last_frame_semantic, -# self.img, -# ) - -# def configure_init_image(self): - -# if not self.restore: -# # move these logging statements into .encode_image() -# logger.info("Encoding image...") -# self.img.encode_image(self.init_image_pil) -# logger.info("Encoded Image:") -# # pretty sure this assumes we're in a notebook -# display.display(self.img.decode_image()) - -# ## wrap this for the flexibility that the loop is pretending to provide... -# # set up init image prompt -# if self.direct_init_weight: -# init_aug = LossBuilder( -# "direct_init_weight", -# self.direct_init_weight, -# f"init image ({self.init_image})", -# self.img, -# self.init_image_pil, -# ).build_loss() -# self.loss_augs.append(init_aug) -# self.init_augs.append(init_aug) - -# ######## -# if self.semantic_init_weight: -# self.semantic_init_prompt = parse_prompt( -# self.embedder, -# f"init image [{self.init_image}]:{self.semantic_init_weight}", -# self.init_image_pil, -# ) -# self.prompts[0].append(self.semantic_init_prompt) - -# # stabilization -# def configure_stabilization_augs(self): -# d_augs = { -# "direct_stabilization_weight": self.direct_stabilization_weight, -# "depth_stabilization_weight": self.depth_stabilization_weight, -# "edge_stabilization_weight": self.edge_stabilization_weight, -# } -# stabilization_augs = [ -# LossBuilder( -# k, v, "stabilization", self.img, self.init_image_pil -# ).build_loss() -# for k, v in d_augs.items() -# if v -# ] -# self.stabilization_augs = stabilization_augs -# self.loss_augs.extend(stabilization_augs) - -# def configure_optical_flows(self): -# optical_flows = None - -# if self.animation_mode == "Video Source": -# if self.flow_stabilization_weight == "": -# self.flow_stabilization_weight = "0" -# optical_flows = [ -# OpticalFlowLoss.TargetImage( -# f"optical flow stabilization (frame {-2**i}):{self.flow_stabilization_weight}", -# self.img.image_shape, -# ) -# for i in range(self.flow_long_term_samples + 1) -# ] - -# elif self.animation_mode == "3D" and self.flow_stabilization_weight: -# optical_flows = [ -# TargetFlowLoss.TargetImage( -# f"optical flow stabilization:{self.flow_stabilization_weight}", -# self.img.image_shape, -# ) -# ] - -# if optical_flows is not None: -# for optical_flow in optical_flows: -# optical_flow.set_enabled(False) -# self.loss_augs.extend(optical_flows) - -# def configure_aesthetic_losses(self): -# if self.smoothing_weight != 0: -# self.loss_augs.append(TVLoss(weight=self.smoothing_weight)) diff --git a/src/pytti/LossAug/MSELossClass.py b/src/pytti/LossAug/MSELossClass.py index 71a4d46..ce308d9 100644 --- a/src/pytti/LossAug/MSELossClass.py +++ b/src/pytti/LossAug/MSELossClass.py @@ -4,10 +4,9 @@ from torch.nn import functional as F from pytti.LossAug.BaseLossClass import Loss -# from pytti.Notebook import Rotoscoper from pytti.rotoscoper import Rotoscoper from pytti import fetch, vram_usage_mode -from pytti.eval_tools import parse, parse_subprompt +from pytti.eval_tools import parse_subprompt import torch diff --git a/src/pytti/LossAug/OpticalFlowLossClass.py b/src/pytti/LossAug/OpticalFlowLossClass.py index 1619349..0edc961 100644 --- a/src/pytti/LossAug/OpticalFlowLossClass.py +++ b/src/pytti/LossAug/OpticalFlowLossClass.py @@ -16,10 +16,8 @@ import gma from gma.core.network import RAFTGMA -# from gma.core.utils import flow_viz from gma.core.utils.utils import InputPadder -# from pytti import fetch, to_pil, DEVICE, vram_usage_mode from pytti import fetch, vram_usage_mode from pytti.LossAug.MSELossClass import MSELoss from pytti.rotoscoper import Rotoscoper @@ -100,7 +98,6 @@ def init_GMA(checkpoint_path=None, device=None): args = parser.parse_args([]) # create new OrderedDict that does not contain `module.` prefix - # state_dict = torch.load(checkpoint_path) state_dict = torch.load(checkpoint_path, map_location=device) from collections import OrderedDict @@ -110,18 +107,9 @@ def init_GMA(checkpoint_path=None, device=None): k = k[7:] # remove `module.` new_state_dict[k] = v - # GMA = torch.nn.DataParallel(RAFTGMA(args), device_ids=[device]) GMA = RAFTGMA(args) - # GMA = torch.nn.parallel.DistributedDataParallel(RAFTGMA(args).to(device), device_ids=[device]) - # GMA = RAFTGMA(args) - # GMA.load_state_dict(torch.load(checkpoint_path, map_location=device)) - # GMA.load_state_dict(torch.load(checkpoint_path)) GMA.load_state_dict(new_state_dict) logger.debug("gma state_dict loaded") - ########################### - # 1. Fix state dict (remove module prefixes) - # 2. load state dict into model without DataParallel - ########################### GMA.to(device) # redundant? GMA.eval() @@ -209,7 +197,6 @@ def get_loss(self, input, img, device=None): if device is None: device = getattr(self, "device", self.device) init_GMA( - # "GMA/checkpoints/gma-sintel.pth" device=device, ) # update this to use model dir from config image1 = self.last_step @@ -220,8 +207,6 @@ def get_loss(self, input, img, device=None): logger.debug(device) logger.debug((flow.shape, flow.device)) logger.debug((self.comp.shape, self.comp.device)) - # logger.debug(GMA.device) # ugh... I bet this is another dataparallel thing. - # logger.debug(GMA.module.device) flow = flow.to(device, memory_format=torch.channels_last) return super().get_loss(TF.resize(flow, self.comp.shape[-2:]), img) / self.mag @@ -232,9 +217,9 @@ class OpticalFlowLoss(MSELoss): def motion_edge_map( flow_forward, flow_backward, - img, # is this even being used anywhere here? - border_mode="smear", - sampling_mode="bilinear", + img, # unused + border_mode="smear", # unused + sampling_mode="bilinear", # unused device=None, ): """ @@ -325,7 +310,6 @@ def get_flow(image1, image2, device=None): """ if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" - # init_GMA("GMA/checkpoints/gma-sintel.pth") init_GMA( device=device, ) @@ -386,11 +370,9 @@ def set_flow( ) logger.debug(device) if path is not None: - # img = img.clone() img = img.clone().to(device) if not isinstance(device, torch.device): device = torch.device(device) - # logger.debug(device) state_dict = torch.load(path, map_location=device) img.load_state_dict(state_dict) @@ -412,13 +394,9 @@ def set_flow( image1.add_(noise) image2.add_(noise) - # flow_forward = OpticalFlowLoss.get_flow(image1, image2) - # flow_backward = OpticalFlowLoss.get_flow(image2, image1) - # flow_forward = self.get_flow(image1, image2, device=device) - # flow_backward = self.get_flow(image2, image1, device=device) flow_forward = OpticalFlowLoss.get_flow(image1, image2, device=device) flow_backward = OpticalFlowLoss.get_flow(image2, image1, device=device) - unwarped_target_direct = img.decode_tensor() + unwarped_target_direct = img.decode_tensor() # unused flow_target_direct = apply_flow( img, -flow_backward, border_mode=border_mode, sampling_mode=sampling_mode ) diff --git a/src/pytti/image_models/pixel.py b/src/pytti/image_models/pixel.py index 5b0babf..0beb892 100644 --- a/src/pytti/image_models/pixel.py +++ b/src/pytti/image_models/pixel.py @@ -2,7 +2,6 @@ from pytti.image_models.differentiable_image import DifferentiableImage from pytti.LossAug.HSVLossClass import HSVLoss -# from pytti.ImageGuide import DirectImageGuide import numpy as np import torch, math from torch import nn, optim @@ -191,7 +190,6 @@ def __init__( .view(pallet_size, 1, 1) .repeat(1, n_pallets, 3) ) - # pallet.set_(torch.rand_like(pallet)*self.pallet_inertia) self.pallet = nn.Parameter(pallet.to(self.device)) self.pallet_size = pallet_size @@ -429,7 +427,6 @@ def encode_image(self, pil_image, smart_encode=True, device=None): scale = self.scale color_ref = pil_image.resize((width // scale, height // scale), Image.LANCZOS) color_ref = TF.to_tensor(color_ref).to(device) - # value_ref = ImageOps.grayscale(color_ref) with torch.no_grad(): # https://alienryderflex.com/hsp.html magic_color = self.pallet.new_tensor([[[0.299]], [[0.587]], [[0.114]]]) @@ -440,13 +437,10 @@ def encode_image(self, pil_image, smart_encode=True, device=None): # no embedder needed without any prompts if smart_encode: - # mse = HSVLoss.TargetImage("HSV loss", self.image_shape, pil_image) - # im = pil_image.resize(image_shape, Image.LANCZOS) comp = HSVLoss.make_comp(pil_image) mse = HSVLoss( comp=comp, name="HSV loss", - #image_shape=pil_image.shape, image_shape=self.image_shape, device=device, ) diff --git a/src/pytti/workhorse.py b/src/pytti/workhorse.py index 9bb8803..05c0763 100644 --- a/src/pytti/workhorse.py +++ b/src/pytti/workhorse.py @@ -53,7 +53,6 @@ ) from pytti.LossAug.DepthLossClass import init_AdaBins -# from pytti.LossAug.LossOrchestratorClass import LossConfigurator logger.info("pytti loaded.") @@ -69,7 +68,6 @@ TB_LOGDIR = "logs" # to do: make this more easily configurable -# writer = SummaryWriter(TB_LOGDIR) OUTPATH = f"{os.getcwd()}/images_out/" ####################################################### @@ -80,24 +78,7 @@ configure_optical_flows, ) -####################################################### - -# To do: ove remaining gunk into this... -# class Renderer: -# """ -# High-level orchestrator for pytti rendering procedure. -# """ -# -# def __init__(self, params): -# pass - - -# this is the only place `parse_prompt` is invoked. -# combine load_scenes, parse_prompt, and parse into a unified, generic parser. -# generic here means the output of the parsing process shouldn't be bound to -# modules yet, just a collection of settings. -# -# ...actually, parse_prompt is invoked in loss orchestration +# move this with the other "prompt" functions... def parse_scenes( embedder, scenes, @@ -175,8 +156,6 @@ def load_video_source( pre_animation_steps = max(steps_per_frame, pre_animation_steps) if init_image_pil is None: init_image_pil = Image.fromarray(video_frames.get_data(0)).convert("RGB") - # enhancer = ImageEnhance.Contrast(init_image_pil) - # init_image_pil = enhancer.enhance(2) init_size = init_image_pil.size if width == -1: width = int(height * init_size[0] / init_size[1]) @@ -187,7 +166,6 @@ def load_video_source( @hydra.main(config_path="config", config_name="default") def _main(cfg: DictConfig): - # params = OmegaConf.to_container(cfg, resolve=True) params = cfg if torch.cuda.is_available(): @@ -213,15 +191,12 @@ def _main(cfg: DictConfig): if params.use_tensorboard: writer = SummaryWriter(TB_LOGDIR) - batch_mode = False # @param{type:"boolean"} + batch_mode = False ### Move these into default.yaml - # @markdown check `restore` to restore from a previous run - restore = params.get("restore") or False # @param{type:"boolean"} - # @markdown check `reencode` if you are restoring with a modified image or modified image settings - reencode = False # @param{type:"boolean"} - # @markdown which run to restore - restore_run = latest # @param{type:"raw"} + restore = params.get("restore") or False + reencode = False + restore_run = latest # which run to restore # NB: `backup/` dir probably not working at present if restore and restore_run == latest: @@ -235,13 +210,10 @@ def do_run(): # Phase 1 - reset state ######################## - # clear_rotoscopers() # what a silly name ROTOSCOPERS.clear_rotoscopers() vram_profiling(params.approximate_vram_usage) reset_vram_usage() - # global CLIP_MODEL_NAMES # we don't do anything with this... - # @markdown which frame to restore from - restore_frame = latest # @param{type:"raw"} + restore_frame = latest # which frame to restore from # set up seed for deterministic RNG if params.seed is not None: @@ -386,28 +358,18 @@ def do_run(): # other image prompts - # loss_augs.extend( - # type(img) - # .get_preferred_loss() - # .TargetImage(p.strip(), img.image_shape, is_path=True) - # for p in params.direct_image_prompts.split("|") - # if p.strip() - # ) - # uh... I'm not sure I actually test direct_image_prompts anywhere. ah well. fuck it. for p in params.direct_image_prompts.split("|"): prompt_string = p.strip() if prompt_string: loss_factory = type(img).get_preferred_loss() text, weight, stop, mask, pil_image = parse_subprompt( - # prompt_string, is_path=True, pil_image=pil_image prompt_string, is_path=True, pil_image=init_image_pil, ) image_shape = img.image_shape if pil_image: - # im = pil_image.resize(image_shape, Image.LANCZOS) im = pil_image.resize(image_shape, Image.LANCZOS) comp = loss_factory.make_comp(im) else: @@ -452,41 +414,6 @@ def do_run(): # optical flow img, loss_augs, optical_flows = configure_optical_flows(img, params, loss_augs) - # # set up losses - # loss_orch = LossConfigurator( - # init_image_pil=init_image_pil, - # restore=restore, - # img=img, - # embedder=embedder, - # prompts=prompts, - # # params=params, - # ######## - # # To do: group arguments into param groups - # animation_mode=params.animation_mode, - # init_image=params.init_image, - # direct_image_prompts=params.direct_image_prompts, - # semantic_init_weight=params.semantic_init_weight, - # semantic_stabilization_weight=params.semantic_stabilization_weight, - # flow_stabilization_weight=params.flow_stabilization_weight, - # flow_long_term_samples=params.flow_long_term_samples, - # smoothing_weight=params.smoothing_weight, - # ########### - # direct_init_weight=params.direct_init_weight, - # direct_stabilization_weight=params.direct_stabilization_weight, - # depth_stabilization_weight=params.depth_stabilization_weight, - # edge_stabilization_weight=params.edge_stabilization_weight, - # ) - - # ( - # loss_augs, - # init_augs, - # stabilization_augs, - # optical_flows, - # semantic_init_prompt, - # last_frame_semantic, - # img, - # ) = loss_orch.configure_losses() - # Phase 4 - setup outputs ########################## From f73e60410d897852976cb23bb40be71543a72fe9 Mon Sep 17 00:00:00 2001 From: David Marx Date: Fri, 17 Jun 2022 16:36:30 -0700 Subject: [PATCH 17/19] removed now unused TargetImage method defs --- src/pytti/LossAug/LatentLossClass.py | 25 ------------------------- src/pytti/LossAug/MSELossClass.py | 24 ------------------------ src/pytti/LossAug/__init__.py | 7 ------- 3 files changed, 56 deletions(-) diff --git a/src/pytti/LossAug/LatentLossClass.py b/src/pytti/LossAug/LatentLossClass.py index 8f267b8..3e8b324 100644 --- a/src/pytti/LossAug/LatentLossClass.py +++ b/src/pytti/LossAug/LatentLossClass.py @@ -31,31 +31,6 @@ def set_comp(self, pil_image, device=DEVICE): self.has_latent = False self.direct_loss.set_comp(pil_image.resize(self.image_shape, Image.LANCZOS)) - @classmethod - @vram_usage_mode("Latent Image Loss") - @torch.no_grad() - def TargetImage( - cls, prompt_string, image_shape, pil_image=None, is_path=False, device=DEVICE - ): - text, weight, stop = parse( - prompt_string, r"(? Date: Fri, 17 Jun 2022 16:42:21 -0700 Subject: [PATCH 18/19] added some notes and removed more dead code --- src/pytti/LossAug/DepthLossClass.py | 1 - src/pytti/LossAug/LatentLossClass.py | 1 + src/pytti/LossAug/MSELossClass.py | 2 ++ 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/pytti/LossAug/DepthLossClass.py b/src/pytti/LossAug/DepthLossClass.py index 0428806..501a31d 100644 --- a/src/pytti/LossAug/DepthLossClass.py +++ b/src/pytti/LossAug/DepthLossClass.py @@ -29,7 +29,6 @@ def init_AdaBins(device=None): class DepthLoss(MSELoss): @torch.no_grad() def set_comp(self, pil_image): - # pil_image = pil_image.resize(self.image_shape, Image.LANCZOS) self.comp.set_(DepthLoss.make_comp(pil_image)) if self.use_mask and self.mask.shape[-2:] != self.comp.shape[-2:]: self.mask.set_(TF.resize(self.mask, self.comp.shape[-2:])) diff --git a/src/pytti/LossAug/LatentLossClass.py b/src/pytti/LossAug/LatentLossClass.py index 3e8b324..a2612b3 100644 --- a/src/pytti/LossAug/LatentLossClass.py +++ b/src/pytti/LossAug/LatentLossClass.py @@ -25,6 +25,7 @@ def __init__( TF.resize(comp.clone(), (h, w)), weight, stop, name, image_shape ) + # Comp and mask should live on the image representation, not the loss class. @torch.no_grad() def set_comp(self, pil_image, device=DEVICE): self.pil_image = pil_image diff --git a/src/pytti/LossAug/MSELossClass.py b/src/pytti/LossAug/MSELossClass.py index 80ab9be..8a1f33e 100644 --- a/src/pytti/LossAug/MSELossClass.py +++ b/src/pytti/LossAug/MSELossClass.py @@ -58,6 +58,8 @@ def set_mask(self, mask, inverted=False, device=None): def convert_input(cls, input, img): return input + # Comp and mask should live on the image representation, not the loss class. + # comp for sure @classmethod def make_comp(cls, pil_image, device=None): if device is None: From d0bf63780dd6003183f4fab0ae6a122c5fb42a7c Mon Sep 17 00:00:00 2001 From: David Marx Date: Mon, 20 Jun 2022 10:14:13 -0700 Subject: [PATCH 19/19] reorganized configure_optical_flows to not initialize for null flow weight --- src/pytti/LossAug/LossOrchestratorClass.py | 38 ++++++++++++---------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/src/pytti/LossAug/LossOrchestratorClass.py b/src/pytti/LossAug/LossOrchestratorClass.py index 6082511..0151500 100644 --- a/src/pytti/LossAug/LossOrchestratorClass.py +++ b/src/pytti/LossAug/LossOrchestratorClass.py @@ -125,22 +125,31 @@ def configure_stabilization_augs(img, init_image_pil, params, loss_augs): def configure_optical_flows(img, params, loss_augs): logger.debug(params.device) _device = params.device + + # this shouldn't be in this function based on the name. + # other loss augs + if params.smoothing_weight != 0: + loss_augs.append( + TVLoss(weight=params.smoothing_weight) + ) # , device=params.device)) + optical_flows = [] if params.animation_mode == "Video Source": if params.flow_stabilization_weight == "": params.flow_stabilization_weight = "0" # TODO: if flow stabilization weight is 0, shouldn't this next block just get skipped? - - for i in range(params.flow_long_term_samples + 1): - optical_flow = OpticalFlowLoss( - comp=torch.zeros(1, 1, 1, 1, device=_device), # ,device=DEVICE) - weight=params.flow_stabilization_weight, - name=f"optical flow stabilization (frame {-2**i}) (direct)", - image_shape=img.image_shape, - device=_device, - ) # , device=device) - optical_flow.set_enabled(False) - optical_flows.append(optical_flow) + if params.flow_stabilization_weight != "0": + # TO DO: if weight is parameterized, need to do a parameteric evaluation here. + for i in range(params.flow_long_term_samples + 1): + optical_flow = OpticalFlowLoss( + comp=torch.zeros(1, 1, 1, 1, device=_device), # ,device=DEVICE) + weight=params.flow_stabilization_weight, + name=f"optical flow stabilization (frame {-2**i}) (direct)", + image_shape=img.image_shape, + device=_device, + ) # , device=device) + optical_flow.set_enabled(False) + optical_flows.append(optical_flow) elif params.animation_mode == "3D" and params.flow_stabilization_weight not in [ "0", @@ -158,13 +167,6 @@ def configure_optical_flows(img, params, loss_augs): loss_augs.extend(optical_flows) - # this shouldn't be in this function based on the name. - # other loss augs - if params.smoothing_weight != 0: - loss_augs.append( - TVLoss(weight=params.smoothing_weight) - ) # , device=params.device)) - return img, loss_augs, optical_flows