From 4f32ca99b9b334c43516bdbb7fd2d54623393f70 Mon Sep 17 00:00:00 2001
From: One
Date: Tue, 1 Aug 2023 17:51:24 +0000
Subject: [PATCH] set flashattention as optional dependency

---
 README.md                      | 15 ++++++++-------
 ochat/models/unpadded_llama.py |  7 +++++--
 pyproject.toml                 |  1 -
 3 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index 34cba6f..c77206e 100644
--- a/README.md
+++ b/README.md
@@ -114,16 +114,20 @@ We will release the evaluation results as soon as they become available, so stay
 
 ## Installation
 
-To use OpenChat, you need to install CUDA and PyTorch, then install FlashAttention 1. After that you can install OpenChat via pip:
+To use OpenChat, you need to install CUDA and PyTorch, then you can install OpenChat via pip:
+
+```bash
+pip3 install ochat
+```
+
+If you want to train models, please also install FlashAttention 1.
 
 ```bash
 pip3 install packaging ninja
 pip3 install --no-build-isolation "flash-attn<2"
-
-pip3 install ochat
 ```
 
-FlashAttention may have compatibility issues. If you encounter these problems, you can try to create a new `conda` environment following the instructions below.
+FlashAttention and vLLM may have compatibility issues. If you encounter these problems, you can try to create a new `conda` environment following the instructions below.
 
 ```bash
 conda create -y --name openchat
@@ -146,9 +150,6 @@ pip3 install ochat
 git clone https://github.com/imoneoi/openchat
 cd openchat
 
-pip3 install packaging ninja
-pip3 install --no-build-isolation "flash-attn<2"
-
 pip3 install --upgrade pip # enable PEP 660 support
 pip3 install -e .
 ```
diff --git a/ochat/models/unpadded_llama.py b/ochat/models/unpadded_llama.py
index 9056234..16439b5 100644
--- a/ochat/models/unpadded_llama.py
+++ b/ochat/models/unpadded_llama.py
@@ -32,8 +32,11 @@
 from transformers.utils import logging
 from transformers.models.llama.configuration_llama import LlamaConfig
 
-from flash_attn.flash_attn_interface import flash_attn_unpadded_func
-from flash_attn.bert_padding import pad_input
+try:
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+    from flash_attn.bert_padding import pad_input
+except ImportError:
+    print("FlashAttention not found. Install it if you need to train models.")
 
 logger = logging.get_logger(__name__)
 
diff --git a/pyproject.toml b/pyproject.toml
index 72d8e0e..cc35170 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,6 @@ dependencies = [
     "sentencepiece",
     "transformers",
     "accelerate",
-    "flash-attn<2",
     "protobuf<3.21",
     "fastapi",
     "pydantic",
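
Note (not part of the patch above): with the guard as written, a failed import leaves `flash_attn_unpadded_func` and `pad_input` undefined, so a training code path that reaches them fails with a bare `NameError`. Below is a minimal sketch of a common alternative pattern that binds the names to `None` and raises a clearer error at call time. It is illustrative only; `require_flash_attn` is a hypothetical helper, not something in the OpenChat codebase.

```python
# Sketch only: optional-import pattern with a call-time check.
try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
    from flash_attn.bert_padding import pad_input
except ImportError:
    # Inference-only installs land here; keep the names defined so later
    # references fail with a descriptive error instead of a NameError.
    flash_attn_unpadded_func = None
    pad_input = None


def require_flash_attn():
    """Raise a descriptive error when a training path needs FlashAttention."""
    if flash_attn_unpadded_func is None:
        raise RuntimeError(
            "FlashAttention is required for training. Install it with:\n"
            "  pip3 install packaging ninja\n"
            '  pip3 install --no-build-isolation "flash-attn<2"'
        )
```

Either way, inference-only users can `pip3 install ochat` without building flash-attn, which is the point of this change.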