diff --git a/examples/segmentation.py b/examples/segmentation.py index e73a129..c195a65 100644 --- a/examples/segmentation.py +++ b/examples/segmentation.py @@ -44,7 +44,7 @@ mlp_classifier = MLPClassifier(num_classes=num_classes, num_parts=num_parts) # Create the multi-view lifting module -mvlifting = MVLiftingModule(image_size=224, lifting_method='mode', mlp_classifier=mlp_classifier, balanced_object_loss=True, balanced_3d_loss_alpha=0, lifting_net=None, use_early_voint_feats=False).cuda() +mvlifting = MVLiftingModule(image_size=224, lifting_method='mode', mlp_classifier=mlp_classifier, balanced_object_loss=True, balanced_3d_loss_alpha=0, lifting_net=torch.nn.Sequential(), use_early_voint_feats=False).cuda() # Create loss function for training criterion = torch.nn.CrossEntropyLoss() @@ -120,4 +120,4 @@ def calc_loss_and_3d_pred(rendered_images, cls, labels_2d, azim, elev, points, i running_loss += loss.item() if (i + 1) % int(len(test_loader) * 0.25) == 0: print(f"\tBatch {i + 1}/{len(test_loader)}: Current Average Test Loss = {(running_loss / (i + 1)):.5f}") - print(f"Total Average Test Loss = {(running_loss / len(test_loader)):.5f}") \ No newline at end of file + print(f"Total Average Test Loss = {(running_loss / len(test_loader)):.5f}")