From be7d3eb189e56f96490a00b34a9101d19379f037 Mon Sep 17 00:00:00 2001 From: Alonso Guevara Date: Mon, 23 Sep 2024 16:54:15 -0600 Subject: [PATCH] Remove aggregate_df from final communities and final text units (#1179) * Remove aggregate_df from final communities and final text units * Semver * Ruff and format * Format * Format * Fix tests, ruff and checks * Remove some leftover prints * Removed _final_join method --- .../patch-20240920221112632172.json | 4 + .../v1/subflows/create_final_communities.py | 63 +++----- .../create_final_text_units_pre_embedding.py | 144 +++++------------- 3 files changed, 63 insertions(+), 148 deletions(-) create mode 100644 .semversioner/next-release/patch-20240920221112632172.json diff --git a/.semversioner/next-release/patch-20240920221112632172.json b/.semversioner/next-release/patch-20240920221112632172.json new file mode 100644 index 000000000..47a2a6d76 --- /dev/null +++ b/.semversioner/next-release/patch-20240920221112632172.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Remove aggregate_df from final communities and final text units" +} diff --git a/graphrag/index/workflows/v1/subflows/create_final_communities.py b/graphrag/index/workflows/v1/subflows/create_final_communities.py index 5db80fc6a..2cbe8f6ca 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_communities.py +++ b/graphrag/index/workflows/v1/subflows/create_final_communities.py @@ -15,7 +15,6 @@ from datashaper.table_store.types import VerbResult, create_verb_result from graphrag.index.verbs.graph.unpack import unpack_graph_df -from graphrag.index.verbs.overrides.aggregate import aggregate_df @verb(name="create_final_communities", treats_input_tables_as_immutable=True) @@ -30,54 +29,35 @@ def create_final_communities( graph_nodes = unpack_graph_df(table, callbacks, "clustered_graph", "nodes") graph_edges = unpack_graph_df(table, callbacks, "clustered_graph", "edges") + # Merge graph_nodes with graph_edges for both source and target matches 
source_clusters = graph_nodes.merge( - graph_edges, - left_on="label", - right_on="source", - how="inner", + graph_edges, left_on="label", right_on="source", how="inner" ) + target_clusters = graph_nodes.merge( - graph_edges, - left_on="label", - right_on="target", - how="inner", + graph_edges, left_on="label", right_on="target", how="inner" ) - concatenated_clusters = pd.concat( - [source_clusters, target_clusters], ignore_index=True - ) + # Concatenate the source and target clusters + clusters = pd.concat([source_clusters, target_clusters], ignore_index=True) - # level_x is the left side of the join - # level_y is the right side of the join - # we only want to keep the clusters that are the same on both sides - combined_clusters = concatenated_clusters[ - concatenated_clusters["level_x"] == concatenated_clusters["level_y"] + # Keep only rows where level_x == level_y + combined_clusters = clusters[ + clusters["level_x"] == clusters["level_y"] ].reset_index(drop=True) - cluster_relationships = aggregate_df( - cast(Table, combined_clusters), - aggregations=[ - { - "column": "id_y", # this is the id of the edge from the join steps above - "to": "relationship_ids", - "operation": "array_agg_distinct", - }, - { - "column": "source_id_x", - "to": "text_unit_ids", - "operation": "array_agg_distinct", - }, - ], - groupby=[ - "cluster", - "level_x", # level_x is the left side of the join - ], + cluster_relationships = ( + combined_clusters.groupby(["cluster", "level_x"], sort=False) + .agg( + relationship_ids=("id_y", "unique"), text_unit_ids=("source_id_x", "unique") + ) + .reset_index() ) - all_clusters = aggregate_df( - graph_nodes, - aggregations=[{"column": "cluster", "to": "id", "operation": "any"}], - groupby=["cluster", "level"], + all_clusters = ( + graph_nodes.groupby(["cluster", "level"], sort=False) + .agg(id=("cluster", "first")) + .reset_index() ) joined = all_clusters.merge( @@ -94,14 +74,15 @@ def create_final_communities( return create_verb_result( cast( 
Table, - filtered[ + filtered.loc[ + :, [ "id", "title", "level", "relationship_ids", "text_unit_ids", - ] + ], ], ) ) diff --git a/graphrag/index/workflows/v1/subflows/create_final_text_units_pre_embedding.py b/graphrag/index/workflows/v1/subflows/create_final_text_units_pre_embedding.py index 5cb48f051..49ebd8198 100644 --- a/graphrag/index/workflows/v1/subflows/create_final_text_units_pre_embedding.py +++ b/graphrag/index/workflows/v1/subflows/create_final_text_units_pre_embedding.py @@ -5,12 +5,11 @@ from typing import cast +import pandas as pd from datashaper.engine.verbs.verb_input import VerbInput from datashaper.engine.verbs.verbs_mapping import verb from datashaper.table_store.types import Table, VerbResult, create_verb_result -from graphrag.index.verbs.overrides.aggregate import aggregate_df - @verb( name="create_final_text_units_pre_embedding", treats_input_tables_as_immutable=True @@ -21,15 +20,15 @@ def create_final_text_units_pre_embedding( **_kwargs: dict, ) -> VerbResult: """All the steps to transform before we embed the text units.""" - table = input.get_input() + table = cast(pd.DataFrame, input.get_input()) others = input.get_others() - selected = cast(Table, table[["id", "chunk", "document_ids", "n_tokens"]]).rename( + selected = table.loc[:, ["id", "chunk", "document_ids", "n_tokens"]].rename( columns={"chunk": "text"} ) - final_entities = others[0] - final_relationships = others[1] + final_entities = cast(pd.DataFrame, others[0]) + final_relationships = cast(pd.DataFrame, others[1]) entity_join = _entities(final_entities) relationship_join = _relationships(final_relationships) @@ -38,116 +37,47 @@ def create_final_text_units_pre_embedding( final_joined = relationship_joined if covariates_enabled: - final_covariates = others[2] + final_covariates = cast(pd.DataFrame, others[2]) covariate_join = _covariates(final_covariates) final_joined = _join(relationship_joined, covariate_join) - aggregated = _final_aggregation(final_joined, 
covariates_enabled) - - return create_verb_result(aggregated) - - -def _final_aggregation(table, covariates_enabled): - aggregations = [ - { - "column": "text", - "operation": "any", - "to": "text", - }, - { - "column": "n_tokens", - "operation": "any", - "to": "n_tokens", - }, - { - "column": "document_ids", - "operation": "any", - "to": "document_ids", - }, - { - "column": "entity_ids", - "operation": "any", - "to": "entity_ids", - }, - { - "column": "relationship_ids", - "operation": "any", - "to": "relationship_ids", - }, - ] - if covariates_enabled: - aggregations.append({ - "column": "covariate_ids", - "operation": "any", - "to": "covariate_ids", - }) - return aggregate_df( - table, - aggregations, - ["id"], - ) + aggregated = final_joined.groupby("id", sort=False).agg("first").reset_index() + + return create_verb_result(cast(Table, aggregated)) + +def _entities(df: pd.DataFrame) -> pd.DataFrame: + selected = df.loc[:, ["id", "text_unit_ids"]] + unrolled = selected.explode(["text_unit_ids"]).reset_index(drop=True) -def _entities(table): - selected = cast(Table, table[["id", "text_unit_ids"]]) - unrolled = selected.explode("text_unit_ids").reset_index(drop=True) - return aggregate_df( - unrolled, - [ - { - "column": "id", - "operation": "array_agg_distinct", - "to": "entity_ids", - }, - { - "column": "text_unit_ids", - "operation": "any", - "to": "id", - }, - ], - ["text_unit_ids"], + return ( + unrolled.groupby("text_unit_ids", sort=False) + .agg(entity_ids=("id", "unique")) + .reset_index() + .rename(columns={"text_unit_ids": "id"}) ) -def _relationships(table): - selected = cast(Table, table[["id", "text_unit_ids"]]) - unrolled = selected.explode("text_unit_ids").reset_index(drop=True) - aggregated = aggregate_df( - unrolled, - [ - { - "column": "id", - "operation": "array_agg_distinct", - "to": "relationship_ids", - }, - { - "column": "text_unit_ids", - "operation": "any", - "to": "id", - }, - ], - ["text_unit_ids"], +def _relationships(df: pd.DataFrame) 
-> pd.DataFrame: + selected = df.loc[:, ["id", "text_unit_ids"]] + unrolled = selected.explode(["text_unit_ids"]).reset_index(drop=True) + + return ( + unrolled.groupby("text_unit_ids", sort=False) + .agg(relationship_ids=("id", "unique")) + .reset_index() + .rename(columns={"text_unit_ids": "id"}) ) - return aggregated[["id", "relationship_ids"]] - - -def _covariates(table): - selected = cast(Table, table[["id", "text_unit_id"]]) - return aggregate_df( - selected, - [ - { - "column": "id", - "operation": "array_agg_distinct", - "to": "covariate_ids", - }, - { - "column": "text_unit_id", - "operation": "any", - "to": "id", - }, - ], - ["text_unit_id"], + + +def _covariates(df: pd.DataFrame) -> pd.DataFrame: + selected = df.loc[:, ["id", "text_unit_id"]] + + return ( + selected.groupby("text_unit_id", sort=False) + .agg(covariate_ids=("id", "unique")) + .reset_index() + .rename(columns={"text_unit_id": "id"}) )