Collapse create base text units #1178

Merged · 14 commits · Sep 23, 2024
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240920215241796658.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Collapse create_base_text_units."
}
1 change: 1 addition & 0 deletions dictionary.txt
@@ -100,6 +100,7 @@ aembed
 dedupe
 dropna
 dtypes
+notna
 
 # LLM Terms
 AOAI
36 changes: 25 additions & 11 deletions graphrag/index/verbs/genid.py
@@ -16,7 +16,7 @@ def genid(
     input: VerbInput,
     to: str,
     method: str = "md5_hash",
-    hash: list[str] = [],  # noqa A002
+    hash: list[str] | None = None,  # noqa A002
     **_kwargs: dict,
 ) -> TableContainer:
     """
@@ -52,15 +52,29 @@ def genid(
     """
     data = cast(pd.DataFrame, input.source.table)
 
-    if method == "md5_hash":
-        if len(hash) == 0:
-            msg = 'Must specify the "hash" columns to use md5_hash method'
+    output = genid_df(data, to, method, hash)
+
+    return TableContainer(table=output)
+
+
+def genid_df(
+    input: pd.DataFrame,
+    to: str,
+    method: str = "md5_hash",
+    hash: list[str] | None = None,  # noqa A002
+):
+    """Generate a unique id for each row in the tabular data."""
+    data = input
+    match method:
+        case "md5_hash":
+            if not hash:
+                msg = 'Must specify the "hash" columns to use md5_hash method'
                 raise ValueError(msg)
+            data[to] = data.apply(lambda row: gen_md5_hash(row, hash), axis=1)
+        case "increment":
+            data[to] = data.index + 1
+        case _:
+            msg = f"Unknown method {method}"
+            raise ValueError(msg)
 
-        data[to] = data.apply(lambda row: gen_md5_hash(row, hash), axis=1)
-    elif method == "increment":
-        data[to] = data.index + 1
-    else:
-        msg = f"Unknown method {method}"
-        raise ValueError(msg)
-    return TableContainer(table=data)
+    return data
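For context, a rough usage sketch of the extracted `genid_df` helper (toy data; assumes `gen_md5_hash` hashes the listed columns of each row, as the existing verb already does):

```python
import pandas as pd

from graphrag.index.verbs.genid import genid_df

df = pd.DataFrame({"title": ["doc-a", "doc-b"], "text": ["hello", "world"]})

# md5_hash mode: the id is derived from the named columns of each row.
hashed = genid_df(df, to="row_id", method="md5_hash", hash=["title", "text"])

# increment mode: the id is the row index + 1 (1-based for a default RangeIndex).
numbered = genid_df(df, to="row_id", method="increment")
```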
19 changes: 17 additions & 2 deletions graphrag/index/verbs/text/chunk/text_chunk.py
@@ -85,9 +85,24 @@ def chunk(
         type: sentence
     ```
     """
+    input_table = cast(pd.DataFrame, input.get_input())
+
+    output = chunk_df(input_table, column, to, callbacks, strategy)
+
+    return TableContainer(table=output)
+
+
+def chunk_df(
+    input: pd.DataFrame,
+    column: str,
+    to: str,
+    callbacks: VerbCallbacks,
+    strategy: dict[str, Any] | None = None,
+) -> pd.DataFrame:
+    """Chunk a piece of text into smaller pieces."""
+    output = input
     if strategy is None:
         strategy = {}
-    output = cast(pd.DataFrame, input.get_input())
     strategy_name = strategy.get("type", ChunkStrategyType.tokens)
     strategy_config = {**strategy}
     strategy_exec = load_strategy(strategy_name)
@@ -102,7 +117,7 @@ def chunk(
         ),
         axis=1,
     )
-    return TableContainer(table=output)
+    return output
 
 
 def run_strategy(
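A minimal sketch of how the new `chunk_df` helper might be called directly (assumes a datashaper `VerbCallbacks` instance is available from the pipeline; the strategy keys shown are illustrative, not exhaustive):

```python
from graphrag.index.verbs.text.chunk.text_chunk import chunk_df

# `aggregated` holds a "texts" column of (id, text) tuples produced upstream.
chunked = chunk_df(
    aggregated,
    column="texts",
    to="chunks",
    callbacks=callbacks,  # VerbCallbacks supplied by the datashaper runtime
    strategy={"type": "tokens", "chunk_size": 1200, "chunk_overlap": 100},
)
```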
87 changes: 6 additions & 81 deletions graphrag/index/workflows/v1/create_base_text_units.py
@@ -22,91 +22,16 @@ def build_steps(
     chunk_column_name = config.get("chunk_column", "chunk")
     chunk_by_columns = config.get("chunk_by", []) or []
     n_tokens_column_name = config.get("n_tokens_column", "n_tokens")
+    text_chunk = config.get("text_chunk", {})
     return [
         {
-            "verb": "orderby",
+            "verb": "create_base_text_units",
             "args": {
-                "orders": [
-                    # sort for reproducibility
-                    {"column": "id", "direction": "asc"},
-                ]
+                "chunk_column_name": chunk_column_name,
+                "n_tokens_column_name": n_tokens_column_name,
+                "chunk_by_columns": chunk_by_columns,
+                **text_chunk,
             },
             "input": {"source": DEFAULT_INPUT_NAME},
         },
-        {
-            "verb": "zip",
-            "args": {
-                # Pack the document ids with the text
-                # So when we unpack the chunks, we can restore the document id
-                "columns": ["id", "text"],
-                "to": "text_with_ids",
-            },
-        },
-        {
-            "verb": "aggregate_override",
-            "args": {
-                "groupby": [*chunk_by_columns] if len(chunk_by_columns) > 0 else None,
-                "aggregations": [
-                    {
-                        "column": "text_with_ids",
-                        "operation": "array_agg",
-                        "to": "texts",
-                    }
-                ],
-            },
-        },
-        {
-            "verb": "chunk",
-            "args": {"column": "texts", "to": "chunks", **config.get("text_chunk", {})},
-        },
-        {
-            "verb": "select",
-            "args": {
-                "columns": [*chunk_by_columns, "chunks"],
-            },
-        },
-        {
-            "verb": "unroll",
-            "args": {
-                "column": "chunks",
-            },
-        },
-        {
-            "verb": "rename",
-            "args": {
-                "columns": {
-                    "chunks": chunk_column_name,
-                }
-            },
-        },
-        {
-            "verb": "genid",
-            "args": {
-                # Generate a unique id for each chunk
-                "to": "chunk_id",
-                "method": "md5_hash",
-                "hash": [chunk_column_name],
-            },
-        },
-        {
-            "verb": "unzip",
-            "args": {
-                "column": chunk_column_name,
-                "to": ["document_ids", chunk_column_name, n_tokens_column_name],
-            },
-        },
-        {"verb": "copy", "args": {"column": "chunk_id", "to": "id"}},
-        {
-            # ELIMINATE EMPTY CHUNKS
-            "verb": "filter",
-            "args": {
-                "column": chunk_column_name,
-                "criteria": [
-                    {
-                        "type": "value",
-                        "operator": "is not empty",
-                    }
-                ],
-            },
-        },
     ]
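With the collapse, `build_steps` now returns a single-step pipeline. Roughly, a config like the following (values illustrative) yields one `create_base_text_units` verb invocation:

```python
from graphrag.index.workflows.v1.create_base_text_units import build_steps

steps = build_steps({
    "chunk_column": "chunk",
    "n_tokens_column": "n_tokens",
    "chunk_by": [],
    "text_chunk": {"strategy": {"type": "tokens"}},
})
# -> [{
#     "verb": "create_base_text_units",
#     "args": {
#         "chunk_column_name": "chunk",
#         "n_tokens_column_name": "n_tokens",
#         "chunk_by_columns": [],
#         "strategy": {"type": "tokens"},
#     },
#     "input": {"source": DEFAULT_INPUT_NAME},
# }]
```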
2 changes: 2 additions & 0 deletions graphrag/index/workflows/v1/subflows/__init__.py
@@ -4,6 +4,7 @@
 """The Indexing Engine workflows -> subflows package root."""
 
 from .create_base_documents import create_base_documents
+from .create_base_text_units import create_base_text_units
 from .create_final_communities import create_final_communities
 from .create_final_nodes import create_final_nodes
 from .create_final_relationships_post_embedding import (
@@ -16,6 +17,7 @@
 
 __all__ = [
     "create_base_documents",
+    "create_base_text_units",
     "create_final_communities",
     "create_final_nodes",
     "create_final_relationships_post_embedding",
86 changes: 86 additions & 0 deletions graphrag/index/workflows/v1/subflows/create_base_text_units.py
@@ -0,0 +1,86 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to transform base text_units."""

from typing import Any, cast

import pandas as pd
from datashaper import (
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.verbs.genid import genid_df
from graphrag.index.verbs.overrides.aggregate import aggregate_df
from graphrag.index.verbs.text.chunk.text_chunk import chunk_df


@verb(name="create_base_text_units", treats_input_tables_as_immutable=True)
def create_base_text_units(
input: VerbInput,
callbacks: VerbCallbacks,
chunk_column_name: str,
n_tokens_column_name: str,
chunk_by_columns: list[str],
strategy: dict[str, Any] | None = None,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform base text_units."""
table = cast(pd.DataFrame, input.get_input())

sort = table.sort_values(by=["id"], ascending=[True])

sort["text_with_ids"] = list(
zip(*[sort[col] for col in ["id", "text"]], strict=True)
)

aggregated = aggregate_df(
sort,
groupby=[*chunk_by_columns] if len(chunk_by_columns) > 0 else None,
aggregations=[
{
"column": "text_with_ids",
"operation": "array_agg",
"to": "texts",
}
],
)

chunked = chunk_df(
aggregated,
column="texts",
to="chunks",
callbacks=callbacks,
strategy=strategy,
)

chunked = cast(pd.DataFrame, chunked[[*chunk_by_columns, "chunks"]])
chunked = chunked.explode("chunks")
chunked.rename(
columns={
"chunks": chunk_column_name,
},
inplace=True,
)

chunked = genid_df(
chunked, to="chunk_id", method="md5_hash", hash=[chunk_column_name]
)

chunked[["document_ids", chunk_column_name, n_tokens_column_name]] = pd.DataFrame(
chunked[chunk_column_name].tolist(), index=chunked.index
)
chunked["id"] = chunked["chunk_id"]

filtered = chunked[chunked[chunk_column_name].notna()].reset_index(drop=True)

return create_verb_result(
cast(
Table,
filtered,
)
)
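The subflow swaps the old zip/unroll/unzip/filter verbs for plain pandas. A self-contained sketch of the unpack-and-filter pattern used at the end (toy tuples, not the real schema):

```python
import pandas as pd

# Each chunk is a (document_ids, text, n_tokens) tuple, mirroring the chunk strategy output.
df = pd.DataFrame({"chunk": [(["d1"], "hello world", 2), (["d1", "d2"], "shared chunk", 2)]})

# Split the tuple into three columns, overwriting "chunk" with the text part.
df[["document_ids", "chunk", "n_tokens"]] = pd.DataFrame(df["chunk"].tolist(), index=df.index)

# Drop rows whose chunk is missing (empty chunks become NaN after explode in the real flow).
df = df[df["chunk"].notna()].reset_index(drop=True)
```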
2 changes: 1 addition & 1 deletion tests/fixtures/min-csv/config.json
@@ -7,7 +7,7 @@
         1,
         2000
       ],
-      "subworkflows": 11,
+      "subworkflows": 1,
       "max_runtime": 10
     },
     "create_base_extracted_entities": {
2 changes: 1 addition & 1 deletion tests/fixtures/text/config.json
@@ -7,7 +7,7 @@
         1,
         2000
       ],
-      "subworkflows": 11,
+      "subworkflows": 1,
       "max_runtime": 10
     },
     "create_base_extracted_entities": {
35 changes: 35 additions & 0 deletions tests/verbs/test_create_base_text_units.py
@@ -0,0 +1,35 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from graphrag.index.workflows.v1.create_base_text_units import (
build_steps,
workflow_name,
)

from .util import (
compare_outputs,
get_config_for_workflow,
get_workflow_output,
load_expected,
load_input_tables,
)


async def test_create_base_text_units():
input_tables = load_input_tables(inputs=[])
expected = load_expected(workflow_name)

config = get_config_for_workflow(workflow_name)
# test data was created with 4o, so we need to match the encoding for chunks to be identical
config["text_chunk"]["strategy"]["encoding_name"] = "o200k_base"

steps = build_steps(config)

actual = await get_workflow_output(
input_tables,
{
"steps": steps,
},
)

compare_outputs(actual, expected)
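Assuming the repo's usual poetry-based dev setup, `poetry run pytest tests/verbs/test_create_base_text_units.py` should exercise just this workflow test.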
2 changes: 2 additions & 0 deletions tests/verbs/util.py
@@ -31,6 +31,7 @@ def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
# remove the workflow: prefix if it exists, because that is not part of the actual table filename
name = input.replace("workflow:", "")
input_tables[input] = pd.read_parquet(f"tests/verbs/data/{name}.parquet")

return input_tables


@@ -42,6 +43,7 @@ def load_expected(output: str) -> pd.DataFrame:
def get_config_for_workflow(name: str) -> PipelineWorkflowConfig:
"""Instantiates the bare minimum config to get a default workflow config for testing."""
config = create_graphrag_config()
print(config)
pipeline_config = create_pipeline_config(config)
print(pipeline_config.workflows)
result = next(conf for conf in pipeline_config.workflows if conf.name == name)