Skip to content

Commit

Permalink
Move SWA Model to AutoUnit.device
Browse files Browse the repository at this point in the history
Summary: Update torchtnt auto_unit to use self.device for the EMA / SWA model, which may be set from environment in the superclass init. This enables model evaluation in GPU.

Differential Revision: D64206735
  • Loading branch information
Victor Bourgin authored and facebook-github-bot committed Oct 11, 2024
1 parent b9e7c1f commit ca5d580
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchtnt/framework/auto_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def __init__(

self.swa_model = AveragedModel(
module_for_swa,
device=device,
device=self.device,
use_buffers=swa_params.use_buffers,
averaging_method=swa_params.averaging_method,
ema_decay=swa_params.ema_decay,
Expand Down

0 comments on commit ca5d580

Please sign in to comment.