diff --git a/torch_xla/distributed/parallel_loader.py b/torch_xla/distributed/parallel_loader.py index 7053361f795..22046df75e2 100644 --- a/torch_xla/distributed/parallel_loader.py +++ b/torch_xla/distributed/parallel_loader.py @@ -131,7 +131,7 @@ def per_device_loader(self, device): return PerDeviceLoader(self, torch.device(device)) def per_device_samples(self): - return len(self._loader) // len(self._devices) + return len(self._cpu_loader) // len(self._devices) def next_item(self, device): dqueue = self._queues[device]