use crate::server::connection::{IpcConn, JsonRpcStream};
use futures::StreamExt;
use futures_util::future::Either;
use interprocess::local_socket::{
tokio::prelude::{LocalSocketListener, LocalSocketStream},
traits::tokio::{Listener, Stream},
GenericFilePath, ListenerOptions, ToFsName,
};
use jsonrpsee::{
core::TEN_MB_SIZE_BYTES,
server::{
middleware::rpc::{RpcLoggerLayer, RpcServiceT},
stop_channel, ConnectionGuard, ConnectionPermit, IdProvider, RandomIntegerIdProvider,
ServerHandle, StopHandle,
},
BoundedSubscriptions, MethodSink, Methods,
};
use std::{
future::Future,
io,
pin::{pin, Pin},
sync::Arc,
task::{Context, Poll},
};
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
sync::oneshot,
};
use tower::{layer::util::Identity, Layer, Service};
use tracing::{debug, instrument, trace, warn, Instrument};
use crate::{
server::{connection::IpcConnDriver, rpc_service::RpcServiceCfg},
stream_codec::StreamCodec,
};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tower::layer::{util::Stack, LayerFn};
mod connection;
mod ipc;
mod rpc_service;
pub use rpc_service::RpcService;
/// A JSON-RPC server listening on a local IPC endpoint (a Unix domain socket path, or a
/// Windows named pipe name).
///
/// Created via [`Builder::build`]; call [`IpcServer::start`] to begin serving.
pub struct IpcServer<HttpMiddleware = Identity, RpcMiddleware = Identity> {
    /// Socket path / pipe name the listener is bound to.
    endpoint: String,
    /// Handed to every connection's RPC service for generating subscription IDs.
    id_provider: Arc<dyn IdProvider>,
    /// Limits and runtime configuration shared by all connections.
    cfg: Settings,
    /// Middleware layered around each per-request [`RpcService`].
    rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
    /// Middleware layered around each connection's [`TowerServiceNoHttp`].
    http_middleware: tower::ServiceBuilder<HttpMiddleware>,
}
impl<HttpMiddleware, RpcMiddleware> IpcServer<HttpMiddleware, RpcMiddleware> {
    /// Returns an owned copy of the endpoint (socket path / pipe name) this server uses.
    pub fn endpoint(&self) -> String {
        String::clone(&self.endpoint)
    }
}
/// Server startup and the connection accept loop.
impl<HttpMiddleware, RpcMiddleware> IpcServer<HttpMiddleware, RpcMiddleware>
where
    RpcMiddleware: for<'a> Layer<RpcService, Service: RpcServiceT<'a>> + Clone + Send + 'static,
    HttpMiddleware: Layer<
            TowerServiceNoHttp<RpcMiddleware>,
            Service: Service<
                String,
                Response = Option<String>,
                Error = Box<dyn core::error::Error + Send + Sync + 'static>,
                Future: Send + Unpin,
            > + Send,
        > + Send
        + 'static,
{
    /// Starts the server and serves `methods` on the configured endpoint.
    ///
    /// The accept loop is spawned on the runtime set via [`Builder::custom_tokio_runtime`]
    /// if there is one, otherwise with [`tokio::spawn`]. This function resolves once the
    /// listener has been bound, or binding has failed.
    ///
    /// # Errors
    ///
    /// Returns [`IpcServerStartError`] if the endpoint cannot be turned into a socket name
    /// or the listener cannot be created.
    ///
    /// # Panics
    ///
    /// Panics if the spawned task drops its readiness sender without reporting a result
    /// (every normal path in `start_inner` sends before returning, so this only happens if
    /// that task panics before binding).
    pub async fn start(
        mut self,
        methods: impl Into<Methods>,
    ) -> Result<ServerHandle, IpcServerStartError> {
        let methods = methods.into();
        let (stop_handle, server_handle) = stop_channel();
        // One-shot readiness signal: `start_inner` reports whether binding succeeded.
        let (tx, rx) = oneshot::channel();
        match self.cfg.tokio_runtime.take() {
            Some(rt) => rt.spawn(self.start_inner(methods, stop_handle, tx)),
            None => tokio::spawn(self.start_inner(methods, stop_handle, tx)),
        };
        rx.await.expect("channel is open")?;
        Ok(server_handle)
    }

    /// Binds the listener, reports readiness through `on_ready`, then accepts connections
    /// until `stop_handle` signals shutdown.
    ///
    /// After shutdown it waits for every spawned connection task to finish. Each task holds a
    /// clone of `drop_on_completion` and drops it when done, so `recv` on the paired receiver
    /// yields `None` only once all clones (and the local one) are gone.
    async fn start_inner(
        self,
        methods: Methods,
        stop_handle: StopHandle,
        on_ready: oneshot::Sender<Result<(), IpcServerStartError>>,
    ) {
        trace!(endpoint = ?self.endpoint, "starting ipc server");
        // Remove whatever file already exists at the endpoint path before binding (presumably
        // a socket left behind by a previous run).
        // NOTE(review): this deletes any file at that path, not only sockets.
        if cfg!(unix) {
            if std::fs::remove_file(&self.endpoint).is_ok() {
                debug!(endpoint = ?self.endpoint, "removed existing IPC endpoint file");
            }
        }
        let listener = match self
            .endpoint
            .as_str()
            .to_fs_name::<GenericFilePath>()
            .and_then(|name| ListenerOptions::new().name(name).create_tokio())
        {
            Ok(listener) => listener,
            Err(err) => {
                on_ready
                    .send(Err(IpcServerStartError { endpoint: self.endpoint.clone(), source: err }))
                    .ok();
                return;
            }
        };
        // Listener is bound: unblock `start`.
        on_ready.send(Ok(())).ok();
        let mut id: u32 = 0;
        let connection_guard = ConnectionGuard::new(self.cfg.max_connections as usize);
        // The shutdown future is pinned once and handed back by `try_accept_conn` on every
        // non-shutdown outcome, so the same future is awaited across all iterations.
        let stopped = stop_handle.clone().shutdown();
        let mut stopped = pin!(stopped);
        let (drop_on_completion, mut process_connection_awaiter) = mpsc::channel::<()>(1);
        trace!("accepting ipc connections");
        loop {
            match try_accept_conn(&listener, stopped).await {
                AcceptConnection::Established { local_socket_stream, stop } => {
                    // Over the connection limit: tell the client, then drop both stream halves
                    // at the end of this block.
                    let Some(conn_permit) = connection_guard.try_acquire() else {
                        let (_reader, mut writer) = local_socket_stream.split();
                        let _ = writer
                            .write_all(b"Too many connections. Please try again later.")
                            .await;
                        stopped = stop;
                        continue;
                    };
                    let max_conns = connection_guard.max_connections();
                    let curr_conns = max_conns - connection_guard.available_connections();
                    trace!("Accepting new connection {}/{}", curr_conns, max_conns);
                    let conn_permit = Arc::new(conn_permit);
                    process_connection(ProcessConnection {
                        http_middleware: &self.http_middleware,
                        rpc_middleware: self.rpc_middleware.clone(),
                        conn_permit,
                        conn_id: id,
                        server_cfg: self.cfg.clone(),
                        stop_handle: stop_handle.clone(),
                        drop_on_completion: drop_on_completion.clone(),
                        methods: methods.clone(),
                        id_provider: self.id_provider.clone(),
                        local_socket_stream,
                    });
                    // Connection ids wrap around instead of overflowing.
                    id = id.wrapping_add(1);
                    stopped = stop;
                }
                AcceptConnection::Shutdown => {
                    break;
                }
                AcceptConnection::Err((err, stop)) => {
                    tracing::error!(%err, "Failed accepting a new IPC connection");
                    stopped = stop;
                }
            }
        }
        // Drop our own sender so the channel closes once all connection tasks have finished;
        // nothing is expected to be sent, the loop only waits for `None`.
        drop(drop_on_completion);
        while process_connection_awaiter.recv().await.is_some() {
        }
    }
}
/// Outcome of racing a single `accept` against the server's shutdown future.
///
/// `S` is the shutdown future (a pinned reference in `start_inner`). It is handed back in the
/// non-shutdown variants so the accept loop can keep waiting on the same future.
enum AcceptConnection<S> {
    /// The shutdown future completed before a client connected.
    Shutdown,
    /// A client connected.
    Established { local_socket_stream: LocalSocketStream, stop: S },
    /// `accept` failed with an I/O error.
    Err((io::Error, S)),
}
/// Waits for either a new client on `listener` or completion of `stopped`, whichever
/// happens first.
///
/// When a connection attempt wins the race (successfully or not), the still-pending
/// `stopped` future is returned inside the result so the caller can reuse it.
async fn try_accept_conn<S>(listener: &LocalSocketListener, stopped: S) -> AcceptConnection<S>
where
    S: Future + Unpin,
{
    let accept = pin!(listener.accept());
    let (accepted, stop) = match futures_util::future::select(accept, stopped).await {
        Either::Right(_) => return AcceptConnection::Shutdown,
        Either::Left(pair) => pair,
    };
    match accepted {
        Err(e) => AcceptConnection::Err((e, stop)),
        Ok(local_socket_stream) => AcceptConnection::Established { local_socket_stream, stop },
    }
}
/// `Debug` for every middleware combination, not only `IpcServer<Identity, Identity>`.
///
/// The middleware stacks are arbitrary (possibly non-`Debug`) types, so they are left out.
/// `finish_non_exhaustive` marks the omission with a trailing `..` in the output.
impl<HttpMiddleware, RpcMiddleware> std::fmt::Debug for IpcServer<HttpMiddleware, RpcMiddleware> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("IpcServer")
            .field("endpoint", &self.endpoint)
            .field("cfg", &self.cfg)
            .field("id_provider", &self.id_provider)
            .finish_non_exhaustive()
    }
}
/// Error returned by [`IpcServer::start`] when the listener cannot be set up.
#[derive(Debug, thiserror::Error)]
#[error("failed to listen on ipc endpoint `{endpoint}`: {source}")]
pub struct IpcServerStartError {
    /// The endpoint the server tried to listen on.
    endpoint: String,
    /// The underlying error from name conversion or listener creation.
    #[source]
    source: io::Error,
}
/// Per-connection state, built once in `process_connection` and shared by every request on
/// that connection.
#[derive(Debug, Clone)]
// Some fields (e.g. `stop_handle`) are stored but not read anywhere in this module.
#[allow(dead_code)]
pub(crate) struct ServiceData {
    /// The registered RPC methods.
    pub(crate) methods: Methods,
    /// Source of subscription IDs.
    pub(crate) id_provider: Arc<dyn IdProvider>,
    /// Handle to the server's stop signal.
    pub(crate) stop_handle: StopHandle,
    /// Connection id assigned by the accept loop.
    pub(crate) conn_id: u32,
    /// Slot in the server's connection limit, held while this connection is alive.
    pub(crate) conn_permit: Arc<ConnectionPermit>,
    /// Subscription budget created with `max_subscriptions_per_connection`.
    pub(crate) bounded_subscriptions: BoundedSubscriptions,
    /// Sink whose messages are forwarded to this connection by `to_ipc_service`.
    pub(crate) method_sink: MethodSink,
    /// Copy of the server settings.
    pub(crate) server_cfg: Settings,
}
/// Builder for the RPC-level middleware stack: a thin wrapper around
/// [`tower::ServiceBuilder`] whose layers are applied around [`RpcService`].
#[derive(Debug, Clone)]
pub struct RpcServiceBuilder<L>(tower::ServiceBuilder<L>);
impl Default for RpcServiceBuilder<Identity> {
fn default() -> Self {
Self(tower::ServiceBuilder::new())
}
}
impl RpcServiceBuilder<Identity> {
pub fn new() -> Self {
Self(tower::ServiceBuilder::new())
}
}
impl<L> RpcServiceBuilder<L> {
    /// Adds `layer` when it is `Some`, otherwise adds an [`Identity`] layer.
    ///
    /// Counterpart of [`tower::ServiceBuilder::option_layer`].
    ///
    /// NOTE(review): `Either` here is `futures_util::future::Either` (imported at module
    /// level). Confirm it implements the `Layer`/`RpcServiceT` traits needed for the
    /// resulting stack to be usable as RPC middleware.
    pub fn option_layer<T>(
        self,
        layer: Option<T>,
    ) -> RpcServiceBuilder<Stack<Either<T, Identity>, L>> {
        let chosen = match layer {
            Some(inner) => Either::Left(inner),
            None => Either::Right(Identity::new()),
        };
        self.layer(chosen)
    }

    /// Adds `layer` as the new outermost layer of the stack.
    pub fn layer<T>(self, layer: T) -> RpcServiceBuilder<Stack<T, L>> {
        let Self(inner) = self;
        RpcServiceBuilder(inner.layer(layer))
    }

    /// Adds a layer built from the function `f` (see [`tower::layer::layer_fn`]).
    pub fn layer_fn<F>(self, f: F) -> RpcServiceBuilder<Stack<LayerFn<F>, L>> {
        let Self(inner) = self;
        RpcServiceBuilder(inner.layer_fn(f))
    }

    /// Adds jsonrpsee's [`RpcLoggerLayer`], truncating logged payloads to `max_log_len`.
    pub fn rpc_logger(self, max_log_len: u32) -> RpcServiceBuilder<Stack<RpcLoggerLayer, L>> {
        self.layer(RpcLoggerLayer::new(max_log_len))
    }

    /// Wraps `service` in every configured layer.
    pub(crate) fn service<S>(&self, service: S) -> L::Service
    where
        L: tower::Layer<S>,
    {
        let Self(inner) = self;
        inner.service(service)
    }
}
/// Innermost tower service of a connection: each [`Service::call`] handles one raw JSON-RPC
/// request string using the RPC middleware and the registered methods.
///
/// No HTTP is involved; the "HTTP" middleware layers wrap this service directly (see
/// `process_connection`).
#[derive(Debug, Clone)]
pub struct TowerServiceNoHttp<L> {
    /// Per-connection state.
    inner: ServiceData,
    /// Middleware applied around the [`RpcService`] built for each request.
    rpc_middleware: RpcServiceBuilder<L>,
}
impl<RpcMiddleware> Service<String> for TowerServiceNoHttp<RpcMiddleware>
where
    RpcMiddleware: for<'a> Layer<RpcService>,
    for<'a> <RpcMiddleware as Layer<RpcService>>::Service: Send + Sync + 'static + RpcServiceT<'a>,
{
    type Response = Option<String>;
    type Error = Box<dyn core::error::Error + Send + Sync + 'static>;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always ready: every call runs on its own spawned task, so no back-pressure is applied
    /// here.
    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Handles one raw request string on a spawned task.
    ///
    /// The returned future yields the response (if any) produced by `ipc::call_with_service`.
    /// A failure to join the spawned task (e.g. it panicked) is returned as the error.
    fn call(&mut self, request: String) -> Self::Future {
        trace!("{:?}", request);
        // Reuse the connection-wide subscription budget. Clones of `BoundedSubscriptions`
        // share the same limiter, so the limit applies across all requests on this
        // connection. Building a fresh one per request (as before) let a single connection
        // exceed `max_subscriptions_per_connection` by spreading subscriptions over several
        // requests.
        let cfg = RpcServiceCfg::CallsAndSubscriptions {
            bounded_subscriptions: self.inner.bounded_subscriptions.clone(),
            id_provider: self.inner.id_provider.clone(),
            sink: self.inner.method_sink.clone(),
        };
        let max_response_body_size = self.inner.server_cfg.max_response_body_size as usize;
        let max_request_body_size = self.inner.server_cfg.max_request_body_size as usize;
        let conn = self.inner.conn_permit.clone();
        let rpc_service = self.rpc_middleware.service(RpcService::new(
            self.inner.methods.clone(),
            max_response_body_size,
            self.inner.conn_id.into(),
            cfg,
        ));
        let f = tokio::task::spawn(async move {
            ipc::call_with_service(
                request,
                rpc_service,
                max_response_body_size,
                max_request_body_size,
                conn,
            )
            .await
        });
        Box::pin(async move { f.await.map_err(|err| err.into()) })
    }
}
/// Everything `process_connection` needs to set up one accepted connection.
struct ProcessConnection<'a, HttpMiddleware, RpcMiddleware> {
    /// Server-wide "HTTP" middleware, borrowed from the [`IpcServer`].
    http_middleware: &'a tower::ServiceBuilder<HttpMiddleware>,
    /// RPC middleware stack (cloned per connection).
    rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
    /// Slot in the server's connection limit.
    conn_permit: Arc<ConnectionPermit>,
    /// Id assigned by the accept loop.
    conn_id: u32,
    /// Copy of the server settings.
    server_cfg: Settings,
    /// Server stop signal; the connection task exits when it fires.
    stop_handle: StopHandle,
    /// Dropped when the connection task finishes; lets the accept loop wait for all
    /// connections during shutdown.
    drop_on_completion: mpsc::Sender<()>,
    /// The registered RPC methods.
    methods: Methods,
    /// Source of subscription IDs.
    id_provider: Arc<dyn IdProvider>,
    /// The accepted socket.
    local_socket_stream: LocalSocketStream,
}
/// Sets up one accepted IPC connection and spawns the task that drives it.
///
/// Steps:
/// - frames the socket with [`StreamCodec::stream_incoming`];
/// - builds the per-connection [`TowerServiceNoHttp`];
/// - wraps it in the configured "HTTP" middleware;
/// - drives it with `to_ipc_service` on a spawned tokio task.
///
/// The task drops `drop_on_completion` when it finishes, so the accept loop can wait for all
/// connections during shutdown.
#[instrument(name = "connection", skip_all, fields(conn_id = %params.conn_id), level = "INFO")]
fn process_connection<'b, RpcMiddleware, HttpMiddleware>(
    params: ProcessConnection<'_, HttpMiddleware, RpcMiddleware>,
) where
    RpcMiddleware: Layer<RpcService> + Clone + Send + 'static,
    for<'a> <RpcMiddleware as Layer<RpcService>>::Service: RpcServiceT<'a>,
    HttpMiddleware: Layer<TowerServiceNoHttp<RpcMiddleware>> + Send + 'static,
    <HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service: Send
        + Service<
            String,
            Response = Option<String>,
            Error = Box<dyn core::error::Error + Send + Sync + 'static>,
        >,
    <<HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service as Service<String>>::Future:
        Send + Unpin,
{
    let ProcessConnection {
        http_middleware,
        rpc_middleware,
        conn_permit,
        conn_id,
        server_cfg,
        stop_handle,
        drop_on_completion,
        id_provider,
        methods,
        local_socket_stream,
    } = params;

    let ipc = IpcConn(tokio_util::codec::Decoder::framed(
        StreamCodec::stream_incoming(),
        local_socket_stream,
    ));

    // Messages pushed through the `MethodSink` handed to the RPC service arrive on `rx`;
    // `to_ipc_service` forwards them to the connection. tokio's `mpsc::channel` panics if
    // the capacity is 0.
    let (tx, rx) = mpsc::channel::<String>(server_cfg.message_buffer_capacity as usize);
    let method_sink = MethodSink::new_with_limit(tx, server_cfg.max_response_body_size);
    let tower_service = TowerServiceNoHttp {
        inner: ServiceData {
            methods,
            id_provider,
            stop_handle: stop_handle.clone(),
            server_cfg: server_cfg.clone(),
            conn_id,
            conn_permit,
            bounded_subscriptions: BoundedSubscriptions::new(
                server_cfg.max_subscriptions_per_connection,
            ),
            method_sink,
        },
        rpc_middleware,
    };
    let service = http_middleware.service(tower_service);

    // Attach the `connection` span to the spawned future *here*, while that span is current.
    // Calling `in_current_span()` inside the async block would only capture whatever span is
    // current when the runtime first polls the task, losing the `conn_id` context.
    tokio::spawn(
        async move {
            to_ipc_service(ipc, service, stop_handle, rx).await;
            // Signals the accept loop that this connection has finished.
            drop(drop_on_completion)
        }
        .in_current_span(),
    );
}
/// Drives a single IPC connection until the connection driver completes or the server is
/// stopped.
///
/// Requests are read and answered by the `IpcConnDriver` built around `service`. Messages
/// arriving on `rx` (from the connection's `MethodSink`) are queued on the driver with
/// `push_back` for sending.
async fn to_ipc_service<S, T>(
    ipc: IpcConn<JsonRpcStream<T>>,
    service: S,
    stop_handle: StopHandle,
    rx: mpsc::Receiver<String>,
) where
    S: Service<String, Response = Option<String>> + Send + 'static,
    S::Error: Into<Box<dyn core::error::Error + Send + Sync>>,
    S::Future: Send + Unpin,
    T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    let mut outgoing = pin!(ReceiverStream::new(rx));
    let driver = IpcConnDriver {
        conn: ipc,
        service,
        pending_calls: Default::default(),
        items: Default::default(),
    };
    let mut driver = pin!(driver);
    let mut shutdown = pin!(stop_handle.shutdown());

    loop {
        tokio::select! {
            _ = &mut driver => break,
            next_msg = outgoing.next() => {
                if let Some(msg) = next_msg {
                    driver.push_back(msg);
                }
            }
            _ = &mut shutdown => break,
        }
    }
}
/// Limits and runtime configuration for an [`IpcServer`]; set through [`Builder`].
#[derive(Debug, Clone)]
pub struct Settings {
    /// Maximum size in bytes of an incoming request, passed to `ipc::call_with_service`.
    max_request_body_size: u32,
    /// Maximum size in bytes of a response (also the `MethodSink` limit).
    max_response_body_size: u32,
    /// NOTE(review): stored and settable, but not read anywhere in this file.
    max_log_length: u32,
    /// Maximum number of simultaneously accepted connections.
    max_connections: u32,
    /// Capacity of each connection's `BoundedSubscriptions`.
    max_subscriptions_per_connection: u32,
    /// Capacity of each connection's outgoing-message channel (must be non-zero: tokio's
    /// `mpsc::channel` panics on 0).
    message_buffer_capacity: u32,
    /// Runtime on which to spawn the accept loop; `None` uses `tokio::spawn`.
    tokio_runtime: Option<tokio::runtime::Handle>,
}
impl Default for Settings {
    /// Default limits:
    /// - request and response bodies: 10 MB each;
    /// - log length: 4096;
    /// - connections: 100;
    /// - subscriptions per connection: 1024;
    /// - message buffer capacity: 1024.
    ///
    /// No custom runtime is set.
    fn default() -> Self {
        let body_limit = TEN_MB_SIZE_BYTES;
        Self {
            max_request_body_size: body_limit,
            max_response_body_size: body_limit,
            max_log_length: 4096,
            max_connections: 100,
            max_subscriptions_per_connection: 1024,
            message_buffer_capacity: 1024,
            tokio_runtime: None,
        }
    }
}
/// Builder for [`IpcServer`]. Start from [`Builder::default`], adjust settings or middleware,
/// then call [`Builder::build`].
#[derive(Debug)]
pub struct Builder<HttpMiddleware, RpcMiddleware> {
    /// Limits and runtime configuration.
    settings: Settings,
    /// Source of subscription IDs.
    id_provider: Arc<dyn IdProvider>,
    /// Middleware applied around each per-request [`RpcService`].
    rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
    /// Middleware applied around each connection's [`TowerServiceNoHttp`].
    http_middleware: tower::ServiceBuilder<HttpMiddleware>,
}
impl Default for Builder<Identity, Identity> {
fn default() -> Self {
Self {
settings: Settings::default(),
id_provider: Arc::new(RandomIntegerIdProvider),
rpc_middleware: RpcServiceBuilder::new(),
http_middleware: tower::ServiceBuilder::new(),
}
}
}
impl<HttpMiddleware, RpcMiddleware> Builder<HttpMiddleware, RpcMiddleware> {
    /// Sets the maximum size, in bytes, of an incoming request.
    pub const fn max_request_body_size(mut self, size: u32) -> Self {
        self.settings.max_request_body_size = size;
        self
    }

    /// Sets the maximum size, in bytes, of a response.
    pub const fn max_response_body_size(mut self, size: u32) -> Self {
        self.settings.max_response_body_size = size;
        self
    }

    /// Sets the maximum log length.
    ///
    /// NOTE(review): the stored value is not read anywhere in this module.
    pub const fn max_log_length(mut self, size: u32) -> Self {
        self.settings.max_log_length = size;
        self
    }

    /// Sets the maximum number of simultaneous connections.
    ///
    /// Connections beyond the limit receive a short error message and are closed.
    pub const fn max_connections(mut self, max: u32) -> Self {
        self.settings.max_connections = max;
        self
    }

    /// Sets the maximum number of subscriptions per connection.
    pub const fn max_subscriptions_per_connection(mut self, max: u32) -> Self {
        self.settings.max_subscriptions_per_connection = max;
        self
    }

    /// Sets the capacity of each connection's outgoing-message channel.
    ///
    /// Must be non-zero: tokio's `mpsc::channel` panics when created with capacity 0, which
    /// would happen as soon as a connection is accepted.
    pub const fn set_message_buffer_capacity(mut self, c: u32) -> Self {
        self.settings.message_buffer_capacity = c;
        self
    }

    /// Spawns the server's accept loop on `rt` instead of the ambient runtime.
    pub fn custom_tokio_runtime(mut self, rt: tokio::runtime::Handle) -> Self {
        self.settings.tokio_runtime = Some(rt);
        self
    }

    /// Replaces the subscription id provider.
    pub fn set_id_provider<I: IdProvider + 'static>(mut self, id_provider: I) -> Self {
        self.id_provider = Arc::new(id_provider);
        self
    }

    /// Replaces the connection-level ("HTTP") middleware, changing the builder's type.
    pub fn set_http_middleware<T>(
        self,
        service_builder: tower::ServiceBuilder<T>,
    ) -> Builder<T, RpcMiddleware> {
        let Self { settings, id_provider, rpc_middleware, .. } = self;
        Builder { settings, id_provider, rpc_middleware, http_middleware: service_builder }
    }

    /// Replaces the RPC middleware, changing the builder's type.
    pub fn set_rpc_middleware<T>(
        self,
        rpc_middleware: RpcServiceBuilder<T>,
    ) -> Builder<HttpMiddleware, T> {
        let Self { settings, id_provider, http_middleware, .. } = self;
        Builder { settings, id_provider, rpc_middleware, http_middleware }
    }

    /// Finishes the builder, producing a server for `endpoint` (socket path / pipe name).
    pub fn build(self, endpoint: String) -> IpcServer<HttpMiddleware, RpcMiddleware> {
        let Self { settings: cfg, id_provider, rpc_middleware, http_middleware } = self;
        IpcServer { endpoint, id_provider, cfg, rpc_middleware, http_middleware }
    }
}
/// Returns a random, test-only endpoint name: a named pipe on Windows, a socket path under
/// `/tmp` elsewhere.
#[cfg(test)]
#[allow(missing_docs)]
pub fn dummy_name() -> String {
    let suffix: u64 = rand::Rng::gen(&mut rand::thread_rng());
    let prefix = if cfg!(windows) { r"\\.\pipe\my-pipe-" } else { "/tmp/my-uds-" };
    format!("{prefix}{suffix}")
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::IpcClientBuilder;
    use futures::future::select;
    use jsonrpsee::{
        core::{
            client,
            client::{ClientT, Error, Subscription, SubscriptionClientT},
            params::BatchRequestBuilder,
        },
        rpc_params,
        types::Request,
        PendingSubscriptionSink, RpcModule, SubscriptionMessage,
    };
    use reth_tracing::init_test_tracing;
    use std::pin::pin;
    use tokio::sync::broadcast;
    use tokio_stream::wrappers::BroadcastStream;

    /// Accepts `pending` and forwards each item of `stream` to the subscriber.
    ///
    /// Stops when the stream ends, the sink closes or a send fails, or the stream yields an
    /// error (which is returned).
    async fn pipe_from_stream_with_bounded_buffer(
        pending: PendingSubscriptionSink,
        stream: BroadcastStream<usize>,
    ) -> Result<(), Box<dyn core::error::Error + Send + Sync>> {
        let sink = pending.accept().await.unwrap();
        let closed = sink.closed();
        let mut closed = pin!(closed);
        let mut stream = pin!(stream);
        loop {
            match select(closed, stream.next()).await {
                Either::Left((_, _)) | Either::Right((None, _)) => break Ok(()),
                Either::Right((Some(Ok(item)), c)) => {
                    let notif = SubscriptionMessage::from_json(&item)?;
                    if sink.send(notif).await.is_err() {
                        break Ok(());
                    }
                    closed = c;
                }
                Either::Right((Some(Err(e)), _)) => break Err(e.into()),
            }
        }
    }

    /// Broadcasts `1..=100` on `tx`, one item per millisecond. Blocking; run it on a thread.
    fn produce_items(tx: broadcast::Sender<usize>) {
        for c in 1..=100 {
            std::thread::sleep(std::time::Duration::from_millis(1));
            let _ = tx.send(c);
        }
    }

    #[tokio::test]
    async fn can_set_the_max_response_body_size() {
        let endpoint = &dummy_name();
        let server = Builder::default().max_response_body_size(100).build(endpoint.clone());
        let mut module = RpcModule::new(());
        module.register_method("anything", |_, _, _| "a".repeat(101)).unwrap();
        let handle = server.start(module).await.unwrap();
        tokio::spawn(handle.stopped());
        let client = IpcClientBuilder::default().build(endpoint).await.unwrap();
        let response: Result<String, Error> = client.request("anything", rpc_params![]).await;
        assert!(response.unwrap_err().to_string().contains("Exceeded max limit of"));
    }

    #[tokio::test]
    async fn can_set_the_max_request_body_size() {
        init_test_tracing();
        let endpoint = &dummy_name();
        let server = Builder::default().max_request_body_size(100).build(endpoint.clone());
        let mut module = RpcModule::new(());
        module.register_method("anything", |_, _, _| "succeed").unwrap();
        let handle = server.start(module).await.unwrap();
        tokio::spawn(handle.stopped());
        let client = IpcClientBuilder::default().build(endpoint).await.unwrap();
        let response: Result<String, Error> =
            client.request("anything", rpc_params!["a".repeat(101)]).await;
        assert!(response.is_err());
        let mut batch_request_builder = BatchRequestBuilder::new();
        let _ = batch_request_builder.insert("anything", rpc_params![]);
        let _ = batch_request_builder.insert("anything", rpc_params![]);
        let _ = batch_request_builder.insert("anything", rpc_params![]);
        let response: Result<client::BatchResponse<'_, String>, Error> =
            client.batch_request(batch_request_builder).await;
        assert!(response.is_err());
    }

    #[tokio::test]
    async fn can_set_max_connections() {
        init_test_tracing();
        let endpoint = &dummy_name();
        let server = Builder::default().max_connections(2).build(endpoint.clone());
        let mut module = RpcModule::new(());
        module.register_method("anything", |_, _, _| "succeed").unwrap();
        let handle = server.start(module).await.unwrap();
        tokio::spawn(handle.stopped());
        let client1 = IpcClientBuilder::default().build(endpoint).await.unwrap();
        let client2 = IpcClientBuilder::default().build(endpoint).await.unwrap();
        let client3 = IpcClientBuilder::default().build(endpoint).await.unwrap();
        let response1: Result<String, Error> = client1.request("anything", rpc_params![]).await;
        let response2: Result<String, Error> = client2.request("anything", rpc_params![]).await;
        let response3: Result<String, Error> = client3.request("anything", rpc_params![]).await;
        assert!(response1.is_ok());
        assert!(response2.is_ok());
        assert!(response3.is_err());
        // Freeing a slot must allow a new connection.
        drop(client2);
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        let client4 = IpcClientBuilder::default().build(endpoint).await.unwrap();
        let response4: Result<String, Error> = client4.request("anything", rpc_params![]).await;
        assert!(response4.is_ok());
    }

    #[tokio::test]
    async fn test_rpc_request() {
        init_test_tracing();
        let endpoint = &dummy_name();
        let server = Builder::default().build(endpoint.clone());
        let mut module = RpcModule::new(());
        let msg = r#"{"jsonrpc":"2.0","id":83,"result":"0x7a69"}"#;
        module.register_method("eth_chainId", move |_, _, _| msg).unwrap();
        let handle = server.start(module).await.unwrap();
        tokio::spawn(handle.stopped());
        let client = IpcClientBuilder::default().build(endpoint).await.unwrap();
        let response: String = client.request("eth_chainId", rpc_params![]).await.unwrap();
        assert_eq!(response, msg);
    }

    #[tokio::test]
    async fn test_batch_request() {
        let endpoint = &dummy_name();
        let server = Builder::default().build(endpoint.clone());
        let mut module = RpcModule::new(());
        module.register_method("anything", |_, _, _| "ok").unwrap();
        let handle = server.start(module).await.unwrap();
        tokio::spawn(handle.stopped());
        let client = IpcClientBuilder::default().build(endpoint).await.unwrap();
        let mut batch_request_builder = BatchRequestBuilder::new();
        let _ = batch_request_builder.insert("anything", rpc_params![]);
        let _ = batch_request_builder.insert("anything", rpc_params![]);
        let _ = batch_request_builder.insert("anything", rpc_params![]);
        let result = client
            .batch_request(batch_request_builder)
            .await
            .unwrap()
            .into_ok()
            .unwrap()
            .collect::<Vec<String>>();
        assert_eq!(result, vec!["ok", "ok", "ok"]);
    }

    #[tokio::test]
    async fn test_ipc_modules() {
        reth_tracing::init_test_tracing();
        let endpoint = &dummy_name();
        let server = Builder::default().build(endpoint.clone());
        let mut module = RpcModule::new(());
        let msg = r#"{"admin":"1.0","debug":"1.0","engine":"1.0","eth":"1.0","ethash":"1.0","miner":"1.0","net":"1.0","rpc":"1.0","txpool":"1.0","web3":"1.0"}"#;
        module.register_method("rpc_modules", move |_, _, _| msg).unwrap();
        let handle = server.start(module).await.unwrap();
        tokio::spawn(handle.stopped());
        let client = IpcClientBuilder::default().build(endpoint).await.unwrap();
        let response: String = client.request("rpc_modules", rpc_params![]).await.unwrap();
        assert_eq!(response, msg);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_rpc_subscription() {
        let endpoint = &dummy_name();
        let server = Builder::default().build(endpoint.clone());
        let (tx, _rx) = broadcast::channel::<usize>(16);
        let mut module = RpcModule::new(tx.clone());
        module
            .register_subscription(
                "subscribe_hello",
                "s_hello",
                "unsubscribe_hello",
                |_, pending, tx, _| async move {
                    let rx = tx.subscribe();
                    let stream = BroadcastStream::new(rx);
                    pipe_from_stream_with_bounded_buffer(pending, stream).await?;
                    Ok(())
                },
            )
            .unwrap();
        let handle = server.start(module).await.unwrap();
        tokio::spawn(handle.stopped());
        let client = IpcClientBuilder::default().build(endpoint).await.unwrap();
        let sub: Subscription<usize> =
            client.subscribe("subscribe_hello", rpc_params![], "unsubscribe_hello").await.unwrap();

        // Start producing only once the subscription exists. The server-side handler calls
        // `tx.subscribe()` before accepting, so no items can be missed. Starting the producer
        // earlier raced with server start-up: if all items were sent before the handler
        // subscribed, `take(16)` never completed, because the module's `tx` clone keeps the
        // channel open.
        std::thread::spawn(move || produce_items(tx));

        let items = sub.take(16).collect::<Vec<_>>().await;
        assert_eq!(items.len(), 16);
    }

    #[tokio::test]
    async fn test_rpc_middleware() {
        /// RPC middleware that swaps the `say_hello` and `say_goodbye` method names.
        #[derive(Clone)]
        struct ModifyRequestIf<S>(S);

        impl<'a, S> RpcServiceT<'a> for ModifyRequestIf<S>
        where
            S: Send + Sync + RpcServiceT<'a>,
        {
            type Future = S::Future;

            fn call(&self, mut req: Request<'a>) -> Self::Future {
                if req.method == "say_hello" {
                    req.method = "say_goodbye".into();
                } else if req.method == "say_goodbye" {
                    req.method = "say_hello".into();
                }
                self.0.call(req)
            }
        }

        reth_tracing::init_test_tracing();
        let endpoint = &dummy_name();
        let rpc_middleware = RpcServiceBuilder::new().layer_fn(ModifyRequestIf);
        let server = Builder::default().set_rpc_middleware(rpc_middleware).build(endpoint.clone());
        let mut module = RpcModule::new(());
        let goodbye_msg = r#"{"jsonrpc":"2.0","id":1,"result":"goodbye"}"#;
        let hello_msg = r#"{"jsonrpc":"2.0","id":2,"result":"hello"}"#;
        module.register_method("say_hello", move |_, _, _| hello_msg).unwrap();
        module.register_method("say_goodbye", move |_, _, _| goodbye_msg).unwrap();
        let handle = server.start(module).await.unwrap();
        tokio::spawn(handle.stopped());
        let client = IpcClientBuilder::default().build(endpoint).await.unwrap();
        let say_hello_response: String = client.request("say_hello", rpc_params![]).await.unwrap();
        let say_goodbye_response: String =
            client.request("say_goodbye", rpc_params![]).await.unwrap();
        assert_eq!(say_hello_response, goodbye_msg);
        assert_eq!(say_goodbye_response, hello_msg);
    }
}