PyTorch/XLA 2.5 Release

@ManfeiBai ManfeiBai released this 18 Oct 23:19
· 138 commits to master since this release
396608c

Cloud TPUs now support the PyTorch 2.5 release via the PyTorch/XLA integration. On top of the underlying improvements and bug fixes in PyTorch 2.5, this release introduces several new features and PyTorch/XLA-specific bug fixes.

Highlights

We are excited to announce the release of PyTorch/XLA 2.5! PyTorch/XLA 2.5 supports the torch_xla.compile function, which improves the debugging experience during development, and aligns its distributed APIs with upstream PyTorch through traceable collective support for both Dynamo and non-Dynamo cases. Starting with PyTorch/XLA 2.5, we also propose a clarified vision for deprecating the older torch_xla APIs in favor of the existing PyTorch APIs, which simplifies the developer experience.

If you’ve used vLLM for serving models on GPUs, you’ll now be able to seamlessly switch to its TPU backend. vLLM is a widely adopted inference framework that also serves as an excellent way to drive accelerator interoperability. With vLLM on TPU, users will retain the same vLLM interface we’ve grown to love, with direct integration with Hugging Face Models to make model experimentation easy.

STABLE FEATURES

Eager

  • Increase max in flight operation to accommodate eager mode [#7263]
  • Unify the logic to check eager mode [#7709]
  • Update eager.md [#7710]
  • Optimize execution for ops that have multiple outputs in eager mode [#7680]

Quantization / Low Precision

  • Asymmetric quantized matmul support [#7626]
  • Add blockwise quantized dot support [#7605]
  • Support int4 weight in quantized matmul / linear [#7235]
  • Support fp8e5m2 dtype [#7740]
  • Add fp8e4m3fn support [#7842]
  • Support dynamic activation quant for per-channel quantized matmul [#7867]
  • Enable cross entropy loss for xla autocast with FP32 precision [#8094]

Pallas Kernels

  • Support ab for flash_attention [#7840]; the actual kernel is implemented in JAX
  • Support the logits_soft_cap parameter in paged_attention [#7704]; the actual kernel is implemented in JAX
  • Support gmm and tgmm trace_pallas caching [#7921]
  • Cache flash attention tracing [#8026]
  • Improve the user guide [#7625]
  • Update pallas doc with paged_attention [#7591]

StableHLO

  • Add user guide for stablehlo composite op [#7826]

gSPMD

  • Handle the parameter wrapping for SPMD [#7604]
  • Add helper function to get 1d mesh [#7577]
  • Support manual all-reduce [#7576]
  • Expose apply_backward_optimization_barrier [#7477]
  • Support reduce-scatter in manual sharding [#7231]
  • Allow MpDeviceLoader to shard dictionaries of tensor [#8202]

Dynamo

  • Optimize dynamo dynamic shape caching [#7726]
  • Add support for dynamic shape in dynamo [#7676]
  • In dynamo optim_mode avoid unnecessary set_attr [#7915]
  • Fix the crash with copy op in dynamo [#7902]
  • Optimize _split_xla_args_tensor_sym_constant [#7900]
  • Dynamo RNG seed update optimization [#7884]
  • Support mark_dynamic [#7812]
  • Support gmm as a custom op for dynamo [#7672]
  • Fix dynamo inplace copy [#7933]
  • CPU time optimization for GraphInputMatcher [#7895]

PJRT

  • Improve device auto-detection [#7787]
  • Move _xla_register_custom_call_target implementation into PjRtComputationClient [#7801]
  • Handle SPMD case inside of ComputationClient::WaitDeviceOps [#7796]

GKE

  • Add tpu example for torchrun on GKE [#7620]
  • Add an example of using GKE with torchrun [#7589]

Functionalization

  • Add 1-layer gradient accumulation test to check aliasing [#7692]

AMP

  • Fix norm data-type when using AMP [#7878]

BETA FEATURES

Op Lowering

  • Lower aten::_linalg_eigh [#7674]
  • Fallback _embedding_bag_backward and force sparse=false [#7584]
  • Support trilinear by using upstream decomp [#7586]

Higher order ops

  • [Fori_loop] Update randint max range to support bool dtype [#7632]

TorchBench Integration

  • [benchmarks] API alignment with PyTorch profiler events [#7930]
  • [benchmarks] Add IR dump option when run torchbench [#7927]
  • [benchmarks] Use the same matmul precision between PyTorch and PyTorch/XLA [#7748]
  • [benchmarks] Introduce verifier to verify the model output correctness against native pytorch [#7724, #7777]
  • [benchmarks] Fix moco model issue on XLA [#7257, #7598]
  • Type annotation for benchmarks/ [#7289]
  • Default with CUDAGraphs on for inductor [#7749]

GPU

  • Deprecate XRT for XLA:CUDA [#8006]

EXPERIMENTAL FEATURES

Backward Compatibility & APIs that will be removed in the 2.7 release:

  • Deprecate APIs (deprecated → new):
    • xla_model.xrt_world_size() → runtime.world_size() [#7679][#7743]
    • xla_model.get_ordinal() → runtime.global_ordinal() [#7679]
    • xla_model.get_local_ordinal() → runtime.local_ordinal() [#7679]
  • Internalize APIs
    • xla_model.parse_xla_device() [#7675]
  • Improvement
    • Automatic PJRT device detection when importing torch_xla [#7787]
  • Add deprecated decorator [#7703]
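
A deprecation such as xla_model.xrt_world_size() → runtime.world_size() is commonly implemented by keeping the old name as a thin wrapper that warns and forwards to the new function. The following is a minimal sketch of that pattern in plain Python. The `deprecated` helper and the stand-in `world_size` function are hypothetical illustrations, not the torch_xla implementation from [#7703].

```python
import functools
import warnings


def deprecated(old_name, new_fn, new_name):
    """Hypothetical helper: return a wrapper that warns, then calls new_fn."""

    @functools.wraps(new_fn)
    def wrapper(*args, **kwargs):
        warnings.warn(
            f"{old_name} is deprecated; use {new_name} instead.",
            DeprecationWarning,
            stacklevel=2,  # point the warning at the caller, not the wrapper
        )
        return new_fn(*args, **kwargs)

    return wrapper


def world_size():
    """Stand-in for the new API; returns a fixed value for illustration."""
    return 4


# The old name stays importable but warns and forwards to the new function.
xrt_world_size = deprecated(
    "xla_model.xrt_world_size", world_size, "runtime.world_size"
)
```

Calling `xrt_world_size()` in this sketch returns the same value as `world_size()` and emits one `DeprecationWarning`, so existing callers keep working while being nudged toward the new name.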

Distributed

  • Enable bucketized all-reduce for gradients [#7216]
  • Use reduce-scatter coalescing for FSDP [#6024]

Distributed API

We have aligned our distributed APIs with upstream PyTorch. Previously, we implemented custom distributed APIs, such as torch_xla.core.xla_model.all_reduce. With traceable collective support, torch.distributed.all_reduce and similar functions now work in torch_xla for both Dynamo and non-Dynamo cases.

  • Support upstream distributed APIs (torch.distributed.*) such as all_reduce, all_gather, reduce_scatter_tensor, and all_to_all. Previously we used XLA-specific distributed APIs in xla_model [#7860, #7950, #8064].
  • Introduce torch_xla.launch() to launch multiple processes, unifying torchrun and torch_xla.distributed.xla_multiprocessing.spawn() [#7764, #7648, #7695].
  • Support torch.distributed.reduce_scatter_tensor() [#7950]
  • Register sdp lower precision autocast [#7299]
  • Add Python binding for xla::DotGeneral [#7863]
  • Fix input output alias for custom inplace ops [#7822]

torch_xla.compile

  • Support full_graph, which raises an error if more than one graph is executed in the compiled region. [#7776][#7789]
  • Support dynamic shape detection, which prints a useful error message when the number of different graphs executed across runs exceeds a predefined limit. [#7918]
  • Support naming each compiled program, which makes debug messages more informative. [#7802]

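
The three torch_xla.compile options listed above can be combined as in the hedged sketch below. It assumes torch_xla is installed and an XLA device is available. The `full_graph` and `name` keywords follow the items above, while the keyword `num_different_graphs_allowed` for the graph-count limit is an assumption that should be checked against the torch_xla API reference. `build_compiled_step` and `run_once` are illustrative names.

```python
def build_compiled_step():
    """Build a small model and wrap its forward step with torch_xla.compile."""
    import torch
    import torch_xla

    device = torch_xla.device()
    model = torch.nn.Linear(8, 2).to(device)

    def forward_step(x):
        return model(x).sum()

    compiled_step = torch_xla.compile(
        forward_step,
        full_graph=True,                 # error out if the region splits into several graphs
        name="forward_step",             # label used in debug messages
        num_different_graphs_allowed=2,  # assumed keyword for the dynamic-shape limit
    )
    return compiled_step, device


def run_once():
    """Run the compiled step on one random batch and return the scalar result."""
    import torch

    compiled_step, device = build_compiled_step()
    x = torch.randn(4, 8, device=device)
    return compiled_step(x).item()
```

Feeding inputs of many different shapes to `compiled_step` would be expected to trigger the dynamic shape error once the limit is exceeded, which is the debugging aid described in the list above.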
Usability & Debuggability

  • Change the wheel name to support pip>=24.1 [issue#7697]
  • Add tpu-info as a dependency of torch_xla[tpu] and test [#7938][#7337]
  • Support torch_xla.manual_seed [#7340]
  • Support a callback on a tensor when async execution finishes [#7984]
  • Implement torch.ops._c10d_functional.broadcast [#7770]
  • The XLA_USE_BF16 and XLA_DOWNCAST_BF16 flags will be removed in the 2.6 release [#7582][#7945]

AWS Neuron:

  • Update Neuron initializations [#7952]
  • Pass local_world_size into neuron.initialize_env [#7852]
  • Update and short circuit the Neuron initialization [#8041]
  • Introduce multi-node SPMD support for Neuron [#8224]