Skip to content

Commit

Permalink
Update segmentation.py
Browse files Browse the repository at this point in the history
issue #12
  • Loading branch information
ajhamdi authored Nov 29, 2023
1 parent 7131fbd commit 8a4ca2f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
mlp_classifier = MLPClassifier(num_classes=num_classes, num_parts=num_parts)

# Create the multi-view lifting module
mvlifting = MVLiftingModule(image_size=224, lifting_method='mode', mlp_classifier=mlp_classifier, balanced_object_loss=True, balanced_3d_loss_alpha=0, lifting_net=None, use_early_voint_feats=False).cuda()
mvlifting = MVLiftingModule(image_size=224, lifting_method='mode', mlp_classifier=mlp_classifier, balanced_object_loss=True, balanced_3d_loss_alpha=0, lifting_net=torch.nn.Sequential(), use_early_voint_feats=False).cuda()

# Create loss function for training
criterion = torch.nn.CrossEntropyLoss()
Expand Down Expand Up @@ -120,4 +120,4 @@ def calc_loss_and_3d_pred(rendered_images, cls, labels_2d, azim, elev, points, i
running_loss += loss.item()
if (i + 1) % int(len(test_loader) * 0.25) == 0:
print(f"\tBatch {i + 1}/{len(test_loader)}: Current Average Test Loss = {(running_loss / (i + 1)):.5f}")
print(f"Total Average Test Loss = {(running_loss / len(test_loader)):.5f}")
print(f"Total Average Test Loss = {(running_loss / len(test_loader)):.5f}")

0 comments on commit 8a4ca2f

Please sign in to comment.