-
Notifications
You must be signed in to change notification settings - Fork 0
/
oai_embeddings.py
45 lines (36 loc) · 1.32 KB
/
oai_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import logging
from typing import Iterator, List

import openai
from tenacity import (
    retry,
    retry_if_exception,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from models import Segment
from utils import batch_segments
@retry(
    # retry_if_exception takes a predicate, not an exception class. Given the
    # class, tenacity "called" it with the raised error, got back a (truthy)
    # exception instance, and therefore retried on every exception.
    # retry_if_exception_type restricts retries to OpenAI errors as intended.
    retry=retry_if_exception_type(openai.OpenAIError),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    stop=stop_after_attempt(4),
)
def get_multi_embeddings(
    texts: List[str], model: str = "text-embedding-ada-002"
) -> List[List[float]]:
    """Return one embedding vector per input text, in response order.

    Newlines in each text are replaced with spaces before the request is
    sent. The call is retried with exponential backoff (up to 4 attempts)
    when it raises ``openai.OpenAIError``; any other exception propagates
    immediately.

    Args:
        texts: Strings to embed in a single API request.
        model: Name of the OpenAI embedding model to use.

    Returns:
        The ``"embedding"`` entry of every item in the response's ``"data"``
        list. An empty ``texts`` yields an empty list without an API call.

    Raises:
        tenacity.RetryError: When every attempt failed with an
            ``openai.OpenAIError`` (``reraise`` is not enabled).
    """
    # Nothing to embed: skip the network round trip entirely.
    if not texts:
        return []

    flattened = [text.replace("\n", " ") for text in texts]
    response = openai.Embedding.create(input=flattened, model=model)
    return [item["embedding"] for item in response["data"]]
def generate_embeddings_batch(
    segments: List[Segment], batch_size: int = 50
) -> Iterator[List[Segment]]:
    """Embed segments batch by batch, yielding each batch once it is filled in.

    The segments are split with ``batch_segments``; for every batch the
    ``text`` of each segment is sent to ``get_multi_embeddings`` and the
    resulting vector is stored on the segment's ``emb`` attribute.

    Args:
        segments: Segments to embed; each must expose a ``text`` attribute.
        batch_size: Batch size passed to ``batch_segments``.

    Yields:
        Each batch, after its segments have had ``emb`` assigned.

    Raises:
        Exception: Any error raised while embedding a batch is logged together
            with the offending batch and then re-raised unchanged.
    """
    for batch in batch_segments(segments, batch_size):
        try:
            embs = get_multi_embeddings([row.text for row in batch])
            # Assigning emb to the segment takes up a lot of memory.
            # Consider dropping this and simply yielding embs instead.
            # Indexing (rather than zip) is deliberate: a response with too
            # few embeddings raises IndexError instead of silently leaving
            # trailing segments without a vector.
            for i, segment in enumerate(batch):
                segment.emb = embs[i]
        except Exception:
            # Logger.exception records the active traceback along with the
            # message, replacing the manual traceback.format_exc() printing.
            logging.getLogger(__name__).exception("Problematic batch: %s", batch)
            raise
        # Yield outside the try so an exception thrown in at the yield point
        # by the consumer is not misreported as a failure of this batch.
        yield batch