#![cfg_attr(feature = "disable-lock", allow(dead_code))]
use reth_storage_errors::lockfile::StorageLockError;
use reth_tracing::tracing::error;
use std::{
path::{Path, PathBuf},
process,
sync::{Arc, OnceLock},
};
use sysinfo::{ProcessRefreshKind, RefreshKind, System};
/// File name of the lock file; it is joined onto the directory passed to the lock functions.
const LOCKFILE_NAME: &str = "lock";
/// Handle to a lock file that marks a storage directory as in use by a process.
///
/// The handle is cheap to clone: all clones share one [`StorageLockInner`] through an [`Arc`].
/// Cleanup of the lock file is handled by the `Drop` impl, which only acts for the last
/// remaining clone.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StorageLock(Arc<StorageLockInner>);
impl StorageLock {
    /// Tries to take the lock for the directory at `path`.
    ///
    /// With the `disable-lock` feature enabled this only builds the handle and does not touch
    /// the filesystem. Otherwise it defers to [`Self::try_acquire_file_lock`].
    ///
    /// Acquiring again from the same process succeeds, because a lock file carrying the
    /// current PID is never treated as foreign.
    ///
    /// # Errors
    ///
    /// Returns [`StorageLockError::Taken`] when the lock file names a different process that
    /// is still running, and propagates any error raised while creating the lock file.
    pub fn try_acquire(path: &Path) -> Result<Self, StorageLockError> {
        #[cfg(feature = "disable-lock")]
        {
            // No lock file is read or written in this mode.
            let inner = StorageLockInner { file_path: path.join(LOCKFILE_NAME) };
            Ok(Self(Arc::new(inner)))
        }
        #[cfg(not(feature = "disable-lock"))]
        Self::try_acquire_file_lock(path)
    }

    /// Takes the lock by (re)writing the lock file inside `path`.
    ///
    /// An existing lock file blocks acquisition only if it parses, names a PID other than the
    /// current one, and that process is still active with the recorded start time. In every
    /// other case the file is overwritten with the identity of the current process.
    #[cfg(any(test, not(feature = "disable-lock")))]
    fn try_acquire_file_lock(path: &Path) -> Result<Self, StorageLockError> {
        let file_path = path.join(LOCKFILE_NAME);

        match ProcessUID::parse(&file_path)? {
            Some(holder) if holder.pid != (process::id() as usize) && holder.is_active() => {
                error!(
                    target: "reth::db::lockfile",
                    path = ?file_path,
                    pid = holder.pid,
                    start_time = holder.start_time,
                    "Storage lock already taken."
                );
                Err(StorageLockError::Taken(holder.pid))
            }
            // No lock file, an unparsable one, our own PID, or a holder that is gone.
            _ => StorageLockInner::new(file_path).map(|inner| Self(Arc::new(inner))),
        }
    }
}
impl Drop for StorageLock {
    /// Removes the lock file when the last clone of the lock is dropped.
    ///
    /// With the `disable-lock` feature enabled (outside of test builds) `try_acquire` never
    /// writes a lock file, so nothing is deleted here either: a file found at that path was
    /// not created by this handle. Deletion failures are logged instead of propagated because
    /// `drop` cannot return an error.
    fn drop(&mut self) {
        // `cfg!` instead of `#[cfg]` keeps the body compiled, and the `error!` import used,
        // in every feature configuration.
        if cfg!(all(feature = "disable-lock", not(test))) {
            return;
        }

        // Only the last handle cleans up. The `exists` check avoids logging an error when
        // the file has already been removed.
        let is_last_handle = Arc::strong_count(&self.0) == 1;
        if is_last_handle && self.0.file_path.exists() {
            if let Err(err) = reth_fs_util::remove_file(&self.0.file_path) {
                error!(%err, "Failed to delete lock file");
            }
        }
    }
}
/// State shared between all clones of a [`StorageLock`].
#[derive(Debug, PartialEq, Eq)]
struct StorageLockInner {
/// Full path of the lock file (the storage directory joined with [`LOCKFILE_NAME`]).
file_path: PathBuf,
}
impl StorageLockInner {
    /// Writes the identity of the current process to `file_path` and returns the inner state.
    ///
    /// Missing parent directories of `file_path` are created first.
    ///
    /// # Errors
    ///
    /// Propagates any failure from creating the directories or writing the file.
    fn new(file_path: PathBuf) -> Result<Self, StorageLockError> {
        // A path without a parent needs no directory to be created.
        file_path.parent().map(reth_fs_util::create_dir_all).transpose()?;

        let owner = ProcessUID::own();
        owner.write(&file_path)?;

        Ok(Self { file_path })
    }
}
/// Identity of a process: its PID together with its start time.
///
/// Both values are compared when checking whether a lock holder is still running, so a
/// process that has the same PID but a different start time is not treated as the holder.
#[derive(Clone, Debug)]
struct ProcessUID {
/// Process id.
pid: usize,
/// Start time of the process, as reported by `sysinfo` or read back from a lock file.
start_time: u64,
}
impl ProcessUID {
    /// Looks up the process with the given `pid` and captures its start time.
    ///
    /// Returns `None` when `sysinfo` does not report a process with that PID.
    fn new(pid: usize) -> Option<Self> {
        let target = sysinfo::Pid::from(pid);

        // Refresh only the one process we are interested in.
        let mut sys = System::new();
        sys.refresh_processes_specifics(
            sysinfo::ProcessesToUpdate::Some(&[target]),
            ProcessRefreshKind::new(),
        );

        let start_time = sys.process(target)?.start_time();
        Some(Self { pid, start_time })
    }

    /// Returns the identity of the current process.
    ///
    /// The value is computed on first use and cached for later calls.
    ///
    /// # Panics
    ///
    /// Panics if the current process cannot be looked up.
    fn own() -> Self {
        static CACHE: OnceLock<ProcessUID> = OnceLock::new();

        let cached = CACHE.get_or_init(|| {
            let own_pid = process::id() as usize;
            Self::new(own_pid).expect("own process")
        });
        cached.clone()
    }

    /// Reads a process identity from the lock file at `path`.
    ///
    /// The first line is expected to hold the PID and the second the start time. Returns
    /// `Ok(None)` when the file does not exist, cannot be read, or either line is missing or
    /// fails to parse.
    fn parse(path: &Path) -> Result<Option<Self>, StorageLockError> {
        if !path.exists() {
            return Ok(None);
        }
        let Ok(contents) = reth_fs_util::read_to_string(path) else {
            return Ok(None);
        };

        let mut lines = contents.lines();
        let pid = lines.next().and_then(|line| line.trim().parse::<usize>().ok());
        let start_time = lines.next().and_then(|line| line.trim().parse::<u64>().ok());

        Ok(pid.zip(start_time).map(|(pid, start_time)| Self { pid, start_time }))
    }

    /// Returns `true` if a process with this `pid` and this `start_time` currently exists.
    fn is_active(&self) -> bool {
        let refresh = RefreshKind::new().with_processes(ProcessRefreshKind::new());
        let system = System::new_with_specifics(refresh);

        match system.process(self.pid.into()) {
            Some(found) => found.start_time() == self.start_time,
            None => false,
        }
    }

    /// Writes `pid` and `start_time` to `path`, one per line.
    fn write(&self, path: &Path) -> Result<(), StorageLockError> {
        let Self { pid, start_time } = self;
        reth_fs_util::write(path, format!("{pid}\n{start_time}"))?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{Mutex, MutexGuard, OnceLock};
// Mutex shared by the tests below so that they never run at the same time.
static SERIAL: OnceLock<Mutex<()>> = OnceLock::new();
/// Returns the guard of the module-wide test mutex, initialising the mutex on first use.
fn serial_lock() -> MutexGuard<'static, ()> {
SERIAL.get_or_init(|| Mutex::new(())).lock().unwrap()
}
/// Walks through the cases in which an existing lock file does or does not block acquisition.
#[test]
fn test_lock() {
let _guard = serial_lock();
let temp_dir = tempfile::tempdir().unwrap();
let lock = StorageLock::try_acquire_file_lock(temp_dir.path()).unwrap();
// The same process can acquire the lock again.
assert_eq!(Ok(lock.clone()), StorageLock::try_acquire_file_lock(temp_dir.path()));
let lock_file = temp_dir.path().join(LOCKFILE_NAME);
// Search upwards from 1337 for a PID that no running process uses.
let mut fake_pid = 1337;
let system = System::new_all();
while system.process(fake_pid.into()).is_some() {
fake_pid += 1;
}
// A lock file naming a PID without a process does not block acquisition.
ProcessUID { pid: fake_pid, start_time: u64::MAX }.write(&lock_file).unwrap();
assert_eq!(Ok(lock.clone()), StorageLock::try_acquire_file_lock(temp_dir.path()));
// The `unwrap` requires a process with PID 1 to exist on the host.
let mut pid_1 = ProcessUID::new(1).unwrap();
// A lock file naming a live process with its real start time blocks acquisition.
pid_1.write(&lock_file).unwrap();
assert_eq!(
Err(StorageLockError::Taken(1)),
StorageLock::try_acquire_file_lock(temp_dir.path())
);
// Same live PID but a start time that does not match: the lock can be acquired.
pid_1.start_time += 1;
pid_1.write(&lock_file).unwrap();
assert_eq!(Ok(lock), StorageLock::try_acquire_file_lock(temp_dir.path()));
}
/// Dropping the only handle removes the lock file.
#[test]
fn test_drop_lock() {
let _guard = serial_lock();
let temp_dir = tempfile::tempdir().unwrap();
let lock_file = temp_dir.path().join(LOCKFILE_NAME);
let lock = StorageLock::try_acquire_file_lock(temp_dir.path()).unwrap();
assert!(lock_file.exists());
drop(lock);
assert!(!lock_file.exists());
}
}