From 9f2cc5f5de7c92ba106e40c4564350bfe4f952a5 Mon Sep 17 00:00:00 2001 From: Anthony Mahanna <43019056+aMahanna@users.noreply.github.com> Date: Fri, 30 Aug 2024 10:50:35 -0400 Subject: [PATCH] new: `insert_edge_fn` (#31) * new: `insert_edge_fn` * fix: return error --- src/graph.rs | 148 ++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 106 insertions(+), 42 deletions(-) diff --git a/src/graph.rs b/src/graph.rs index e2c29e1..09338b9 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -58,12 +58,6 @@ pub struct NumpyGraph { #[derive(Debug)] pub struct NetworkXGraph { - pub load_adj_dict: bool, - pub load_coo: bool, - pub load_all_vertex_attributes: bool, - pub load_all_edge_attributes: bool, - pub is_directed: bool, - pub is_multigraph: bool, pub symmetrize_edges_if_directed: bool, // node_map is a dictionary of node IDs to their json data @@ -87,12 +81,11 @@ pub struct NetworkXGraph { // pre-defined functions get_vertex_properties_fn: fn(&mut NetworkXGraph, String, Vec, &Vec) -> Map, - get_edge_properties_fn: fn(&mut NetworkXGraph, String, String, Vec, &Vec) -> Map, - insert_coo_fn: fn(&mut NetworkXGraph, String, String, HashMap), insert_adj_fn: fn(&mut NetworkXGraph, String, String, Map), + insert_edge_fn: fn(&mut NetworkXGraph, String, String, Vec, &Vec) -> Result<()>, } impl NumpyGraph { @@ -172,13 +165,15 @@ impl NetworkXGraph { } }; + let insert_edge_fn = if load_coo && load_adj_dict { + NetworkXGraph::insert_edge_as_coo_and_adj + } else if load_coo { + NetworkXGraph::insert_edge_as_coo_only + } else { + NetworkXGraph::insert_edge_as_adj_only + }; + Arc::new(RwLock::new(NetworkXGraph { - load_adj_dict, - load_coo, - load_all_vertex_attributes, - load_all_edge_attributes, - is_directed, - is_multigraph, symmetrize_edges_if_directed, node_map: HashMap::new(), adj_map_graph: HashMap::new(), @@ -194,6 +189,7 @@ impl NetworkXGraph { get_edge_properties_fn, insert_coo_fn, insert_adj_fn, + insert_edge_fn, })) } @@ -580,6 +576,101 @@ impl NetworkXGraph { pred_from_to_map.insert(index, properties); } } + + fn insert_edge_as_coo( + &mut self, + from_id_str: String, + to_id_str: String, + columns: &Vec, + field_names: &Vec, + ) -> Result<()> { + let mut properties: HashMap = HashMap::new(); + for (field_position, field_name) in field_names.iter().enumerate() { + if field_name == "@collection_name" { + continue; + } + let field_vec = match columns[field_position].as_f64() { + Some(v) => v, + _ => return Err(anyhow!("Edge data must be a numeric value")), + }; + + properties.insert(field_name.clone(), field_vec); + } + + (self.insert_coo_fn)(self, from_id_str, to_id_str, properties); + + Ok(()) + } + + fn insert_edge_as_adj( + &mut self, + from_id_str: String, + to_id_str: String, + columns: Vec, + field_names: &Vec, + ) -> Result<()> { + let properties = (self.get_edge_properties_fn)( + self, + from_id_str.clone(), + to_id_str.clone(), + columns, + field_names, + ); + + (self.insert_adj_fn)(self, from_id_str, to_id_str, properties); + + Ok(()) + } + + fn insert_edge_as_coo_and_adj( + &mut self, + from_id_str: String, + to_id_str: String, + columns: Vec, + field_names: &Vec, + ) -> Result<()> { + let res = self.insert_edge_as_coo( + from_id_str.clone(), + to_id_str.clone(), + &columns, + field_names, + ); + + if let Err(e) = res { + return Err(e); + } + + self.insert_edge_as_adj(from_id_str, to_id_str, columns, field_names)?; + + Ok(()) + } + + fn insert_edge_as_coo_only( + &mut self, + from_id_str: String, + to_id_str: String, + columns: Vec, + field_names: &Vec, + ) -> Result<()> { + let res = self.insert_edge_as_coo(from_id_str, to_id_str, &columns, field_names); + if let Err(e) = res { + return Err(e); + } + + Ok(()) + } + + fn insert_edge_as_adj_only( + &mut self, + from_id_str: String, + to_id_str: String, + columns: Vec, + field_names: &Vec, + ) -> Result<()> { + self.insert_edge_as_adj(from_id_str, to_id_str, columns, field_names)?; + + Ok(()) + } } impl Graph for NumpyGraph { @@ -756,34 +847,7 @@ impl Graph for NetworkXGraph { let from_id_str: String = String::from_utf8(from_id.clone()).unwrap(); let to_id_str: String = String::from_utf8(to_id.clone()).unwrap(); - if self.load_coo { - let mut properties: HashMap = HashMap::new(); - for (field_position, field_name) in field_names.iter().enumerate() { - if field_name == "@collection_name" { - continue; - } - let field_vec = match columns[field_position].as_f64() { - Some(v) => v, - _ => return Err(anyhow!("Edge data must be a numeric value")), - }; - - properties.insert(field_name.clone(), field_vec); - } - - (self.insert_coo_fn)(self, from_id_str.clone(), to_id_str.clone(), properties); - } - - if self.load_adj_dict { - let properties = (self.get_edge_properties_fn)( - self, - from_id_str.clone(), - to_id_str.clone(), - columns, - field_names, - ); - - (self.insert_adj_fn)(self, from_id_str, to_id_str, properties); - } + (self.insert_edge_fn)(self, from_id_str, to_id_str, columns, field_names)?; Ok(()) }