Skip to content

Commit

Permalink
new: insert_edge_fn (#31)
Browse files Browse the repository at this point in the history
* new: `insert_edge_fn`

* fix: return error
  • Loading branch information
aMahanna authored Aug 30, 2024
1 parent acdc346 commit 9f2cc5f
Showing 1 changed file with 106 additions and 42 deletions.
148 changes: 106 additions & 42 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,6 @@ pub struct NumpyGraph {

#[derive(Debug)]
pub struct NetworkXGraph {
pub load_adj_dict: bool,
pub load_coo: bool,
pub load_all_vertex_attributes: bool,
pub load_all_edge_attributes: bool,
pub is_directed: bool,
pub is_multigraph: bool,
pub symmetrize_edges_if_directed: bool,

// node_map is a dictionary of node IDs to their json data
Expand All @@ -87,12 +81,11 @@ pub struct NetworkXGraph {
// pre-defined functions
get_vertex_properties_fn:
fn(&mut NetworkXGraph, String, Vec<Value>, &Vec<String>) -> Map<String, Value>,

get_edge_properties_fn:
fn(&mut NetworkXGraph, String, String, Vec<Value>, &Vec<String>) -> Map<String, Value>,

insert_coo_fn: fn(&mut NetworkXGraph, String, String, HashMap<String, f64>),
insert_adj_fn: fn(&mut NetworkXGraph, String, String, Map<String, Value>),
insert_edge_fn: fn(&mut NetworkXGraph, String, String, Vec<Value>, &Vec<String>) -> Result<()>,
}

impl NumpyGraph {
Expand Down Expand Up @@ -172,13 +165,15 @@ impl NetworkXGraph {
}
};

let insert_edge_fn = if load_coo && load_adj_dict {
NetworkXGraph::insert_edge_as_coo_and_adj
} else if load_coo {
NetworkXGraph::insert_edge_as_coo_only
} else {
NetworkXGraph::insert_edge_as_adj_only
};

Arc::new(RwLock::new(NetworkXGraph {
load_adj_dict,
load_coo,
load_all_vertex_attributes,
load_all_edge_attributes,
is_directed,
is_multigraph,
symmetrize_edges_if_directed,
node_map: HashMap::new(),
adj_map_graph: HashMap::new(),
Expand All @@ -194,6 +189,7 @@ impl NetworkXGraph {
get_edge_properties_fn,
insert_coo_fn,
insert_adj_fn,
insert_edge_fn,
}))
}

Expand Down Expand Up @@ -580,6 +576,101 @@ impl NetworkXGraph {
pred_from_to_map.insert(index, properties);
}
}

fn insert_edge_as_coo(
&mut self,
from_id_str: String,
to_id_str: String,
columns: &Vec<Value>,
field_names: &Vec<String>,
) -> Result<()> {
let mut properties: HashMap<String, f64> = HashMap::new();
for (field_position, field_name) in field_names.iter().enumerate() {
if field_name == "@collection_name" {
continue;
}
let field_vec = match columns[field_position].as_f64() {
Some(v) => v,
_ => return Err(anyhow!("Edge data must be a numeric value")),
};

properties.insert(field_name.clone(), field_vec);
}

(self.insert_coo_fn)(self, from_id_str, to_id_str, properties);

Ok(())
}

fn insert_edge_as_adj(
&mut self,
from_id_str: String,
to_id_str: String,
columns: Vec<Value>,
field_names: &Vec<String>,
) -> Result<()> {
let properties = (self.get_edge_properties_fn)(
self,
from_id_str.clone(),
to_id_str.clone(),
columns,
field_names,
);

(self.insert_adj_fn)(self, from_id_str, to_id_str, properties);

Ok(())
}

fn insert_edge_as_coo_and_adj(
&mut self,
from_id_str: String,
to_id_str: String,
columns: Vec<Value>,
field_names: &Vec<String>,
) -> Result<()> {
let res = self.insert_edge_as_coo(
from_id_str.clone(),
to_id_str.clone(),
&columns,
field_names,
);

if let Err(e) = res {
return Err(e);
}

self.insert_edge_as_adj(from_id_str, to_id_str, columns, field_names)?;

Ok(())
}

fn insert_edge_as_coo_only(
&mut self,
from_id_str: String,
to_id_str: String,
columns: Vec<Value>,
field_names: &Vec<String>,
) -> Result<()> {
let res = self.insert_edge_as_coo(from_id_str, to_id_str, &columns, field_names);
if let Err(e) = res {
return Err(e);
}

Ok(())
}

fn insert_edge_as_adj_only(
&mut self,
from_id_str: String,
to_id_str: String,
columns: Vec<Value>,
field_names: &Vec<String>,
) -> Result<()> {
self.insert_edge_as_adj(from_id_str, to_id_str, columns, field_names)?;

Ok(())
}
}

impl Graph for NumpyGraph {
Expand Down Expand Up @@ -756,34 +847,7 @@ impl Graph for NetworkXGraph {
let from_id_str: String = String::from_utf8(from_id.clone()).unwrap();
let to_id_str: String = String::from_utf8(to_id.clone()).unwrap();

if self.load_coo {
let mut properties: HashMap<String, f64> = HashMap::new();
for (field_position, field_name) in field_names.iter().enumerate() {
if field_name == "@collection_name" {
continue;
}
let field_vec = match columns[field_position].as_f64() {
Some(v) => v,
_ => return Err(anyhow!("Edge data must be a numeric value")),
};

properties.insert(field_name.clone(), field_vec);
}

(self.insert_coo_fn)(self, from_id_str.clone(), to_id_str.clone(), properties);
}

if self.load_adj_dict {
let properties = (self.get_edge_properties_fn)(
self,
from_id_str.clone(),
to_id_str.clone(),
columns,
field_names,
);

(self.insert_adj_fn)(self, from_id_str, to_id_str, properties);
}
(self.insert_edge_fn)(self, from_id_str, to_id_str, columns, field_names)?;

Ok(())
}
Expand Down

0 comments on commit 9f2cc5f

Please sign in to comment.