From ca5d5802e867aaaf3def41cc694120a72eb87355 Mon Sep 17 00:00:00 2001 From: Victor Bourgin Date: Thu, 10 Oct 2024 23:08:44 -0700 Subject: [PATCH] Move SWA Model to AutoUnit.device Summary: Update torchtnt auto_unit to use self.device for the EMA / SWA model, which may be set from environment in the superclass init. This enables model evaluation in GPU. Differential Revision: D64206735 --- torchtnt/framework/auto_unit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py index 38a50157c6..ea04ad0046 100644 --- a/torchtnt/framework/auto_unit.py +++ b/torchtnt/framework/auto_unit.py @@ -512,7 +512,7 @@ def __init__( self.swa_model = AveragedModel( module_for_swa, - device=device, + device=self.device, use_buffers=swa_params.use_buffers, averaging_method=swa_params.averaging_method, ema_decay=swa_params.ema_decay,