Skip to content

Commit

Permalink
Use checkpoint_id in the internal APIs (#860)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookexternal/vizard#2

Pull Request resolved: #860

Use checkpoint_id in the internal APIs instead of checkpoint paths. ID is a more generic parameter which will be used in the subsequent diffs to represent Meta internal abstractions like model entity id to identify a checkpoint.

Reviewed By: galrotem

Differential Revision: D59638742

fbshipit-source-id: 1a1958c0ef93f7eaf2ef14a1668b3dd6a4237641
  • Loading branch information
saumishr authored and facebook-github-bot committed Jul 15, 2024
1 parent 53c6f91 commit 860d155
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 13 deletions.
8 changes: 4 additions & 4 deletions tests/framework/callbacks/test_base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ def __init__(
self._latest_checkpoint_path: str = ""

def _checkpoint_impl(
self, state: State, unit: AppStateMixin, checkpoint_path: str, hook: str
self, state: State, unit: AppStateMixin, checkpoint_id: str, hook: str
) -> bool:
self._latest_checkpoint_path = checkpoint_path
if not os.path.exists(checkpoint_path):
os.mkdir(checkpoint_path)
self._latest_checkpoint_path = checkpoint_id
if not os.path.exists(checkpoint_id):
os.mkdir(checkpoint_id)
return True

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions torchtnt/framework/callbacks/base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def _generate_checkpoint_and_upkeep(

# 3) try to save checkpoint
if not self._checkpoint_impl(
state, unit, checkpoint_path=checkpoint_path.path, hook=hook
state, unit, checkpoint_id=checkpoint_path.path, hook=hook
):
return False

Expand Down Expand Up @@ -299,7 +299,7 @@ def _checkpoint_impl(
state: State,
unit: AppStateMixin,
*,
checkpoint_path: str,
checkpoint_id: str,
hook: str,
) -> bool:
"""
Expand All @@ -308,7 +308,7 @@ def _checkpoint_impl(
Args:
state: current application state
unit: current unit
checkpoint_path: path to save checkpoint
checkpoint_id: Checkpoint id to save a checkpoint. It can be a path
hook: name of callback hook that triggered this function call
Returns:
Expand Down
6 changes: 3 additions & 3 deletions torchtnt/framework/callbacks/dcp_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _checkpoint_impl(
state: State,
unit: AppStateMixin,
*,
checkpoint_path: str,
checkpoint_id: str,
hook: str,
planner: Optional[SavePlanner] = None,
storage_writer: Optional[StorageWriter] = None,
Expand All @@ -156,14 +156,14 @@ def _checkpoint_impl(
# future, add logic to set successful flag
# only when checkpoint is fully written
checkpoint_success = self._async_save(
checkpoint_path, app_state, planner, storage_writer
checkpoint_id, app_state, planner, storage_writer
)
if curr_snapshot_wait:
self._wait()
else:
with get_timing_context(state, f"{self.__class__.__name__}.save"):
checkpoint_success = self._save(
checkpoint_path, app_state, planner, storage_writer
checkpoint_id, app_state, planner, storage_writer
)

return checkpoint_success
Expand Down
6 changes: 3 additions & 3 deletions torchtnt/framework/callbacks/torchsnapshot_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _checkpoint_impl(
state: State,
unit: AppStateMixin,
*,
checkpoint_path: str,
checkpoint_id: str,
hook: str,
) -> bool:
"""
Expand Down Expand Up @@ -185,12 +185,12 @@ def _checkpoint_impl(
# since this is async checkpointed, so in
# future, add logic to set successful flag
# only when checkpoint is fully written
checkpoint_success = self._async_snapshot(checkpoint_path, app_state)
checkpoint_success = self._async_snapshot(checkpoint_id, app_state)
if curr_snapshot_wait:
self._wait()
else:
with get_timing_context(state, f"{self.__class__.__name__}.take_snapshot"):
checkpoint_success = self._sync_snapshot(checkpoint_path, app_state)
checkpoint_success = self._sync_snapshot(checkpoint_id, app_state)
return checkpoint_success

def _wait(self) -> None:
Expand Down

0 comments on commit 860d155

Please sign in to comment.