Skip to content

Commit

Permalink
fix issues in session route
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyukang committed Oct 24, 2024
1 parent 1df513d commit 92c63c8
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 57 deletions.
32 changes: 21 additions & 11 deletions src/fiber/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ where
}

pub(crate) fn record_payment_success(&mut self, payment_session: &PaymentSession) {
let session_route = &payment_session.route.channels;
let session_route = &payment_session.route.nodes;
let mut result = InternalResult::default();
result.succeed_range_pairs(session_route, 0, session_route.len() - 1);
self.history.apply_internal_result(result);
Expand All @@ -488,9 +488,9 @@ where
payment_session: &PaymentSession,
tlc_err: TlcErr,
) -> bool {
let route = &payment_session.route.channels;
let mut internal_result = InternalResult::default();
let need_to_retry = internal_result.record_payment_fail(route, tlc_err);
let nodes = &payment_session.route.nodes;
let need_to_retry = internal_result.record_payment_fail(nodes, tlc_err);
self.history.apply_internal_result(internal_result);
return need_to_retry;
}
Expand Down Expand Up @@ -527,7 +527,7 @@ where
source, target, amount, payment_hash
);

let route = self.find_route(
let route = self.find_path(
source,
target,
amount,
Expand Down Expand Up @@ -608,7 +608,7 @@ where
}

// the algorithm works from target-to-source to find the shortest path
pub fn find_route(
pub fn find_path(
&self,
source: Pubkey,
target: Pubkey,
Expand Down Expand Up @@ -843,32 +843,42 @@ pub struct SessionRouteNode {

// The router is a list of nodes that the payment will go through.
// We store in the payment session and then will use it to track the payment history.
// The router is a list of nodes that the payment will go through.
// For example:
// A(amount, channel) -> B -> C -> D means A will send `amount` with `channel` to B.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct SessionRoute {
pub channels: Vec<SessionRouteNode>,
pub nodes: Vec<SessionRouteNode>,
}

impl SessionRoute {
pub fn new(target: Pubkey, payment_hops: &Vec<PaymentHopData>) -> Self {
// Create a new route from the source to the target with the given payment hops.
// The payment hops are the hops that the payment will go through.
// for a payment route A -> B -> C -> D
// the `payment_hops` is [B, C, D], which is a convinent way for onion routing.
// here we need to create a session route with source, which is A -> B -> C -> D
pub fn new(source: Pubkey, target: Pubkey, payment_hops: &Vec<PaymentHopData>) -> Self {
let mut router = Self::default();
let mut current = source;
for hop in payment_hops {
if let Some(key) = hop.next_hop {
router.add_node(
key,
current,
hop.channel_outpoint
.clone()
.expect("expect channel outpoint"),
hop.amount,
);
} else {
router.add_node(target, OutPoint::default(), hop.amount);
current = key;
}
}
assert_eq!(current, target);
router.add_node(target, OutPoint::default(), 0);
router
}

fn add_node(&mut self, pubkey: Pubkey, channel_outpoint: OutPoint, amount: u128) {
self.channels.push(SessionRouteNode {
self.nodes.push(SessionRouteNode {
pubkey,
channel_outpoint,
amount,
Expand Down
78 changes: 39 additions & 39 deletions src/fiber/history.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ impl InternalResult {
self.add(from, target, current_time(), amount, false);
}

pub fn fail_node(&mut self, route: &Vec<SessionRouteNode>, index: usize) {
self.fail_node = Some(route[index].pubkey);
pub fn fail_node(&mut self, nodes: &Vec<SessionRouteNode>, index: usize) {
self.fail_node = Some(nodes[index].pubkey);
if index > 0 {
self.fail_pair(route, index);
self.fail_pair(nodes, index);
}
if index + 1 < route.len() {
self.fail_pair(route, index + 1);
if index + 1 < nodes.len() {
self.fail_pair(nodes, index + 1);
}
}

Expand All @@ -82,36 +82,36 @@ impl InternalResult {
}
}

pub fn fail_pair_balanced(&mut self, route: &Vec<SessionRouteNode>, index: usize) {
pub fn fail_pair_balanced(&mut self, nodes: &Vec<SessionRouteNode>, index: usize) {
if index > 0 {
let a = route[index - 1].pubkey;
let b = route[index].pubkey;
let amount = route[index].amount;
let a = nodes[index - 1].pubkey;
let b = nodes[index].pubkey;
let amount = nodes[index].amount;
self.add_fail_pair_balanced(a, b, amount);
}
}

pub fn succeed_range_pairs(&mut self, route: &Vec<SessionRouteNode>, start: usize, end: usize) {
pub fn succeed_range_pairs(&mut self, nodes: &Vec<SessionRouteNode>, start: usize, end: usize) {
for i in start..end {
self.add(
route[i].pubkey,
route[i + 1].pubkey,
nodes[i].pubkey,
nodes[i + 1].pubkey,
current_time(),
route[i].amount,
nodes[i].amount,
true,
);
}
}
pub fn fail_range_pairs(&mut self, route: &Vec<SessionRouteNode>, start: usize, end: usize) {
pub fn fail_range_pairs(&mut self, nodes: &Vec<SessionRouteNode>, start: usize, end: usize) {
for index in start.max(1)..=end {
self.fail_pair(route, index);
self.fail_pair(nodes, index);
}
}

pub fn record_payment_fail(&mut self, route: &Vec<SessionRouteNode>, tlc_err: TlcErr) -> bool {
pub fn record_payment_fail(&mut self, nodes: &Vec<SessionRouteNode>, tlc_err: TlcErr) -> bool {
let mut need_to_retry = true;

let error_index = route.iter().position(|s| {
let error_index = nodes.iter().position(|s| {
Some(s.channel_outpoint.clone()) == tlc_err.error_channel_outpoint()
|| Some(s.pubkey) == tlc_err.error_node_id()
});
Expand All @@ -121,18 +121,18 @@ impl InternalResult {
return need_to_retry;
};

let len = route.len();
let len = nodes.len();
assert!(len >= 2);
let error_code = tlc_err.error_code;
if index == 0 {
if index == 1 {
match error_code {
// we received an error from the first node, we trust our own node
// so we need to penalize the first node
TlcErrorCode::InvalidOnionVersion
| TlcErrorCode::InvalidOnionHmac
| TlcErrorCode::InvalidOnionKey
| TlcErrorCode::InvalidOnionPayload => {
self.fail_node(route, 1);
self.fail_node(nodes, 1);
}
_ => {
// we can not penalize our own node, the whole payment session need to retry
Expand All @@ -144,15 +144,15 @@ impl InternalResult {
TlcErrorCode::FinalIncorrectCltvExpiry | TlcErrorCode::FinalIncorrectHtlcAmount => {
if len == 2 {
need_to_retry = false;
self.fail_node(route, len - 1);
self.fail_node(nodes, len - 1);
} else {
self.fail_pair(route, index - 1);
self.succeed_range_pairs(route, 0, index - 2);
self.fail_pair(nodes, index - 1);
self.succeed_range_pairs(nodes, 0, index - 2);
}
}
TlcErrorCode::IncorrectOrUnknownPaymentDetails | TlcErrorCode::InvoiceExpired => {
need_to_retry = false;
self.succeed_range_pairs(route, 0, len - 1);
self.succeed_range_pairs(nodes, 0, len - 1);
}
TlcErrorCode::ExpiryTooSoon => {
need_to_retry = false;
Expand All @@ -161,9 +161,9 @@ impl InternalResult {
unimplemented!("not implemented");
}
_ => {
self.fail_node(route, len - 1);
self.fail_node(nodes, len - 1);
if len > 1 {
self.succeed_range_pairs(route, 0, len - 2);
self.succeed_range_pairs(nodes, 0, len - 2);
}
}
}
Expand All @@ -173,44 +173,44 @@ impl InternalResult {
TlcErrorCode::InvalidOnionVersion
| TlcErrorCode::InvalidOnionHmac
| TlcErrorCode::InvalidOnionKey => {
self.fail_pair(route, index);
self.fail_pair(nodes, index);
}
TlcErrorCode::InvalidOnionPayload => {
self.fail_node(route, index);
self.fail_node(nodes, index);
if index > 1 {
self.succeed_range_pairs(route, 0, index - 1);
self.succeed_range_pairs(nodes, 0, index - 1);
}
}
TlcErrorCode::UnknownNextPeer => {
self.fail_pair(route, index);
self.fail_pair(nodes, index);
}
TlcErrorCode::PermanentChannelFailure => {
self.fail_pair(route, index);
self.fail_pair(nodes, index);
}
TlcErrorCode::FeeInsufficient | TlcErrorCode::IncorrectCltvExpiry => {
need_to_retry = false;
if index == 1 {
self.fail_node(route, 1);
self.fail_node(nodes, 1);
} else {
self.fail_pair(route, index - 1);
self.fail_pair(nodes, index - 1);
if index > 1 {
self.succeed_range_pairs(route, 0, index - 2);
self.succeed_range_pairs(nodes, 0, index - 2);
}
}
}
TlcErrorCode::TemporaryChannelFailure => {
self.fail_pair_balanced(route, index);
self.succeed_range_pairs(route, 0, index - 1);
self.fail_pair_balanced(nodes, index);
self.succeed_range_pairs(nodes, 0, index - 1);
}
TlcErrorCode::ExpiryTooSoon => {
if index == 1 {
self.fail_node(route, 1);
self.fail_node(nodes, 1);
} else {
self.fail_range_pairs(route, 0, index - 1);
self.fail_range_pairs(nodes, 0, index - 1);
}
}
_ => {
self.fail_node(route, index);
self.fail_node(nodes, index);
}
}
}
Expand Down
6 changes: 5 additions & 1 deletion src/fiber/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2290,7 +2290,11 @@ where
.clone()
.expect("first hop channel outpoint");

let session_route = SessionRoute::new(payment_data.target_pubkey, &hops_infos);
let session_route = SessionRoute::new(
state.get_public_key(),
payment_data.target_pubkey,
&hops_infos,
);
// generate session key
let session_key = Privkey::from_slice(KeyPair::generate_random_key().as_ref());
let peeled_packet = match PeeledPaymentOnionPacket::create(
Expand Down
49 changes: 43 additions & 6 deletions src/fiber/tests/graph.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::test_utils::{generate_keypair, generate_pubkey};
use crate::fiber::types::Pubkey;
use super::test_utils::generate_keypair;
use crate::fiber::graph::SessionRoute;
use crate::{
fiber::{
graph::{ChannelInfo, GraphError, NetworkGraph, NodeInfo, PathEdge},
Expand Down Expand Up @@ -180,7 +180,7 @@ impl MockNetworkGraph {
let source = self.keys[source].into();
let target = self.keys[target].into();
self.graph
.find_route(source, target, amount, Some(max_fee), None)
.find_path(source, target, amount, Some(max_fee), None)
}

pub fn find_route_udt(
Expand All @@ -194,7 +194,7 @@ impl MockNetworkGraph {
let source = self.keys[source].into();
let target = self.keys[target].into();
self.graph
.find_route(source, target, amount, Some(max_fee), Some(udt_type_script))
.find_path(source, target, amount, Some(max_fee), Some(udt_type_script))
}
}

Expand Down Expand Up @@ -459,7 +459,7 @@ fn test_graph_find_path_err() {
assert!(route.is_err());

let no_exits_public_key = network.keys[0];
let route = network.graph.find_route(
let route = network.graph.find_path(
node1.into(),
no_exits_public_key.into(),
100,
Expand All @@ -468,7 +468,7 @@ fn test_graph_find_path_err() {
);
assert!(route.is_err());

let route = network.graph.find_route(
let route = network.graph.find_path(
no_exits_public_key.into(),
node1.into(),
100,
Expand Down Expand Up @@ -631,6 +631,43 @@ fn test_graph_mark_failed_channel() {
assert!(route.is_ok());
}

#[test]
fn test_graph_session_router() {
let mut network = MockNetworkGraph::new(5);
network.add_edge(0, 2, Some(500), Some(2));
network.add_edge(2, 3, Some(500), Some(2));
network.add_edge(3, 4, Some(500), Some(2));

let node0 = network.keys[0];
let node2 = network.keys[2];
let node3 = network.keys[3];
let node4 = network.keys[4];

// Test build route from node1 to node4 should be Ok
let route = network.graph.build_route(SendPaymentData {
target_pubkey: node4.into(),
amount: 100,
payment_hash: Hash256::default(),
invoice: None,
final_cltv_delta: Some(100),
timeout: Some(10),
max_fee_amount: Some(1000),
max_parts: None,
keysend: false,
udt_type_script: None,
preimage: None,
});
assert!(route.is_ok());

let route = route.unwrap();
let session_route = SessionRoute::new(node0.into(), node4.into(), &route);
let session_route_keys: Vec<_> = session_route.nodes.iter().map(|x| x.pubkey).collect();
assert_eq!(
session_route_keys,
vec![node0.into(), node2.into(), node3.into(), node4.into()]
);
}

#[test]
fn test_graph_mark_failed_node() {
let mut network = MockNetworkGraph::new(5);
Expand Down

0 comments on commit 92c63c8

Please sign in to comment.