Skip to content

Commit

Permalink
add len to RandomIterableDataset test util (#859)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #859

Reviewed By: diego-urgell

Differential Revision: D59363709

fbshipit-source-id: 0ed7a515c6a87ec0e42897f56ec8b583e254bd71
  • Loading branch information
galrotem authored and facebook-github-bot committed Jul 8, 2024
1 parent 58b6ea7 commit 60059ef
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
16 changes: 8 additions & 8 deletions tests/utils/test_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@
# pyre-strict

import unittest
from typing import Iterable, Iterator
from unittest.mock import patch

from torchtnt.framework._test_utils import (
generate_random_dataloader,
generate_random_iterable_dataloader,
)
from torchtnt.framework._test_utils import generate_random_dataloader

from torchtnt.utils.progress import (
estimated_steps_in_epoch,
Expand Down Expand Up @@ -274,12 +272,14 @@ def test_estimated_steps_in_fit(self) -> None:
)

def test_estimate_epoch_without_len(self) -> None:
dataloader = generate_random_iterable_dataloader(
num_samples=10, input_dim=2, batch_size=2
)
class IterableWithoutLen(Iterable):
def __iter__(self) -> Iterator[int]:
for _ in range(5):
yield 1

self.assertEqual(
estimated_steps_in_epoch(
dataloader,
IterableWithoutLen(),
num_steps_completed=0,
max_steps=None,
max_steps_per_epoch=None,
Expand Down
3 changes: 3 additions & 0 deletions torchtnt/framework/_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ def __iter__(self) -> Iterator[Tensor]:
for _ in range(self.count):
yield torch.randn(self.size)

def __len__(self) -> int:
return self.count


def generate_random_iterable_dataloader(
num_samples: int, input_dim: int, batch_size: int
Expand Down

0 comments on commit 60059ef

Please sign in to comment.