Add DataStax Astra DB vector store driver #1034

Merged on Jul 31, 2024 (25 commits)

Commits
9f8a6b0
initial commit for astra db branch
hemidactylus Jul 25, 2024
3b3e4e7
wip on unit tests with mocker
hemidactylus Jul 25, 2024
472c523
slight unit test refactor
hemidactylus Jul 25, 2024
bd2bdb5
full unit test coverage
hemidactylus Jul 26, 2024
26c5bb8
docstrings; added a find-nothing test
hemidactylus Jul 26, 2024
3586465
docs about the Astra vector store + standalone example; changelog
hemidactylus Jul 26, 2024
21ed76b
minor formatting in code and docs
hemidactylus Jul 26, 2024
0d49f83
'dimension' parameter optional
hemidactylus Jul 26, 2024
6609eb0
Astra DB vector store driver, integration testing
hemidactylus Jul 26, 2024
384a145
default count for query is the base class default
hemidactylus Jul 26, 2024
4d093b7
optionally specify metric in Astra DB vector store
hemidactylus Jul 26, 2024
679769f
support for HCD and other non-Astra databases
hemidactylus Jul 26, 2024
d42c6a8
Add test coverage for the unknown-kwargs warnings
hemidactylus Jul 28, 2024
c12fb1e
Merge branch 'dev' into SL-astra-db-vector-store
hemidactylus Jul 28, 2024
dec82a9
Address most of the comments to the PR
hemidactylus Jul 30, 2024
89a60e6
Merge branch 'dev' into SL-astra-db-vector-store
hemidactylus Jul 30, 2024
cc0e93b
remove ad-hoc ParserEmbeddingDriver, rather extending MockEmbeddingDr…
hemidactylus Jul 30, 2024
fe4bc28
rename mock_output parameter to MockEmbeddingDriver
hemidactylus Jul 30, 2024
c00bf77
clear separation of vector store driver from collection provisioning
hemidactylus Jul 30, 2024
4301f75
Merge branch 'dev' into SL-astra-db-vector-store
hemidactylus Jul 30, 2024
6f126e9
adapt full demo script to latest api
hemidactylus Jul 30, 2024
2d8580d
final driver name AstraDbVectorStoreDriver; private method at end of …
hemidactylus Jul 30, 2024
d7bc1d9
Merge branch 'dev' into SL-astra-db-vector-store
hemidactylus Jul 30, 2024
ad15226
module rename astra_db_vector_store_driver.py => astradb_vector_store…
hemidactylus Jul 30, 2024
bdef2ae
Fix integration tests
collindutter Jul 31, 2024
2 changes: 2 additions & 0 deletions .github/workflows/docs-integration-tests.yml
@@ -124,6 +124,8 @@ jobs:
ZENROWS_API_KEY: ${{ secrets.INTEG_ZENROWS_API_KEY }}
QDRANT_CLUSTER_ENDPOINT: ${{ secrets.INTEG_QDRANT_CLUSTER_ENDPOINT }}
QDRANT_CLUSTER_API_KEY: ${{ secrets.INTEG_QDRANT_CLUSTER_API_KEY }}
ASTRA_DB_API_ENDPOINT: ${{ secrets.INTEG_ASTRA_DB_API_ENDPOINT }}
ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.INTEG_ASTRA_DB_APPLICATION_TOKEN }}
services:
postgres:
image: ankane/pgvector:v0.5.0
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added
- `AstraDbVectorStoreDriver` to support DataStax Astra DB as a vector store.

## [0.29.0] - 2024-07-30

### Added
78 changes: 78 additions & 0 deletions docs/examples/query-webpage-astra-db.md
@@ -0,0 +1,78 @@
The following example script ingests a web page (a blog post),
stores its chunked contents in Astra DB through the Astra DB vector store driver,
and then runs a RAG process to answer a question about the topic of that page.

This script requires that a vector collection already exists in the Astra database,
named `"griptape_test_collection"` and with a vector dimension matching the embedding model in use (1536 in this case).

_Note:_ Besides the [Astra DB](../griptape-framework/drivers/vector-store-drivers.md#astra-db) extra,
this example also requires the `drivers-web-scraper-trafilatura`
Griptape extra to be installed.


```python
import os

from griptape.drivers import (
    AstraDbVectorStoreDriver,
    OpenAiChatPromptDriver,
    OpenAiEmbeddingDriver,
)
from griptape.engines.rag import RagEngine
from griptape.engines.rag.modules import (
    PromptResponseRagModule,
    VectorStoreRetrievalRagModule,
)
from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage
from griptape.loaders import WebLoader
from griptape.structures import Agent
from griptape.tools import RagClient, TaskMemoryClient


namespace = "datastax_blog"
input_blogpost = (
    "www.datastax.com/blog/indexing-all-of-wikipedia-on-a-laptop"
)

vector_store_driver = AstraDbVectorStoreDriver(
    embedding_driver=OpenAiEmbeddingDriver(),
    api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
    token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
    collection_name="griptape_test_collection",
    astra_db_namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)

engine = RagEngine(
    retrieval_stage=RetrievalRagStage(
        retrieval_modules=[
            VectorStoreRetrievalRagModule(
                vector_store_driver=vector_store_driver,
                query_params={
                    "count": 2,
                    "namespace": namespace,
                },
            )
        ]
    ),
    response_stage=ResponseRagStage(
        response_module=PromptResponseRagModule(
            prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")
        )
    )
)

vector_store_driver.upsert_text_artifacts(
    {namespace: WebLoader(max_tokens=256).load(input_blogpost)}
)

vector_store_tool = RagClient(
    description="A DataStax blog post",
    rag_engine=engine,
)
agent = Agent(tools=[vector_store_tool, TaskMemoryClient(off_prompt=False)])
agent.run(
    "What engine made it possible to index such an amount of data, "
    "and what kind of tuning was required?"
)
```
46 changes: 46 additions & 0 deletions docs/griptape-framework/drivers/vector-store-drivers.md
@@ -482,3 +482,49 @@ values = [r.to_artifact().value for r in results]

print("\n\n".join(values))
```

### Astra DB

!!! info
    This Driver requires the `drivers-vector-astra-db` [extra](../index.md#extras).

The `AstraDbVectorStoreDriver` supports [DataStax Astra DB](https://www.datastax.com/products/datastax-astra).

The following example shows how to store vector entries and query the information using the driver:

```python
import os
from griptape.drivers import AstraDbVectorStoreDriver, OpenAiEmbeddingDriver
from griptape.loaders import WebLoader

# Astra DB secrets and connection parameters
api_endpoint = os.environ["ASTRA_DB_API_ENDPOINT"]
token = os.environ["ASTRA_DB_APPLICATION_TOKEN"]
astra_db_namespace = os.environ.get("ASTRA_DB_KEYSPACE") # optional

# Initialize an Embedding Driver.
embedding_driver = OpenAiEmbeddingDriver(api_key=os.environ["OPENAI_API_KEY"])

vector_store_driver = AstraDbVectorStoreDriver(
    embedding_driver=embedding_driver,
    api_endpoint=api_endpoint,
    token=token,
    collection_name="griptape_test_collection",
    astra_db_namespace=astra_db_namespace,  # optional
)

# Load Artifacts from the web
artifacts = WebLoader().load("https://www.griptape.ai")

# Upsert Artifacts into the Vector Store Driver
[
    vector_store_driver.upsert_text_artifact(a, namespace="griptape")
    for a in artifacts
]

results = vector_store_driver.query(query="What is griptape?")

values = [r.to_artifact().value for r in results]

print("\n\n".join(values))
```
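
Note that the example above upserts into the `"griptape"` namespace but queries without one, so the search spans the whole collection. The sketch below mirrors how the driver's `query` method combines its `namespace` argument with the optional free-form `filter` keyword argument. `build_find_filter` is a hypothetical standalone name used here for illustration only, and `meta.lang` is a made-up example field, not something the driver defines.

```python
from __future__ import annotations

from typing import Any, Optional


def build_find_filter(
    namespace: Optional[str] = None,
    query_filter: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Combine a free-form filter with the namespace restriction (illustrative helper)."""
    namespace_filter: dict[str, Any] = {} if namespace is None else {"namespace": namespace}
    # The namespace entry is merged last, so it overrides any "namespace" key
    # that the caller may have placed in the free-form filter.
    return {**(query_filter or {}), **namespace_filter}


# No namespace and no filter: the search is unrestricted.
unrestricted = build_find_filter()

# Namespace plus a custom condition on a hypothetical metadata field.
restricted = build_find_filter("griptape", {"meta.lang": "en"})
```

To restrict the documentation example to the namespace it wrote to, the call would become `vector_store_driver.query(query="What is griptape?", namespace="griptape")`.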
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
@@ -41,6 +41,7 @@
from .vector.azure_mongodb_vector_store_driver import AzureMongoDbVectorStoreDriver
from .vector.dummy_vector_store_driver import DummyVectorStoreDriver
from .vector.qdrant_vector_store_driver import QdrantVectorStoreDriver
from .vector.astradb_vector_store_driver import AstraDbVectorStoreDriver
from .vector.griptape_cloud_knowledge_base_vector_store_driver import GriptapeCloudKnowledgeBaseVectorStoreDriver

from .sql.base_sql_driver import BaseSqlDriver
@@ -171,6 +172,7 @@
"AmazonOpenSearchVectorStoreDriver",
"PgVectorVectorStoreDriver",
"QdrantVectorStoreDriver",
"AstraDbVectorStoreDriver",
"DummyVectorStoreDriver",
"GriptapeCloudKnowledgeBaseVectorStoreDriver",
"BaseSqlDriver",
184 changes: 184 additions & 0 deletions griptape/drivers/vector/astradb_vector_store_driver.py
@@ -0,0 +1,184 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional

from attrs import define, field

from griptape.drivers import BaseVectorStoreDriver
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
    from astrapy import Collection
    from astrapy.authentication import TokenProvider


@define
class AstraDbVectorStoreDriver(BaseVectorStoreDriver):
    """A Vector Store Driver for Astra DB.

    Attributes:
        embedding_driver: a `griptape.drivers.BaseEmbeddingDriver` for embedding computations within the store
        api_endpoint: the "API Endpoint" for the Astra DB instance.
        token: a Database Token ("AstraCS:...") secret to access Astra DB. An instance of `astrapy.authentication.TokenProvider` is also accepted.
        collection_name: the name of the collection on Astra DB. The collection must have been created beforehand,
            and support vectors with a vector dimension matching the embeddings being used by this driver.
        environment: the environment ("prod", "hcd", ...) hosting the target Data API.
            It can be omitted for production Astra DB targets. See `astrapy.constants.Environment` for allowed values.
        astra_db_namespace: optional specification of the namespace (in the Astra database) for the data.
            *Note*: not to be confused with the "namespace" mentioned elsewhere, which is a grouping within this vector store.
    """

    api_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    token: Optional[str | TokenProvider] = field(kw_only=True, default=None, metadata={"serializable": False})
    collection_name: str = field(kw_only=True, metadata={"serializable": True})
    environment: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
    astra_db_namespace: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})

    collection: Collection = field(init=False)

    def __attrs_post_init__(self) -> None:
        astrapy = import_optional_dependency("astrapy")
        self.collection = (
            astrapy.DataAPIClient(
                caller_name="griptape",
                environment=self.environment,
            )
            .get_database(
                self.api_endpoint,
                token=self.token,
                namespace=self.astra_db_namespace,
            )
            .get_collection(
                name=self.collection_name,
            )
        )

    def delete_vector(self, vector_id: str) -> None:
        """Delete a vector from Astra DB store.

        The method succeeds regardless of whether a vector with the provided ID
        was actually stored or not in the first place.

        Args:
            vector_id: ID of the vector to delete.
        """
        self.collection.delete_one({"_id": vector_id})

    def upsert_vector(
        self,
        vector: list[float],
        *,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs: Any,
    ) -> str:
        """Write a vector to the Astra DB store.

        In case the provided ID exists already, an overwrite will take place.

        Args:
            vector: the vector to be upserted.
            vector_id: the ID for the vector to store. If omitted, a server-provided new ID will be employed.
            namespace: a namespace (a grouping within the vector store) to assign the vector to.
            meta: a metadata dictionary associated to the vector.
            kwargs: additional keyword arguments. Currently none are used: any that are passed are ignored.

        Returns:
            the ID of the written vector (str).
        """
        document = {
            k: v
            for k, v in {"$vector": vector, "_id": vector_id, "namespace": namespace, "meta": meta}.items()
            if v is not None
        }
        if vector_id is not None:
            self.collection.find_one_and_replace({"_id": vector_id}, document, upsert=True)
            return vector_id
        else:
            insert_result = self.collection.insert_one(document)
            return insert_result.inserted_id

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Load a single vector entry from the Astra DB store given its ID.

        Args:
            vector_id: the ID of the required vector.
            namespace: a namespace, within the vector store, to constrain the search.

        Returns:
            The vector entry (a `BaseVectorStoreDriver.Entry`) if found, otherwise None.
        """
        find_filter = {k: v for k, v in {"_id": vector_id, "namespace": namespace}.items() if v is not None}
        match = self.collection.find_one(filter=find_filter, projection={"*": 1})
        if match is not None:
            return BaseVectorStoreDriver.Entry(
                id=match["_id"], vector=match.get("$vector"), meta=match.get("meta"), namespace=match.get("namespace")
            )
        else:
            return None

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Load entries from the Astra DB store.

        Args:
            namespace: a namespace, within the vector store, to constrain the search.

        Returns:
            A list of vector (`BaseVectorStoreDriver.Entry`) entries.
        """
        find_filter: dict[str, str] = {} if namespace is None else {"namespace": namespace}
        return [
            BaseVectorStoreDriver.Entry(
                id=match["_id"], vector=match.get("$vector"), meta=match.get("meta"), namespace=match.get("namespace")
            )
            for match in self.collection.find(filter=find_filter, projection={"*": 1})
        ]

    def query(
        self,
        query: str,
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs: Any,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Run a similarity search on the Astra DB store, based on a query string.

        Args:
            query: the query string.
            count: the maximum number of results to return. If omitted, defaults will apply.
            namespace: the namespace to filter results by.
            include_vectors: whether to include vector data in the results.
            kwargs: additional keyword arguments. Currently only the free-form dict `filter`
                is recognized (and goes straight to the Data API query);
                others are ignored.

        Returns:
            A list of vector (`BaseVectorStoreDriver.Entry`) entries,
            with their `score` attribute set to the vector similarity to the query.
        """
        query_filter: Optional[dict[str, Any]] = kwargs.get("filter")
        find_filter_ns: dict[str, Any] = {} if namespace is None else {"namespace": namespace}
        find_filter = {**(query_filter or {}), **find_filter_ns}
        find_projection: Optional[dict[str, int]] = {"*": 1} if include_vectors else None
        vector = self.embedding_driver.embed_string(query)
        ann_limit = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
        matches = self.collection.find(
            filter=find_filter,
            sort={"$vector": vector},
            limit=ann_limit,
            projection=find_projection,
            include_similarity=True,
        )
        return [
            BaseVectorStoreDriver.Entry(
                id=match["_id"],
                vector=match.get("$vector"),
                score=match["$similarity"],
                meta=match.get("meta"),
                namespace=match.get("namespace"),
            )
            for match in matches
        ]
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -170,5 +170,6 @@ nav:
- Load and Query Pinecone: "examples/load-and-query-pinecone.md"
- Load and Query Marqo: "examples/load-query-and-chat-marqo.md"
- Query a Webpage: "examples/query-webpage.md"
- RAG with Astra DB vector store: "examples/query-webpage-astra-db.md"
- Reference Guide: "reference/"
- Trade School: "https://learn.griptape.ai"