Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove aggregate_df from final communities and final text units #1179

Merged
merged 10 commits into from
Sep 23, 2024
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240920221112632172.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Remove aggregate_df from final coomunities and final text units"
}
63 changes: 22 additions & 41 deletions graphrag/index/workflows/v1/subflows/create_final_communities.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.verbs.graph.unpack import unpack_graph_df
from graphrag.index.verbs.overrides.aggregate import aggregate_df


@verb(name="create_final_communities", treats_input_tables_as_immutable=True)
Expand All @@ -30,54 +29,35 @@ def create_final_communities(
graph_nodes = unpack_graph_df(table, callbacks, "clustered_graph", "nodes")
graph_edges = unpack_graph_df(table, callbacks, "clustered_graph", "edges")

# Merge graph_nodes with graph_edges for both source and target matches
source_clusters = graph_nodes.merge(
graph_edges,
left_on="label",
right_on="source",
how="inner",
graph_edges, left_on="label", right_on="source", how="inner"
)

target_clusters = graph_nodes.merge(
graph_edges,
left_on="label",
right_on="target",
how="inner",
graph_edges, left_on="label", right_on="target", how="inner"
)

concatenated_clusters = pd.concat(
[source_clusters, target_clusters], ignore_index=True
)
# Concatenate the source and target clusters
clusters = pd.concat([source_clusters, target_clusters], ignore_index=True)

# level_x is the left side of the join
# level_y is the right side of the join
# we only want to keep the clusters that are the same on both sides
combined_clusters = concatenated_clusters[
concatenated_clusters["level_x"] == concatenated_clusters["level_y"]
# Keep only rows where level_x == level_y
combined_clusters = clusters[
clusters["level_x"] == clusters["level_y"]
].reset_index(drop=True)

cluster_relationships = aggregate_df(
cast(Table, combined_clusters),
aggregations=[
{
"column": "id_y", # this is the id of the edge from the join steps above
"to": "relationship_ids",
"operation": "array_agg_distinct",
},
{
"column": "source_id_x",
"to": "text_unit_ids",
"operation": "array_agg_distinct",
},
],
groupby=[
"cluster",
"level_x", # level_x is the left side of the join
],
cluster_relationships = (
combined_clusters.groupby(["cluster", "level_x"], sort=False)
.agg(
relationship_ids=("id_y", "unique"), text_unit_ids=("source_id_x", "unique")
)
.reset_index()
)

all_clusters = aggregate_df(
graph_nodes,
aggregations=[{"column": "cluster", "to": "id", "operation": "any"}],
groupby=["cluster", "level"],
all_clusters = (
graph_nodes.groupby(["cluster", "level"], sort=False)
.agg(id=("cluster", "first"))
.reset_index()
)

joined = all_clusters.merge(
Expand All @@ -94,14 +74,15 @@ def create_final_communities(
return create_verb_result(
cast(
Table,
filtered[
filtered.loc[
:,
[
"id",
"title",
"level",
"relationship_ids",
"text_unit_ids",
]
],
],
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@

from typing import cast

import pandas as pd
from datashaper.engine.verbs.verb_input import VerbInput
from datashaper.engine.verbs.verbs_mapping import verb
from datashaper.table_store.types import Table, VerbResult, create_verb_result

from graphrag.index.verbs.overrides.aggregate import aggregate_df


@verb(
name="create_final_text_units_pre_embedding", treats_input_tables_as_immutable=True
Expand All @@ -21,15 +20,15 @@ def create_final_text_units_pre_embedding(
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform before we embed the text units."""
table = input.get_input()
table = cast(pd.DataFrame, input.get_input())
others = input.get_others()

selected = cast(Table, table[["id", "chunk", "document_ids", "n_tokens"]]).rename(
selected = table.loc[:, ["id", "chunk", "document_ids", "n_tokens"]].rename(
columns={"chunk": "text"}
)

final_entities = others[0]
final_relationships = others[1]
final_entities = cast(pd.DataFrame, others[0])
final_relationships = cast(pd.DataFrame, others[1])
entity_join = _entities(final_entities)
relationship_join = _relationships(final_relationships)

Expand All @@ -38,116 +37,47 @@ def create_final_text_units_pre_embedding(
final_joined = relationship_joined

if covariates_enabled:
final_covariates = others[2]
final_covariates = cast(pd.DataFrame, others[2])
covariate_join = _covariates(final_covariates)
final_joined = _join(relationship_joined, covariate_join)

aggregated = _final_aggregation(final_joined, covariates_enabled)

return create_verb_result(aggregated)


def _final_aggregation(table, covariates_enabled):
aggregations = [
{
"column": "text",
"operation": "any",
"to": "text",
},
{
"column": "n_tokens",
"operation": "any",
"to": "n_tokens",
},
{
"column": "document_ids",
"operation": "any",
"to": "document_ids",
},
{
"column": "entity_ids",
"operation": "any",
"to": "entity_ids",
},
{
"column": "relationship_ids",
"operation": "any",
"to": "relationship_ids",
},
]
if covariates_enabled:
aggregations.append({
"column": "covariate_ids",
"operation": "any",
"to": "covariate_ids",
})
return aggregate_df(
table,
aggregations,
["id"],
)
aggregated = final_joined.groupby("id", sort=False).agg("first").reset_index()

return create_verb_result(cast(Table, aggregated))


def _entities(df: pd.DataFrame) -> pd.DataFrame:
selected = df.loc[:, ["id", "text_unit_ids"]]
unrolled = selected.explode(["text_unit_ids"]).reset_index(drop=True)

def _entities(table):
selected = cast(Table, table[["id", "text_unit_ids"]])
unrolled = selected.explode("text_unit_ids").reset_index(drop=True)
return aggregate_df(
unrolled,
[
{
"column": "id",
"operation": "array_agg_distinct",
"to": "entity_ids",
},
{
"column": "text_unit_ids",
"operation": "any",
"to": "id",
},
],
["text_unit_ids"],
return (
unrolled.groupby("text_unit_ids", sort=False)
.agg(entity_ids=("id", "unique"))
.reset_index()
.rename(columns={"text_unit_ids": "id"})
)


def _relationships(table):
selected = cast(Table, table[["id", "text_unit_ids"]])
unrolled = selected.explode("text_unit_ids").reset_index(drop=True)
aggregated = aggregate_df(
unrolled,
[
{
"column": "id",
"operation": "array_agg_distinct",
"to": "relationship_ids",
},
{
"column": "text_unit_ids",
"operation": "any",
"to": "id",
},
],
["text_unit_ids"],
def _relationships(df: pd.DataFrame) -> pd.DataFrame:
selected = df.loc[:, ["id", "text_unit_ids"]]
unrolled = selected.explode(["text_unit_ids"]).reset_index(drop=True)

return (
unrolled.groupby("text_unit_ids", sort=False)
.agg(relationship_ids=("id", "unique"))
.reset_index()
.rename(columns={"text_unit_ids": "id"})
)
return aggregated[["id", "relationship_ids"]]


def _covariates(table):
selected = cast(Table, table[["id", "text_unit_id"]])
return aggregate_df(
selected,
[
{
"column": "id",
"operation": "array_agg_distinct",
"to": "covariate_ids",
},
{
"column": "text_unit_id",
"operation": "any",
"to": "id",
},
],
["text_unit_id"],


def _covariates(df: pd.DataFrame) -> pd.DataFrame:
selected = df.loc[:, ["id", "text_unit_id"]]

return (
selected.groupby("text_unit_id", sort=False)
.agg(covariate_ids=("id", "unique"))
.reset_index()
.rename(columns={"text_unit_id": "id"})
)


Expand Down
Loading