Skip to content

Commit

Permalink
provider: Introduce trace_types to RpcWithBlock
Browse files Browse the repository at this point in the history
  • Loading branch information
moricho committed Jul 21, 2024
1 parent 9829039 commit d1f0463
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 5 deletions.
4 changes: 2 additions & 2 deletions crates/provider/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ pub use heart::{PendingTransaction, PendingTransactionBuilder, PendingTransactio

mod provider;
pub use provider::{
builder, EthCall, FilterPollerBuilder, Provider, RootProvider, RpcWithBlock, SendableTx,
WalletProvider,
builder, EthCall, FilterPollerBuilder, Provider, RootProvider, RpcWithBlock, RpcWithBlockFut,
SendableTx, WalletProvider,
};

pub mod utils;
Expand Down
2 changes: 1 addition & 1 deletion crates/provider/src/provider/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ mod wallet;
pub use wallet::WalletProvider;

mod with_block;
pub use with_block::RpcWithBlock;
pub use with_block::{RpcWithBlock, RpcWithBlockFut};
70 changes: 68 additions & 2 deletions crates/provider/src/provider/with_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ use std::{
task::Poll,
};

#[cfg(feature = "trace-api")]
use {
crate::ext::TraceRpcWithBlock,
alloy_rpc_types_trace::parity::TraceType,
std::{collections::HashSet, ops::Deref},
};

/// States of the
#[derive(Clone)]
enum States<T, Params, Resp, Output = Resp, Map = fn(Resp) -> Output>
Expand All @@ -26,6 +33,8 @@ where
method: Cow<'static, str>,
params: Params,
block_id: BlockId,
#[cfg(feature = "trace-api")]
trace_types: HashSet<TraceType>,
map: Map,
},
Running(RpcCall<T, serde_json::Value, Resp, Output, Map>),
Expand Down Expand Up @@ -80,6 +89,15 @@ where
cx: &mut std::task::Context<'_>,
) -> Poll<TransportResult<Output>> {
let this = self.project();

#[cfg(feature = "trace-api")]
let States::Preparing { client, method, params, block_id, trace_types, map } =
std::mem::replace(this.state, States::Invalid)
else {
unreachable!("bad state")
};

#[cfg(not(feature = "trace-api"))]
let States::Preparing { client, method, params, block_id, map } =
std::mem::replace(this.state, States::Invalid)
else {
Expand Down Expand Up @@ -110,10 +128,26 @@ where
// append the block id to the params
if let serde_json::Value::Array(ref mut arr) = ser {
arr.push(block_id);
#[cfg(feature = "trace-api")]
if !trace_types.is_empty() {
arr.push(serde_json::to_value(trace_types).map_err(RpcError::ser_err)?);
}
} else if ser.is_null() {
ser = serde_json::Value::Array(vec![block_id]);
let mut arr = vec![];
arr.push(block_id);
#[cfg(feature = "trace-api")]
if !trace_types.is_empty() {
arr.push(serde_json::to_value(trace_types).map_err(RpcError::ser_err)?);
}
ser = serde_json::Value::Array(arr);
} else {
ser = serde_json::Value::Array(vec![ser, block_id]);
let mut arr = vec![ser];
arr.push(block_id);
#[cfg(feature = "trace-api")]
if !trace_types.is_empty() {
arr.push(serde_json::to_value(trace_types).map_err(RpcError::ser_err)?);
}
ser = serde_json::Value::Array(arr);
}

// create the call
Expand Down Expand Up @@ -173,6 +207,8 @@ where
method: Cow<'static, str>,
params: Params,
block_id: BlockId,
#[cfg(feature = "trace-api")]
trace_types: HashSet<TraceType>,
map: Map,
_pd: PhantomData<fn() -> (Resp, Output)>,
}
Expand All @@ -194,6 +230,8 @@ where
method: method.into(),
params,
block_id: Default::default(),
#[cfg(feature = "trace-api")]
trace_types: vec![TraceType::Trace].into_iter().collect(),
map: std::convert::identity,
_pd: PhantomData,
}
Expand All @@ -220,6 +258,8 @@ where
method: self.method,
params: self.params,
block_id: self.block_id,
#[cfg(feature = "trace-api")]
trace_types: self.trace_types,
map,
_pd: PhantomData,
}
Expand Down Expand Up @@ -293,8 +333,34 @@ where
method: self.method,
params: self.params,
block_id: self.block_id,
#[cfg(feature = "trace-api")]
trace_types: self.trace_types,
map: self.map,
},
}
}
}

#[cfg(feature = "trace-api")]
impl<T, Params, Resp, Output, Map> From<TraceRpcWithBlock<T, Params, Resp, Output, Map>>
for RpcWithBlock<T, Params, Resp, Output, Map>
where
T: Transport + Clone,
Params: RpcParam,
Resp: RpcReturn,
Output: 'static,
Map: Fn(Resp) -> Output + 'static + Copy,
{
fn from(trace_rpc: TraceRpcWithBlock<T, Params, Resp, Output, Map>) -> Self {
let rpc = trace_rpc.deref();
RpcWithBlock {
client: rpc.client.clone(),
method: rpc.method.clone(),
params: rpc.params.clone(),
block_id: rpc.block_id.clone(),
trace_types: trace_rpc.get_trace_types().clone(),
map: rpc.map,
_pd: rpc._pd,
}
}
}

0 comments on commit d1f0463

Please sign in to comment.