# Miscellaneous CI, dependency, and version fixes #1151

Changes from 9 commits.
**README**

````diff
@@ -156,7 +156,16 @@ You can find a full list of all our Llama3 configs [here.](recipes/configs/llama

 ## Installation

-**Step 1:** [Install PyTorch](https://pytorch.org/get-started/locally/). torchtune is tested with the latest stable PyTorch release as well as the preview nightly version.
+**Step 1:** [Install PyTorch](https://pytorch.org/get-started/locally/). torchtune is tested with the latest stable PyTorch release as well as the preview nightly version. For multimodality
+be sure to also install torchvision.
+
+```
+# Install stable version of PyTorch using pip
+pip3 install torch torchvision
+
+# Nightly install for latest features
+pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121
+```

 **Step 2:** The latest stable version of torchtune is hosted on PyPI and can be downloaded with the following command:
````

> **Reviewer** (on the stable install command): I don't think we need to explicitly say
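Installing both packages in a single command keeps their versions in sync. As a quick sanity check of what the resolver actually installed, you can query `importlib.metadata` (a minimal sketch; `installed_versions` is a hypothetical helper, not part of torchtune):

```python
# Report which versions of a set of distributions are installed,
# e.g. installed_versions(["torch", "torchvision"]) after the pip command.
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    found: Dict[str, Optional[str]] = {}
    for pkg in packages:
        try:
            found[pkg] = version(pkg)
        except PackageNotFoundError:
            # Distribution is not installed in this environment.
            found[pkg] = None
    return found
```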
**pyproject.toml**

```diff
@@ -10,8 +10,6 @@ authors = [
 ]
 keywords = ["pytorch", "finetuning", "llm"]
 dependencies = [
-    # multimodality
-    "torchvision",

     # Hugging Face integrations
     "datasets",
```

> **Reviewer:** I wonder if there is a way to keep torchvision without causing issues with torch nightlies. Do you know if it's worth researching, or is there no way to make it work?

> **Author:** Just chatted with @NicolasHug on this and he confirmed it's not possible, since there's no way to point pyproject.toml to a specific conda channel or PyPI repo.

> **Reviewer:** Got it! I am just afraid that in the long run it may cause issues, since we may want to pin the torchvision version. For example, in ClipTransforms, older versions will break.

> **Author:** It should be installed in a single command, which is why I updated the README to clarify this. So if the user runs
>
> I can't find the link to the original comment, but I believe that was demonstrating what happens when we install using the existing pyproject.toml. For reference, here is my pip list after running the first command; here is my pip list after running the second command. You can see that the versions are as expected.
>
> Btw, regarding pinning versions -- we do not test on anything older than the latest stable version of PyTorch, and I don't think we want to worry about breaking folks on older versions than that. By the same logic, I don't think we should be pinning to older versions of torchvision. The best way to keep things in sync is just to install the two together using these commands.

> **Reviewer:** Thanks for the links and explanation!
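With `torchvision` dropped from the hard dependencies, code that uses it has to cope with environments where it is absent. A minimal sketch of such a guard (hypothetical; `has_package` is not a torchtune API) uses `importlib.util.find_spec` so the check does not import the package eagerly:

```python
# Detect whether an optional dependency is importable without importing it.
import importlib.util


def has_package(name: str) -> bool:
    """Return True if `name` can be imported in this environment."""
    return importlib.util.find_spec(name) is not None


# Illustrative flag a library might consult before multimodal code paths.
HAS_TORCHVISION = has_package("torchvision")
```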
**Distributed test** (`test_lora_state_dict`)

```diff
@@ -262,6 +262,10 @@ def world_size(self) -> int:
         return 2

     @gpu_test(gpu_count=2)
+    @pytest.mark.skipif(
+        version.parse(torch.__version__).base_version < "2.4.0",
+        reason="torch >= 2.4 required",
+    )
     def test_lora_state_dict(self):
         rank = self.rank
         is_rank_zero = rank == 0
```

> **Reviewer:** Any particular reason we don't want to test this with torch 2.3.x?

> **Author:** Some of the DTensor APIs we use in `load_from_full_model_state_dict` were not stable prior to 2.4. We already addressed this for the QLoRA state dict test in #1087. In this case it's OK because we are testing FSDP2 functionality, which is not available until 2.4 anyway. cc @weifengpy in case I'm missing any important points here.
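One thing to watch in the condition above: `base_version` is a string, so `<` compares lexicographically, and for example `"2.10.0" < "2.4.0"` evaluates to True. Comparing parsed versions avoids that; `packaging.version.parse` objects do this properly, and the stdlib-only sketch below (with a hypothetical `is_older_than` helper, assuming plain "X.Y.Z" release strings) shows the idea:

```python
# Compare dotted release strings numerically instead of lexicographically.
from typing import Tuple


def parse_release(v: str) -> Tuple[int, ...]:
    """Parse a release string like '2.4.0' into an int tuple (2, 4, 0)."""
    return tuple(int(part) for part in v.split("."))


def is_older_than(current: str, required: str) -> bool:
    """True if `current` is an older release than `required`."""
    # Tuples of ints compare element-wise, so 2.10 correctly beats 2.4.
    return parse_release(current) < parse_release(required)
```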
**New file** (torchao version helper)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

import torch

import torchao


def _is_fbcode():
    return not hasattr(torch.version, "git_version")


def _get_torchao_version() -> Tuple[Optional[str], Optional[bool]]:
    """
    Get torchao version. Returns a tuple of two elements, the first element
    is the version string, the second element is whether it's a nightly version.
    For fbcode usage, return None, None.

    Checks:
    1) is_fbcode, then
    2) importlib's version(torchao-nightly) for nightlies, then
    3) torchao.__version__ (only defined for torchao >= 0.3.0), then
    4) importlib's version(torchao) for non-nightly

    If none of these work, raise an error.
    """
    if _is_fbcode():
        return None, None
    # Check for nightly install first
    try:
        ao_version = version("torchao-nightly")
        is_nightly = True
    except PackageNotFoundError:
        try:
            ao_version = torchao.__version__
            is_nightly = False
        except AttributeError:
            ao_version = "unknown"
        if ao_version == "unknown":
            try:
                ao_version = version("torchao")
                is_nightly = False
            except Exception as e:
                raise PackageNotFoundError("Could not find torchao version") from e
    return ao_version, is_nightly
```

> **Reviewer:** yay
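The try/except chain in `_get_torchao_version` generalizes to any package that may be installed under several distribution names. A sketch of that pattern (hypothetical `resolve_version` helper, not part of this PR):

```python
# Try several distribution names in order; return the first that resolves.
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def resolve_version(*names: str) -> Tuple[Optional[str], Optional[str]]:
    """Return (distribution_name, version) for the first installed name,
    or (None, None) when none of the candidates are installed."""
    for name in names:
        try:
            return name, version(name)
        except PackageNotFoundError:
            # Not installed under this name; try the next candidate.
            continue
    return None, None
```

For example, `resolve_version("torchao-nightly", "torchao")` would mirror steps 2 and 4 of the helper above, with the nightly distribution taking precedence.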