1#![doc(
4 html_logo_url = "https://raw.githubusercontent.com/paradigmxyz/reth/main/assets/reth-docs.png",
5 html_favicon_url = "https://avatars0.githubusercontent.com/u/97369466?s=256",
6 issue_tracker_base_url = "https://github.com/paradigmxyz/reth/issues/"
7)]
8#![cfg_attr(not(test), warn(unused_crate_dependencies))]
9#![cfg_attr(docsrs, feature(doc_cfg))]
10
11use reth_tasks::{TaskExecutor, TaskManager};
14use std::{future::Future, pin::pin, sync::mpsc, time::Duration};
15use tracing::{debug, error, trace};
16
17#[derive(Debug)]
21#[non_exhaustive]
22pub struct CliRunner {
23 tokio_runtime: tokio::runtime::Runtime,
24}
25
26impl CliRunner {
27 pub fn try_default_runtime() -> Result<Self, std::io::Error> {
32 Ok(Self { tokio_runtime: tokio_runtime()? })
33 }
34
35 pub const fn from_runtime(tokio_runtime: tokio::runtime::Runtime) -> Self {
37 Self { tokio_runtime }
38 }
39
40 pub fn block_on<F, T>(&self, fut: F) -> T
42 where
43 F: Future<Output = T>,
44 {
45 self.tokio_runtime.block_on(fut)
46 }
47
48 pub fn run_command_until_exit<F, E>(
54 self,
55 command: impl FnOnce(CliContext) -> F,
56 ) -> Result<(), E>
57 where
58 F: Future<Output = Result<(), E>>,
59 E: Send + Sync + From<std::io::Error> + From<reth_tasks::PanickedTaskError> + 'static,
60 {
61 let AsyncCliRunner { context, mut task_manager, tokio_runtime } =
62 AsyncCliRunner::new(self.tokio_runtime);
63
64 let command_res = tokio_runtime.block_on(run_to_completion_or_panic(
66 &mut task_manager,
67 run_until_ctrl_c(command(context)),
68 ));
69
70 if command_res.is_err() {
71 error!(target: "reth::cli", "shutting down due to error");
72 } else {
73 debug!(target: "reth::cli", "shutting down gracefully");
74 task_manager.graceful_shutdown_with_timeout(Duration::from_secs(5));
78 }
79
80 let (tx, rx) = mpsc::channel();
85 std::thread::Builder::new()
86 .name("tokio-runtime-shutdown".to_string())
87 .spawn(move || {
88 drop(tokio_runtime);
89 let _ = tx.send(());
90 })
91 .unwrap();
92
93 let _ = rx.recv_timeout(Duration::from_secs(5)).inspect_err(|err| {
94 debug!(target: "reth::cli", %err, "tokio runtime shutdown timed out");
95 });
96
97 command_res
98 }
99
100 pub fn run_blocking_command_until_exit<F, E>(
104 self,
105 command: impl FnOnce(CliContext) -> F + Send + 'static,
106 ) -> Result<(), E>
107 where
108 F: Future<Output = Result<(), E>> + Send + 'static,
109 E: Send + Sync + From<std::io::Error> + From<reth_tasks::PanickedTaskError> + 'static,
110 {
111 let AsyncCliRunner { context, mut task_manager, tokio_runtime } =
112 AsyncCliRunner::new(self.tokio_runtime);
113
114 let handle = tokio_runtime.handle().clone();
116 let command_handle =
117 tokio_runtime.handle().spawn_blocking(move || handle.block_on(command(context)));
118
119 let command_res = tokio_runtime.block_on(run_to_completion_or_panic(
121 &mut task_manager,
122 run_until_ctrl_c(
123 async move { command_handle.await.expect("Failed to join blocking task") },
124 ),
125 ));
126
127 if command_res.is_err() {
128 error!(target: "reth::cli", "shutting down due to error");
129 } else {
130 debug!(target: "reth::cli", "shutting down gracefully");
131 task_manager.graceful_shutdown_with_timeout(Duration::from_secs(5));
132 }
133
134 let (tx, rx) = mpsc::channel();
136 std::thread::Builder::new()
137 .name("tokio-runtime-shutdown".to_string())
138 .spawn(move || {
139 drop(tokio_runtime);
140 let _ = tx.send(());
141 })
142 .unwrap();
143
144 let _ = rx.recv_timeout(Duration::from_secs(5)).inspect_err(|err| {
145 debug!(target: "reth::cli", %err, "tokio runtime shutdown timed out");
146 });
147
148 command_res
149 }
150
151 pub fn run_until_ctrl_c<F, E>(self, fut: F) -> Result<(), E>
153 where
154 F: Future<Output = Result<(), E>>,
155 E: Send + Sync + From<std::io::Error> + 'static,
156 {
157 self.tokio_runtime.block_on(run_until_ctrl_c(fut))?;
158 Ok(())
159 }
160
161 pub fn run_blocking_until_ctrl_c<F, E>(self, fut: F) -> Result<(), E>
166 where
167 F: Future<Output = Result<(), E>> + Send + 'static,
168 E: Send + Sync + From<std::io::Error> + 'static,
169 {
170 let tokio_runtime = self.tokio_runtime;
171 let handle = tokio_runtime.handle().clone();
172 let fut = tokio_runtime.handle().spawn_blocking(move || handle.block_on(fut));
173 tokio_runtime
174 .block_on(run_until_ctrl_c(async move { fut.await.expect("Failed to join task") }))?;
175
176 std::thread::Builder::new()
180 .name("tokio-runtime-shutdown".to_string())
181 .spawn(move || drop(tokio_runtime))
182 .unwrap();
183
184 Ok(())
185 }
186}
187
188struct AsyncCliRunner {
190 context: CliContext,
191 task_manager: TaskManager,
192 tokio_runtime: tokio::runtime::Runtime,
193}
194
195impl AsyncCliRunner {
198 fn new(tokio_runtime: tokio::runtime::Runtime) -> Self {
201 let task_manager = TaskManager::new(tokio_runtime.handle().clone());
202 let task_executor = task_manager.executor();
203 Self { context: CliContext { task_executor }, task_manager, tokio_runtime }
204 }
205}
206
207#[derive(Debug)]
209pub struct CliContext {
210 pub task_executor: TaskExecutor,
212}
213
214pub fn tokio_runtime() -> Result<tokio::runtime::Runtime, std::io::Error> {
217 tokio::runtime::Builder::new_multi_thread().enable_all().build()
218}
219
220async fn run_to_completion_or_panic<F, E>(tasks: &mut TaskManager, fut: F) -> Result<(), E>
224where
225 F: Future<Output = Result<(), E>>,
226 E: Send + Sync + From<reth_tasks::PanickedTaskError> + 'static,
227{
228 {
229 let fut = pin!(fut);
230 tokio::select! {
231 task_manager_result = tasks => {
232 if let Err(panicked_error) = task_manager_result {
233 return Err(panicked_error.into());
234 }
235 },
236 res = fut => res?,
237 }
238 }
239 Ok(())
240}
241
242async fn run_until_ctrl_c<F, E>(fut: F) -> Result<(), E>
246where
247 F: Future<Output = Result<(), E>>,
248 E: Send + Sync + 'static + From<std::io::Error>,
249{
250 let ctrl_c = tokio::signal::ctrl_c();
251
252 #[cfg(unix)]
253 {
254 let mut stream = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
255 let sigterm = stream.recv();
256 let sigterm = pin!(sigterm);
257 let ctrl_c = pin!(ctrl_c);
258 let fut = pin!(fut);
259
260 tokio::select! {
261 _ = ctrl_c => {
262 trace!(target: "reth::cli", "Received ctrl-c");
263 },
264 _ = sigterm => {
265 trace!(target: "reth::cli", "Received SIGTERM");
266 },
267 res = fut => res?,
268 }
269 }
270
271 #[cfg(not(unix))]
272 {
273 let ctrl_c = pin!(ctrl_c);
274 let fut = pin!(fut);
275
276 tokio::select! {
277 _ = ctrl_c => {
278 trace!(target: "reth::cli", "Received ctrl-c");
279 },
280 res = fut => res?,
281 }
282 }
283
284 Ok(())
285}