Skip to main content

reth_db/implementation/mdbx/
cursor.rs

1//! Cursor wrapper for libmdbx-sys.
2
3use super::utils::*;
4use crate::{
5    metrics::{DatabaseEnvMetrics, Operation},
6    DatabaseError,
7};
8use reth_db_api::{
9    common::{PairResult, ValueOnlyResult},
10    cursor::{
11        DbCursorRO, DbCursorRW, DbDupCursorRO, DbDupCursorRW, DupWalker, RangeWalker,
12        ReverseWalker, Walker,
13    },
14    table::{Compress, Decode, Decompress, DupSort, Encode, IntoVec, Table},
15};
16use reth_libmdbx::{Error as MDBXError, TransactionKind, WriteFlags, RO, RW};
17use reth_storage_errors::db::{DatabaseErrorInfo, DatabaseWriteError, DatabaseWriteOperation};
18use std::{borrow::Cow, collections::Bound, marker::PhantomData, ops::RangeBounds, sync::Arc};
19
20/// Read only Cursor.
21pub type CursorRO<T> = Cursor<RO, T>;
22/// Read write cursor.
23pub type CursorRW<T> = Cursor<RW, T>;
24
25/// Cursor wrapper to access KV items.
26#[derive(Debug)]
27pub struct Cursor<K: TransactionKind, T: Table> {
28    /// Inner `libmdbx` cursor.
29    pub(crate) inner: reth_libmdbx::Cursor<K>,
30    /// Cache buffer that receives compressed values.
31    buf: Vec<u8>,
32    /// Reference to metric handles in the DB environment. If `None`, metrics are not recorded.
33    metrics: Option<Arc<DatabaseEnvMetrics>>,
34    /// Phantom data to enforce encoding/decoding.
35    _dbi: PhantomData<T>,
36}
37
38impl<K: TransactionKind, T: Table> Cursor<K, T> {
39    pub(crate) const fn new_with_metrics(
40        inner: reth_libmdbx::Cursor<K>,
41        metrics: Option<Arc<DatabaseEnvMetrics>>,
42    ) -> Self {
43        Self { inner, buf: Vec::new(), metrics, _dbi: PhantomData }
44    }
45
46    /// If `self.metrics` is `Some(...)`, record a metric with the provided operation and value
47    /// size.
48    ///
49    /// Otherwise, just execute the closure.
50    fn execute_with_operation_metric<R>(
51        &mut self,
52        operation: Operation,
53        value_size: Option<usize>,
54        f: impl FnOnce(&mut Self) -> R,
55    ) -> R {
56        if let Some(metrics) = self.metrics.clone() {
57            metrics.record_operation(T::NAME, operation, value_size, || f(self))
58        } else {
59            f(self)
60        }
61    }
62}
63
64/// Decodes a `(key, value)` pair from the database.
65#[expect(clippy::type_complexity)]
66pub fn decode<T>(
67    res: Result<Option<(Cow<'_, [u8]>, Cow<'_, [u8]>)>, impl Into<DatabaseErrorInfo>>,
68) -> PairResult<T>
69where
70    T: Table,
71    T::Key: Decode,
72    T::Value: Decompress,
73{
74    res.map_err(|e| DatabaseError::Read(e.into()))?.map(decoder::<T>).transpose()
75}
76
/// Prepares `$value` for writing without copying it when that can be avoided.
///
/// Values that expose their bytes through `uncompressable_ref` (e.g. B256) evaluate to
/// `Some(bytes)`. Every other value is compressed into the freshly cleared `$self.buf` and the
/// macro evaluates to `None`, meaning the caller must read the bytes from `$self.buf`.
macro_rules! compress_to_buf_or_ref {
    ($self:expr, $value:expr) => {
        match $value.uncompressable_ref() {
            Some(raw) => Some(raw),
            None => {
                $self.buf.clear();
                $value.compress_to_buf(&mut $self.buf);
                None
            }
        }
    };
}
90
91impl<K: TransactionKind, T: Table> DbCursorRO<T> for Cursor<K, T> {
92    fn first(&mut self) -> PairResult<T> {
93        decode::<T>(self.inner.first())
94    }
95
96    fn seek_exact(&mut self, key: <T as Table>::Key) -> PairResult<T> {
97        decode::<T>(self.inner.set_key(key.encode().as_ref()))
98    }
99
100    fn seek(&mut self, key: <T as Table>::Key) -> PairResult<T> {
101        decode::<T>(self.inner.set_range(key.encode().as_ref()))
102    }
103
104    fn next(&mut self) -> PairResult<T> {
105        decode::<T>(self.inner.next())
106    }
107
108    fn prev(&mut self) -> PairResult<T> {
109        decode::<T>(self.inner.prev())
110    }
111
112    fn last(&mut self) -> PairResult<T> {
113        decode::<T>(self.inner.last())
114    }
115
116    fn current(&mut self) -> PairResult<T> {
117        decode::<T>(self.inner.get_current())
118    }
119
120    fn walk(&mut self, start_key: Option<T::Key>) -> Result<Walker<'_, T, Self>, DatabaseError> {
121        let start = if let Some(start_key) = start_key {
122            decode::<T>(self.inner.set_range(start_key.encode().as_ref())).transpose()
123        } else {
124            self.first().transpose()
125        };
126
127        Ok(Walker::new(self, start))
128    }
129
130    fn walk_range(
131        &mut self,
132        range: impl RangeBounds<T::Key>,
133    ) -> Result<RangeWalker<'_, T, Self>, DatabaseError> {
134        let start = match range.start_bound().cloned() {
135            Bound::Included(key) => self.inner.set_range(key.encode().as_ref()),
136            Bound::Excluded(_key) => {
137                unreachable!("Rust doesn't allow for Bound::Excluded in starting bounds");
138            }
139            Bound::Unbounded => self.inner.first(),
140        };
141        let start = decode::<T>(start).transpose();
142        Ok(RangeWalker::new(self, start, range.end_bound().cloned()))
143    }
144
145    fn walk_back(
146        &mut self,
147        start_key: Option<T::Key>,
148    ) -> Result<ReverseWalker<'_, T, Self>, DatabaseError> {
149        let start = if let Some(start_key) = start_key {
150            decode::<T>(self.inner.set_range(start_key.encode().as_ref()))
151        } else {
152            self.last()
153        }
154        .transpose();
155
156        Ok(ReverseWalker::new(self, start))
157    }
158}
159
160impl<K: TransactionKind, T: DupSort> DbDupCursorRO<T> for Cursor<K, T> {
161    /// Returns the previous `(key, value)` pair of a DUPSORT table.
162    fn prev_dup(&mut self) -> PairResult<T> {
163        decode::<T>(self.inner.prev_dup())
164    }
165
166    /// Returns the next `(key, value)` pair of a DUPSORT table.
167    fn next_dup(&mut self) -> PairResult<T> {
168        decode::<T>(self.inner.next_dup())
169    }
170
171    /// Returns the last `value` of the current duplicate `key`.
172    fn last_dup(&mut self) -> ValueOnlyResult<T> {
173        self.inner
174            .last_dup()
175            .map_err(|e| DatabaseError::Read(e.into()))?
176            .map(decode_one::<T>)
177            .transpose()
178    }
179
180    /// Returns the next `(key, value)` pair skipping the duplicates.
181    fn next_no_dup(&mut self) -> PairResult<T> {
182        decode::<T>(self.inner.next_nodup())
183    }
184
185    /// Returns the next `value` of a duplicate `key`.
186    fn next_dup_val(&mut self) -> ValueOnlyResult<T> {
187        self.inner
188            .next_dup()
189            .map_err(|e| DatabaseError::Read(e.into()))?
190            .map(decode_value::<T>)
191            .transpose()
192    }
193
194    fn seek_by_key_subkey(
195        &mut self,
196        key: <T as Table>::Key,
197        subkey: <T as DupSort>::SubKey,
198    ) -> ValueOnlyResult<T> {
199        self.inner
200            .get_both_range(key.encode().as_ref(), subkey.encode().as_ref())
201            .map_err(|e| DatabaseError::Read(e.into()))?
202            .map(decode_one::<T>)
203            .transpose()
204    }
205
206    /// Depending on its arguments, returns an iterator starting at:
207    /// - Some(key), Some(subkey): a `key` item whose data is >= than `subkey`
208    /// - Some(key), None: first item of a specified `key`
209    /// - None, Some(subkey): like first case, but in the first key
210    /// - None, None: first item in the table of a DUPSORT table.
211    fn walk_dup(
212        &mut self,
213        key: Option<T::Key>,
214        subkey: Option<T::SubKey>,
215    ) -> Result<DupWalker<'_, T, Self>, DatabaseError> {
216        let start = match (key, subkey) {
217            (Some(key), Some(subkey)) => {
218                let encoded_key = key.encode();
219                self.inner
220                    .get_both_range(encoded_key.as_ref(), subkey.encode().as_ref())
221                    .map_err(|e| DatabaseError::Read(e.into()))?
222                    .map(|val| decoder::<T>((Cow::Borrowed(encoded_key.as_ref()), val)))
223            }
224            (Some(key), None) => {
225                let encoded_key = key.encode();
226                self.inner
227                    .set(encoded_key.as_ref())
228                    .map_err(|e| DatabaseError::Read(e.into()))?
229                    .map(|val| decoder::<T>((Cow::Borrowed(encoded_key.as_ref()), val)))
230            }
231            (None, Some(subkey)) => {
232                if let Some((key, _)) = self.first()? {
233                    let encoded_key = key.encode();
234                    self.inner
235                        .get_both_range(encoded_key.as_ref(), subkey.encode().as_ref())
236                        .map_err(|e| DatabaseError::Read(e.into()))?
237                        .map(|val| decoder::<T>((Cow::Borrowed(encoded_key.as_ref()), val)))
238                } else {
239                    Some(Err(DatabaseError::Read(MDBXError::NotFound.into())))
240                }
241            }
242            (None, None) => self.first().transpose(),
243        };
244
245        Ok(DupWalker::<'_, T, Self> { cursor: self, start })
246    }
247}
248
249impl<T: Table> DbCursorRW<T> for Cursor<RW, T> {
250    /// Database operation that will update an existing row if a specified value already
251    /// exists in a table, and insert a new row if the specified value doesn't already exist
252    ///
253    /// For a DUPSORT table, `upsert` will not actually update-or-insert. If the key already exists,
254    /// it will append the value to the subkey, even if the subkeys are the same. So if you want
255    /// to properly upsert, you'll need to `seek_exact` & `delete_current` if the key+subkey was
256    /// found, before calling `upsert`.
257    fn upsert(&mut self, key: T::Key, value: &T::Value) -> Result<(), DatabaseError> {
258        let key = key.encode();
259        let value = compress_to_buf_or_ref!(self, value);
260        self.execute_with_operation_metric(
261            Operation::CursorUpsert,
262            Some(value.unwrap_or(&self.buf).len()),
263            |this| {
264                this.inner
265                    .put(key.as_ref(), value.unwrap_or(&this.buf), WriteFlags::UPSERT)
266                    .map_err(|e| {
267                        DatabaseWriteError {
268                            info: e.into(),
269                            operation: DatabaseWriteOperation::CursorUpsert,
270                            table_name: T::NAME,
271                            key: key.into_vec(),
272                        }
273                        .into()
274                    })
275            },
276        )
277    }
278
279    fn insert(&mut self, key: T::Key, value: &T::Value) -> Result<(), DatabaseError> {
280        let key = key.encode();
281        let value = compress_to_buf_or_ref!(self, value);
282        self.execute_with_operation_metric(
283            Operation::CursorInsert,
284            Some(value.unwrap_or(&self.buf).len()),
285            |this| {
286                this.inner
287                    .put(key.as_ref(), value.unwrap_or(&this.buf), WriteFlags::NO_OVERWRITE)
288                    .map_err(|e| {
289                        DatabaseWriteError {
290                            info: e.into(),
291                            operation: DatabaseWriteOperation::CursorInsert,
292                            table_name: T::NAME,
293                            key: key.into_vec(),
294                        }
295                        .into()
296                    })
297            },
298        )
299    }
300
301    /// Appends the data to the end of the table. Consequently, the append operation
302    /// will fail if the inserted key is less than the last table key
303    fn append(&mut self, key: T::Key, value: &T::Value) -> Result<(), DatabaseError> {
304        let key = key.encode();
305        let value = compress_to_buf_or_ref!(self, value);
306        self.execute_with_operation_metric(
307            Operation::CursorAppend,
308            Some(value.unwrap_or(&self.buf).len()),
309            |this| {
310                this.inner
311                    .put(key.as_ref(), value.unwrap_or(&this.buf), WriteFlags::APPEND)
312                    .map_err(|e| {
313                        DatabaseWriteError {
314                            info: e.into(),
315                            operation: DatabaseWriteOperation::CursorAppend,
316                            table_name: T::NAME,
317                            key: key.into_vec(),
318                        }
319                        .into()
320                    })
321            },
322        )
323    }
324
325    fn delete_current(&mut self) -> Result<(), DatabaseError> {
326        self.execute_with_operation_metric(Operation::CursorDeleteCurrent, None, |this| {
327            this.inner.del(WriteFlags::CURRENT).map_err(|e| DatabaseError::Delete(e.into()))
328        })
329    }
330}
331
332impl<T: DupSort> DbDupCursorRW<T> for Cursor<RW, T> {
333    fn delete_current_duplicates(&mut self) -> Result<(), DatabaseError> {
334        self.execute_with_operation_metric(Operation::CursorDeleteCurrentDuplicates, None, |this| {
335            this.inner.del(WriteFlags::NO_DUP_DATA).map_err(|e| DatabaseError::Delete(e.into()))
336        })
337    }
338
339    fn append_dup(&mut self, key: T::Key, value: T::Value) -> Result<(), DatabaseError> {
340        let key = key.encode();
341        let value = compress_to_buf_or_ref!(self, value);
342        self.execute_with_operation_metric(
343            Operation::CursorAppendDup,
344            Some(value.unwrap_or(&self.buf).len()),
345            |this| {
346                this.inner
347                    .put(key.as_ref(), value.unwrap_or(&this.buf), WriteFlags::APPEND_DUP)
348                    .map_err(|e| {
349                        DatabaseWriteError {
350                            info: e.into(),
351                            operation: DatabaseWriteOperation::CursorAppendDup,
352                            table_name: T::NAME,
353                            key: key.into_vec(),
354                        }
355                        .into()
356                    })
357            },
358        )
359    }
360}
361
#[cfg(test)]
mod tests {
    use crate::{
        mdbx::{DatabaseArguments, DatabaseEnv, DatabaseEnvKind},
        tables::StorageChangeSets,
        Database,
    };
    use alloy_primitives::{address, Address, B256, U256};
    use reth_db_api::{
        cursor::{DbCursorRO, DbDupCursorRW},
        models::{BlockNumberAddress, ClientVersion},
        table::TableImporter,
        transaction::{DbTx, DbTxMut},
    };
    use reth_primitives_traits::StorageEntry;
    use tempfile::TempDir;

    /// Opens a read-write database with all tables created inside a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the environment: dropping the guard deletes
    /// the directory, so the caller must keep it alive for as long as the database is used.
    fn create_test_db() -> (TempDir, DatabaseEnv) {
        let dir = TempDir::new().unwrap();
        let mut db = DatabaseEnv::open(
            dir.path(),
            DatabaseEnvKind::RW,
            DatabaseArguments::new(ClientVersion::default()),
        )
        .unwrap();
        db.create_tables().unwrap();
        (dir, db)
    }

    /// Copies a DUPSORT table between two databases with `import_table_with_range` and checks
    /// that every entry, duplicates included, arrives in the original order.
    #[test]
    fn test_import_table_with_range_works_on_dupsort() {
        let addr1 = address!("0000000000000000000000000000000000000001");
        let addr2 = address!("0000000000000000000000000000000000000002");
        let addr3 = address!("0000000000000000000000000000000000000003");
        // The directory guards are declared before the environments they back, so they are
        // dropped after them.
        let (_source_dir, source_db) = create_test_db();
        let (_target_dir, target_db) = create_test_db();
        let test_data = vec![
            (
                BlockNumberAddress((100, addr1)),
                StorageEntry { key: B256::with_last_byte(1), value: U256::from(100) },
            ),
            (
                BlockNumberAddress((100, addr1)),
                StorageEntry { key: B256::with_last_byte(2), value: U256::from(200) },
            ),
            (
                BlockNumberAddress((100, addr1)),
                StorageEntry { key: B256::with_last_byte(3), value: U256::from(300) },
            ),
            (
                BlockNumberAddress((101, addr1)),
                StorageEntry { key: B256::with_last_byte(1), value: U256::from(400) },
            ),
            (
                BlockNumberAddress((101, addr2)),
                StorageEntry { key: B256::with_last_byte(1), value: U256::from(500) },
            ),
            (
                BlockNumberAddress((101, addr2)),
                StorageEntry { key: B256::with_last_byte(2), value: U256::from(600) },
            ),
            (
                BlockNumberAddress((102, addr3)),
                StorageEntry { key: B256::with_last_byte(1), value: U256::from(700) },
            ),
        ];

        // setup data
        let tx = source_db.tx_mut().unwrap();
        {
            let mut cursor = tx.cursor_dup_write::<StorageChangeSets>().unwrap();
            for (key, value) in &test_data {
                cursor.append_dup(*key, *value).unwrap();
            }
        }
        tx.commit().unwrap();

        // import data from source db to target
        let source_tx = source_db.tx().unwrap();
        let target_tx = target_db.tx_mut().unwrap();

        target_tx
            .import_table_with_range::<StorageChangeSets, _>(
                &source_tx,
                Some(BlockNumberAddress((100, Address::ZERO))),
                BlockNumberAddress((102, Address::repeat_byte(0xff))),
            )
            .unwrap();
        target_tx.commit().unwrap();

        // fetch all data from target db
        let verify_tx = target_db.tx().unwrap();
        let mut cursor = verify_tx.cursor_dup_read::<StorageChangeSets>().unwrap();
        let copied: Vec<_> = cursor.walk(None).unwrap().collect::<Result<Vec<_>, _>>().unwrap();

        // verify each entry matches the test data
        assert_eq!(copied.len(), test_data.len(), "Should copy all entries including duplicates");
        for ((copied_key, copied_value), (expected_key, expected_value)) in
            copied.iter().zip(test_data.iter())
        {
            assert_eq!(copied_key, expected_key);
            assert_eq!(copied_value, expected_value);
        }
    }
}
466}