//! Crate root of the database layer.
//!
//! Declares the storage submodules, re-exports the database error types from
//! `reth_storage_errors`, and re-exports everything from `reth_db_api`. The `mdbx`, `metrics`
//! and `utils` modules, as well as the `create_db` / `init_db` / `open_db` /
//! `open_db_read_only` re-exports, are only compiled when the `mdbx` feature is enabled.

#![doc(
    html_logo_url = "https://raw.githubusercontent.com/paradigmxyz/reth/main/assets/reth-docs.png",
    html_favicon_url = "https://avatars0.githubusercontent.com/u/97369466?s=256",
    issue_tracker_base_url = "https://github.com/paradigmxyz/reth/issues/"
)]
// Unused crate dependencies are only warned about outside of test builds.
#![cfg_attr(not(test), warn(unused_crate_dependencies))]
// The `doc_cfg` feature is switched on only when the `docsrs` cfg is set.
#![cfg_attr(docsrs, feature(doc_cfg))]

mod implementation;
pub mod lockfile;
// Compiled only with the `mdbx` feature.
#[cfg(feature = "mdbx")]
mod metrics;
pub mod static_file;
// Compiled only with the `mdbx` feature.
#[cfg(feature = "mdbx")]
mod utils;
pub mod version;

// Public `mdbx` module; compiled only with the `mdbx` feature.
#[cfg(feature = "mdbx")]
pub mod mdbx;

// Database error types, re-exported from `reth_storage_errors`.
pub use reth_storage_errors::db::{DatabaseError, DatabaseWriteOperation};
#[cfg(feature = "mdbx")]
pub use utils::is_database_empty;

// Functions and environment types re-exported from the `mdbx` module.
#[cfg(feature = "mdbx")]
pub use mdbx::{create_db, init_db, open_db, open_db_read_only, DatabaseEnv, DatabaseEnvKind};

// NOTE(review): `models` is not declared in this file; presumably it is brought into scope by
// the `reth_db_api` glob re-export below — confirm.
pub use models::ClientVersion;
pub use reth_db_api::*;
39
40#[cfg(any(test, feature = "test-utils"))]
42pub mod test_utils {
43 use super::*;
44 use crate::mdbx::DatabaseArguments;
45 use parking_lot::RwLock;
46 use reth_db_api::{
47 database::Database, database_metrics::DatabaseMetrics, models::ClientVersion,
48 };
49 use reth_fs_util;
50 use reth_libmdbx::MaxReadTransactionDuration;
51 use std::{
52 fmt::Formatter,
53 path::{Path, PathBuf},
54 sync::Arc,
55 };
56 use tempfile::TempDir;
57
    /// Panic message used when an existing database cannot be opened.
    pub const ERROR_DB_OPEN: &str = "could not open the database file";
    /// Panic message prefix used when a database cannot be created; the helpers in this module
    /// append the offending path to it.
    pub const ERROR_DB_CREATION: &str = "could not create the database file";
    /// Error message for a failure to create the static file path.
    pub const ERROR_STATIC_FILES_CREATION: &str = "could not create the static file path";
    /// Error message for a failure to create the tables in the database.
    pub const ERROR_TABLE_CREATION: &str = "could not create tables in the database";
    /// Panic message used when a temporary directory cannot be created.
    pub const ERROR_TEMPDIR: &str = "could not create a temporary directory";
68
    /// A database wrapper for tests that removes the database directory when it is dropped.
    ///
    /// It also carries two hooks that are invoked around the creation of read transactions.
    pub struct TempDatabase<DB> {
        /// The wrapped database. It is `Some` from construction until the value is taken out
        /// by `into_inner_db` or by `Drop`.
        db: Option<DB>,
        /// Directory of the database; removed on drop while `db` is still present.
        path: PathBuf,
        /// Hook invoked by `Database::tx` right before the inner transaction is created.
        pre_tx_hook: RwLock<Box<dyn Fn() + Send + Sync>>,
        /// Hook invoked by `Database::tx` right after the inner transaction was created
        /// successfully.
        post_tx_hook: RwLock<Box<dyn Fn() + Send + Sync>>,
    }
78
79 impl<DB: std::fmt::Debug> std::fmt::Debug for TempDatabase<DB> {
80 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
81 f.debug_struct("TempDatabase").field("db", &self.db).field("path", &self.path).finish()
82 }
83 }
84
85 impl<DB> Drop for TempDatabase<DB> {
86 fn drop(&mut self) {
87 if let Some(db) = self.db.take() {
88 drop(db);
89 let _ = reth_fs_util::remove_dir_all(&self.path);
90 }
91 }
92 }
93
94 impl<DB> TempDatabase<DB> {
95 pub fn new(db: DB, path: PathBuf) -> Self {
97 Self {
98 db: Some(db),
99 path,
100 pre_tx_hook: RwLock::new(Box::new(|| ())),
101 post_tx_hook: RwLock::new(Box::new(|| ())),
102 }
103 }
104
105 pub const fn db(&self) -> &DB {
107 self.db.as_ref().unwrap()
108 }
109
110 pub fn path(&self) -> &Path {
112 &self.path
113 }
114
115 pub fn into_inner_db(mut self) -> DB {
117 self.db.take().unwrap() }
119
120 pub fn set_pre_transaction_hook(&self, hook: Box<dyn Fn() + Send + Sync>) {
122 let mut db_hook = self.pre_tx_hook.write();
123 *db_hook = hook;
124 }
125
126 pub fn set_post_transaction_hook(&self, hook: Box<dyn Fn() + Send + Sync>) {
128 let mut db_hook = self.post_tx_hook.write();
129 *db_hook = hook;
130 }
131 }
132
133 impl<DB: Database> Database for TempDatabase<DB> {
134 type TX = <DB as Database>::TX;
135 type TXMut = <DB as Database>::TXMut;
136 fn tx(&self) -> Result<Self::TX, DatabaseError> {
137 self.pre_tx_hook.read()();
138 let tx = self.db().tx()?;
139 self.post_tx_hook.read()();
140 Ok(tx)
141 }
142
143 fn tx_mut(&self) -> Result<Self::TXMut, DatabaseError> {
144 self.db().tx_mut()
145 }
146 }
147
148 impl<DB: DatabaseMetrics> DatabaseMetrics for TempDatabase<DB> {
149 fn report_metrics(&self) {
150 self.db().report_metrics()
151 }
152 }
153
154 #[track_caller]
156 pub fn create_test_static_files_dir() -> (TempDir, PathBuf) {
157 let temp_dir = TempDir::with_prefix("reth-test-static-").expect(ERROR_TEMPDIR);
158 let path = temp_dir.path().to_path_buf();
159 (temp_dir, path)
160 }
161
162 #[track_caller]
164 pub fn create_test_rocksdb_dir() -> (TempDir, PathBuf) {
165 let temp_dir = TempDir::with_prefix("reth-test-rocksdb-").expect(ERROR_TEMPDIR);
166 let path = temp_dir.path().to_path_buf();
167 (temp_dir, path)
168 }
169
170 pub fn tempdir_path() -> PathBuf {
172 let builder = tempfile::Builder::new().prefix("reth-test-").rand_bytes(8).tempdir();
173 builder.expect(ERROR_TEMPDIR).keep()
174 }
175
176 #[track_caller]
178 pub fn create_test_rw_db() -> Arc<TempDatabase<DatabaseEnv>> {
179 let path = tempdir_path();
180 let emsg = format!("{ERROR_DB_CREATION}: {path:?}");
181
182 let db = init_db(
183 &path,
184 DatabaseArguments::new(ClientVersion::default())
185 .with_max_read_transaction_duration(Some(MaxReadTransactionDuration::Unbounded)),
186 )
187 .expect(&emsg);
188
189 Arc::new(TempDatabase::new(db, path))
190 }
191
192 #[track_caller]
194 pub fn create_test_rw_db_with_path<P: AsRef<Path>>(path: P) -> Arc<TempDatabase<DatabaseEnv>> {
195 let path = path.as_ref().to_path_buf();
196 let emsg = format!("{ERROR_DB_CREATION}: {path:?}");
197 let db = init_db(
198 path.as_path(),
199 DatabaseArguments::new(ClientVersion::default())
200 .with_max_read_transaction_duration(Some(MaxReadTransactionDuration::Unbounded)),
201 )
202 .expect(&emsg);
203 Arc::new(TempDatabase::new(db, path))
204 }
205
206 #[track_caller]
208 pub fn create_test_ro_db() -> Arc<TempDatabase<DatabaseEnv>> {
209 let args = DatabaseArguments::new(ClientVersion::default())
210 .with_max_read_transaction_duration(Some(MaxReadTransactionDuration::Unbounded));
211
212 let path = tempdir_path();
213 let emsg = format!("{ERROR_DB_CREATION}: {path:?}");
214 {
215 init_db(path.as_path(), args.clone()).expect(&emsg);
216 }
217 let db = open_db_read_only(path.as_path(), args).expect(ERROR_DB_OPEN);
218 Arc::new(TempDatabase::new(db, path))
219 }
220}
221
222#[cfg(test)]
223mod tests {
224 use crate::{
225 init_db,
226 mdbx::DatabaseArguments,
227 open_db, tables,
228 version::{db_version_file_path, DatabaseVersionError},
229 };
230 use assert_matches::assert_matches;
231 use reth_db_api::{
232 cursor::DbCursorRO, database::Database, models::ClientVersion, transaction::DbTx,
233 };
234 use reth_libmdbx::MaxReadTransactionDuration;
235 use std::time::Duration;
236 use tempfile::tempdir;
237
238 #[test]
239 fn db_version() {
240 let path = tempdir().unwrap();
241
242 let args = DatabaseArguments::new(ClientVersion::default())
243 .with_max_read_transaction_duration(Some(MaxReadTransactionDuration::Unbounded));
244
245 {
247 let db = init_db(&path, args.clone());
248 assert_matches!(db, Ok(_));
249 }
250
251 {
253 let db = init_db(&path, args.clone());
254 assert_matches!(db, Ok(_));
255 }
256
257 {
259 reth_fs_util::write(path.path().join(db_version_file_path(&path)), "invalid-version")
260 .unwrap();
261 let db = init_db(&path, args.clone());
262 assert!(db.is_err());
263 assert_matches!(
264 db.unwrap_err().downcast_ref::<DatabaseVersionError>(),
265 Some(DatabaseVersionError::MalformedFile)
266 )
267 }
268
269 {
271 reth_fs_util::write(path.path().join(db_version_file_path(&path)), "0").unwrap();
272 let db = init_db(&path, args);
273 assert!(db.is_err());
274 assert_matches!(
275 db.unwrap_err().downcast_ref::<DatabaseVersionError>(),
276 Some(DatabaseVersionError::VersionMismatch { version: 0 })
277 )
278 }
279 }
280
281 #[test]
282 fn db_client_version() {
283 let path = tempdir().unwrap();
284
285 {
287 let db = init_db(&path, DatabaseArguments::new(ClientVersion::default())).unwrap();
288 let tx = db.tx().unwrap();
289 let mut cursor = tx.cursor_read::<tables::VersionHistory>().unwrap();
290 assert_matches!(cursor.first(), Ok(None));
291 }
292
293 let first_version = ClientVersion { version: String::from("v1"), ..Default::default() };
295 {
296 let db = init_db(&path, DatabaseArguments::new(first_version.clone())).unwrap();
297 let tx = db.tx().unwrap();
298 let mut cursor = tx.cursor_read::<tables::VersionHistory>().unwrap();
299 assert_eq!(
300 cursor
301 .walk_range(..)
302 .unwrap()
303 .map(|x| x.map(|(_, v)| v))
304 .collect::<Result<Vec<_>, _>>()
305 .unwrap(),
306 vec![first_version.clone()]
307 );
308 }
309
310 {
312 let db = init_db(&path, DatabaseArguments::new(first_version.clone())).unwrap();
313 let tx = db.tx().unwrap();
314 let mut cursor = tx.cursor_read::<tables::VersionHistory>().unwrap();
315 assert_eq!(
316 cursor
317 .walk_range(..)
318 .unwrap()
319 .map(|x| x.map(|(_, v)| v))
320 .collect::<Result<Vec<_>, _>>()
321 .unwrap(),
322 vec![first_version.clone()]
323 );
324 }
325
326 std::thread::sleep(Duration::from_secs(1));
328 let second_version = ClientVersion { version: String::from("v2"), ..Default::default() };
329 {
330 let db = init_db(&path, DatabaseArguments::new(second_version.clone())).unwrap();
331 let tx = db.tx().unwrap();
332 let mut cursor = tx.cursor_read::<tables::VersionHistory>().unwrap();
333 assert_eq!(
334 cursor
335 .walk_range(..)
336 .unwrap()
337 .map(|x| x.map(|(_, v)| v))
338 .collect::<Result<Vec<_>, _>>()
339 .unwrap(),
340 vec![first_version.clone(), second_version.clone()]
341 );
342 }
343
344 std::thread::sleep(Duration::from_secs(1));
346 let third_version = ClientVersion { version: String::from("v3"), ..Default::default() };
347 {
348 let db = open_db(path.path(), DatabaseArguments::new(third_version.clone())).unwrap();
349 let tx = db.tx().unwrap();
350 let mut cursor = tx.cursor_read::<tables::VersionHistory>().unwrap();
351 assert_eq!(
352 cursor
353 .walk_range(..)
354 .unwrap()
355 .map(|x| x.map(|(_, v)| v))
356 .collect::<Result<Vec<_>, _>>()
357 .unwrap(),
358 vec![first_version, second_version, third_version]
359 );
360 }
361 }
362}