Commit

code adjustment
noemotiovon committed Oct 21, 2024
1 parent 051f16e commit b6332dd
Showing 10 changed files with 107 additions and 165 deletions.
5 changes: 3 additions & 2 deletions recipes/full_finetune_distributed.py
@@ -26,7 +26,7 @@
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
 from torchtune.training.activations import apply_selective_activation_checkpointing
-from torchtune.utils import get_torch_device
+from torchtune.utils import DeviceSupport, get_torch_device

 from tqdm import tqdm

@@ -743,7 +743,8 @@ def recipe_main(cfg: DictConfig) -> None:
             "Distributed finetune recipe should be run via a distributed launcher."
             "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
         )
-    init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
+    device_support = DeviceSupport.from_type(cfg.device)
+    init_process_group(backend=device_support.communication_backend)
     if cfg.get("fsdp_cpu_offload", False):
         # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
         # speed up when benchmarking fused AdamW on CPU
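The same two-line substitution appears in all four distributed recipes in this commit, so the device-to-backend mapping now lives in one place instead of being re-derived in each recipe_main. A minimal standalone sketch of the selection logic, assuming a hypothetical cfg dict in place of the recipe's DictConfig and a launcher such as torchrun supplying the process-group environment:

# Sketch of the backend selection this commit centralizes. `cfg` is a
# hypothetical stand-in for the recipe's DictConfig; run under torchrun (or
# another launcher) so init_process_group can read rank/world size from env.
from torch.distributed import init_process_group

from torchtune.utils import DeviceSupport

cfg = {"device": "cpu"}  # "cpu" -> "gloo", "cuda" -> "nccl", "npu" -> "hccl"

device_support = DeviceSupport.from_type(cfg["device"])
init_process_group(backend=device_support.communication_backend)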
4 changes: 3 additions & 1 deletion recipes/lora_dpo_distributed.py
@@ -33,6 +33,7 @@
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.rlhf.loss import SimPOLoss
+from torchtune.utils import DeviceSupport
 from tqdm import tqdm

 log = utils.get_logger("DEBUG")
@@ -759,7 +760,8 @@ def recipe_main(cfg: DictConfig) -> None:
         # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
         # speed up when benchmarking fused AdamW on CPU
         training.set_torch_num_threads()
-    init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
+    device_support = DeviceSupport.from_type(cfg.device)
+    init_process_group(backend=device_support.communication_backend)

     config.log_config(recipe_name="LoRADPORecipeDistributed", cfg=cfg)

11 changes: 3 additions & 8 deletions recipes/lora_finetune_distributed.py
@@ -41,7 +41,7 @@
     OffloadActivations,
     PROFILER_KEY,
 )
-from torchtune.utils import get_torch_device
+from torchtune.utils import DeviceSupport, get_torch_device

 from tqdm import tqdm

@@ -897,13 +897,8 @@ def recipe_main(cfg: DictConfig) -> None:
         # speed up when benchmarking fused AdamW on CPU
         training.set_torch_num_threads()

-    if cfg.device == "cpu":
-        backend = "gloo"
-    elif cfg.device == "npu":
-        backend = "hccl"
-    else:
-        backend = "nccl"
-    init_process_group(backend=backend)
+    device_support = DeviceSupport.from_type(cfg.device)
+    init_process_group(backend=device_support.communication_backend)

     config.log_config(recipe_name="LoRAFinetuneRecipeDistributed", cfg=cfg)

4 changes: 3 additions & 1 deletion recipes/qat_distributed.py
@@ -782,7 +782,9 @@ def recipe_main(cfg: DictConfig) -> None:
             "Distributed QAT recipe should be run via a distributed launcher."
             "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
         )
-    init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
+
+    device_support = DeviceSupport.from_type(cfg.device)
+    init_process_group(backend=device_support.communication_backend)
     if cfg.get("fsdp_cpu_offload", False):
         # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
         # speed up when benchmarking fused AdamW on CPU
20 changes: 20 additions & 0 deletions tests/torchtune/utils/test_device.py
@@ -16,7 +16,10 @@
     _get_device_type_from_env,
     _setup_device,
     batch_to_device,
+    DeviceSupport,
     get_device,
+    get_device_support,
+    get_torch_device,
 )


@@ -87,3 +90,20 @@ def test_get_gpu_device(self) -> None:
         assert device.type == "cuda"
         assert device.index == 0
         assert device.index == torch.cuda.current_device()
+
+    @pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.")
+    @patch("torch.cuda.is_available", return_value=True)
+    def test_cuda_available(self, mock_cuda):
+        # When CUDA is available, get_device_support should return DeviceSupport.CUDA
+        device_support = get_device_support()
+        assert device_support == DeviceSupport.CUDA
+        assert device_support.device_type == "cuda"
+        assert device_support.device_name == "GPU"
+        assert device_support.communication_backend == "nccl"
+
+    @pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.")
+    @patch("torch.cuda.is_available", return_value=True)
+    def test_get_torch_device_for_cuda(self, mock_cuda):
+        # get_torch_device takes no arguments (see its definition below) and
+        # should resolve to the torch.cuda module on a CUDA machine
+        torch_device = get_torch_device()
+        assert torch_device == torch.cuda
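These new tests exercise only the CUDA entry. Because DeviceSupport.from_type is a pure table lookup over the enum members, an analogous NPU check would need neither torch_npu nor NPU hardware; a possible companion test, not part of this commit, might look like:

def test_npu_device_support_lookup(self):
    # Pure enum lookup; neither torch_npu nor NPU hardware is required.
    device_support = DeviceSupport.from_type("npu")
    assert device_support == DeviceSupport.NPU
    assert device_support.device_type == "npu"
    assert device_support.device_name == "NPU"
    assert device_support.communication_backend == "hccl"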
39 changes: 0 additions & 39 deletions tests/torchtune/utils/test_device_support.py

This file was deleted.

5 changes: 2 additions & 3 deletions torchtune/training/precision.py
@@ -56,16 +56,15 @@ def verify_bf16_support() -> bool:
         bool: True if bf16 is available, False otherwise.
     """
-    if is_npu_available:
-        return torch.npu.is_bf16_supported()
     cuda_support = (
         torch.cuda.is_available()
         and torch.cuda.is_bf16_supported()
         and torch.distributed.is_nccl_available()
         and torch.cuda.nccl.version() >= (2, 10)
     )
     mps_support = torch.backends.mps.is_available() and torch.backends.mps.is_built()
-    return cuda_support or mps_support
+    npu_support = is_npu_available and torch.npu.is_bf16_supported()
+    return cuda_support or mps_support or npu_support


 def get_dtype(
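With the early NPU return removed, each platform's support flag is computed once and the results are OR-ed together, so no backend short-circuits the others. A hedged usage sketch of the check (the float32 fallback here is an assumption for illustration, not something this module prescribes):

import torch

from torchtune.training.precision import verify_bf16_support

# Prefer bf16 wherever the consolidated CUDA/MPS/NPU check passes.
dtype = torch.bfloat16 if verify_bf16_support() else torch.float32
print(dtype)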
11 changes: 9 additions & 2 deletions torchtune/utils/__init__.py
@@ -4,8 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from ._device import batch_to_device, get_device
-from ._device_support import get_device_support, get_torch_device, is_npu_available
+from ._device import (
+    batch_to_device,
+    DeviceSupport,
+    get_device,
+    get_device_support,
+    get_torch_device,
+    is_npu_available,
+)
 from ._logging import get_logger

 from ._version import torch_version_ge
@@ -18,4 +24,5 @@
     "is_npu_available",
     "get_device_support",
     "get_torch_device",
+    "DeviceSupport",
 ]
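After this consolidation, all device helpers are importable from torchtune.utils directly. A quick smoke check of the re-exported surface, assuming a checkout with this commit installed:

# All names resolve from the single torchtune.utils entry point.
from torchtune.utils import (
    batch_to_device,
    DeviceSupport,
    get_device,
    get_device_support,
    get_torch_device,
    is_npu_available,
)

print(DeviceSupport.from_type("cpu").communication_backend)  # "gloo"
print(is_npu_available)  # False unless torch_npu is installed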
69 changes: 64 additions & 5 deletions torchtune/utils/_device.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import os
+from enum import Enum
 from typing import Optional

 import torch
@@ -16,11 +17,18 @@
 else:
     BlockMask = torch.Tensor

-from torchtune.utils._device_support import (
-    get_device_support,
-    get_torch_device,
-    is_npu_available,
-)
+
+def is_torch_npu_available() -> bool:
+    """Check the availability of NPU"""
+    try:
+        import torch_npu  # noqa: F401
+
+        return torch.npu.is_available()
+    except ImportError:
+        return False
+
+
+is_npu_available = is_torch_npu_available()


 def _get_local_rank() -> Optional[int]:
@@ -167,3 +175,54 @@ def batch_to_device(batch: dict, device: torch.device) -> None:
             f"""To use batch_to_device, all elements in the batch must be a dict or Tensor.
 Got key "{k}" with value of type {type(v)}"""
         )
+
+
+class DeviceSupport(Enum):
+    """
+    A simple enum for compute devices.
+    Currently supports CPU, CUDA, and NPU.
+    """
+
+    CPU = ("cpu", "CPU", "gloo")
+    CUDA = ("cuda", "GPU", "nccl")
+    NPU = ("npu", "NPU", "hccl")
+
+    def __init__(self, device_type: str, device_name: str, communication_backend: str):
+        self.device_type = device_type
+        self.device_name = device_name
+        self.communication_backend = communication_backend
+
+    @staticmethod
+    def from_type(device_type: str) -> "DeviceSupport":
+        for member in DeviceSupport:
+            if member.device_type == device_type:
+                return member
+        raise ValueError(f"Unknown device type: {device_type}.")
+
+
+def get_device_support() -> DeviceSupport:
+    """Get the DeviceSupport entry matching the current machine's compute device.
+    Currently supports CPU, CUDA, and NPU.
+    Returns:
+        device_support: DeviceSupport
+    """
+    device_type = _get_device_type_from_env()
+    return DeviceSupport.from_type(device_type)
+
+
+def get_torch_device() -> any:
+    """Return the torch device module matching the detected device type.
+    Returns:
+        module: The corresponding torch module, or torch.cuda if not found.
+    """
+    device_type = get_device_support().device_type
+    try:
+        return getattr(torch, device_type)
+    except AttributeError:
+        print(
+            f"Device module '{device_type}' not found in torch; falling back to torch.cuda."
+        )
+        return torch.cuda
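A short end-to-end sketch of the two helpers added above: on a CUDA machine this resolves to DeviceSupport.CUDA and torch.cuda, and on an Ascend machine with torch_npu installed it should resolve to DeviceSupport.NPU and torch.npu.

import torch

from torchtune.utils import get_device_support, get_torch_device

support = get_device_support()  # e.g. DeviceSupport.CUDA
device_mod = get_torch_device()  # e.g. torch.cuda

print(support.device_type, support.device_name, support.communication_backend)
print(device_mod is getattr(torch, support.device_type))  # True on cpu/cuda/npu setups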
104 changes: 0 additions & 104 deletions torchtune/utils/_device_support.py

This file was deleted.
