Skip to content

Commit

Permalink
reorganized configure_optical_flows to not initialize for null flow w…
Browse files Browse the repository at this point in the history
…eight
  • Loading branch information
dmarx committed Jun 20, 2022
1 parent 0e77611 commit 5ec3e59
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions src/pytti/LossAug/LossOrchestratorClass.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,22 +125,31 @@ def configure_stabilization_augs(img, init_image_pil, params, loss_augs):
def configure_optical_flows(img, params, loss_augs):
logger.debug(params.device)
_device = params.device

# this shouldn't be in this function based on the name.
# other loss augs
if params.smoothing_weight != 0:
loss_augs.append(
TVLoss(weight=params.smoothing_weight)
) # , device=params.device))

optical_flows = []
if params.animation_mode == "Video Source":
if params.flow_stabilization_weight == "":
params.flow_stabilization_weight = "0"
# TODO: if flow stabilization weight is 0, shouldn't this next block just get skipped?

for i in range(params.flow_long_term_samples + 1):
optical_flow = OpticalFlowLoss(
comp=torch.zeros(1, 1, 1, 1, device=_device), # ,device=DEVICE)
weight=params.flow_stabilization_weight,
name=f"optical flow stabilization (frame {-2**i}) (direct)",
image_shape=img.image_shape,
device=_device,
) # , device=device)
optical_flow.set_enabled(False)
optical_flows.append(optical_flow)
if params.flow_stabilization_weight != "0":
# TO DO: if weight is parameterized, need to do a parameteric evaluation here.
for i in range(params.flow_long_term_samples + 1):
optical_flow = OpticalFlowLoss(
comp=torch.zeros(1, 1, 1, 1, device=_device), # ,device=DEVICE)
weight=params.flow_stabilization_weight,
name=f"optical flow stabilization (frame {-2**i}) (direct)",
image_shape=img.image_shape,
device=_device,
) # , device=device)
optical_flow.set_enabled(False)
optical_flows.append(optical_flow)

elif params.animation_mode == "3D" and params.flow_stabilization_weight not in [
"0",
Expand All @@ -158,13 +167,6 @@ def configure_optical_flows(img, params, loss_augs):

loss_augs.extend(optical_flows)

# this shouldn't be in this function based on the name.
# other loss augs
if params.smoothing_weight != 0:
loss_augs.append(
TVLoss(weight=params.smoothing_weight)
) # , device=params.device))

return img, loss_augs, optical_flows


Expand Down

0 comments on commit 5ec3e59

Please sign in to comment.