Skip to content

Commit

Permalink
feat(provider): subscribe to new blocks if possible in heartbeat
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniPopes committed Sep 20, 2024
1 parent 57dd4c5 commit a18a111
Show file tree
Hide file tree
Showing 11 changed files with 212 additions and 165 deletions.
4 changes: 2 additions & 2 deletions crates/network-primitives/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,9 @@ pub trait HeaderResponse {
/// Block JSON-RPC response.
pub trait BlockResponse {
/// Header type
type Header;
type Header: HeaderResponse;
/// Transaction type
type Transaction;
type Transaction: TransactionResponse;

/// Block header
fn header(&self) -> &Self::Header;
Expand Down
113 changes: 70 additions & 43 deletions crates/provider/src/chain.rs → crates/provider/src/blocks.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use alloy_network::{Ethereum, Network};
use alloy_primitives::{BlockNumber, U64};
use alloy_rpc_client::{NoParams, PollerBuilder, WeakClient};
use alloy_rpc_types_eth::Block;
use alloy_transport::{RpcError, Transport};
use alloy_transport::{RpcError, Transport, TransportResult};
use async_stream::stream;
use futures::{Stream, StreamExt};
use futures::{future::Either, Stream, StreamExt};
use lru::LruCache;
use std::{marker::PhantomData, num::NonZeroUsize};

Expand All @@ -17,38 +16,48 @@ const MAX_RETRIES: usize = 3;
/// Default block number for when we don't have a block yet.
const NO_BLOCK_NUMBER: BlockNumber = BlockNumber::MAX;

pub(crate) struct ChainStreamPoller<T, N = Ethereum> {
/// Streams new blocks from the client.
pub(crate) struct NewBlocks<T, N: Network = Ethereum> {
client: WeakClient<T>,
poll_task: PollerBuilder<T, NoParams, U64>,
next_yield: BlockNumber,
known_blocks: LruCache<BlockNumber, Block>,
known_blocks: LruCache<BlockNumber, N::BlockResponse>,
_phantom: PhantomData<N>,
}

impl<T: Transport + Clone, N: Network> ChainStreamPoller<T, N> {
pub(crate) fn from_weak_client(w: WeakClient<T>) -> Self {
Self::new(w)
}

impl<T: Transport + Clone, N: Network> NewBlocks<T, N> {
pub(crate) fn new(client: WeakClient<T>) -> Self {
Self::with_next_yield(client, NO_BLOCK_NUMBER)
}

/// Can be used to force the poller to start at a specific block number.
/// Mostly useful for tests.
fn with_next_yield(client: WeakClient<T>, next_yield: BlockNumber) -> Self {
Self {
client: client.clone(),
poll_task: PollerBuilder::new(client, "eth_blockNumber", []),
next_yield,
next_yield: NO_BLOCK_NUMBER,
known_blocks: LruCache::new(BLOCK_CACHE_SIZE),
_phantom: PhantomData,
}
}

pub(crate) fn into_stream(mut self) -> impl Stream<Item = Block> + 'static {
pub(crate) async fn into_stream(
self,
) -> TransportResult<impl Stream<Item = N::BlockResponse> + 'static> {
#[cfg(feature = "pubsub")]
if let Some(client) = self.client.upgrade() {
if let Some(pubsub) = client.pubsub_frontend() {
let id = client.request("eth_subscribe", ("newHeads",)).await?;
let sub = pubsub.get_subscription(id).await?;
return Ok(Either::Left(sub.into_typed::<N::BlockResponse>().into_stream()));
}
}

#[cfg(feature = "pubsub")]
let right = Either::Right;
#[cfg(not(feature = "pubsub"))]
let right = std::convert::identity;
Ok(right(self.into_poll_stream()))
}

fn into_poll_stream(mut self) -> impl Stream<Item = N::BlockResponse> + 'static {
let poll_task_builder: PollerBuilder<T, NoParams, U64> =
PollerBuilder::new(self.client.clone(), "eth_blockNumber", []);
let mut poll_task = poll_task_builder.spawn().into_stream_raw();
stream! {
let mut poll_task = self.poll_task.spawn().into_stream_raw();
'task: loop {
// Clear any buffered blocks.
while let Some(known_block) = self.known_blocks.pop(&self.next_yield) {
Expand All @@ -62,11 +71,11 @@ impl<T: Transport + Clone, N: Network> ChainStreamPoller<T, N> {
Some(Ok(block_number)) => block_number,
Some(Err(err)) => {
// This is fine.
debug!(%err, "polling stream lagged");
debug!(%err, "block number stream lagged");
continue 'task;
}
None => {
debug!("polling stream ended");
debug!("block number stream ended");
break 'task;
}
};
Expand Down Expand Up @@ -125,64 +134,82 @@ impl<T: Transport + Clone, N: Network> ChainStreamPoller<T, N> {

#[cfg(all(test, feature = "anvil-api"))] // Tests rely heavily on ability to mine blocks on demand.
mod tests {
use std::{future::Future, time::Duration};

use crate::{ext::AnvilApi, ProviderBuilder};
use super::*;
use crate::{ext::AnvilApi, Provider, ProviderBuilder};
use alloy_node_bindings::Anvil;
use alloy_primitives::U256;
use alloy_rpc_client::ReqwestClient;

use super::*;
use std::{future::Future, time::Duration};

fn init_tracing() {
let _ = tracing_subscriber::fmt::try_init();
}

async fn with_timeout<T: Future>(fut: T) -> T::Output {
tokio::select! {
_ = tokio::time::sleep(Duration::from_secs(1)) => panic!("Operation timed out"),
_ = tokio::time::sleep(Duration::from_secs(2)) => panic!("Operation timed out"),
out = fut => out,
}
}

#[tokio::test]
async fn yield_block() {
async fn yield_block_http() {
yield_block(false).await;
}
#[tokio::test]
#[cfg(feature = "ws")]
async fn yield_block_ws() {
yield_block(true).await;
}
async fn yield_block(ws: bool) {
init_tracing();

let anvil = Anvil::new().spawn();

let client = ReqwestClient::new_http(anvil.endpoint_url());
let poller: ChainStreamPoller<_, Ethereum> =
ChainStreamPoller::with_next_yield(client.get_weak(), 1);
let mut stream = Box::pin(poller.into_stream());
let url = if ws { anvil.ws_endpoint() } else { anvil.endpoint() };
let provider = ProviderBuilder::new().on_builtin(&url).await.unwrap();

let poller: NewBlocks<_, Ethereum> = NewBlocks::new(provider.weak_client());
let mut stream = Box::pin(poller.into_stream().await.unwrap());

// We will also use provider to manipulate anvil instance via RPC.
let provider = ProviderBuilder::new().on_http(anvil.endpoint_url());
provider.anvil_mine(Some(U256::from(1)), None).await.unwrap();

let block = with_timeout(stream.next()).await.expect("Block wasn't fetched");
assert_eq!(block.header.number, 1);
assert!(block.header.number <= 1);
}

#[tokio::test]
async fn yield_many_blocks() {
async fn yield_many_blocks_http() {
yield_many_blocks(false).await;
}
#[tokio::test]
#[cfg(feature = "ws")]
async fn yield_many_blocks_ws() {
yield_many_blocks(true).await;
}
async fn yield_many_blocks(ws: bool) {
// Make sure that we can process more blocks than fits in the cache.
const BLOCKS_TO_MINE: usize = BLOCK_CACHE_SIZE.get() + 1;

init_tracing();

let anvil = Anvil::new().spawn();

let client = ReqwestClient::new_http(anvil.endpoint_url());
let poller: ChainStreamPoller<_, Ethereum> =
ChainStreamPoller::with_next_yield(client.get_weak(), 1);
let stream = Box::pin(poller.into_stream());
let url = if ws { anvil.ws_endpoint() } else { anvil.endpoint() };
let provider = ProviderBuilder::new().on_builtin(&url).await.unwrap();

let poller: NewBlocks<_, Ethereum> = NewBlocks::new(provider.weak_client());
let stream = Box::pin(poller.into_stream().await.unwrap());

// We will also use provider to manipulate anvil instance via RPC.
let provider = ProviderBuilder::new().on_http(anvil.endpoint_url());
provider.anvil_mine(Some(U256::from(BLOCKS_TO_MINE)), None).await.unwrap();

let blocks = with_timeout(stream.take(BLOCKS_TO_MINE).collect::<Vec<_>>()).await;
assert_eq!(blocks.len(), BLOCKS_TO_MINE);
let first = blocks[0].header.number;
assert!(first <= 1);
for (i, block) in blocks.iter().enumerate() {
assert_eq!(block.header.number, first + i as u64);
}
}
}
58 changes: 31 additions & 27 deletions crates/provider/src/heart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

use crate::{Provider, RootProvider};
use alloy_json_rpc::RpcError;
use alloy_network::Network;
use alloy_network::{BlockResponse, HeaderResponse, Network};
use alloy_primitives::{TxHash, B256};
use alloy_rpc_types_eth::Block;
use alloy_transport::{utils::Spawnable, Transport, TransportError};
use futures::{stream::StreamExt, FutureExt, Stream};
use std::{
Expand Down Expand Up @@ -74,7 +73,7 @@ pub enum PendingTransactionError {
#[must_use = "this type does nothing unless you call `register`, `watch` or `get_receipt`"]
#[derive(Debug)]
#[doc(alias = "PendingTxBuilder")]
pub struct PendingTransactionBuilder<'a, T, N> {
pub struct PendingTransactionBuilder<'a, T, N: Network> {
config: PendingTransactionConfig,
provider: &'a RootProvider<T, N>,
}
Expand Down Expand Up @@ -400,12 +399,12 @@ impl Future for PendingTransaction {

/// A handle to the heartbeat task.
#[derive(Clone, Debug)]
pub(crate) struct HeartbeatHandle {
pub(crate) struct HeartbeatHandle<N: Network> {
tx: mpsc::Sender<TxWatcher>,
latest: watch::Receiver<Option<Block>>,
latest: watch::Receiver<Option<N::BlockResponse>>,
}

impl HeartbeatHandle {
impl<N: Network> HeartbeatHandle<N> {
/// Watch for a transaction to be confirmed with the given config.
#[doc(alias = "watch_transaction")]
pub(crate) async fn watch_tx(
Expand All @@ -423,14 +422,14 @@ impl HeartbeatHandle {

/// Returns a watcher that always sees the latest block.
#[allow(dead_code)]
pub(crate) const fn latest(&self) -> &watch::Receiver<Option<Block>> {
pub(crate) const fn latest(&self) -> &watch::Receiver<Option<N::BlockResponse>> {
&self.latest
}
}

// TODO: Parameterize with `Network`
/// A heartbeat task that receives blocks and watches for transactions.
pub(crate) struct Heartbeat<S> {
pub(crate) struct Heartbeat<N, S> {
/// The stream of incoming blocks to watch.
stream: futures::stream::Fuse<S>,

Expand All @@ -445,9 +444,11 @@ pub(crate) struct Heartbeat<S> {

/// Ordered map of transactions to reap at a certain time.
reap_at: BTreeMap<Instant, B256>,

_network: std::marker::PhantomData<N>,
}

impl<S: Stream<Item = Block> + Unpin + 'static> Heartbeat<S> {
impl<N: Network, S: Stream<Item = N::BlockResponse> + Unpin + 'static> Heartbeat<N, S> {
/// Create a new heartbeat task.
pub(crate) fn new(stream: S) -> Self {
Self {
Expand All @@ -456,11 +457,10 @@ impl<S: Stream<Item = Block> + Unpin + 'static> Heartbeat<S> {
unconfirmed: Default::default(),
waiting_confs: Default::default(),
reap_at: Default::default(),
_network: Default::default(),
}
}
}

impl<S> Heartbeat<S> {
/// Check if any transactions have enough confirmations to notify.
fn check_confirmations(&mut self, current_height: u64) {
let to_keep = self.waiting_confs.split_off(&(current_height + 1));
Expand Down Expand Up @@ -561,9 +561,13 @@ impl<S> Heartbeat<S> {
/// Handle a new block by checking if any of the transactions we're
/// watching are in it, and if so, notifying the watcher. Also updates
/// the latest block.
fn handle_new_block(&mut self, block: Block, latest: &watch::Sender<Option<Block>>) {
fn handle_new_block(
&mut self,
block: N::BlockResponse,
latest: &watch::Sender<Option<N::BlockResponse>>,
) {
// Blocks without numbers are ignored, as they're not part of the chain.
let block_height = &block.header.number;
let block_height = block.header().number();

// Add the block the lookbehind.
// The value is chosen arbitrarily to not have a huge memory footprint but still
Expand All @@ -577,19 +581,19 @@ impl<S> Heartbeat<S> {
}
if let Some((last_height, _)) = self.past_blocks.back().as_ref() {
// Check that the chain is continuous.
if *last_height + 1 != *block_height {
if *last_height + 1 != block_height {
// Move all the transactions that were reset by the reorg to the unconfirmed list.
warn!(%block_height, last_height, "reorg detected");
self.move_reorg_to_unconfirmed(*block_height);
self.move_reorg_to_unconfirmed(block_height);
// Remove past blocks that are now invalid.
self.past_blocks.retain(|(h, _)| h < block_height);
self.past_blocks.retain(|(h, _)| *h < block_height);
}
}
self.past_blocks.push_back((*block_height, block.transactions.hashes().collect()));
self.past_blocks.push_back((block_height, block.transactions().hashes().collect()));

// Check if we are watching for any of the transactions in this block.
let to_check: Vec<_> = block
.transactions
.transactions()
.hashes()
.filter_map(|tx_hash| self.unconfirmed.remove(&tx_hash))
.collect();
Expand All @@ -607,12 +611,12 @@ impl<S> Heartbeat<S> {
warn!(tx=%watcher.config.tx_hash, set_block=%set_block, new_block=%block_height, "received_at_block already set");
// We don't override the set value.
} else {
watcher.received_at_block = Some(*block_height);
watcher.received_at_block = Some(block_height);
}
self.add_to_waiting_list(watcher, *block_height);
self.add_to_waiting_list(watcher, block_height);
}

self.check_confirmations(*block_height);
self.check_confirmations(block_height);

// Update the latest block. We use `send_replace` here to ensure the
// latest block is always up to date, even if no receivers exist.
Expand All @@ -623,7 +627,7 @@ impl<S> Heartbeat<S> {
}

#[cfg(target_arch = "wasm32")]
impl<S: Stream<Item = Block> + Unpin + 'static> Heartbeat<S> {
impl<N: Network, S: Stream<Item = N::BlockResponse> + Unpin + 'static> Heartbeat<N, S> {
/// Spawn the heartbeat task, returning a [`HeartbeatHandle`].
pub(crate) fn spawn(self) -> HeartbeatHandle {
let (latest, latest_rx) = watch::channel(None::<Block>);
Expand All @@ -636,10 +640,10 @@ impl<S: Stream<Item = Block> + Unpin + 'static> Heartbeat<S> {
}

#[cfg(not(target_arch = "wasm32"))]
impl<S: Stream<Item = Block> + Unpin + Send + 'static> Heartbeat<S> {
impl<N: Network, S: Stream<Item = N::BlockResponse> + Unpin + Send + 'static> Heartbeat<N, S> {
/// Spawn the heartbeat task, returning a [`HeartbeatHandle`].
pub(crate) fn spawn(self) -> HeartbeatHandle {
let (latest, latest_rx) = watch::channel(None::<Block>);
pub(crate) fn spawn(self) -> HeartbeatHandle<N> {
let (latest, latest_rx) = watch::channel(None::<N::BlockResponse>);
let (ix_tx, ixns) = mpsc::channel(16);

self.into_future(latest, ixns).spawn_task();
Expand All @@ -648,10 +652,10 @@ impl<S: Stream<Item = Block> + Unpin + Send + 'static> Heartbeat<S> {
}
}

impl<S: Stream<Item = Block> + Unpin + 'static> Heartbeat<S> {
impl<N: Network, S: Stream<Item = N::BlockResponse> + Unpin + 'static> Heartbeat<N, S> {
async fn into_future(
mut self,
latest: watch::Sender<Option<Block>>,
latest: watch::Sender<Option<N::BlockResponse>>,
mut ixns: mpsc::Receiver<TxWatcher>,
) {
'shutdown: loop {
Expand Down
2 changes: 1 addition & 1 deletion crates/provider/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ extern crate tracing;
mod builder;
pub use builder::{Identity, ProviderBuilder, ProviderLayer, Stack};

mod chain;
mod blocks;

pub mod ext;

Expand Down
Loading

0 comments on commit a18a111

Please sign in to comment.