Skip to content

Commit

Permalink
Collapse relationship embeddings (#1199)
Browse files Browse the repository at this point in the history
* Merge text_embed into a single relationships subflow

* Update smoke tests

* Semver

* Spelling
  • Loading branch information
natoverse authored Sep 24, 2024
1 parent 1755afb commit f518c8b
Show file tree
Hide file tree
Showing 10 changed files with 104 additions and 83 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240923202146450500.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Merge text_embed into create-final-relationships subflow."
}
37 changes: 26 additions & 11 deletions graphrag/index/verbs/text/embed/text_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,23 @@ async def text_embed(
<...>
```
"""
input_df = cast(pd.DataFrame, input.get_input())
result_df = await text_embed_df(
input_df, callbacks, cache, column, strategy, **kwargs
)
return TableContainer(table=result_df)


# TODO: this ultimately just creates a new column, so our embed function could just generate a series instead of updating the dataframe
async def text_embed_df(
input: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
column: str,
strategy: dict,
**kwargs,
):
"""Embed a piece of text into a vector space."""
vector_store_config = strategy.get("vector_store")

if vector_store_config:
Expand Down Expand Up @@ -113,28 +130,28 @@ async def text_embed(


async def _text_embed_in_memory(
    input: pd.DataFrame,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    column: str,
    strategy: dict,
    to: str,
):
    """Embed the text in `column` and store the vectors in a new `to` column.

    Args:
        input: Dataframe containing the text to embed. NOTE: mutated in place —
            the embedding column is assigned directly onto this frame.
        callbacks: Verb callbacks forwarded to the embedding strategy.
        cache: Pipeline cache forwarded to the embedding strategy.
        column: Name of the column holding the text to embed.
        strategy: Strategy config; `strategy["type"]` selects the implementation
            and the full dict is passed through as strategy arguments.
        to: Name of the output column to receive the embeddings.

    Returns:
        The same dataframe with the embeddings column added.
    """
    strategy_exec = load_strategy(strategy["type"])
    # Pass the whole strategy config through (copied so the caller's dict is untouched).
    strategy_args = {**strategy}

    texts: list[str] = input[column].to_numpy().tolist()
    result = await strategy_exec(texts, callbacks, cache, strategy_args)

    # In-place column assignment: output is the input frame plus the new column.
    input[to] = result.embeddings
    return input


async def _text_embed_with_vector_store(
input: VerbInput,
input: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
column: str,
Expand All @@ -144,7 +161,7 @@ async def _text_embed_with_vector_store(
store_in_table: bool = False,
to: str = "",
):
output_df = cast(pd.DataFrame, input.get_input())
output_df = input
strategy_type = strategy["type"]
strategy_exec = load_strategy(strategy_type)
strategy_args = {**strategy}
Expand Down Expand Up @@ -179,10 +196,8 @@ async def _text_embed_with_vector_store(

all_results = []

while insert_batch_size * i < input.get_input().shape[0]:
batch = input.get_input().iloc[
insert_batch_size * i : insert_batch_size * (i + 1)
]
while insert_batch_size * i < input.shape[0]:
batch = input.iloc[insert_batch_size * i : insert_batch_size * (i + 1)]
texts: list[str] = batch[column].to_numpy().tolist()
titles: list[str] = batch[title_column].to_numpy().tolist()
ids: list[str] = batch[id_column].to_numpy().tolist()
Expand Down Expand Up @@ -218,7 +233,7 @@ async def _text_embed_with_vector_store(
if store_in_table:
output_df[to] = all_results

return TableContainer(table=output_df)
return output_df


def _create_vector_store(
Expand Down
23 changes: 4 additions & 19 deletions graphrag/index/workflows/v1/create_final_relationships.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,15 @@ def build_steps(
"relationship_description_embed", base_text_embed
)
skip_description_embedding = config.get("skip_description_embedding", False)

return [
{
"id": "pre_embedding",
"verb": "create_final_relationships_pre_embedding",
"input": {"source": "workflow:create_base_entity_graph"},
},
{
"id": "description_embedding",
"verb": "text_embed",
"enabled": not skip_description_embedding,
"verb": "create_final_relationships",
"args": {
"embedding_name": "relationship_description",
"column": "description",
"to": "description_embedding",
**relationship_description_embed_config,
"skip_embedding": skip_description_embedding,
"text_embed": relationship_description_embed_config,
},
},
{
"verb": "create_final_relationships_post_embedding",
"input": {
"source": "pre_embedding"
if skip_description_embedding
else "description_embedding",
"source": "workflow:create_base_entity_graph",
"nodes": "workflow:create_final_nodes",
},
},
Expand Down
10 changes: 3 additions & 7 deletions graphrag/index/workflows/v1/subflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@
from .create_base_text_units import create_base_text_units
from .create_final_communities import create_final_communities
from .create_final_nodes import create_final_nodes
from .create_final_relationships_post_embedding import (
create_final_relationships_post_embedding,
)
from .create_final_relationships_pre_embedding import (
create_final_relationships_pre_embedding,
from .create_final_relationships import (
create_final_relationships,
)
from .create_final_text_units_pre_embedding import create_final_text_units_pre_embedding

Expand All @@ -20,7 +17,6 @@
"create_base_text_units",
"create_final_communities",
"create_final_nodes",
"create_final_relationships_post_embedding",
"create_final_relationships_pre_embedding",
"create_final_relationships",
"create_final_text_units_pre_embedding",
]
Original file line number Diff line number Diff line change
@@ -1,37 +1,64 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to transform final relationships after they are embedded."""
"""All the steps to transform final relationships before they are embedded."""

from typing import Any, cast

import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.cache import PipelineCache
from graphrag.index.utils.ds_util import get_required_input_table
from graphrag.index.verbs.graph.compute_edge_combined_degree import (
compute_edge_combined_degree_df,
)
from graphrag.index.verbs.graph.unpack import unpack_graph_df
from graphrag.index.verbs.text.embed.text_embed import text_embed_df


@verb(
name="create_final_relationships_post_embedding",
name="create_final_relationships",
treats_input_tables_as_immutable=True,
)
def create_final_relationships_post_embedding(
async def create_final_relationships(
input: VerbInput,
callbacks: VerbCallbacks,
cache: PipelineCache,
text_embed: dict,
skip_embedding: bool = False,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform final relationships after they are embedded."""
"""All the steps to transform final relationships before they are embedded."""
table = cast(pd.DataFrame, input.get_input())
nodes = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)

pruned_edges = table.drop(columns=["level"])
graph_edges = unpack_graph_df(table, callbacks, "clustered_graph", "edges")

graph_edges.rename(columns={"source_id": "text_unit_ids"}, inplace=True)

filtered = cast(
pd.DataFrame, graph_edges[graph_edges["level"] == 0].reset_index(drop=True)
)

if not skip_embedding:
filtered = await text_embed_df(
filtered,
callbacks,
cache,
column="description",
strategy=text_embed["strategy"],
to="description_embedding",
embedding_name="relationship_description",
)

pruned_edges = filtered.drop(columns=["level"])

filtered_nodes = cast(
pd.DataFrame,
Expand Down

This file was deleted.

2 changes: 1 addition & 1 deletion tests/fixtures/min-csv/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
1,
2000
],
"subworkflows": 2,
"subworkflows": 1,
"max_runtime": 100
},
"create_final_nodes": {
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/text/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
1,
2000
],
"subworkflows": 2,
"subworkflows": 1,
"max_runtime": 100
},
"create_final_nodes": {
Expand Down
29 changes: 29 additions & 0 deletions tests/verbs/test_create_final_relationships.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,32 @@ async def test_create_final_relationships():
)

compare_outputs(actual, expected)


async def test_create_final_relationships_with_embeddings():
    expected = load_expected(workflow_name)
    input_tables = load_input_tables([
        "workflow:create_base_entity_graph",
        "workflow:create_final_nodes",
    ])

    config = get_config_for_workflow(workflow_name)
    config["skip_description_embedding"] = False
    # default config has a detailed standard embed config
    # just override the strategy to mock so the rest of the required parameters are in place
    config["relationship_description_embed"]["strategy"]["type"] = "mock"

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": remove_disabled_steps(build_steps(config)),
        },
    )

    # the embedding step should append exactly one column to the expected schema
    assert "description_embedding" in actual.columns
    assert len(actual.columns) == len(expected.columns) + 1
    # the mock impl returns an array of 3 floats for each embedding
    assert len(actual["description_embedding"][0]) == 3
5 changes: 4 additions & 1 deletion tests/verbs/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
PipelineWorkflowStep,
create_pipeline_config,
)
from graphrag.index.run.utils import _create_run_context


def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
Expand Down Expand Up @@ -61,7 +62,9 @@ async def get_workflow_output(
input_tables=input_tables,
)

await workflow.run()
context = _create_run_context(None, None, None)

await workflow.run(context=context)

# if there's only one output, it is the default here, no name required
return cast(pd.DataFrame, workflow.output())
Expand Down

0 comments on commit f518c8b

Please sign in to comment.