Skip to content

Commit

Permalink
add invoice get rpc and check expire
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyukang committed Oct 16, 2024
1 parent 856a2ad commit 1734bfb
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 5 deletions.
6 changes: 6 additions & 0 deletions src/fiber/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,12 @@ impl SendPaymentData {
.transpose()
.map_err(|_| "invoice is invalid".to_string())?;

if let Some(invoice) = invoice.clone() {
if invoice.is_expired() {
return Err("invoice is expired".to_string());
}
}

fn validate_field<T: PartialEq + Clone>(
field: Option<T>,
invoice_field: Option<T>,
Expand Down
7 changes: 7 additions & 0 deletions src/invoice/invoice_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,13 @@ impl CkbInvoice {
&self.data.payment_hash
}

pub fn is_expired(&self) -> bool {
self.expiry_time().map_or(false, |expiry| {
self.data.timestamp + expiry.as_millis()
< std::time::UNIX_EPOCH.elapsed().unwrap().as_millis()
})
}

/// Check that the invoice is signed correctly and that key recovery works
pub fn check_signature(&self) -> Result<(), InvoiceError> {
if self.signature.is_none() {
Expand Down
15 changes: 15 additions & 0 deletions src/invoice/tests/invoice_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,3 +437,18 @@ fn test_invoice_udt_script() {
let decoded = serde_json::from_str::<CkbInvoice>(&res.unwrap()).unwrap();
assert_eq!(decoded, invoice);
}

#[test]
fn test_invoice_check_expired() {
let private_key = gen_rand_private_key();
let invoice = InvoiceBuilder::new(Currency::Fibb)
.amount(Some(1280))
.payment_hash(rand_sha256_hash())
.expiry_time(Duration::from_secs(1))
.build_with_sign(|hash| Secp256k1::new().sign_ecdsa_recoverable(hash, &private_key))
.unwrap();

assert_eq!(invoice.is_expired(), false);
std::thread::sleep(Duration::from_secs(2));
assert_eq!(invoice.is_expired(), true);
}
72 changes: 67 additions & 5 deletions src/rpc/invoice.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::time::Duration;

use crate::fiber::graph::{NetworkGraphStateStore, PaymentSessionStatus};
use crate::fiber::hash_algorithm::HashAlgorithm;
use crate::fiber::serde_utils::{U128Hex, U64Hex};
use crate::fiber::types::Hash256;
Expand Down Expand Up @@ -32,7 +33,7 @@ pub(crate) struct NewInvoiceParams {
}

#[derive(Clone, Serialize, Deserialize)]
pub(crate) struct NewInvoiceResult {
pub(crate) struct InvoiceResult {
invoice_address: String,
invoice: CkbInvoice,
}
Expand All @@ -47,19 +48,45 @@ pub(crate) struct ParseInvoiceResult {
invoice: CkbInvoice,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct GetInvoiceParams {
payment_hash: Hash256,
}

#[derive(Clone, Serialize, Deserialize)]
enum InvoiceStatus {
Unpaid,
Inflight,
Paid,
Expired,
}

#[derive(Clone, Serialize, Deserialize)]
pub(crate) struct GetInvoiceResult {
invoice_address: String,
invoice: CkbInvoice,
status: InvoiceStatus,
}

#[rpc(server)]
trait InvoiceRpc {
#[method(name = "new_invoice")]
async fn new_invoice(
&self,
params: NewInvoiceParams,
) -> Result<NewInvoiceResult, ErrorObjectOwned>;
) -> Result<InvoiceResult, ErrorObjectOwned>;

#[method(name = "parse_invoice")]
async fn parse_invoice(
&self,
params: ParseInvoiceParams,
) -> Result<ParseInvoiceResult, ErrorObjectOwned>;

#[method(name = "get_invoice")]
async fn get_invoice(
&self,
payment_hash: GetInvoiceParams,
) -> Result<GetInvoiceResult, ErrorObjectOwned>;
}

pub(crate) struct InvoiceRpcServerImpl<S> {
Expand All @@ -76,12 +103,12 @@ impl<S> InvoiceRpcServerImpl<S> {
#[async_trait]
impl<S> InvoiceRpcServer for InvoiceRpcServerImpl<S>
where
S: InvoiceStore + Send + Sync + 'static,
S: InvoiceStore + NetworkGraphStateStore + Send + Sync + 'static,
{
async fn new_invoice(
&self,
params: NewInvoiceParams,
) -> Result<NewInvoiceResult, ErrorObjectOwned> {
) -> Result<InvoiceResult, ErrorObjectOwned> {
let mut invoice_builder = InvoiceBuilder::new(params.currency)
.amount(Some(params.amount))
.payment_preimage(params.payment_preimage);
Expand Down Expand Up @@ -116,7 +143,7 @@ where
.store
.insert_invoice(invoice.clone(), Some(params.payment_preimage))
{
Ok(_) => Ok(NewInvoiceResult {
Ok(_) => Ok(InvoiceResult {
invoice_address: invoice.to_string(),
invoice,
}),
Expand Down Expand Up @@ -150,4 +177,39 @@ where
)),
}
}

async fn get_invoice(
&self,
params: GetInvoiceParams,
) -> Result<GetInvoiceResult, ErrorObjectOwned> {
let payment_hash = params.payment_hash;
match self.store.get_invoice(&payment_hash) {
Some(invoice) => {
let invoice_status = if invoice.is_expired() {
InvoiceStatus::Expired
} else {
InvoiceStatus::Unpaid
};
let payment_session = self.store.get_payment_session(payment_hash);
let status = match payment_session {
Some(session) => match session.status {
PaymentSessionStatus::Inflight => InvoiceStatus::Inflight,
PaymentSessionStatus::Success => InvoiceStatus::Paid,
_ => invoice_status,
},
None => invoice_status,
};
Ok(GetInvoiceResult {
invoice_address: invoice.to_string(),
invoice,
status,
})
}
None => Err(ErrorObjectOwned::owned(
CALL_EXECUTION_FAILED_CODE,
"invoice not found".to_string(),
Some(payment_hash),
)),
}
}
}

0 comments on commit 1734bfb

Please sign in to comment.