Collapse create base text units (#1178)
* Collapse non-attribute verbs

* Include document_column_attributes in collapse

* Remove merge_override verb

* Semver

* Setup initial test and config

* Collapse create_base_text_units

* Semver

* Spelling

* Fix smoke tests

* Address PR comments

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
natoverse and AlonsoGuevara authored Sep 23, 2024
1 parent be7d3eb commit 1755afb
Showing 11 changed files with 180 additions and 96 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240920215241796658.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Collapse create_base_text_units."
}
1 change: 1 addition & 0 deletions dictionary.txt
@@ -100,6 +100,7 @@ aembed
 dedupe
 dropna
 dtypes
+notna
 
 # LLM Terms
 AOAI
36 changes: 25 additions & 11 deletions graphrag/index/verbs/genid.py
@@ -16,7 +16,7 @@ def genid(
     input: VerbInput,
     to: str,
     method: str = "md5_hash",
-    hash: list[str] = [],  # noqa A002
+    hash: list[str] | None = None,  # noqa A002
     **_kwargs: dict,
 ) -> TableContainer:
"""
@@ -52,15 +52,29 @@ def genid(
     """
     data = cast(pd.DataFrame, input.source.table)
 
-    if method == "md5_hash":
-        if len(hash) == 0:
-            msg = 'Must specify the "hash" columns to use md5_hash method'
-            raise ValueError(msg)
-
-        data[to] = data.apply(lambda row: gen_md5_hash(row, hash), axis=1)
-    elif method == "increment":
-        data[to] = data.index + 1
-    else:
-        msg = f"Unknown method {method}"
-        raise ValueError(msg)
-    return TableContainer(table=data)
+    output = genid_df(data, to, method, hash)
+
+    return TableContainer(table=output)
+
+
+def genid_df(
+    input: pd.DataFrame,
+    to: str,
+    method: str = "md5_hash",
+    hash: list[str] | None = None,  # noqa A002
+):
+    """Generate a unique id for each row in the tabular data."""
+    data = input
+    match method:
+        case "md5_hash":
+            if not hash:
+                msg = 'Must specify the "hash" columns to use md5_hash method'
+                raise ValueError(msg)
+            data[to] = data.apply(lambda row: gen_md5_hash(row, hash), axis=1)
+        case "increment":
+            data[to] = data.index + 1
+        case _:
+            msg = f"Unknown method {method}"
+            raise ValueError(msg)
+    return data
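For orientation, a minimal usage sketch of the extracted genid_df helper (not part of the diff); the frame and column names are made up for illustration, and only the signature comes from the code above.

import pandas as pd

from graphrag.index.verbs.genid import genid_df

# toy frame; the column names are illustrative only
df = pd.DataFrame({"chunk": ["alpha", "beta"]})

# md5_hash derives a deterministic id from the listed columns;
# note that genid_df adds the column in place and returns the same frame
hashed = genid_df(df, to="chunk_id", method="md5_hash", hash=["chunk"])

# increment simply numbers the rows starting at 1
numbered = genid_df(df, to="row_id", method="increment")

print(numbered[["chunk", "chunk_id", "row_id"]])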
19 changes: 17 additions & 2 deletions graphrag/index/verbs/text/chunk/text_chunk.py
@@ -85,9 +85,24 @@ def chunk(
         type: sentence
     ```
     """
+    input_table = cast(pd.DataFrame, input.get_input())
+
+    output = chunk_df(input_table, column, to, callbacks, strategy)
+
+    return TableContainer(table=output)
+
+
+def chunk_df(
+    input: pd.DataFrame,
+    column: str,
+    to: str,
+    callbacks: VerbCallbacks,
+    strategy: dict[str, Any] | None = None,
+) -> pd.DataFrame:
+    """Chunk a piece of text into smaller pieces."""
+    output = input
     if strategy is None:
         strategy = {}
-    output = cast(pd.DataFrame, input.get_input())
     strategy_name = strategy.get("type", ChunkStrategyType.tokens)
     strategy_config = {**strategy}
     strategy_exec = load_strategy(strategy_name)
@@ -102,7 +117,7 @@ def chunk(
         ),
         axis=1,
     )
-    return TableContainer(table=output)
+    return output


 def run_strategy(
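As a rough sketch of driving the new chunk_df helper directly on a DataFrame (not part of the diff): NoopVerbCallbacks and the strategy keys shown here are assumptions about the datashaper/graphrag defaults and may differ by version.

import pandas as pd
from datashaper import NoopVerbCallbacks  # assumed to be exported by datashaper

from graphrag.index.verbs.text.chunk.text_chunk import chunk_df

# each row carries a list of (document_id, text) pairs, mirroring the
# aggregation step in create_base_text_units
frame = pd.DataFrame({
    "texts": [[("doc-1", "first document text"), ("doc-2", "second document text")]]
})

chunked = chunk_df(
    frame,
    column="texts",
    to="chunks",
    callbacks=NoopVerbCallbacks(),
    # illustrative strategy; pass {} to fall back to the token-based defaults
    strategy={"type": "tokens", "chunk_size": 1200, "chunk_overlap": 100},
)

# with the tokens strategy, each entry in "chunks" is a
# (document_ids, chunk_text, n_tokens) tuple
print(chunked["chunks"].iloc[0][0])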
87 changes: 6 additions & 81 deletions graphrag/index/workflows/v1/create_base_text_units.py
@@ -22,91 +22,16 @@ def build_steps(
     chunk_column_name = config.get("chunk_column", "chunk")
     chunk_by_columns = config.get("chunk_by", []) or []
     n_tokens_column_name = config.get("n_tokens_column", "n_tokens")
+    text_chunk = config.get("text_chunk", {})
     return [
         {
-            "verb": "orderby",
+            "verb": "create_base_text_units",
             "args": {
-                "orders": [
-                    # sort for reproducibility
-                    {"column": "id", "direction": "asc"},
-                ]
+                "chunk_column_name": chunk_column_name,
+                "n_tokens_column_name": n_tokens_column_name,
+                "chunk_by_columns": chunk_by_columns,
+                **text_chunk,
             },
             "input": {"source": DEFAULT_INPUT_NAME},
         },
-        {
-            "verb": "zip",
-            "args": {
-                # Pack the document ids with the text
-                # So when we unpack the chunks, we can restore the document id
-                "columns": ["id", "text"],
-                "to": "text_with_ids",
-            },
-        },
-        {
-            "verb": "aggregate_override",
-            "args": {
-                "groupby": [*chunk_by_columns] if len(chunk_by_columns) > 0 else None,
-                "aggregations": [
-                    {
-                        "column": "text_with_ids",
-                        "operation": "array_agg",
-                        "to": "texts",
-                    }
-                ],
-            },
-        },
-        {
-            "verb": "chunk",
-            "args": {"column": "texts", "to": "chunks", **config.get("text_chunk", {})},
-        },
-        {
-            "verb": "select",
-            "args": {
-                "columns": [*chunk_by_columns, "chunks"],
-            },
-        },
-        {
-            "verb": "unroll",
-            "args": {
-                "column": "chunks",
-            },
-        },
-        {
-            "verb": "rename",
-            "args": {
-                "columns": {
-                    "chunks": chunk_column_name,
-                }
-            },
-        },
-        {
-            "verb": "genid",
-            "args": {
-                # Generate a unique id for each chunk
-                "to": "chunk_id",
-                "method": "md5_hash",
-                "hash": [chunk_column_name],
-            },
-        },
-        {
-            "verb": "unzip",
-            "args": {
-                "column": chunk_column_name,
-                "to": ["document_ids", chunk_column_name, n_tokens_column_name],
-            },
-        },
-        {"verb": "copy", "args": {"column": "chunk_id", "to": "id"}},
-        {
-            # ELIMINATE EMPTY CHUNKS
-            "verb": "filter",
-            "args": {
-                "column": chunk_column_name,
-                "criteria": [
-                    {
-                        "type": "value",
-                        "operator": "is not empty",
-                    }
-                ],
-            },
-        },
     ]
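To make the collapse concrete, a hedged sketch of what build_steps returns for an empty config after this change (assuming graphrag is importable; the expected values come from the defaults named in the code above, not from captured output).

from graphrag.index.workflows.v1.create_base_text_units import build_steps, workflow_name

# an empty config falls back to the defaults
steps = build_steps({})
print(workflow_name, len(steps))        # create_base_text_units 1
print(steps[0]["verb"])                 # create_base_text_units
print(sorted(steps[0]["args"].keys()))  # chunk_by_columns, chunk_column_name, n_tokens_column_name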
2 changes: 2 additions & 0 deletions graphrag/index/workflows/v1/subflows/__init__.py
@@ -4,6 +4,7 @@
 """The Indexing Engine workflows -> subflows package root."""
 
 from .create_base_documents import create_base_documents
+from .create_base_text_units import create_base_text_units
 from .create_final_communities import create_final_communities
 from .create_final_nodes import create_final_nodes
 from .create_final_relationships_post_embedding import (
@@ -16,6 +17,7 @@
 
 __all__ = [
     "create_base_documents",
+    "create_base_text_units",
     "create_final_communities",
     "create_final_nodes",
     "create_final_relationships_post_embedding",
86 changes: 86 additions & 0 deletions graphrag/index/workflows/v1/subflows/create_base_text_units.py
@@ -0,0 +1,86 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to transform base text_units."""

from typing import Any, cast

import pandas as pd
from datashaper import (
    Table,
    VerbCallbacks,
    VerbInput,
    verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.verbs.genid import genid_df
from graphrag.index.verbs.overrides.aggregate import aggregate_df
from graphrag.index.verbs.text.chunk.text_chunk import chunk_df


@verb(name="create_base_text_units", treats_input_tables_as_immutable=True)
def create_base_text_units(
    input: VerbInput,
    callbacks: VerbCallbacks,
    chunk_column_name: str,
    n_tokens_column_name: str,
    chunk_by_columns: list[str],
    strategy: dict[str, Any] | None = None,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to transform base text_units."""
    table = cast(pd.DataFrame, input.get_input())

    sort = table.sort_values(by=["id"], ascending=[True])

    sort["text_with_ids"] = list(
        zip(*[sort[col] for col in ["id", "text"]], strict=True)
    )

    aggregated = aggregate_df(
        sort,
        groupby=[*chunk_by_columns] if len(chunk_by_columns) > 0 else None,
        aggregations=[
            {
                "column": "text_with_ids",
                "operation": "array_agg",
                "to": "texts",
            }
        ],
    )

    chunked = chunk_df(
        aggregated,
        column="texts",
        to="chunks",
        callbacks=callbacks,
        strategy=strategy,
    )

    chunked = cast(pd.DataFrame, chunked[[*chunk_by_columns, "chunks"]])
    chunked = chunked.explode("chunks")
    chunked.rename(
        columns={
            "chunks": chunk_column_name,
        },
        inplace=True,
    )

    chunked = genid_df(
        chunked, to="chunk_id", method="md5_hash", hash=[chunk_column_name]
    )

    chunked[["document_ids", chunk_column_name, n_tokens_column_name]] = pd.DataFrame(
        chunked[chunk_column_name].tolist(), index=chunked.index
    )
    chunked["id"] = chunked["chunk_id"]

    filtered = chunked[chunked[chunk_column_name].notna()].reset_index(drop=True)

    return create_verb_result(
        cast(
            Table,
            filtered,
        )
    )
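A small self-contained illustration (not from the repository) of the pack/unpack trick the subflow relies on: document ids travel with the text as tuples through chunking, then the tuples are expanded back into columns and empty chunks are dropped.

import pandas as pd

docs = pd.DataFrame({"id": ["d1", "d2"], "text": ["hello world", "foo bar"]})

# pack: one (id, text) tuple per row, mirroring sort["text_with_ids"] above
docs["text_with_ids"] = list(zip(docs["id"], docs["text"], strict=True))

# pretend chunking already produced (document_ids, chunk_text, n_tokens) tuples
chunked = pd.DataFrame({
    "chunk": [(["d1"], "hello world", 2), (["d2"], "foo bar", 2), (["d2"], None, 0)]
})

# unpack: the same three-column expansion used in create_base_text_units
chunked[["document_ids", "chunk", "n_tokens"]] = pd.DataFrame(
    chunked["chunk"].tolist(), index=chunked.index
)

# drop empty chunks, as the subflow does with notna()
chunked = chunked[chunked["chunk"].notna()].reset_index(drop=True)
print(chunked)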
2 changes: 1 addition & 1 deletion tests/fixtures/min-csv/config.json
@@ -7,7 +7,7 @@
       1,
       2000
     ],
-    "subworkflows": 11,
+    "subworkflows": 1,
     "max_runtime": 10
   },
   "create_base_extracted_entities": {
2 changes: 1 addition & 1 deletion tests/fixtures/text/config.json
@@ -7,7 +7,7 @@
       1,
       2000
     ],
-    "subworkflows": 11,
+    "subworkflows": 1,
     "max_runtime": 10
   },
   "create_base_extracted_entities": {
35 changes: 35 additions & 0 deletions tests/verbs/test_create_base_text_units.py
@@ -0,0 +1,35 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from graphrag.index.workflows.v1.create_base_text_units import (
    build_steps,
    workflow_name,
)

from .util import (
    compare_outputs,
    get_config_for_workflow,
    get_workflow_output,
    load_expected,
    load_input_tables,
)


async def test_create_base_text_units():
    input_tables = load_input_tables(inputs=[])
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)
    # test data was created with 4o, so we need to match the encoding for chunks to be identical
    config["text_chunk"]["strategy"]["encoding_name"] = "o200k_base"

    steps = build_steps(config)

    actual = await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
    )

    compare_outputs(actual, expected)
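A brief aside on the encoding override in the test above: chunk boundaries depend on the tokenizer, so the test pins the encoding the fixture data was built with. A quick way to see the difference, assuming tiktoken is installed:

import tiktoken

text = "GraphRAG collapses create_base_text_units into a single verb."
for name in ("cl100k_base", "o200k_base"):
    encoding = tiktoken.get_encoding(name)
    print(name, len(encoding.encode(text)))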
2 changes: 2 additions & 0 deletions tests/verbs/util.py
@@ -31,6 +31,7 @@ def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
        # remove the workflow: prefix if it exists, because that is not part of the actual table filename
        name = input.replace("workflow:", "")
        input_tables[input] = pd.read_parquet(f"tests/verbs/data/{name}.parquet")

    return input_tables


@@ -42,6 +43,7 @@ def load_expected(output: str) -> pd.DataFrame:
def get_config_for_workflow(name: str) -> PipelineWorkflowConfig:
    """Instantiates the bare minimum config to get a default workflow config for testing."""
    config = create_graphrag_config()
    print(config)
    pipeline_config = create_pipeline_config(config)
    print(pipeline_config.workflows)
    result = next(conf for conf in pipeline_config.workflows if conf.name == name)
