1use crate::{
7 error::PoolError, AddedTransactionOutcome, PoolTransaction, TransactionOrigin, TransactionPool,
8};
9use pin_project::pin_project;
10use std::{
11 future::Future,
12 pin::Pin,
13 task::{ready, Context, Poll},
14};
15use tokio::sync::{mpsc, oneshot};
16
17#[derive(Debug)]
19pub struct BatchTxRequest<T: PoolTransaction> {
20 origin: TransactionOrigin,
22 pool_tx: T,
24 response_tx: oneshot::Sender<Result<AddedTransactionOutcome, PoolError>>,
26}
27
28impl<T> BatchTxRequest<T>
29where
30 T: PoolTransaction,
31{
32 pub const fn new(
34 origin: TransactionOrigin,
35 pool_tx: T,
36 response_tx: oneshot::Sender<Result<AddedTransactionOutcome, PoolError>>,
37 ) -> Self {
38 Self { origin, pool_tx, response_tx }
39 }
40}
41
42#[pin_project]
44#[derive(Debug)]
45pub struct BatchTxProcessor<Pool: TransactionPool> {
46 pool: Pool,
47 max_batch_size: usize,
48 buf: Vec<BatchTxRequest<Pool::Transaction>>,
49 #[pin]
50 request_rx: mpsc::UnboundedReceiver<BatchTxRequest<Pool::Transaction>>,
51}
52
53impl<Pool> BatchTxProcessor<Pool>
54where
55 Pool: TransactionPool + 'static,
56{
57 pub fn new(
59 pool: Pool,
60 max_batch_size: usize,
61 ) -> (Self, mpsc::UnboundedSender<BatchTxRequest<Pool::Transaction>>) {
62 let (request_tx, request_rx) = mpsc::unbounded_channel();
63
64 let processor = Self { pool, max_batch_size, buf: Vec::with_capacity(1), request_rx };
65
66 (processor, request_tx)
67 }
68
69 async fn process_request(pool: &Pool, req: BatchTxRequest<Pool::Transaction>) {
70 let BatchTxRequest { origin, pool_tx, response_tx } = req;
71 let pool_result = pool.add_transaction(origin, pool_tx).await;
72 let _ = response_tx.send(pool_result);
73 }
74
75 async fn process_batch(pool: &Pool, batch: Vec<BatchTxRequest<Pool::Transaction>>) {
77 if batch.len() == 1 {
78 Self::process_request(pool, batch.into_iter().next().expect("batch is not empty"))
79 .await;
80 return
81 }
82
83 let mut batch_iter = batch.iter();
85 if let Some(origin) = batch_iter.next().map(|req| req.origin) &&
86 batch_iter.all(|req| req.origin == origin)
87 {
88 let (transactions, response_txs): (Vec<_>, Vec<_>) =
89 batch.into_iter().map(|req| (req.pool_tx, req.response_tx)).unzip();
90
91 let pool_results = pool.add_transactions(origin, transactions).await;
92 for (response_tx, pool_result) in response_txs.into_iter().zip(pool_results) {
93 let _ = response_tx.send(pool_result);
94 }
95 return
96 }
97
98 let (transactions, response_txs): (Vec<_>, Vec<_>) =
99 batch.into_iter().map(|req| ((req.origin, req.pool_tx), req.response_tx)).unzip();
100
101 let pool_results = pool.add_transactions_with_origins(transactions).await;
102 for (response_tx, pool_result) in response_txs.into_iter().zip(pool_results) {
103 let _ = response_tx.send(pool_result);
104 }
105 }
106}
107
108impl<Pool> Future for BatchTxProcessor<Pool>
109where
110 Pool: TransactionPool + 'static,
111{
112 type Output = ();
113
114 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
115 let mut this = self.project();
116
117 loop {
118 ready!(this.request_rx.poll_recv_many(cx, this.buf, *this.max_batch_size));
120
121 if !this.buf.is_empty() {
122 let batch = std::mem::take(this.buf);
123 let pool = this.pool.clone();
124 tokio::spawn(async move {
125 Self::process_batch(&pool, batch).await;
126 });
127 this.buf.reserve(1);
128
129 continue;
130 }
131
132 return Poll::Pending;
134 }
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{testing_pool, MockTransaction};
    use futures::stream::{FuturesUnordered, StreamExt};
    use std::time::Duration;
    use tokio::time::timeout;

    /// Calls `process_batch` directly with 100 `Local` requests and expects every response
    /// to carry a successful pool result.
    #[tokio::test]
    async fn test_process_batch() {
        let pool = testing_pool();

        let mut batch_requests = Vec::new();
        let mut responses = Vec::new();

        for i in 0..100 {
            let tx = MockTransaction::legacy().with_nonce(i).with_gas_price(100);
            let (response_tx, response_rx) = tokio::sync::oneshot::channel();

            batch_requests.push(BatchTxRequest::new(TransactionOrigin::Local, tx, response_tx));
            responses.push(response_rx);
        }

        BatchTxProcessor::process_batch(&pool, batch_requests).await;

        for response_rx in responses {
            let result = timeout(Duration::from_millis(5), response_rx)
                .await
                .expect("Timeout waiting for response")
                .expect("Response channel was closed unexpectedly");
            assert!(result.is_ok());
        }
    }

    /// Calls `process_batch` with requests of three different origins and expects every
    /// response to carry a successful pool result.
    #[tokio::test]
    async fn test_process_batch_mixed_origins() {
        let pool = testing_pool();

        let mut batch_requests = Vec::new();
        let mut responses = Vec::new();

        for (nonce, origin) in [
            (0, TransactionOrigin::Local),
            (1, TransactionOrigin::External),
            (2, TransactionOrigin::Private),
        ] {
            let tx = MockTransaction::legacy().with_nonce(nonce).with_gas_price(100);
            let (response_tx, response_rx) = tokio::sync::oneshot::channel();

            batch_requests.push(BatchTxRequest::new(origin, tx, response_tx));
            responses.push(response_rx);
        }

        BatchTxProcessor::process_batch(&pool, batch_requests).await;

        for response_rx in responses {
            let result = timeout(Duration::from_millis(5), response_rx)
                .await
                .expect("Timeout waiting for response")
                .expect("Response channel was closed unexpectedly");
            assert!(result.is_ok());
        }
    }

    /// Sends 50 requests through a spawned processor and expects a successful pool result
    /// for each of them.
    #[tokio::test]
    async fn test_batch_processor() {
        let pool = testing_pool();
        let (processor, request_tx) = BatchTxProcessor::new(pool.clone(), 1000);

        let handle = tokio::spawn(processor);

        let mut responses = Vec::new();

        for i in 0..50 {
            let tx = MockTransaction::legacy().with_nonce(i).with_gas_price(100);
            let (response_tx, response_rx) = tokio::sync::oneshot::channel();

            request_tx
                .send(BatchTxRequest::new(TransactionOrigin::Local, tx, response_tx))
                .expect("Could not send batch tx");
            responses.push(response_rx);
        }

        // Give the spawned processor time to pick the requests up.
        tokio::time::sleep(Duration::from_millis(10)).await;

        for rx in responses {
            let result = timeout(Duration::from_millis(10), rx)
                .await
                .expect("Timeout waiting for response")
                .expect("Response channel was closed unexpectedly");
            assert!(result.is_ok());
        }

        drop(request_tx);
        handle.abort();
    }

    /// Sends 10 requests through a spawned processor and checks the pool result of each.
    #[tokio::test]
    async fn test_add_transaction() {
        let pool = testing_pool();
        let (processor, request_tx) = BatchTxProcessor::new(pool.clone(), 1000);

        let handle = tokio::spawn(processor);

        let mut results = Vec::new();
        for i in 0..10 {
            let tx = MockTransaction::legacy().with_nonce(i).with_gas_price(100);
            let (response_tx, response_rx) = tokio::sync::oneshot::channel();
            let request = BatchTxRequest::new(TransactionOrigin::Local, tx, response_tx);
            request_tx.send(request).expect("Could not send batch tx");
            results.push(response_rx);
        }

        for res in results {
            // Unwrap the channel layer as well, so the assertion below checks the pool's
            // result rather than merely that a message was received.
            let result = timeout(Duration::from_millis(10), res)
                .await
                .expect("Timeout waiting for transaction result")
                .expect("Response channel was closed unexpectedly");
            assert!(result.is_ok());
        }

        handle.abort();
    }

    /// Sends exactly `max_batch_size` requests from concurrent futures and expects each to
    /// resolve with a successful pool result.
    #[tokio::test]
    async fn test_max_batch_size() {
        let pool = testing_pool();
        let max_batch_size = 10;
        let (processor, request_tx) = BatchTxProcessor::new(pool.clone(), max_batch_size);

        let handle = tokio::spawn(processor);

        let mut futures = FuturesUnordered::new();
        for i in 0..max_batch_size {
            let tx = MockTransaction::legacy().with_nonce(i as u64).with_gas_price(100);
            let (response_tx, response_rx) = tokio::sync::oneshot::channel();
            let request = BatchTxRequest::new(TransactionOrigin::Local, tx, response_tx);
            let request_tx_clone = request_tx.clone();

            let tx_fut = async move {
                request_tx_clone.send(request).expect("Could not send batch tx");
                response_rx.await.expect("Could not receive batch response")
            };
            futures.push(tx_fut);
        }

        while let Some(result) = timeout(Duration::from_millis(5), futures.next())
            .await
            .expect("Timeout waiting for transaction result")
        {
            assert!(result.is_ok());
        }

        handle.abort();
    }
}