Skip to content

Commit

Permalink
Merge pull request #282 from chenyukang/yukang-add-self-payment
Browse files Browse the repository at this point in the history
Add allow_self_payment for send_payment
  • Loading branch information
quake authored Oct 24, 2024
2 parents 38b7547 + 36306a0 commit f61ee09
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 38 deletions.
45 changes: 31 additions & 14 deletions src/fiber/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,15 +482,16 @@ where
/// Returns a list of `PaymentHopData` for all nodes in the route, including the origin and the target node.
pub fn build_route(
&self,
payment_request: SendPaymentData,
payment_data: &SendPaymentData,
) -> Result<Vec<PaymentHopData>, GraphError> {
let payment_data = payment_data.clone();
let source = self.get_source_pubkey();
let target = payment_request.target_pubkey;
let amount = payment_request.amount;
let preimage = payment_request.preimage;
let payment_hash = payment_request.payment_hash;
let udt_type_script = payment_request.udt_type_script;
let invoice = payment_request
let target = payment_data.target_pubkey;
let amount = payment_data.amount;
let preimage = payment_data.preimage;
let payment_hash = payment_data.payment_hash;
let udt_type_script = payment_data.udt_type_script;
let invoice = payment_data
.invoice
.map(|x| x.parse::<CkbInvoice>().unwrap());
let hash_algorithm = invoice
Expand All @@ -503,12 +504,20 @@ where
source, target, amount, payment_hash
);

let allow_self_payment = payment_data.allow_self_payment;
if source == target && !allow_self_payment {
return Err(GraphError::PathFind(
"source and target are the same and allow_self_payment is not enable".to_string(),
));
}

let route = self.find_route(
source,
target,
amount,
payment_request.max_fee_amount,
payment_data.max_fee_amount,
udt_type_script,
allow_self_payment,
)?;
assert!(!route.is_empty());

Expand Down Expand Up @@ -581,6 +590,7 @@ where
amount: u128,
max_fee_amount: Option<u128>,
udt_type_script: Option<Script>,
allow_self: bool,
) -> Result<Vec<PathEdge>, GraphError> {
let started_time = std::time::Instant::now();
let nodes_len = self.nodes.len();
Expand All @@ -596,11 +606,12 @@ where
));
}

if source == target {
if source == target && !allow_self {
return Err(GraphError::PathFind(
"source and target are the same".to_string(),
));
}

let Some(source_node) = self.nodes.get(&source) else {
return Err(GraphError::PathFind(format!(
"source node not found: {:?}",
Expand All @@ -613,6 +624,7 @@ where
&target
)));
};

// initialize the target node
nodes_heap.push(NodeHeapElement {
node_id: target,
Expand All @@ -624,19 +636,21 @@ where
next_hop: None,
incoming_cltv_height: 0,
});
let route_to_self = source == target;
while let Some(cur_hop) = nodes_heap.pop() {
if cur_hop.node_id == source {
break;
}
nodes_visited += 1;

for (from, channel_info, channel_update) in self.get_node_inbounds(cur_hop.node_id) {
edges_expanded += 1;
if from == target && !route_to_self {
continue;
}
// if charge inbound fees for exit hop
if udt_type_script != channel_info.announcement_msg.udt_type_script {
continue;
}

edges_expanded += 1;

let fee_rate = channel_update.fee_rate;
let next_hop_received_amount = cur_hop.amount_received;
let fee = calculate_tlc_forward_fee(next_hop_received_amount, fee_rate as u128);
Expand Down Expand Up @@ -720,7 +734,7 @@ where
}

let mut current = source_node.node_id;
while current != target {
loop {
if let Some(elem) = distances.get(&current) {
let next_hop = elem.next_hop.as_ref().expect("next_hop is none");
result.push(PathEdge {
Expand All @@ -731,6 +745,9 @@ where
} else {
break;
}
if current == target {
break;
}
}

info!(
Expand Down
26 changes: 15 additions & 11 deletions src/fiber/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ pub struct SendPaymentCommand {
// udt type script
#[serde_as(as = "Option<EntityHex>")]
pub udt_type_script: Option<Script>,
// allow self payment, default is false
pub allow_self_payment: bool,
}

#[serde_as]
Expand All @@ -309,10 +311,11 @@ pub struct SendPaymentData {
#[serde_as(as = "Option<EntityHex>")]
pub udt_type_script: Option<Script>,
pub preimage: Option<Hash256>,
pub allow_self_payment: bool,
}

impl SendPaymentData {
pub fn new(command: SendPaymentCommand) -> Result<SendPaymentData, String> {
pub fn new(command: SendPaymentCommand, source: Pubkey) -> Result<SendPaymentData, String> {
let invoice = command
.invoice
.as_ref()
Expand Down Expand Up @@ -352,6 +355,10 @@ impl SendPaymentData {
"target_pubkey",
)?;

if !command.allow_self_payment && target == source {
return Err("allow_self_payment is not enable, can not pay self".to_string());
}

let amount = validate_field(
command.amount,
invoice.as_ref().and_then(|i| i.amount()),
Expand Down Expand Up @@ -404,6 +411,7 @@ impl SendPaymentData {
keysend,
udt_type_script,
preimage,
allow_self_payment: command.allow_self_payment,
})
}
}
Expand Down Expand Up @@ -2261,12 +2269,7 @@ where
let mut error = None;
while payment_session.can_retry() {
payment_session.retried_times += 1;
let hops_infos = match self
.network_graph
.read()
.await
.build_route(payment_data.clone())
{
let hops_infos = match self.network_graph.read().await.build_route(&payment_data) {
Err(e) => {
error!("Failed to build route: {:?}", e);
error = Some(format!("Failed to build route: {:?}", payment_hash));
Expand Down Expand Up @@ -2330,10 +2333,11 @@ where
state: &mut NetworkActorState<S>,
payment_request: SendPaymentCommand,
) -> Result<SendPaymentResponse, Error> {
let payment_data = SendPaymentData::new(payment_request.clone()).map_err(|e| {
error!("Failed to validate payment request: {:?}", e);
Error::InvalidParameter(format!("Failed to validate payment request: {:?}", e))
})?;
let payment_data = SendPaymentData::new(payment_request.clone(), state.get_public_key())
.map_err(|e| {
error!("Failed to validate payment request: {:?}", e);
Error::InvalidParameter(format!("Failed to validate payment request: {:?}", e))
})?;

// initialize the payment session in db and begin the payment process lifecycle
if let Some(payment_session) = self.store.get_payment_session(payment_data.payment_hash) {
Expand Down
Loading

0 comments on commit f61ee09

Please sign in to comment.