1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
use reth_network_api::Direction;
use reth_network_types::SessionLimits;

use super::ExceedsSessionLimit;

/// Keeps track of all sessions.
#[derive(Debug, Clone)]
pub struct SessionCounter {
    /// Limits to enforce.
    limits: SessionLimits,
    /// Number of pending incoming sessions.
    pending_inbound: u32,
    /// Number of pending outgoing sessions.
    pending_outbound: u32,
    /// Number of active inbound sessions.
    active_inbound: u32,
    /// Number of active outbound sessions.
    active_outbound: u32,
}

// === impl SessionCounter ===

impl SessionCounter {
    pub(crate) const fn new(limits: SessionLimits) -> Self {
        Self {
            limits,
            pending_inbound: 0,
            pending_outbound: 0,
            active_inbound: 0,
            active_outbound: 0,
        }
    }

    pub(crate) fn inc_pending_inbound(&mut self) {
        self.pending_inbound += 1;
    }

    pub(crate) fn inc_pending_outbound(&mut self) {
        self.pending_outbound += 1;
    }

    pub(crate) fn dec_pending(&mut self, direction: &Direction) {
        match direction {
            Direction::Outgoing(_) => {
                self.pending_outbound -= 1;
            }
            Direction::Incoming => {
                self.pending_inbound -= 1;
            }
        }
    }

    pub(crate) fn inc_active(&mut self, direction: &Direction) {
        match direction {
            Direction::Outgoing(_) => {
                self.active_outbound += 1;
            }
            Direction::Incoming => {
                self.active_inbound += 1;
            }
        }
    }

    pub(crate) fn dec_active(&mut self, direction: &Direction) {
        match direction {
            Direction::Outgoing(_) => {
                self.active_outbound -= 1;
            }
            Direction::Incoming => {
                self.active_inbound -= 1;
            }
        }
    }

    pub(crate) const fn ensure_pending_outbound(&self) -> Result<(), ExceedsSessionLimit> {
        Self::ensure(self.pending_outbound, self.limits.max_pending_outbound)
    }

    pub(crate) const fn ensure_pending_inbound(&self) -> Result<(), ExceedsSessionLimit> {
        Self::ensure(self.pending_inbound, self.limits.max_pending_inbound)
    }

    const fn ensure(current: u32, limit: Option<u32>) -> Result<(), ExceedsSessionLimit> {
        if let Some(limit) = limit {
            if current >= limit {
                return Err(ExceedsSessionLimit(limit))
            }
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_limits() {
        let mut limits = SessionCounter::new(SessionLimits::default().with_max_pending_inbound(2));
        assert!(limits.ensure_pending_outbound().is_ok());
        limits.inc_pending_inbound();
        assert!(limits.ensure_pending_inbound().is_ok());
        limits.inc_pending_inbound();
        assert!(limits.ensure_pending_inbound().is_err());
    }
}