Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(provider): subscribe to new blocks if possible in heartbeat #1321

Merged
merged 9 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/network-primitives/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,9 @@ pub trait HeaderResponse {
/// Block JSON-RPC response.
pub trait BlockResponse {
/// Header type
type Header;
type Header: HeaderResponse;
/// Transaction type
type Transaction;
type Transaction: TransactionResponse;

/// Block header
fn header(&self) -> &Self::Header;
Expand Down
112 changes: 71 additions & 41 deletions crates/provider/src/chain.rs → crates/provider/src/blocks.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use alloy_network::{Ethereum, Network};
use alloy_primitives::{BlockNumber, U64};
use alloy_rpc_client::{NoParams, PollerBuilder, WeakClient};
use alloy_rpc_types_eth::Block;
use alloy_transport::{RpcError, Transport};
use alloy_transport::{RpcError, Transport, TransportResult};
use async_stream::stream;
use futures::{Stream, StreamExt};
use lru::LruCache;
use std::{marker::PhantomData, num::NonZeroUsize};

#[cfg(feature = "pubsub")]
use futures::future::Either;

/// The size of the block cache.
const BLOCK_CACHE_SIZE: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(10) };

Expand All @@ -17,38 +19,48 @@ const MAX_RETRIES: usize = 3;
/// Default block number for when we don't have a block yet.
const NO_BLOCK_NUMBER: BlockNumber = BlockNumber::MAX;

pub(crate) struct ChainStreamPoller<T, N = Ethereum> {
/// Streams new blocks from the client.
pub(crate) struct NewBlocks<T, N: Network = Ethereum> {
client: WeakClient<T>,
poll_task: PollerBuilder<T, NoParams, U64>,
next_yield: BlockNumber,
known_blocks: LruCache<BlockNumber, Block>,
known_blocks: LruCache<BlockNumber, N::BlockResponse>,
_phantom: PhantomData<N>,
}

impl<T: Transport + Clone, N: Network> ChainStreamPoller<T, N> {
pub(crate) fn from_weak_client(w: WeakClient<T>) -> Self {
Self::new(w)
}

impl<T: Transport + Clone, N: Network> NewBlocks<T, N> {
pub(crate) fn new(client: WeakClient<T>) -> Self {
Self::with_next_yield(client, NO_BLOCK_NUMBER)
}

/// Can be used to force the poller to start at a specific block number.
/// Mostly useful for tests.
fn with_next_yield(client: WeakClient<T>, next_yield: BlockNumber) -> Self {
Self {
client: client.clone(),
poll_task: PollerBuilder::new(client, "eth_blockNumber", []),
next_yield,
client,
next_yield: NO_BLOCK_NUMBER,
known_blocks: LruCache::new(BLOCK_CACHE_SIZE),
_phantom: PhantomData,
}
}

pub(crate) fn into_stream(mut self) -> impl Stream<Item = Block> + 'static {
pub(crate) async fn into_stream(
self,
) -> TransportResult<impl Stream<Item = N::BlockResponse> + 'static> {
#[cfg(feature = "pubsub")]
if let Some(client) = self.client.upgrade() {
if let Some(pubsub) = client.pubsub_frontend() {
let id = client.request("eth_subscribe", ("newHeads",)).await?;
let sub = pubsub.get_subscription(id).await?;
return Ok(Either::Left(sub.into_typed::<N::BlockResponse>().into_stream()));
}
}

#[cfg(feature = "pubsub")]
let right = Either::Right;
#[cfg(not(feature = "pubsub"))]
let right = std::convert::identity;
Ok(right(self.into_poll_stream()))
}

fn into_poll_stream(mut self) -> impl Stream<Item = N::BlockResponse> + 'static {
stream! {
let mut poll_task = self.poll_task.spawn().into_stream_raw();
let poll_task_builder: PollerBuilder<T, NoParams, U64> =
PollerBuilder::new(self.client.clone(), "eth_blockNumber", []);
let mut poll_task = poll_task_builder.spawn().into_stream_raw();
'task: loop {
// Clear any buffered blocks.
while let Some(known_block) = self.known_blocks.pop(&self.next_yield) {
Expand Down Expand Up @@ -125,64 +137,82 @@ impl<T: Transport + Clone, N: Network> ChainStreamPoller<T, N> {

#[cfg(all(test, feature = "anvil-api"))] // Tests rely heavily on ability to mine blocks on demand.
mod tests {
use std::{future::Future, time::Duration};

use crate::{ext::AnvilApi, ProviderBuilder};
use super::*;
use crate::{ext::AnvilApi, Provider, ProviderBuilder};
use alloy_node_bindings::Anvil;
use alloy_primitives::U256;
use alloy_rpc_client::ReqwestClient;

use super::*;
use std::{future::Future, time::Duration};

fn init_tracing() {
let _ = tracing_subscriber::fmt::try_init();
}

async fn with_timeout<T: Future>(fut: T) -> T::Output {
tokio::select! {
_ = tokio::time::sleep(Duration::from_secs(1)) => panic!("Operation timed out"),
_ = tokio::time::sleep(Duration::from_secs(2)) => panic!("Operation timed out"),
out = fut => out,
}
}

#[tokio::test]
async fn yield_block() {
async fn yield_block_http() {
yield_block(false).await;
}
#[tokio::test]
#[cfg(feature = "ws")]
async fn yield_block_ws() {
yield_block(true).await;
}
async fn yield_block(ws: bool) {
init_tracing();

let anvil = Anvil::new().spawn();

let client = ReqwestClient::new_http(anvil.endpoint_url());
let poller: ChainStreamPoller<_, Ethereum> =
ChainStreamPoller::with_next_yield(client.get_weak(), 1);
let mut stream = Box::pin(poller.into_stream());
let url = if ws { anvil.ws_endpoint() } else { anvil.endpoint() };
let provider = ProviderBuilder::new().on_builtin(&url).await.unwrap();

let poller: NewBlocks<_, Ethereum> = NewBlocks::new(provider.weak_client());
let mut stream = Box::pin(poller.into_stream().await.unwrap());

// We will also use provider to manipulate anvil instance via RPC.
let provider = ProviderBuilder::new().on_http(anvil.endpoint_url());
provider.anvil_mine(Some(U256::from(1)), None).await.unwrap();

let block = with_timeout(stream.next()).await.expect("Block wasn't fetched");
assert_eq!(block.header.number, 1);
assert!(block.header.number <= 1);
}

#[tokio::test]
async fn yield_many_blocks() {
async fn yield_many_blocks_http() {
yield_many_blocks(false).await;
}
#[tokio::test]
#[cfg(feature = "ws")]
async fn yield_many_blocks_ws() {
yield_many_blocks(true).await;
}
async fn yield_many_blocks(ws: bool) {
// Make sure that we can process more blocks than fits in the cache.
const BLOCKS_TO_MINE: usize = BLOCK_CACHE_SIZE.get() + 1;

init_tracing();

let anvil = Anvil::new().spawn();

let client = ReqwestClient::new_http(anvil.endpoint_url());
let poller: ChainStreamPoller<_, Ethereum> =
ChainStreamPoller::with_next_yield(client.get_weak(), 1);
let stream = Box::pin(poller.into_stream());
let url = if ws { anvil.ws_endpoint() } else { anvil.endpoint() };
let provider = ProviderBuilder::new().on_builtin(&url).await.unwrap();

let poller: NewBlocks<_, Ethereum> = NewBlocks::new(provider.weak_client());
let stream = Box::pin(poller.into_stream().await.unwrap());

// We will also use provider to manipulate anvil instance via RPC.
let provider = ProviderBuilder::new().on_http(anvil.endpoint_url());
provider.anvil_mine(Some(U256::from(BLOCKS_TO_MINE)), None).await.unwrap();

let blocks = with_timeout(stream.take(BLOCKS_TO_MINE).collect::<Vec<_>>()).await;
assert_eq!(blocks.len(), BLOCKS_TO_MINE);
let first = blocks[0].header.number;
assert!(first <= 1);
for (i, block) in blocks.iter().enumerate() {
assert_eq!(block.header.number, first + i as u64);
}
}
}
82 changes: 43 additions & 39 deletions crates/provider/src/heart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

use crate::{Provider, RootProvider};
use alloy_json_rpc::RpcError;
use alloy_network::Network;
use alloy_network::{BlockResponse, HeaderResponse, Network};
use alloy_primitives::{TxHash, B256};
use alloy_rpc_types_eth::Block;
use alloy_transport::{utils::Spawnable, Transport, TransportError};
use futures::{stream::StreamExt, FutureExt, Stream};
use std::{
Expand Down Expand Up @@ -74,7 +73,7 @@ pub enum PendingTransactionError {
#[must_use = "this type does nothing unless you call `register`, `watch` or `get_receipt`"]
#[derive(Debug)]
#[doc(alias = "PendingTxBuilder")]
pub struct PendingTransactionBuilder<'a, T, N> {
pub struct PendingTransactionBuilder<'a, T, N: Network> {
config: PendingTransactionConfig,
provider: &'a RootProvider<T, N>,
}
Expand Down Expand Up @@ -400,12 +399,12 @@ impl Future for PendingTransaction {

/// A handle to the heartbeat task.
#[derive(Clone, Debug)]
pub(crate) struct HeartbeatHandle {
pub(crate) struct HeartbeatHandle<N: Network> {
tx: mpsc::Sender<TxWatcher>,
latest: watch::Receiver<Option<Block>>,
latest: watch::Receiver<Option<N::BlockResponse>>,
}

impl HeartbeatHandle {
impl<N: Network> HeartbeatHandle<N> {
/// Watch for a transaction to be confirmed with the given config.
#[doc(alias = "watch_transaction")]
pub(crate) async fn watch_tx(
Expand All @@ -423,14 +422,14 @@ impl HeartbeatHandle {

/// Returns a watcher that always sees the latest block.
#[allow(dead_code)]
pub(crate) const fn latest(&self) -> &watch::Receiver<Option<Block>> {
pub(crate) const fn latest(&self) -> &watch::Receiver<Option<N::BlockResponse>> {
&self.latest
}
}

// TODO: Parameterize with `Network`
/// A heartbeat task that receives blocks and watches for transactions.
pub(crate) struct Heartbeat<S> {
pub(crate) struct Heartbeat<N, S> {
/// The stream of incoming blocks to watch.
stream: futures::stream::Fuse<S>,

Expand All @@ -445,9 +444,11 @@ pub(crate) struct Heartbeat<S> {

/// Ordered map of transactions to reap at a certain time.
reap_at: BTreeMap<Instant, B256>,

_network: std::marker::PhantomData<N>,
}

impl<S: Stream<Item = Block> + Unpin + 'static> Heartbeat<S> {
impl<N: Network, S: Stream<Item = N::BlockResponse> + Unpin + 'static> Heartbeat<N, S> {
/// Create a new heartbeat task.
pub(crate) fn new(stream: S) -> Self {
Self {
Expand All @@ -456,11 +457,10 @@ impl<S: Stream<Item = Block> + Unpin + 'static> Heartbeat<S> {
unconfirmed: Default::default(),
waiting_confs: Default::default(),
reap_at: Default::default(),
_network: Default::default(),
}
}
}

impl<S> Heartbeat<S> {
/// Check if any transactions have enough confirmations to notify.
fn check_confirmations(&mut self, current_height: u64) {
let to_keep = self.waiting_confs.split_off(&(current_height + 1));
Expand Down Expand Up @@ -561,9 +561,13 @@ impl<S> Heartbeat<S> {
/// Handle a new block by checking if any of the transactions we're
/// watching are in it, and if so, notifying the watcher. Also updates
/// the latest block.
fn handle_new_block(&mut self, block: Block, latest: &watch::Sender<Option<Block>>) {
fn handle_new_block(
&mut self,
block: N::BlockResponse,
latest: &watch::Sender<Option<N::BlockResponse>>,
) {
// Blocks without numbers are ignored, as they're not part of the chain.
let block_height = &block.header.number;
let block_height = block.header().number();

// Add the block the lookbehind.
// The value is chosen arbitrarily to not have a huge memory footprint but still
Expand All @@ -577,19 +581,19 @@ impl<S> Heartbeat<S> {
}
if let Some((last_height, _)) = self.past_blocks.back().as_ref() {
// Check that the chain is continuous.
if *last_height + 1 != *block_height {
if *last_height + 1 != block_height {
// Move all the transactions that were reset by the reorg to the unconfirmed list.
warn!(%block_height, last_height, "reorg detected");
self.move_reorg_to_unconfirmed(*block_height);
self.move_reorg_to_unconfirmed(block_height);
// Remove past blocks that are now invalid.
self.past_blocks.retain(|(h, _)| h < block_height);
self.past_blocks.retain(|(h, _)| *h < block_height);
}
}
self.past_blocks.push_back((*block_height, block.transactions.hashes().collect()));
self.past_blocks.push_back((block_height, block.transactions().hashes().collect()));

// Check if we are watching for any of the transactions in this block.
let to_check: Vec<_> = block
.transactions
.transactions()
.hashes()
.filter_map(|tx_hash| self.unconfirmed.remove(&tx_hash))
.collect();
Expand All @@ -607,12 +611,12 @@ impl<S> Heartbeat<S> {
warn!(tx=%watcher.config.tx_hash, set_block=%set_block, new_block=%block_height, "received_at_block already set");
// We don't override the set value.
} else {
watcher.received_at_block = Some(*block_height);
watcher.received_at_block = Some(block_height);
}
self.add_to_waiting_list(watcher, *block_height);
self.add_to_waiting_list(watcher, block_height);
}

self.check_confirmations(*block_height);
self.check_confirmations(block_height);

// Update the latest block. We use `send_replace` here to ensure the
// latest block is always up to date, even if no receivers exist.
Expand All @@ -623,35 +627,35 @@ impl<S> Heartbeat<S> {
}

#[cfg(target_arch = "wasm32")]
impl<S: Stream<Item = Block> + Unpin + 'static> Heartbeat<S> {
impl<N: Network, S: Stream<Item = N::BlockResponse> + Unpin + 'static> Heartbeat<N, S> {
/// Spawn the heartbeat task, returning a [`HeartbeatHandle`].
pub(crate) fn spawn(self) -> HeartbeatHandle {
let (latest, latest_rx) = watch::channel(None::<Block>);
let (ix_tx, ixns) = mpsc::channel(16);

self.into_future(latest, ixns).spawn_task();

HeartbeatHandle { tx: ix_tx, latest: latest_rx }
pub(crate) fn spawn(self) -> HeartbeatHandle<N> {
let (task, handle) = self.consume();
task.spawn_task();
handle
}
}

#[cfg(not(target_arch = "wasm32"))]
impl<S: Stream<Item = Block> + Unpin + Send + 'static> Heartbeat<S> {
impl<N: Network, S: Stream<Item = N::BlockResponse> + Unpin + Send + 'static> Heartbeat<N, S> {
/// Spawn the heartbeat task, returning a [`HeartbeatHandle`].
pub(crate) fn spawn(self) -> HeartbeatHandle {
let (latest, latest_rx) = watch::channel(None::<Block>);
let (ix_tx, ixns) = mpsc::channel(16);

self.into_future(latest, ixns).spawn_task();

HeartbeatHandle { tx: ix_tx, latest: latest_rx }
pub(crate) fn spawn(self) -> HeartbeatHandle<N> {
let (task, handle) = self.consume();
task.spawn_task();
handle
}
}

impl<S: Stream<Item = Block> + Unpin + 'static> Heartbeat<S> {
impl<N: Network, S: Stream<Item = N::BlockResponse> + Unpin + 'static> Heartbeat<N, S> {
fn consume(self) -> (impl Future<Output = ()>, HeartbeatHandle<N>) {
let (latest, latest_rx) = watch::channel(None::<N::BlockResponse>);
let (ix_tx, ixns) = mpsc::channel(16);
(self.into_future(latest, ixns), HeartbeatHandle { tx: ix_tx, latest: latest_rx })
}

async fn into_future(
mut self,
latest: watch::Sender<Option<Block>>,
latest: watch::Sender<Option<N::BlockResponse>>,
mut ixns: mpsc::Receiver<TxWatcher>,
) {
'shutdown: loop {
Expand Down
2 changes: 1 addition & 1 deletion crates/provider/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ extern crate tracing;
mod builder;
pub use builder::{Identity, ProviderBuilder, ProviderLayer, Stack};

mod chain;
mod blocks;

pub mod ext;

Expand Down
Loading
Loading