Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DataStax Astra DB vector store driver #1022

Closed
Closed
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
9f8a6b0
initial commit for astra db branch
hemidactylus Jul 25, 2024
3b3e4e7
wip on unit tests with mocker
hemidactylus Jul 25, 2024
472c523
slight unit test refactor
hemidactylus Jul 25, 2024
bd2bdb5
full unit test coverage
hemidactylus Jul 26, 2024
26c5bb8
docstrings; added a find-nothing test
hemidactylus Jul 26, 2024
3586465
docs about the Astra vector store + standalone example; changelog
hemidactylus Jul 26, 2024
21ed76b
minor formatting in code and docs
hemidactylus Jul 26, 2024
0d49f83
'dimension' parameter optional
hemidactylus Jul 26, 2024
6609eb0
Astra DB vector store driver, integration testing
hemidactylus Jul 26, 2024
384a145
default count for query is the base class default
hemidactylus Jul 26, 2024
4d093b7
optionally specify metric in Astra DB vector store
hemidactylus Jul 26, 2024
679769f
support for HCD and other non-Astra databases
hemidactylus Jul 26, 2024
d42c6a8
Add test coverage for the unkown-kwargs warnings
hemidactylus Jul 28, 2024
c12fb1e
Merge branch 'dev' into SL-astra-db-vector-store
hemidactylus Jul 28, 2024
dec82a9
Address most of the comments to the PR
hemidactylus Jul 30, 2024
89a60e6
Merge branch 'dev' into SL-astra-db-vector-store
hemidactylus Jul 30, 2024
cc0e93b
remove ad-hoc ParserEmbeddingDriver, rather extending MockEmbeddingDr…
hemidactylus Jul 30, 2024
fe4bc28
rename mock_output parameter to MockEmbeddingDriver
hemidactylus Jul 30, 2024
c00bf77
clear separation of vector store driver from collection provisioning
hemidactylus Jul 30, 2024
4301f75
Merge branch 'dev' into SL-astra-db-vector-store
hemidactylus Jul 30, 2024
6f126e9
adapt full demo script to latest api
hemidactylus Jul 30, 2024
2d8580d
final driver name AstraDbVectorStoreDriver; private method at end of …
hemidactylus Jul 30, 2024
d7bc1d9
Merge branch 'dev' into SL-astra-db-vector-store
hemidactylus Jul 30, 2024
ad15226
module rename astra_db_vector_store_driver.py => astradb_vector_store…
hemidactylus Jul 30, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `@observable` decorator for selecting which functions/methods to provide observability for.
- `GenericArtifact` for storing any data.
- `BaseTextArtifact` for text-based Artifacts to subclass.
- `AstraDBVectorStoreDriver` to support DataStax Astra DB as a vector store.

### Changed
- **BREAKING**: `BaseVectorStoreDriver.upsert_text_artifacts` optional arguments are now keyword-only arguments.
Expand Down
73 changes: 73 additions & 0 deletions docs/examples/query-webpage-astra-db.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
Required Griptape extras:

```
pip install griptape[drivers-vector-astra-db,drivers-web-scraper-trafilatura]
```
collindutter marked this conversation as resolved.
Show resolved Hide resolved

Python script:
hemidactylus marked this conversation as resolved.
Show resolved Hide resolved

```python
import os

from griptape.drivers import (
AstraDBVectorStoreDriver,
OpenAiChatPromptDriver,
OpenAiEmbeddingDriver,
)
from griptape.engines.rag import RagEngine
from griptape.engines.rag.modules import (
PromptResponseRagModule,
VectorStoreRetrievalRagModule,
)
from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage
from griptape.loaders import WebLoader
from griptape.structures import Agent
from griptape.tools import RagClient, TaskMemoryClient

if __name__ == "__main__":
hemidactylus marked this conversation as resolved.
Show resolved Hide resolved
namespace = "datastax_blog"
input_blogpost = (
"www.datastax.com/blog/indexing-all-of-wikipedia-on-a-laptop"
)

vector_store_driver = AstraDBVectorStoreDriver(
embedding_driver=OpenAiEmbeddingDriver(),
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
collection_name="griptape_test_collection",
astra_db_namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)

engine = RagEngine(
retrieval_stage=RetrievalRagStage(
retrieval_modules=[
VectorStoreRetrievalRagModule(
vector_store_driver=vector_store_driver,
query_params={
"count": 2,
"namespace": namespace,
},
)
]
),
response_stage=ResponseRagStage(
response_module=PromptResponseRagModule(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")
)
)
)

vector_store_driver.upsert_text_artifacts(
{namespace: WebLoader(max_tokens=256).load(input_blogpost)}
)

vector_store_tool = RagClient(
description="A DataStax blog post",
rag_engine=engine,
)
agent = Agent(tools=[vector_store_tool, TaskMemoryClient(off_prompt=False)])
agent.run(
"What engine made possible to index such an amount of data, "
"and what kind of tuning was required?"
)
```
46 changes: 46 additions & 0 deletions docs/griptape-framework/drivers/vector-store-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,49 @@ values = [r.to_artifact().value for r in results]

print("\n\n".join(values))
```

### Astra DB

!!! info
This Driver requires the `drivers-vector-astra-db` [extra](../index.md#extras).

The AstraDBVectorStoreDriver supports [DataStax Astra DB](https://www.datastax.com/products/datastax-astra).

The following example shows how to store vector entries and query the information using the driver:

```python
import os
from griptape.drivers import AstraDBVectorStoreDriver, OpenAiEmbeddingDriver
from griptape.loaders import WebLoader

# Astra DB secrets and connection parameters
api_endpoint = os.environ["ASTRA_DB_API_ENDPOINT"]
token = os.environ["ASTRA_DB_APPLICATION_TOKEN"]
astra_db_namespace = os.environ.get("ASTRA_DB_KEYSPACE") # optional
hemidactylus marked this conversation as resolved.
Show resolved Hide resolved

# Initialize an Embedding Driver.
embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"])

vector_store_driver = AstraDBVectorStoreDriver(
embedding_driver=embedding_driver,
api_endpoint=api_endpoint,
token=token,
collection_name="astra_db_demo",
astra_db_namespace=astra_db_namespace,
)

# Load Artifacts from the web
artifacts = WebLoader().load("https://www.griptape.ai")

# Upsert Artifacts into the Vector Store Driver
[
vector_store_driver.upsert_text_artifact(a, namespace="griptape")
for a in artifacts
]

results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

print("\n\n".join(values))
```
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .vector.azure_mongodb_vector_store_driver import AzureMongoDbVectorStoreDriver
from .vector.dummy_vector_store_driver import DummyVectorStoreDriver
from .vector.qdrant_vector_store_driver import QdrantVectorStoreDriver
from .vector.astra_db_vector_store_driver import AstraDBVectorStoreDriver
from .vector.griptape_cloud_knowledge_base_vector_store_driver import GriptapeCloudKnowledgeBaseVectorStoreDriver

from .sql.base_sql_driver import BaseSqlDriver
Expand Down Expand Up @@ -155,6 +156,7 @@
"AmazonOpenSearchVectorStoreDriver",
"PgVectorVectorStoreDriver",
"QdrantVectorStoreDriver",
"AstraDBVectorStoreDriver",
"DummyVectorStoreDriver",
"GriptapeCloudKnowledgeBaseVectorStoreDriver",
"BaseSqlDriver",
Expand Down
224 changes: 224 additions & 0 deletions griptape/drivers/vector/astra_db_vector_store_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
from __future__ import annotations

import logging
import warnings
from typing import TYPE_CHECKING, Any, Optional

from attrs import define, field

from griptape.drivers import BaseVectorStoreDriver
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
from astrapy import Collection
from astrapy.authentication import TokenProvider

GRIPTAPE_VERSION: Optional[str]
try:
from importlib import metadata

GRIPTAPE_VERSION = metadata.version("griptape")
except Exception:
GRIPTAPE_VERSION = None
hemidactylus marked this conversation as resolved.
Show resolved Hide resolved

logging.basicConfig(level=logging.WARNING)
hemidactylus marked this conversation as resolved.
Show resolved Hide resolved


COLLECTION_INDEXING = {"deny": ["meta.artifact"]}
hemidactylus marked this conversation as resolved.
Show resolved Hide resolved


@define
class AstraDBVectorStoreDriver(BaseVectorStoreDriver):
collindutter marked this conversation as resolved.
Show resolved Hide resolved
"""A Vector Store Driver for Astra DB.

Attributes:
embedding_driver: a `griptape.drivers.BaseEmbeddingDriver` for embedding computations within the store
api_endpoint: the "API Endpoint" for the Astra DB instance.
token: a Database Token ("AstraCS:...") secret to access Astra DB. An instance of `astrapy.authentication.TokenProvider` is also accepted.
collection_name: the name of the collection on Astra DB.
environment: the environment ("prod", "hcd", ...) hosting the target Data API.
It can be omitted for production Astra DB targets. See `astrapy.constants.Environment` for allowed values.
dimension: the number of components for embedding vectors. If not provided, it will be guessed from the embedding driver.
metric: the similarity metric to use, one of "dot_product", "euclidean" or "cosine".
If omitted, the server default ("cosine") will be used. See also values of `astrapy.constants.VectorMetric`.
If the vectors are normalized to unit norm, choosing "dot_product" over cosine yields up to 2x speedup in searches.
astra_db_namespace: optional specification of the namespace (in the Astra database) for the data.
*Note*: not to be confused with the "namespace" mentioned elsewhere, which is a grouping within this vector store.
collindutter marked this conversation as resolved.
Show resolved Hide resolved
"""

api_endpoint: str = field(kw_only=True, metadata={"serializable": True})
token: Optional[str | TokenProvider] = field(kw_only=True, default=None, metadata={"serializable": False})
collection_name: str = field(kw_only=True, metadata={"serializable": True})
environment: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
dimension: Optional[int] = field(kw_only=True, default=None, metadata={"serializable": True})
metric: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
astra_db_namespace: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})

collection: Collection = field(init=False)

def __attrs_post_init__(self) -> None:
astrapy = import_optional_dependency("astrapy")
if not self.dimension:
hemidactylus marked this conversation as resolved.
Show resolved Hide resolved
# auto-compute dimension from the embedding
self.dimension = len(self.embedding_driver.embed_string("This is a sample text."))
hemidactylus marked this conversation as resolved.
Show resolved Hide resolved
self.collection = (
astrapy.DataAPIClient(
caller_name="griptape",
caller_version=GRIPTAPE_VERSION,
environment=self.environment,
)
.get_database(
self.api_endpoint,
token=self.token,
namespace=self.astra_db_namespace,
)
.create_collection(
name=self.collection_name,
dimension=self.dimension,
metric=self.metric,
indexing=COLLECTION_INDEXING,
check_exists=False,
)
hemidactylus marked this conversation as resolved.
Show resolved Hide resolved
)

def delete_vector(self, vector_id: str) -> None:
"""Delete a vector from Astra DB store.

The method succeeds regardless of whether a vector with the provided ID
was actually stored or not in the first place.

Args:
vector_id: ID of the vector to delete.
"""
self.collection.delete_one({"_id": vector_id})

def upsert_vector(
self,
vector: list[float],
*,
vector_id: Optional[str] = None,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
**kwargs: Any,
) -> str:
"""Write a vector to the Astra DB store.

In case the provided ID exists already, an overwrite will take place.

Args:
vector: the vector to be upserted.
vector_id: the ID for the vector to store. If omitted, a server-provided new ID will be employed.
namespace: a namespace (a grouping within the vector store) to assign the vector to.
meta: a metadata dictionary associated to the vector.
kwargs: additional keyword arguments. Currently none is used: if they are passed, they will be ignored with a warning.

Returns:
the ID of the written vector (str).
"""
if kwargs:
warnings.warn(
"Unhandled keyword argument(s) provided to AstraDBVectorStore.upsert_vector: "
f"'{','.join(sorted(kwargs.keys()))}'. These will be ignored.",
stacklevel=2,
)
hemidactylus marked this conversation as resolved.
Show resolved Hide resolved
document = {
k: v
for k, v in {"$vector": vector, "_id": vector_id, "namespace": namespace, "meta": meta}.items()
if v is not None
}
if vector_id is not None:
self.collection.find_one_and_replace({"_id": vector_id}, document, upsert=True)
return vector_id
else:
insert_result = self.collection.insert_one(document)
return insert_result.inserted_id

def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
"""Load a single vector entry from the Astra DB store given its ID.

Args:
vector_id: the ID of the required vector.
namespace: a namespace, within the vector store, to constrain the search.

Returns:
The vector entry (a `BaseVectorStoreDriver.Entry`) if found, otherwise None.
"""
find_filter = {k: v for k, v in {"_id": vector_id, "namespace": namespace}.items() if v is not None}
match = self.collection.find_one(filter=find_filter, projection={"*": 1})
if match:
hemidactylus marked this conversation as resolved.
Show resolved Hide resolved
return BaseVectorStoreDriver.Entry(
id=match["_id"], vector=match.get("$vector"), meta=match.get("meta"), namespace=match.get("namespace")
)
else:
return None

def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
"""Load entries from the Astra DB store.

Args:
namespace: a namespace, within the vector store, to constrain the search.

Returns:
A list of vector (`BaseVectorStoreDriver.Entry`) entries.
"""
find_filter: dict[str, str] = {} if namespace is None else {"namespace": namespace}
return [
BaseVectorStoreDriver.Entry(
id=match["_id"], vector=match.get("$vector"), meta=match.get("meta"), namespace=match.get("namespace")
)
for match in self.collection.find(filter=find_filter, projection={"*": 1})
]

def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
**kwargs: Any,
) -> list[BaseVectorStoreDriver.Entry]:
"""Run a similarity search on the Astra DB store, based on a query string.

Args:
query: the query string.
count: the maximum number of results to return. If omitted, defaults will apply.
namespace: the namespace to filter results by.
include_vectors: whether to include vector data in the results.
kwargs: additional keyword arguments. Currently only the free-form dict `filter`
is recognized (and goes straight to the Data API query);
others will generate a warning and be ignored.

Returns:
A list of vector (`BaseVectorStoreDriver.Entry`) entries,
with their `score` attribute set to the vector similarity to the query.
"""
query_filter: Optional[dict[str, Any]] = kwargs.pop("filter", None)
hemidactylus marked this conversation as resolved.
Show resolved Hide resolved
if kwargs:
warnings.warn(
"Unhandled keyword argument(s) provided to AstraDBVectorStore.query: "
f"'{','.join(sorted(kwargs.keys()))}'. These will be ignored.",
stacklevel=2,
)
hemidactylus marked this conversation as resolved.
Show resolved Hide resolved
find_filter_ns: dict[str, Any] = {} if namespace is None else {"namespace": namespace}
find_filter = {**(query_filter or {}), **find_filter_ns}
find_projection: Optional[dict[str, int]] = {"*": 1} if include_vectors else None
vector = self.embedding_driver.embed_string(query)
ann_limit = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
matches = self.collection.find(
filter=find_filter,
sort={"$vector": vector},
limit=ann_limit,
projection=find_projection,
include_similarity=True,
)
return [
BaseVectorStoreDriver.Entry(
id=match["_id"],
vector=match.get("$vector"),
score=match["$similarity"],
meta=match.get("meta"),
namespace=match.get("namespace"),
)
for match in matches
]
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -170,5 +170,6 @@ nav:
- Load and Query Pinecone: "examples/load-and-query-pinecone.md"
- Load and Query Marqo: "examples/load-query-and-chat-marqo.md"
- Query a Webpage: "examples/query-webpage.md"
- RAG with Astra DB vector store: "examples/query-webpage-astra-db.md"
- Reference Guide: "reference/"
- Trade School: "https://learn.griptape.ai"
Loading
Loading