
Add Ascend NPU as a backend #1826

Open · wants to merge 6 commits into base: main

Conversation

@noemotiovon commented Oct 14, 2024

What does this PR do?

Overview

🚀 This PR enables users of torchtune to leverage the Ascend NPU for better performance when a GPU is not available.

For more details, see: [#1797].

Environment

  • OS: Ubuntu 20.04
  • NPU: Atlas 300T A2
  • CANN: 8.0.RC2
  • torch-npu: 2.4.0 rc1
  • torch: 2.4.0

Note

To properly install CANN, see [here] for more details.

The version of torch-npu should match that of torch, see [here] for more details.

In addition, torch_npu has a pre-release version, 2.4.0 RC1, which is the version used for this test. For more information, please visit [here].
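
As a quick sanity check, the installed versions can be compared at runtime (illustrative snippet; the exact version strings depend on your installation):

import torch
import torch_npu

# torch-npu releases are expected to track torch releases,
# e.g. torch 2.4.0 with torch-npu 2.4.0 rc1.
print(f"torch: {torch.__version__}, torch_npu: {torch_npu.__version__}")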

Examples

To start with, the torch_npu library should be correctly installed and imported. Part of the code is shown below:

torchtune/utils/_device_support.py:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch

def is_torch_npu_available() -> bool:
    """Check the availability of the Ascend NPU backend."""
    try:
        # Importing torch_npu registers the `torch.npu` module as a side effect.
        import torch_npu  # noqa: F401
    except ImportError:
        return False
    return torch.npu.is_available()
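
For illustration, a hypothetical caller (not part of this PR) could use this helper to pick a default device, falling back from CUDA to NPU to CPU:

import torch

from torchtune.utils._device_support import is_torch_npu_available


def get_default_device() -> torch.device:
    """Illustrative fallback order: CUDA first, then NPU, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if is_torch_npu_available():
        return torch.device("npu")
    return torch.device("cpu")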

Beyond this, a few other places in the code need to be adjusted, but the changes are small.

Feel free to leave comments to guide me in further improvements 😊.

Tests

This PR has passed the tests shown below:

Basic Usage Test

A single-device fine-tuning process was performed on the Llama 3.1 8B model using the LoRA (Low-Rank Adaptation) technique.

  • Recipe: lora_finetune_single_device

  • Model: Meta-Llama-3.1-8B-Instruct

  • Config:

    # Config for single device LoRA finetuning in lora_finetune_single_device.py
    # using a Llama3.1 8B Instruct model
    #
    # This config assumes that you've run the following command before launching
    # this run:
    #   tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
    #
    # To launch on a single device, run the following command from root:
    #   tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device
    #
    # You can add specific overrides through the command line. For example
    # to override the checkpointer directory while launching training
    # you can run:
    #   tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
    #
    # This config works only for training on single device.
    
    
    # Model Arguments
    model:
      _component_: torchtune.models.llama3_1.lora_llama3_1_8b
      lora_attn_modules: ['q_proj', 'v_proj']
      apply_lora_to_mlp: False
      apply_lora_to_output: False
      lora_rank: 8
      lora_alpha: 16
      lora_dropout: 0.0
    
    # Tokenizer
    tokenizer:
      _component_: torchtune.models.llama3.llama3_tokenizer
      path: /home/lcg/.cache/modelscope/hub/LLM-Research/Meta-Llama-3___1-8B-Instruct/original/tokenizer.model
      max_seq_len: null
    
    checkpointer:
      _component_: torchtune.training.FullModelHFCheckpointer
      checkpoint_dir: /home/lcg/.cache/modelscope/hub/LLM-Research/Meta-Llama-3___1-8B-Instruct
      checkpoint_files: [
        model-00001-of-00004.safetensors,
        model-00002-of-00004.safetensors,
        model-00003-of-00004.safetensors,
        model-00004-of-00004.safetensors
      ]
      recipe_checkpoint: null
      output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/
      model_type: LLAMA3
    resume_from_checkpoint: False
    save_adapter_weights_only: False
    
    # Dataset and Sampler
    dataset:
      _component_: torchtune.datasets.alpaca_cleaned_dataset
    seed: null
    shuffle: True
    batch_size: 30
    
    # Optimizer and Scheduler
    optimizer:
      _component_: torch.optim.AdamW
      fused: False
      weight_decay: 0.01
      lr: 3e-4
    lr_scheduler:
      _component_: torchtune.modules.get_cosine_schedule_with_warmup
      num_warmup_steps: 100
    
    loss:
      _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
    
    # Training
    epochs: 1
    max_steps_per_epoch: null
    gradient_accumulation_steps: 64
    compile: False
    
    # Logging
    output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
    metric_logger:
      _component_: torchtune.training.metric_logging.DiskLogger
      log_dir: ${output_dir}
    log_every_n_steps: 1
    log_peak_memory_stats: False
    
    # Environment
    device: npu
    dtype: bf16
    
    # Activations Memory
    enable_activation_checkpointing: True
    enable_activation_offloading: False
    
    # Profiler (disabled)
    profiler:
      _component_: torchtune.training.setup_torch_profiler
      enabled: False
    
      #Output directory of trace artifacts
      output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
    
      #`torch.profiler.ProfilerActivity` types to trace
      cpu: True
      cuda: True
    
      #trace options passed to `torch.profiler.profile`
      profile_memory: False
      with_stack: False
      record_shapes: True
      with_flops: False
    
      # `torch.profiler.schedule` options:
      # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
      wait_steps: 5
      warmup_steps: 5
      active_steps: 2
      num_cycles: 1
  • Logs:

    INFO:torchtune.utils._logging:Running LoRAFinetuneRecipeSingleDevice with resolved config:
    
    batch_size: 30
    checkpointer:
      _component_: torchtune.training.FullModelHFCheckpointer
      checkpoint_dir: /home/lcg/.cache/modelscope/hub/LLM-Research/Meta-Llama-3___1-8B-Instruct
      checkpoint_files:
      - model-00001-of-00004.safetensors
      - model-00002-of-00004.safetensors
      - model-00003-of-00004.safetensors
      - model-00004-of-00004.safetensors
      model_type: LLAMA3
      output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/
      recipe_checkpoint: null
    compile: false
    dataset:
      _component_: torchtune.datasets.alpaca_cleaned_dataset
    device: npu
    dtype: bf16
    enable_activation_checkpointing: true
    enable_activation_offloading: false
    epochs: 1
    gradient_accumulation_steps: 64
    log_every_n_steps: 1
    log_peak_memory_stats: false
    loss:
      _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
    lr_scheduler:
      _component_: torchtune.modules.get_cosine_schedule_with_warmup
      num_warmup_steps: 100
    max_steps_per_epoch: null
    metric_logger:
      _component_: torchtune.training.metric_logging.DiskLogger
      log_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
    model:
      _component_: torchtune.models.llama3_1.lora_llama3_1_8b
      apply_lora_to_mlp: false
      apply_lora_to_output: false
      lora_alpha: 16
      lora_attn_modules:
      - q_proj
      - v_proj
      lora_dropout: 0.0
      lora_rank: 8
    optimizer:
      _component_: torch.optim.AdamW
      fused: false
      lr: 0.0003
      weight_decay: 0.01
    output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
    profiler:
      _component_: torchtune.training.setup_torch_profiler
      active_steps: 2
      cpu: true
      cuda: true
      enabled: false
      num_cycles: 1
      output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
      profile_memory: false
      record_shapes: true
      wait_steps: 5
      warmup_steps: 5
      with_flops: false
      with_stack: false
    resume_from_checkpoint: false
    save_adapter_weights_only: false
    seed: null
    shuffle: true
    tokenizer:
      _component_: torchtune.models.llama3.llama3_tokenizer
      max_seq_len: null
      path: /home/lcg/.cache/modelscope/hub/LLM-Research/Meta-Llama-3___1-8B-Instruct/original/tokenizer.model
    
    DEBUG:torchtune.utils._logging:Setting manual seed to local seed 4173222699. Local seed is seed + rank = 4173222699 + 0
    Writing logs to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test/log_1728874769.txt
    INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
    INFO:torchtune.utils._logging:Memory stats after model init:
            NPU peak memory allocation: 16.98 GiB
            NPU peak memory reserved: 17.00 GiB
            NPU peak memory active: 16.98 GiB
    INFO:torchtune.utils._logging:Tokenizer is initialized from file.
    INFO:torchtune.utils._logging:Optimizer and loss are initialized.
    INFO:torchtune.utils._logging:Loss is initialized.
    Using the latest cached version of the dataset since yahma/alpaca-cleaned couldn't be found on the Hugging Face Hub
    WARNING:datasets.load:Using the latest cached version of the dataset since yahma/alpaca-cleaned couldn't be found on the Hugging Face Hub
    Found the latest cached dataset configuration 'default' at /home/lcg/.cache/huggingface/datasets/yahma___alpaca-cleaned/default/0.0.0/12567cabf869d7c92e573c7c783905fc160e9639 (last modified on Fri Oct 11 01:21:44 2024).
    WARNING:datasets.packaged_modules.cache.cache:Found the latest cached dataset configuration 'default' at /home/lcg/.cache/huggingface/datasets/yahma___alpaca-cleaned/default/0.0.0/12567cabf869d7c92e573c7c783905fc160e9639 (last modified on Fri Oct 11 01:21:44 2024).
    INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
    INFO:torchtune.utils._logging:Learning rate scheduler is initialized.
    WARNING:torchtune.utils._logging: Profiling disabled.
    INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
      0%|                                                                                                                                                            | 0/26 [00:00<?, ?it/s]/home/lcg/miniconda3/envs/torchtune/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
      with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
    1|1|Loss: 1.7533539533615112:   4%| 1/26 [07:09]
    2|2|Loss: 1.7825285196304321:   8%| 2/26 [14:09]
    3|3|Loss: 1.7610299587249756:  12%| 3/26 [21:17]
    4|4|Loss: 1.7874119281768799:  15%| 4/26 [28:28]
    5|5|Loss: 1.7903798818588257:  19%| 5/26 [35:36]
    6|6|Loss: 1.776786208152771:  23%| 6/26 [42:52]
    7|7|Loss: 1.7698196172714233:  27%| 7/26 [49:45<2:14:29, 424.69s/it]
    (torchtune) (base) lcg@lcg-docker:~/github/torchtune$ tune run lora_finetune_single_device --config my_custom_config.yaml
    INFO:torchtune.utils._logging:Running LoRAFinetuneRecipeSingleDevice with resolved config:
    
    batch_size: 30
    checkpointer:
      _component_: torchtune.training.FullModelHFCheckpointer
      checkpoint_dir: /home/lcg/.cache/modelscope/hub/LLM-Research/Meta-Llama-3___1-8B-Instruct
      checkpoint_files:
      - model-00001-of-00004.safetensors
      - model-00002-of-00004.safetensors
      - model-00003-of-00004.safetensors
      - model-00004-of-00004.safetensors
      model_type: LLAMA3
      output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/
      recipe_checkpoint: null
    compile: false
    dataset:
      _component_: torchtune.datasets.alpaca_cleaned_dataset
    device: npu
    dtype: bf16
    enable_activation_checkpointing: true
    enable_activation_offloading: false
    epochs: 1
    gradient_accumulation_steps: 64
    log_every_n_steps: 1
    log_peak_memory_stats: false
    loss:
      _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
    lr_scheduler:
      _component_: torchtune.modules.get_cosine_schedule_with_warmup
      num_warmup_steps: 100
    max_steps_per_epoch: null
    metric_logger:
      _component_: torchtune.training.metric_logging.DiskLogger
      log_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
    model:
      _component_: torchtune.models.llama3_1.lora_llama3_1_8b
      apply_lora_to_mlp: false
      apply_lora_to_output: false
      lora_alpha: 16
      lora_attn_modules:
      - q_proj
      - v_proj
      lora_dropout: 0.0
      lora_rank: 8
    optimizer:
      _component_: torch.optim.AdamW
      fused: false
      lr: 0.0003
      weight_decay: 0.01
    output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
    profiler:
      _component_: torchtune.training.setup_torch_profiler
      active_steps: 2
      cpu: true
      cuda: true
      enabled: false
      num_cycles: 1
      output_dir: /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test
      profile_memory: false
      record_shapes: true
      wait_steps: 5
      warmup_steps: 5
      with_flops: false
      with_stack: false
    resume_from_checkpoint: false
    save_adapter_weights_only: false
    seed: null
    shuffle: true
    tokenizer:
      _component_: torchtune.models.llama3.llama3_tokenizer
      max_seq_len: null
      path: /home/lcg/.cache/modelscope/hub/LLM-Research/Meta-Llama-3___1-8B-Instruct/original/tokenizer.model
    
    DEBUG:torchtune.utils._logging:Setting manual seed to local seed 1031355438. Local seed is seed + rank = 1031355438 + 0
    Writing logs to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/lcg-test/log_1728878132.txt
    INFO:torchtune.utils._logging:Model is initialized with precision torch.bfloat16.
    INFO:torchtune.utils._logging:Memory stats after model init:
            NPU peak memory allocation: 16.98 GiB
            NPU peak memory reserved: 17.00 GiB
            NPU peak memory active: 16.98 GiB
    INFO:torchtune.utils._logging:Tokenizer is initialized from file.
    INFO:torchtune.utils._logging:Optimizer and loss are initialized.
    INFO:torchtune.utils._logging:Loss is initialized.
    Using the latest cached version of the dataset since yahma/alpaca-cleaned couldn't be found on the Hugging Face Hub
    WARNING:datasets.load:Using the latest cached version of the dataset since yahma/alpaca-cleaned couldn't be found on the Hugging Face Hub
    Found the latest cached dataset configuration 'default' at /home/lcg/.cache/huggingface/datasets/yahma___alpaca-cleaned/default/0.0.0/12567cabf869d7c92e573c7c783905fc160e9639 (last modified on Fri Oct 11 01:21:44 2024).
    WARNING:datasets.packaged_modules.cache.cache:Found the latest cached dataset configuration 'default' at /home/lcg/.cache/huggingface/datasets/yahma___alpaca-cleaned/default/0.0.0/12567cabf869d7c92e573c7c783905fc160e9639 (last modified on Fri Oct 11 01:21:44 2024).
    INFO:torchtune.utils._logging:Dataset and Sampler are initialized.
    INFO:torchtune.utils._logging:Learning rate scheduler is initialized.
    WARNING:torchtune.utils._logging: Profiling disabled.
    INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
      0%|                                                                                                                                             | 0/26 [00:00<?, ?it/s]/home/lcg/miniconda3/envs/torchtune/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
      with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
    1|26|Loss: 1.427944302558899: 100%| 26/26 [3:04:04<00:00, 418.75s/it]
    INFO:torchtune.utils._logging:Starting checkpoint save...
    INFO:torchtune.utils._logging:Model checkpoint of size 4.98 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/hf_model_0001_0.pt
    INFO:torchtune.utils._logging:Model checkpoint of size 5.00 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/hf_model_0002_0.pt
    INFO:torchtune.utils._logging:Model checkpoint of size 4.92 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/hf_model_0003_0.pt
    INFO:torchtune.utils._logging:Model checkpoint of size 1.17 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/hf_model_0004_0.pt
    INFO:torchtune.utils._logging:Adapter checkpoint of size 0.01 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/adapter_0.pt
    INFO:torchtune.utils._logging:Adapter checkpoint of size 0.01 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/adapter_model.bin
    INFO:torchtune.utils._logging:Adapter checkpoint of size 0.00 GB saved to /home/lcg/tmp/torchtune/Meta-Llama-3.1-8B-Instruct/adapter_config.json
    INFO:torchtune.utils._logging:Saving final epoch checkpoint.
    INFO:torchtune.utils._logging:The full model checkpoint, including all weights and configurations, has been saved successfully. You can now use this checkpoint for further training or inference.
    INFO:torchtune.utils._logging:Checkpoint saved in 65.93 seconds.
    1|26|Loss: 1.427944302558899: 100%| 26/26 [3:11:40<00:00, 442.34s/it]
    
  • Result: The test results demonstrate the successful completion of a single-device LoRA fine-tuning process on the Llama 3.1 8B model. The configuration included a batch size of 30, gradient accumulation over 64 steps, and one epoch of training on an NPU device using the bf16 data type. Activation checkpointing was enabled, and LoRA fine-tuning was applied to attention modules. The process utilized AdamW as the optimizer with a learning rate of 0.0003 and a cosine learning rate scheduler.

pytorch-bot (bot) commented Oct 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1826

Note: Links to docs will display an error until the docs builds have been completed.

❌ 6 New Failures, 4 Cancelled Jobs

As of commit b6332dd with merge base 73aa126:

NEW FAILURES - The following jobs have failed:

  • Build Docs / build_docs (3.11) (gh)
    sphinx.ext.autosummary.ImportExceptionGroup: no module named torchtune.training
  • GPU tests / gpu_test (3.10, stable) (gh)
    E ImportError: cannot import name 'TensorCoreTiledLayout' from 'torchao.dtypes' (/home/ec2-user/actions-runner/_work/torchtune/torchtune/3/envs/test/lib/python3.10/site-packages/torchao/dtypes/__init__.py)
  • Recipe Tests / recipe_test (3.10) (gh)
    E ImportError: cannot import name 'TensorCoreTiledLayout' from 'torchao.dtypes' (/usr/share/miniconda3/envs/test/lib/python3.10/site-packages/torchao/dtypes/__init__.py)
  • Recipe Tests / recipe_test (3.9) (gh)
    E ImportError: cannot import name 'TensorCoreTiledLayout' from 'torchao.dtypes' (/usr/share/miniconda3/envs/test/lib/python3.9/site-packages/torchao/dtypes/__init__.py)
  • Unit Test / unit_tests (3.11) (gh)
    E ImportError: cannot import name 'TensorCoreTiledLayout' from 'torchao.dtypes' (/usr/share/miniconda3/envs/test/lib/python3.11/site-packages/torchao/dtypes/__init__.py)
  • Unit Test / unit_tests (3.9) (gh)
    E ImportError: cannot import name 'TensorCoreTiledLayout' from 'torchao.dtypes' (/usr/share/miniconda3/envs/test/lib/python3.9/site-packages/torchao/dtypes/__init__.py)

CANCELLED JOBS - The following jobs were cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Oct 14, 2024
@noemotiovon marked this pull request as draft on October 14, 2024 10:21
Outdated review threads (resolved):
  • torchtune/utils/_device_support.py
  • recipes/lora_finetune_single_device.py
  • recipes/quantize.py
@noemotiovon marked this pull request as ready for review on October 21, 2024 10:49
@noemotiovon (Author) commented:

Hi @ebsmothers, @RdoubleA:

I hope you’re doing well! Could you please help me review my code? I would really appreciate it if you could take a look and share any feedback or suggestions. Thank you so much in advance for your time and support! 😊

Best regards

@ebsmothers (Contributor) left a comment:

Hi @noemotiovon, thanks for the PR! And apologies for the delay in getting to the review. A couple of other questions that don't fit neatly anywhere inline:

  1. Do we expect compile to work? If so, we should test that. If not, we could raise an error (a possible guard is sketched after this list)
  2. Do we expect quant-related APIs (e.g. QLoRA or QAT) from torchao to work? Same as point 1: if so we should test or possibly raise an error
  3. PyTorch has now released 2.5 as stable. In general we do not claim to support anything but the latest stable release of PyTorch -- do you know the contract on torch_npu releases here?
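
As a hypothetical illustration of point 1 (not part of this PR; `cfg.compile` and `self._device` follow existing torchtune recipe conventions):

# Hypothetical early guard in a recipe's setup, assuming compile is not
# expected to work on NPU (illustrative only).
if cfg.compile and self._device.type == "npu":
    raise ValueError(
        "torch.compile has not been validated on Ascend NPU; please set compile=False."
    )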

@@ -45,7 +46,7 @@ def _set_float32_precision(precision: str = "high") -> None:
def verify_bf16_support() -> bool:
"""
Check that bf16 is available on this hardware. Requirements:
- CUDA is available and supports bf16
- CUDA or NPU is available and supports bf16
Contributor:

Just to make sure I understand this: are the requirements for bf16 support on NPU identical to the bf16 support requirements on CUDA?

Author:

The requirements for NPU and CUDA are similar but not identical; I will adjust the code comments. Thank you for your valuable feedback!
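
For context, a minimal sketch of the generalized check (assuming torch_npu mirrors the CUDA API with torch.npu.is_bf16_supported(); the real check also covers other requirements):

import torch


def verify_bf16_support_sketch() -> bool:
    # Illustrative only: bf16 is usable if either backend reports support.
    if torch.cuda.is_available():
        return torch.cuda.is_bf16_supported()
    try:
        import torch_npu  # noqa: F401
    except ImportError:
        return False
    return torch.npu.is_available() and torch.npu.is_bf16_supported()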

@@ -617,14 +618,14 @@ def train(self) -> None:
):
break

# Start tracking CUDA memory for active steps for just the first epoch
# Start tracking CUDA or NPU memory for active steps for just the first epoch
Contributor:

More of a nit, but I wonder if we should just generalize these comments to "Start tracking device memory" (otherwise if we add other devices this will start to get pretty verbose)

Author:

That’s a great suggestion! I will adjust the comments to "Start tracking CUDA-like device memory". Thank you very much!

if (
curr_epoch == 0
and self.profiler_profile_memory
and idx == self.profiler_wait_steps + self.profiler_warmup_steps
):
torch.cuda.memory._record_memory_history()
get_torch_device().memory._record_memory_history()
Contributor:

Did you also test this? I am not familiar with NPU memory snapshot APIs but would be good to make sure this works as expected too.

Author:

The NPU has these APIs, but I haven’t tested whether they function as expected, so I’ll roll back these changes for now and address them in a separate PR.


@pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.")
@patch("torch.cuda.is_available", return_value=True)
def test_get_torch_device_for_cuda(self, mock_cuda):
Contributor:

I wonder if we should add a similar test for NPU (with corresponding patch)? Ofc if we don't have the device in our CI runners, maybe it's too trivial to actually be meaningful?

Author:

Currently, NPU testing has not been considered. I will look into proposing a CI-related PR and the necessary hardware later.
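
One lightweight option that runs without NPU hardware (a hypothetical unit test against the availability helper from this PR):

import sys
from unittest import mock

import torchtune.utils._device_support as dev_support


def test_is_torch_npu_available_without_torch_npu():
    # Setting the sys.modules entry to None makes `import torch_npu`
    # raise ImportError, so the helper must report no NPU available.
    with mock.patch.dict(sys.modules, {"torch_npu": None}):
        assert dev_support.is_torch_npu_available() is False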

@@ -87,15 +88,15 @@ def __init__(
60 # we should not exceed this percentage of memory
)

self.s0 = torch.cuda.default_stream() # comp stream
self.s0 = get_torch_device().default_stream() # comp stream
Contributor:

Similar comment here: do we know that activation offloading will work on NPU?

Author:

The NPU has these APIs, but I haven’t tested whether they function as expected, so I’ll roll back these changes for now and address them in a separate PR.

Comment on lines +186 to +188
CPU = ("cpu", "CPU", "gloo")
CUDA = ("cuda", "GPU", "nccl")
NPU = ("npu", "NPU", "hccl")
Contributor:

Do we also need an item for MPS here? cc @SalmanMohammadi

Contributor:

Also more of a nit, but I wonder if we should use a NamedTuple abstraction here. (Alternatively can just add a comment explaining what each of the fields correspond to)
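
For illustration, a NamedTuple-based variant of that suggestion (hypothetical, not part of this PR) could look like:

from enum import Enum
from typing import NamedTuple


class DeviceInfo(NamedTuple):
    device_type: str  # torch device type string, e.g. "cuda"
    device_name: str  # human-readable name, e.g. "GPU"
    communication_backend: str  # distributed backend, e.g. "nccl"


class DeviceSupport(Enum):
    CPU = DeviceInfo("cpu", "CPU", "gloo")
    CUDA = DeviceInfo("cuda", "GPU", "nccl")
    NPU = DeviceInfo("npu", "NPU", "hccl")


# Fields become self-documenting, e.g.:
# DeviceSupport.NPU.value.communication_backend == "hccl"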

Author:

Thank you so much for your suggestions! I’ll make sure to add the relevant comments here.

@elfisworking commented Oct 22, 2024

Distributed training seems to have problems, e.g. qat_distributed @noemotiovon. In the function torchtune/training/_distributed.py::load_from_full_model_state_dict, sharded_meta_param.device_mesh.device_type is cpu when loading the model, which raises an error. Are you willing to contact me through the email on my GitHub profile? Maybe we can discuss how to support Ascend NPU.

@noemotiovon (Author) commented:

> Distributed training seems to have problems, e.g. qat_distributed... Maybe we can discuss how to support Ascend NPU.

I would be very happy to! I will contact you via email.

@elfisworking commented:

> I would be very happy to! I will contact you via email.

@noemotiovon Through my 126.com email, thanks. Looking forward to your email.

Comment on lines +21 to +33
def is_torch_npu_available() -> bool:
"""Check the availability of NPU"""
try:
import torch_npu # noqa: F401

return torch.npu.is_available()
except ImportError:
return False


is_npu_available = is_torch_npu_available()



These are all redundant after the autoload mechanism landed in PyTorch 2.5.0
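
For context, once autoload is in play, something like the following should suffice (assuming a torch_npu build that registers the PyTorch autoload entry point):

import torch  # torch_npu is loaded automatically via its entry point

# No explicit `import torch_npu` needed on PyTorch >= 2.5.0 with autoload.
if torch.npu.is_available():
    x = torch.ones(2, 2, device="npu")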

Author:

Thank you for your suggestion! We will make adjustments to this part once torch-npu is updated to version 2.5.0.

Comment on lines +215 to +228
def get_torch_device() -> any:
"""Return the corresponding torch attribute based on the device type string.

Returns:
module: The corresponding torch module, or torch.cuda if not found.
"""
device_type = get_device_support().device_type
try:
return getattr(torch, device_type)
except AttributeError:
print(
f"Device Module '{device_type}' not found in torch, try to load torch.cuda."
)
return torch.cuda

We can use torch.get_device_module() I think
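
For reference, a short sketch of that suggestion (torch.get_device_module is available in recent PyTorch releases):

import torch

# Returns the torch submodule for a device type, e.g. torch.cuda for "cuda";
# with no argument it returns the module for the current accelerator.
assert torch.get_device_module("cuda") is torch.cuda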

Author:

Thank you for your suggestion! We will make adjustments to this part once torch-npu is updated to version 2.5.0.

Comment on lines +110 to +115
if device.type in ["cuda", "npu"] and local_rank is not None:
# Ensure device index matches assigned index when distributed training
if device.index != local_rank:
raise RuntimeError(
f"You can't specify a device index when using distributed training. \
Device specified is {device} but was assigned cuda:{local_rank}"
Device specified is {device} but was assigned cuda/npu:{local_rank}"

All npu occurrences can be replaced with torch._C._get_privateuse1_backend_name(), for two reasons:

  1. All privateuse1 backends are CUDA-like devices
  2. This change will benefit all out-of-tree backends

cc: @FFFrog
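
A minimal sketch of that suggestion (a hypothetical rewrite of the check above; note that torch._C._get_privateuse1_backend_name() is a private API):

import torch


def check_device_index(device: torch.device, local_rank: int) -> None:
    # Resolve the out-of-tree backend name (e.g. "npu" once torch_npu is
    # loaded) instead of hard-coding it alongside "cuda".
    backend = torch._C._get_privateuse1_backend_name()
    if device.type in ("cuda", backend) and device.index != local_rank:
        raise RuntimeError(
            "You can't specify a device index when using distributed training. "
            f"Device specified is {device} but was assigned {device.type}:{local_rank}"
        )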

Author:

Thank you for your suggestion! We will make adjustments to this part once torch-npu is updated to version 2.5.0.
