diff --git a/tests/framework/callbacks/test_throughput_logger.py b/tests/framework/callbacks/test_throughput_logger.py index 12f578265b..a179c6522a 100644 --- a/tests/framework/callbacks/test_throughput_logger.py +++ b/tests/framework/callbacks/test_throughput_logger.py @@ -8,7 +8,7 @@ # pyre-strict import unittest -from unittest.mock import ANY, call, MagicMock +from unittest.mock import ANY, call, MagicMock, patch import torch from pyre_extensions import none_throws @@ -17,13 +17,15 @@ from torchtnt.framework._test_utils import ( DummyAutoUnit, DummyPredictUnit, + DummyTrainUnit, generate_random_dataloader, ) from torchtnt.framework.callbacks.throughput_logger import ThroughputLogger from torchtnt.framework.predict import predict -from torchtnt.framework.state import EntryPoint, PhaseState, State -from torchtnt.framework.train import _train_impl +from torchtnt.framework.state import ActivePhase, EntryPoint, PhaseState, State +from torchtnt.framework.train import _train_impl, train +from torchtnt.framework.unit import TrainUnit from torchtnt.utils.loggers.logger import MetricLogger @@ -121,21 +123,18 @@ def test_with_comparing_time(self) -> None: evaluate_every_n_epochs=2, ), ) + throughput_logger = ThroughputLogger( + logger=logger, + throughput_per_batch={"Batches": 1, "Queries": 8}, + log_every_n_steps=1, + ) # we want to be able to compare the logging value to the state, so we need to create state manually and # call _train_impl. This would have been similar to calling fit() and getting the state as a ret value _train_impl( state, DummyAutoUnit(module=torch.nn.Linear(2, 2)), - CallbackHandler( - [ - ThroughputLogger( - logger=logger, - throughput_per_batch={"Batches": 1, "Queries": 8}, - log_every_n_steps=1, - ) - ], - ), + CallbackHandler([throughput_logger]), ) train_iteration_times = none_throws( @@ -163,8 +162,8 @@ def test_with_comparing_time(self) -> None: eval_iteration_times[i] + eval_twfb_times[i] for i in range(2) ] self.assertEqual( - logger.log.call_count, 12 - ) # 8 train (2epochs x 2steps x 2items), 4 eval (1x2x2) + logger.log.call_count, 18 + ) # steps: 8 train (2epochs x 2steps x 2items), 4 eval (1x2x2). epochs: 4 train (2epoch x 2items). 2 eval (1x2) train_batches_step_logs = [ call( "Train: Batches per second (step granularity)", @@ -197,11 +196,36 @@ def test_with_comparing_time(self) -> None: ) for i in range(2) ] + # for epoch, we test the logged value separately + train_batches_epoch_logs = [ + call("Train: Batches per second (epoch granularity)", ANY, i) + for i in range(1, 3) + ] + train_queries_epoch_logs = [ + call("Train: Queries per second (epoch granularity)", ANY, i) + for i in range(1, 3) + ] + eval_epoch_logs = [ + call( + "Eval: Queries per second (epoch granularity)", + ANY, + 1, + ), + call( + "Eval: Batches per second (epoch granularity)", + ANY, + 1, + ), + ] + logger.log.assert_has_calls( train_batches_step_logs + train_queries_step_logs + eval_batches_step_logs - + eval_queries_step_logs, + + eval_queries_step_logs + + train_batches_epoch_logs + + train_queries_epoch_logs + + eval_epoch_logs, any_order=True, ) @@ -227,6 +251,79 @@ def test_with_predict(self) -> None: 1, ) ], + [ + call( + "Predict: Batches per second (epoch granularity)", + ANY, + 1, + ) + ], + ) + + def test_log_for_epoch(self) -> None: + logger = MagicMock(spec=MetricLogger) + unit = DummyTrainUnit(input_dim=2) + throughput_logger = ThroughputLogger(logger, {"Batches": 1, "Queries": 8}) + state = State(entry_point=EntryPoint.TRAIN) + + self.assertIsNone(throughput_logger._epoch_start_times.get(ActivePhase.TRAIN)) + self.assertEqual(throughput_logger._steps_in_epoch[ActivePhase.TRAIN], 0) + with patch.object(throughput_logger, "_maybe_log_for_step"): + throughput_logger.on_train_step_end(state, unit) + self.assertEqual(throughput_logger._steps_in_epoch[ActivePhase.TRAIN], 1) + + with patch("time.perf_counter", return_value=0.5): + throughput_logger.on_train_epoch_start(state, MagicMock(spec=TrainUnit)) + self.assertEqual(throughput_logger._epoch_start_times[ActivePhase.TRAIN], 0.5) + + throughput_logger._steps_in_epoch[ActivePhase.TRAIN] = ( + 2 # to assume there were two steps in the epoch + ) + logger.log.reset_mock() + with patch("time.perf_counter", return_value=0.6): + throughput_logger._log_for_epoch(state, epoch_logging_for=15) + + logger.log.assert_has_calls( + [ + call( + "Train: Batches per second (epoch granularity)", + (1 * 2) / (0.6 - 0.5), + 15, + ), + call( + "Train: Queries per second (epoch granularity)", + (8 * 2) / (0.6 - 0.5), + 15, + ), + ] + ) + + def test_epoch_logging_time(self) -> None: + logger = MagicMock(spec=MetricLogger) + throughput_logger = ThroughputLogger(logger, {"Queries": 4}) + with patch("time.perf_counter", side_effect=[0.1, 0.5, 0.8, 1.5]): + train( + DummyTrainUnit(input_dim=2), + generate_random_dataloader(num_samples=16, input_dim=2, batch_size=4), + max_epochs=2, + max_steps_per_epoch=2, + callbacks=[throughput_logger], + ) + + logger.log.assert_has_calls( + [ + call( + "Train: Queries per second (epoch granularity)", + (4 * 2) / (0.5 - 0.1), + 1, + ), + call( + "Train: Queries per second (epoch granularity)", + (4 * 2) / (1.5 - 0.8), + 2, + ), + ], + any_order=True, ) def test_input_validation(self) -> None: diff --git a/torchtnt/framework/callbacks/throughput_logger.py b/torchtnt/framework/callbacks/throughput_logger.py index 230a0d90c9..5fa09b7d1c 100644 --- a/torchtnt/framework/callbacks/throughput_logger.py +++ b/torchtnt/framework/callbacks/throughput_logger.py @@ -7,6 +7,8 @@ # pyre-strict +import time +from collections import defaultdict from typing import Mapping from pyre_extensions import none_throws @@ -32,7 +34,7 @@ class ThroughputLogger(Callback): """ A callback which logs throughput. For instance, it can be used to log QPS and number of batches processed per second. - The callback logs the throughput on a step basis. + The callback logs the throughput on a step basis and on an epoch basis. Args: logger: A a subclass of :class:`torchtnt.utils.loggers.logger.MetricLogger`. @@ -66,12 +68,15 @@ def __init__( ) self._log_every_n_steps = log_every_n_steps + self._epoch_start_times: dict[ActivePhase, float] = {} + self._steps_in_epoch: dict[ActivePhase, int] = defaultdict(int) def on_train_step_end(self, state: State, unit: TTrainUnit) -> None: self._maybe_log_for_step( state, unit.train_progress.num_steps_completed - 1, ) + self._steps_in_epoch[ActivePhase.TRAIN] += 1 def on_train_end(self, state: State, unit: TTrainUnit) -> None: self._maybe_log_for_step( @@ -85,6 +90,7 @@ def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None: state, unit.eval_progress.num_steps_completed - 1, ) + self._steps_in_epoch[ActivePhase.EVALUATE] += 1 def on_eval_end(self, state: State, unit: TEvalUnit) -> None: self._maybe_log_for_step( @@ -92,6 +98,10 @@ def on_eval_end(self, state: State, unit: TEvalUnit) -> None: unit.eval_progress.num_steps_completed, is_step_end_hook=False, ) + self._log_for_epoch( + state, + epoch_logging_for=unit.eval_progress.num_epochs_completed, + ) def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None: self._maybe_log_for_step( @@ -105,6 +115,25 @@ def on_predict_end(self, state: State, unit: TPredictUnit) -> None: unit.predict_progress.num_steps_completed, is_step_end_hook=False, ) + self._log_for_epoch( + state, + epoch_logging_for=unit.predict_progress.num_epochs_completed, + ) + + def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None: + self._epoch_start_times[ActivePhase.TRAIN] = time.perf_counter() + + def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None: + self._log_for_epoch( + state, + epoch_logging_for=unit.train_progress.num_epochs_completed, + ) + + def on_eval_start(self, state: State, unit: TEvalUnit) -> None: + self._epoch_start_times[ActivePhase.EVALUATE] = time.perf_counter() + + def on_predict_start(self, state: State, unit: TPredictUnit) -> None: + self._epoch_start_times[ActivePhase.PREDICT] = time.perf_counter() def _maybe_log_for_step( self, @@ -147,3 +176,26 @@ def _maybe_log_for_step( num_items / total_time, step_logging_for, ) + + def _log_for_epoch( + self, + state: State, + *, + epoch_logging_for: int, + ) -> None: + time_since_epoch_start = ( + time.perf_counter() - self._epoch_start_times[state.active_phase] + ) + + steps_in_epoch = self._steps_in_epoch[state.active_phase] + if steps_in_epoch <= 0: + return + + for item, num_items in self._throughput_per_batch.items(): + self._logger.log( + f"{ACTIVE_PHASE_TO_LABEL_PREFIX[state.active_phase]}: {item} per second (epoch granularity)", + (num_items * steps_in_epoch) / time_since_epoch_start, + epoch_logging_for, + ) + + self._steps_in_epoch[state.active_phase] = 0