diff --git a/crates/provider/src/lib.rs b/crates/provider/src/lib.rs index 314e1ff2494f..9be20928ea99 100644 --- a/crates/provider/src/lib.rs +++ b/crates/provider/src/lib.rs @@ -40,8 +40,8 @@ pub use heart::{PendingTransaction, PendingTransactionBuilder, PendingTransactio mod provider; pub use provider::{ - builder, EthCall, FilterPollerBuilder, Provider, RootProvider, RpcWithBlock, SendableTx, - WalletProvider, + builder, EthCall, FilterPollerBuilder, Provider, RootProvider, RpcWithBlock, RpcWithBlockFut, + SendableTx, WalletProvider, }; pub mod utils; diff --git a/crates/provider/src/provider/mod.rs b/crates/provider/src/provider/mod.rs index 0d8e939ed5a3..7528847e7873 100644 --- a/crates/provider/src/provider/mod.rs +++ b/crates/provider/src/provider/mod.rs @@ -14,4 +14,4 @@ mod wallet; pub use wallet::WalletProvider; mod with_block; -pub use with_block::RpcWithBlock; +pub use with_block::{RpcWithBlock, RpcWithBlockFut}; diff --git a/crates/provider/src/provider/with_block.rs b/crates/provider/src/provider/with_block.rs index 808b72a15f6e..cfbbe1d6905e 100644 --- a/crates/provider/src/provider/with_block.rs +++ b/crates/provider/src/provider/with_block.rs @@ -11,6 +11,13 @@ use std::{ task::Poll, }; +#[cfg(feature = "trace-api")] +use { + crate::ext::TraceRpcWithBlock, + alloy_rpc_types_trace::parity::TraceType, + std::{collections::HashSet, ops::Deref}, +}; + /// States of the #[derive(Clone)] enum States Output> @@ -26,6 +33,8 @@ where method: Cow<'static, str>, params: Params, block_id: BlockId, + #[cfg(feature = "trace-api")] + trace_types: HashSet, map: Map, }, Running(RpcCall), @@ -80,6 +89,15 @@ where cx: &mut std::task::Context<'_>, ) -> Poll> { let this = self.project(); + + #[cfg(feature = "trace-api")] + let States::Preparing { client, method, params, block_id, trace_types, map } = + std::mem::replace(this.state, States::Invalid) + else { + unreachable!("bad state") + }; + + #[cfg(not(feature = "trace-api"))] let States::Preparing { client, method, params, block_id, map } = std::mem::replace(this.state, States::Invalid) else { @@ -110,10 +128,26 @@ where // append the block id to the params if let serde_json::Value::Array(ref mut arr) = ser { arr.push(block_id); + #[cfg(feature = "trace-api")] + if !trace_types.is_empty() { + arr.push(serde_json::to_value(trace_types).map_err(RpcError::ser_err)?); + } } else if ser.is_null() { - ser = serde_json::Value::Array(vec![block_id]); + let mut arr = vec![]; + arr.push(block_id); + #[cfg(feature = "trace-api")] + if !trace_types.is_empty() { + arr.push(serde_json::to_value(trace_types).map_err(RpcError::ser_err)?); + } + ser = serde_json::Value::Array(arr); } else { - ser = serde_json::Value::Array(vec![ser, block_id]); + let mut arr = vec![ser]; + arr.push(block_id); + #[cfg(feature = "trace-api")] + if !trace_types.is_empty() { + arr.push(serde_json::to_value(trace_types).map_err(RpcError::ser_err)?); + } + ser = serde_json::Value::Array(arr); } // create the call @@ -173,6 +207,8 @@ where method: Cow<'static, str>, params: Params, block_id: BlockId, + #[cfg(feature = "trace-api")] + trace_types: HashSet, map: Map, _pd: PhantomData (Resp, Output)>, } @@ -194,6 +230,8 @@ where method: method.into(), params, block_id: Default::default(), + #[cfg(feature = "trace-api")] + trace_types: vec![TraceType::Trace].into_iter().collect(), map: std::convert::identity, _pd: PhantomData, } @@ -220,6 +258,8 @@ where method: self.method, params: self.params, block_id: self.block_id, + #[cfg(feature = "trace-api")] + trace_types: self.trace_types, map, _pd: PhantomData, } @@ -293,8 +333,34 @@ where method: self.method, params: self.params, block_id: self.block_id, + #[cfg(feature = "trace-api")] + trace_types: self.trace_types, map: self.map, }, } } } + +#[cfg(feature = "trace-api")] +impl From> + for RpcWithBlock +where + T: Transport + Clone, + Params: RpcParam, + Resp: RpcReturn, + Output: 'static, + Map: Fn(Resp) -> Output + 'static + Copy, +{ + fn from(trace_rpc: TraceRpcWithBlock) -> Self { + let rpc = trace_rpc.deref(); + RpcWithBlock { + client: rpc.client.clone(), + method: rpc.method.clone(), + params: rpc.params.clone(), + block_id: rpc.block_id.clone(), + trace_types: trace_rpc.get_trace_types().clone(), + map: rpc.map, + _pd: rpc._pd, + } + } +}