Skip to content

Commit

Permalink
throughput logger - log per epoch
Browse files Browse the repository at this point in the history
Differential Revision: D56498952
  • Loading branch information
galrotem authored and facebook-github-bot committed Apr 24, 2024
1 parent 6f3bdfa commit 5bb592f
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 16 deletions.
127 changes: 112 additions & 15 deletions tests/framework/callbacks/test_throughput_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# pyre-strict

import unittest
from unittest.mock import ANY, call, MagicMock
from unittest.mock import ANY, call, MagicMock, patch

import torch
from pyre_extensions import none_throws
Expand All @@ -17,13 +17,15 @@
from torchtnt.framework._test_utils import (
DummyAutoUnit,
DummyPredictUnit,
DummyTrainUnit,
generate_random_dataloader,
)
from torchtnt.framework.callbacks.throughput_logger import ThroughputLogger
from torchtnt.framework.predict import predict

from torchtnt.framework.state import EntryPoint, PhaseState, State
from torchtnt.framework.train import _train_impl
from torchtnt.framework.state import ActivePhase, EntryPoint, PhaseState, State
from torchtnt.framework.train import _train_impl, train
from torchtnt.framework.unit import TrainUnit
from torchtnt.utils.loggers.logger import MetricLogger


Expand Down Expand Up @@ -121,21 +123,18 @@ def test_with_comparing_time(self) -> None:
evaluate_every_n_epochs=2,
),
)
throughput_logger = ThroughputLogger(
logger=logger,
throughput_per_batch={"Batches": 1, "Queries": 8},
log_every_n_steps=1,
)

# we want to be able to compare the logging value to the state, so we need to create state manually and
# call _train_impl. This would have been similar to calling fit() and getting the state as a ret value
_train_impl(
state,
DummyAutoUnit(module=torch.nn.Linear(2, 2)),
CallbackHandler(
[
ThroughputLogger(
logger=logger,
throughput_per_batch={"Batches": 1, "Queries": 8},
log_every_n_steps=1,
)
],
),
CallbackHandler([throughput_logger]),
)

train_iteration_times = none_throws(
Expand Down Expand Up @@ -163,8 +162,8 @@ def test_with_comparing_time(self) -> None:
eval_iteration_times[i] + eval_twfb_times[i] for i in range(2)
]
self.assertEqual(
logger.log.call_count, 12
) # 8 train (2epochs x 2steps x 2items), 4 eval (1x2x2)
logger.log.call_count, 18
) # steps: 8 train (2epochs x 2steps x 2items), 4 eval (1x2x2). epochs: 4 train (2epoch x 2items). 2 eval (1x2)
train_batches_step_logs = [
call(
"Train: Batches per second (step granularity)",
Expand Down Expand Up @@ -197,11 +196,36 @@ def test_with_comparing_time(self) -> None:
)
for i in range(2)
]
# for epoch, we test the logged value separately
train_batches_epoch_logs = [
call("Train: Batches per second (epoch granularity)", ANY, i)
for i in range(1, 3)
]
train_queries_epoch_logs = [
call("Train: Queries per second (epoch granularity)", ANY, i)
for i in range(1, 3)
]
eval_epoch_logs = [
call(
"Eval: Queries per second (epoch granularity)",
ANY,
1,
),
call(
"Eval: Batches per second (epoch granularity)",
ANY,
1,
),
]

logger.log.assert_has_calls(
train_batches_step_logs
+ train_queries_step_logs
+ eval_batches_step_logs
+ eval_queries_step_logs,
+ eval_queries_step_logs
+ train_batches_epoch_logs
+ train_queries_epoch_logs
+ eval_epoch_logs,
any_order=True,
)

Expand All @@ -227,6 +251,79 @@ def test_with_predict(self) -> None:
1,
)
],
[
call(
"Predict: Batches per second (epoch granularity)",
ANY,
1,
)
],
)

def test_log_for_epoch(self) -> None:
    """
    Drive the epoch-granularity bookkeeping by hand: the per-phase step
    counter, the recorded epoch start time, and the values passed to the
    logger by ``_log_for_epoch``.
    """
    mock_logger = MagicMock(spec=MetricLogger)
    train_unit = DummyTrainUnit(input_dim=2)
    callback = ThroughputLogger(mock_logger, {"Batches": 1, "Queries": 8})
    train_state = State(entry_point=EntryPoint.TRAIN)
    phase = ActivePhase.TRAIN

    # before any hook fires there is no start time and no counted step
    self.assertIsNone(callback._epoch_start_times.get(phase))
    self.assertEqual(callback._steps_in_epoch[phase], 0)

    # stub out step-level logging so only the step counter is exercised
    with patch.object(callback, "_maybe_log_for_step"):
        callback.on_train_step_end(train_state, train_unit)
    self.assertEqual(callback._steps_in_epoch[phase], 1)

    with patch("time.perf_counter", return_value=0.5):
        callback.on_train_epoch_start(train_state, MagicMock(spec=TrainUnit))
    self.assertEqual(callback._epoch_start_times[phase], 0.5)

    # pretend the epoch consisted of two steps
    callback._steps_in_epoch[phase] = 2
    mock_logger.log.reset_mock()
    with patch("time.perf_counter", return_value=0.6):
        callback._log_for_epoch(train_state, epoch_logging_for=15)

    # same float arithmetic as the callback: end time minus start time
    elapsed = 0.6 - 0.5
    expected_calls = [
        call(
            "Train: Batches per second (epoch granularity)",
            (1 * 2) / elapsed,
            15,
        ),
        call(
            "Train: Queries per second (epoch granularity)",
            (8 * 2) / elapsed,
            15,
        ),
    ]
    mock_logger.log.assert_has_calls(expected_calls)

def test_epoch_logging_time(self) -> None:
    """
    Run a two-epoch ``train`` with ``time.perf_counter`` patched to return a
    fixed sequence of values, then check the epoch-granularity throughput
    logged for each epoch.
    """
    mock_logger = MagicMock(spec=MetricLogger)
    callback = ThroughputLogger(mock_logger, {"Queries": 4})

    # successive return values of the patched clock
    first_start, first_end, second_start, second_end = 0.1, 0.5, 0.8, 1.5
    clock_values = [first_start, first_end, second_start, second_end]

    with patch("time.perf_counter", side_effect=clock_values):
        train_unit = DummyTrainUnit(input_dim=2)
        dataloader = generate_random_dataloader(
            num_samples=16, input_dim=2, batch_size=4
        )
        train(
            train_unit,
            dataloader,
            max_epochs=2,
            max_steps_per_epoch=2,
            callbacks=[callback],
        )

    # 4 queries per batch, 2 steps per epoch
    expected_calls = [
        call(
            "Train: Queries per second (epoch granularity)",
            (4 * 2) / (first_end - first_start),
            1,
        ),
        call(
            "Train: Queries per second (epoch granularity)",
            (4 * 2) / (second_end - second_start),
            2,
        ),
    ]
    mock_logger.log.assert_has_calls(expected_calls, any_order=True)

def test_input_validation(self) -> None:
Expand Down
54 changes: 53 additions & 1 deletion torchtnt/framework/callbacks/throughput_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# pyre-strict


import time
from collections import defaultdict
from typing import Mapping

from pyre_extensions import none_throws
Expand All @@ -32,7 +34,7 @@
class ThroughputLogger(Callback):
"""
A callback which logs throughput. For instance, it can be used to log QPS and number of batches processed per second.
The callback logs the throughput on a step basis.
The callback logs the throughput on a step basis and on an epoch basis.
Args:
logger: A subclass of :class:`torchtnt.utils.loggers.logger.MetricLogger`.
Expand Down Expand Up @@ -66,12 +68,15 @@ def __init__(
)

self._log_every_n_steps = log_every_n_steps
self._epoch_start_times: dict[ActivePhase, float] = {}
self._steps_in_epoch: dict[ActivePhase, int] = defaultdict(int)

def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
    """
    Run step-granularity logging for the train step that just ended, then
    count that step toward the current train epoch.
    """
    # the step index passed on is one less than the completed-step count
    finished_step = unit.train_progress.num_steps_completed - 1
    self._maybe_log_for_step(state, finished_step)
    self._steps_in_epoch[ActivePhase.TRAIN] += 1

def on_train_end(self, state: State, unit: TTrainUnit) -> None:
self._maybe_log_for_step(
Expand All @@ -85,13 +90,18 @@ def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
state,
unit.eval_progress.num_steps_completed - 1,
)
self._steps_in_epoch[ActivePhase.EVALUATE] += 1

def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
    """
    At the end of evaluation, run the step-granularity logging (flagged as
    not coming from a step-end hook) and then log the epoch-granularity
    throughput against the number of completed eval epochs.
    """
    self._maybe_log_for_step(
        state, unit.eval_progress.num_steps_completed, is_step_end_hook=False
    )
    completed_epochs = unit.eval_progress.num_epochs_completed
    self._log_for_epoch(state, epoch_logging_for=completed_epochs)

def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
self._maybe_log_for_step(
Expand All @@ -105,6 +115,25 @@ def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
unit.predict_progress.num_steps_completed,
is_step_end_hook=False,
)
self._log_for_epoch(
state,
epoch_logging_for=unit.predict_progress.num_epochs_completed,
)

def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None:
    """Remember when the train epoch started, for epoch-granularity throughput."""
    epoch_began_at = time.perf_counter()
    self._epoch_start_times[ActivePhase.TRAIN] = epoch_began_at

def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
    """Log epoch-granularity throughput against the number of completed train epochs."""
    completed_epochs = unit.train_progress.num_epochs_completed
    self._log_for_epoch(state, epoch_logging_for=completed_epochs)

def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
    """Remember when evaluation started; ``on_eval_end`` logs the epoch throughput from it."""
    eval_began_at = time.perf_counter()
    self._epoch_start_times[ActivePhase.EVALUATE] = eval_began_at

def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
    """Remember when prediction started; ``on_predict_end`` logs the epoch throughput from it."""
    predict_began_at = time.perf_counter()
    self._epoch_start_times[ActivePhase.PREDICT] = predict_began_at

def _maybe_log_for_step(
self,
Expand Down Expand Up @@ -147,3 +176,26 @@ def _maybe_log_for_step(
num_items / total_time,
step_logging_for,
)

def _log_for_epoch(
    self,
    state: State,
    *,
    epoch_logging_for: int,
) -> None:
    """
    Log the throughput of the epoch that just finished for ``state.active_phase``.

    For every entry in ``self._throughput_per_batch`` the logged value is
    ``(items per batch * steps counted in the epoch) / seconds elapsed since
    the start time recorded for the phase``.

    Nothing is logged when no steps were counted for the phase, when no start
    time was recorded for it, or when no time has elapsed; the last two cases
    would otherwise surface as ``KeyError`` / ``ZeroDivisionError`` from a
    logging-only callback. Whenever steps were counted, the phase's step
    counter is reset to zero so the next epoch starts from a clean count.

    Args:
        state: current state; only ``active_phase`` is read.
        epoch_logging_for: epoch number handed to the logger as the step value.
    """
    # read the clock first so the bookkeeping below is not part of the measurement
    now = time.perf_counter()
    phase = state.active_phase

    steps_in_epoch = self._steps_in_epoch[phase]
    if steps_in_epoch <= 0:
        return

    # reset up front so a skipped or failing log call cannot leak steps into the next epoch
    self._steps_in_epoch[phase] = 0

    epoch_start = self._epoch_start_times.get(phase)
    if epoch_start is None:
        # the start hook never ran for this phase; there is no interval to measure
        return

    time_since_epoch_start = now - epoch_start
    if time_since_epoch_start <= 0:
        # avoid dividing by zero (or reporting a negative rate)
        return

    for item, num_items in self._throughput_per_batch.items():
        self._logger.log(
            f"{ACTIVE_PHASE_TO_LABEL_PREFIX[phase]}: {item} per second (epoch granularity)",
            (num_items * steps_in_epoch) / time_since_epoch_start,
            epoch_logging_for,
        )

0 comments on commit 5bb592f

Please sign in to comment.