1use crate::{
4 errors::{P2PHandshakeError, P2PStreamError},
5 p2pstream::MAX_RESERVED_MESSAGE_ID,
6 protocol::{ProtoVersion, Protocol},
7 version::ParseVersionError,
8 Capability, EthMessageID, EthVersion,
9};
10use derive_more::{Deref, DerefMut};
11use std::{
12 borrow::Cow,
13 collections::{BTreeSet, HashMap},
14};
15
16#[derive(Debug, Clone, PartialEq, Eq, Hash)]
23pub enum SharedCapability {
24 Eth {
26 version: EthVersion,
28 offset: u8,
33 },
34 UnknownCapability {
36 cap: Capability,
38 offset: u8,
43 messages: u8,
46 },
47}
48
49impl SharedCapability {
50 pub(crate) fn new(
55 name: &str,
56 version: u8,
57 offset: u8,
58 messages: u8,
59 ) -> Result<Self, SharedCapabilityError> {
60 if offset <= MAX_RESERVED_MESSAGE_ID {
61 return Err(SharedCapabilityError::ReservedMessageIdOffset(offset))
62 }
63
64 match name {
65 "eth" => Ok(Self::eth(EthVersion::try_from(version)?, offset)),
66 _ => Ok(Self::UnknownCapability {
67 cap: Capability::new(name.to_string(), version as usize),
68 offset,
69 messages,
70 }),
71 }
72 }
73
74 pub(crate) const fn eth(version: EthVersion, offset: u8) -> Self {
76 Self::Eth { version, offset }
77 }
78
79 pub const fn capability(&self) -> Cow<'_, Capability> {
81 match self {
82 Self::Eth { version, .. } => Cow::Owned(Capability::eth(*version)),
83 Self::UnknownCapability { cap, .. } => Cow::Borrowed(cap),
84 }
85 }
86
87 #[inline]
89 pub fn name(&self) -> &str {
90 match self {
91 Self::Eth { .. } => "eth",
92 Self::UnknownCapability { cap, .. } => cap.name.as_ref(),
93 }
94 }
95
96 #[inline]
98 pub const fn is_eth(&self) -> bool {
99 matches!(self, Self::Eth { .. })
100 }
101
102 pub const fn version(&self) -> u8 {
104 match self {
105 Self::Eth { version, .. } => *version as u8,
106 Self::UnknownCapability { cap, .. } => cap.version as u8,
107 }
108 }
109
110 pub const fn eth_version(&self) -> Option<EthVersion> {
112 match self {
113 Self::Eth { version, .. } => Some(*version),
114 _ => None,
115 }
116 }
117
118 pub const fn message_id_offset(&self) -> u8 {
123 match self {
124 Self::Eth { offset, .. } | Self::UnknownCapability { offset, .. } => *offset,
125 }
126 }
127
128 pub const fn relative_message_id_offset(&self) -> u8 {
131 self.message_id_offset() - MAX_RESERVED_MESSAGE_ID - 1
132 }
133
134 pub const fn num_messages(&self) -> u8 {
136 match self {
137 Self::Eth { version: _version, .. } => EthMessageID::max() + 1,
138 Self::UnknownCapability { messages, .. } => *messages,
139 }
140 }
141}
142
143#[derive(Debug, Clone, Deref, DerefMut, PartialEq, Eq)]
147pub struct SharedCapabilities(Vec<SharedCapability>);
148
149impl SharedCapabilities {
150 #[inline]
152 pub fn try_new(
153 local_protocols: Vec<Protocol>,
154 peer_capabilities: Vec<Capability>,
155 ) -> Result<Self, P2PStreamError> {
156 shared_capability_offsets(local_protocols, peer_capabilities).map(Self)
157 }
158
159 #[inline]
161 pub fn iter_caps(&self) -> impl Iterator<Item = &SharedCapability> {
162 self.0.iter()
163 }
164
165 #[inline]
167 pub fn eth(&self) -> Result<&SharedCapability, P2PStreamError> {
168 self.iter_caps().find(|c| c.is_eth()).ok_or(P2PStreamError::CapabilityNotShared)
169 }
170
171 #[inline]
173 pub fn eth_version(&self) -> Result<EthVersion, P2PStreamError> {
174 self.iter_caps()
175 .find_map(SharedCapability::eth_version)
176 .ok_or(P2PStreamError::CapabilityNotShared)
177 }
178
179 #[inline]
181 pub fn contains(&self, cap: &Capability) -> bool {
182 self.find(cap).is_some()
183 }
184
185 #[inline]
187 pub fn find(&self, cap: &Capability) -> Option<&SharedCapability> {
188 self.0.iter().find(|c| c.version() == cap.version as u8 && c.name() == cap.name)
189 }
190
191 #[inline]
200 pub fn find_by_relative_offset(&self, offset: u8) -> Option<&SharedCapability> {
201 self.find_by_offset(offset.saturating_add(MAX_RESERVED_MESSAGE_ID + 1))
202 }
203
204 #[inline]
212 pub fn find_by_offset(&self, offset: u8) -> Option<&SharedCapability> {
213 let mut iter = self.0.iter();
214 let mut cap = iter.next()?;
215 if offset < cap.message_id_offset() {
216 return None
218 }
219
220 for next in iter {
221 if offset < next.message_id_offset() {
222 return Some(cap)
223 }
224 cap = next
225 }
226
227 Some(cap)
228 }
229
230 #[inline]
232 pub fn ensure_matching_capability(
233 &self,
234 cap: &Capability,
235 ) -> Result<&SharedCapability, UnsupportedCapabilityError> {
236 self.find(cap).ok_or_else(|| UnsupportedCapabilityError { capability: cap.clone() })
237 }
238
239 #[inline]
241 pub fn len(&self) -> usize {
242 self.0.len()
243 }
244
245 #[inline]
247 pub fn is_empty(&self) -> bool {
248 self.0.is_empty()
249 }
250}
251
252#[inline]
261pub fn shared_capability_offsets(
262 local_protocols: Vec<Protocol>,
263 peer_capabilities: Vec<Capability>,
264) -> Result<Vec<SharedCapability>, P2PStreamError> {
265 let our_capabilities =
267 local_protocols.into_iter().map(Protocol::split).collect::<HashMap<_, _>>();
268
269 let mut shared_capabilities: HashMap<_, ProtoVersion> = HashMap::default();
271
272 let mut shared_capability_names = BTreeSet::new();
284
285 for peer_capability in peer_capabilities {
287 if let Some(messages) = our_capabilities.get(&peer_capability).copied() {
289 if shared_capabilities
292 .get(&peer_capability.name)
293 .is_none_or(|v| peer_capability.version > v.version)
294 {
295 shared_capabilities.insert(
296 peer_capability.name.clone(),
297 ProtoVersion { version: peer_capability.version, messages },
298 );
299 shared_capability_names.insert(peer_capability.name);
300 }
301 }
302 }
303
304 if shared_capabilities.is_empty() {
306 return Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))
307 }
308
309 let mut shared_with_offsets = Vec::new();
312
313 let mut offset = MAX_RESERVED_MESSAGE_ID + 1;
317 for name in shared_capability_names {
318 let proto_version = &shared_capabilities[&name];
319 let shared_capability = SharedCapability::new(
320 &name,
321 proto_version.version as u8,
322 offset,
323 proto_version.messages,
324 )?;
325 offset += shared_capability.num_messages();
326 shared_with_offsets.push(shared_capability);
327 }
328
329 if shared_with_offsets.is_empty() {
330 return Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))
331 }
332
333 Ok(shared_with_offsets)
334}
335
336#[derive(Debug, thiserror::Error)]
338pub enum SharedCapabilityError {
339 #[error(transparent)]
341 UnsupportedVersion(#[from] ParseVersionError),
342 #[error("message id offset `{0}` is reserved")]
345 ReservedMessageIdOffset(u8),
346}
347
348#[derive(Debug, thiserror::Error)]
350#[error("unsupported capability {capability}")]
351pub struct UnsupportedCapabilityError {
352 capability: Capability,
353}
354
355impl UnsupportedCapabilityError {
356 pub const fn new(capability: Capability) -> Self {
358 Self { capability }
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Capabilities, Capability};
    use alloy_primitives::bytes::Bytes;
    use alloy_rlp::{Decodable, Encodable};
    use reth_eth_wire_types::RawCapabilityMessage;

    // "eth"/68 at the first non-reserved offset maps to the `Eth` variant. The `messages`
    // argument (13) is not stored for `eth`.
    #[test]
    fn from_eth_68() {
        let capability = SharedCapability::new("eth", 68, MAX_RESERVED_MESSAGE_ID + 1, 13).unwrap();

        assert_eq!(capability.name(), "eth");
        assert_eq!(capability.version(), 68);
        assert_eq!(
            capability,
            SharedCapability::Eth {
                version: EthVersion::Eth68,
                offset: MAX_RESERVED_MESSAGE_ID + 1
            }
        );
    }

    // Same as above for eth/67.
    #[test]
    fn from_eth_67() {
        let capability = SharedCapability::new("eth", 67, MAX_RESERVED_MESSAGE_ID + 1, 13).unwrap();

        assert_eq!(capability.name(), "eth");
        assert_eq!(capability.version(), 67);
        assert_eq!(
            capability,
            SharedCapability::Eth {
                version: EthVersion::Eth67,
                offset: MAX_RESERVED_MESSAGE_ID + 1
            }
        );
    }

    // Same as above for eth/66. The different `messages` value (15) is likewise ignored.
    #[test]
    fn from_eth_66() {
        let capability = SharedCapability::new("eth", 66, MAX_RESERVED_MESSAGE_ID + 1, 15).unwrap();

        assert_eq!(capability.name(), "eth");
        assert_eq!(capability.version(), 66);
        assert_eq!(
            capability,
            SharedCapability::Eth {
                version: EthVersion::Eth66,
                offset: MAX_RESERVED_MESSAGE_ID + 1
            }
        );
    }

    // `Capabilities` built from eth/66..68 reports support for each of those versions.
    #[test]
    fn capabilities_supports_eth() {
        let capabilities: Capabilities = vec![
            Capability::new_static("eth", 66),
            Capability::new_static("eth", 67),
            Capability::new_static("eth", 68),
        ]
        .into();

        assert!(capabilities.supports_eth());
        assert!(capabilities.supports_eth_v66());
        assert!(capabilities.supports_eth_v67());
        assert!(capabilities.supports_eth_v68());
    }

    // A non-eth capability with version 0 and zero messages is still shared. It lands at the
    // first non-reserved offset; the literal 16 assumes MAX_RESERVED_MESSAGE_ID + 1 == 16.
    #[test]
    fn test_peer_capability_version_zero() {
        let cap = Capability::new_static("TestName", 0);
        let local_capabilities: Vec<Protocol> =
            vec![Protocol::new(cap.clone(), 0), EthVersion::Eth67.into(), EthVersion::Eth68.into()];
        let peer_capabilities = vec![cap.clone()];

        let shared = shared_capability_offsets(local_capabilities, peer_capabilities).unwrap();
        assert_eq!(shared.len(), 1);
        assert_eq!(shared[0], SharedCapability::UnknownCapability { cap, offset: 16, messages: 0 })
    }

    // When the peer only offers eth/66, that version is shared even though we also support
    // newer ones.
    #[test]
    fn test_peer_lower_capability_version() {
        let local_capabilities: Vec<Protocol> =
            vec![EthVersion::Eth66.into(), EthVersion::Eth67.into(), EthVersion::Eth68.into()];
        let peer_capabilities: Vec<Capability> = vec![EthVersion::Eth66.into()];

        let shared_capability =
            shared_capability_offsets(local_capabilities, peer_capabilities).unwrap()[0].clone();

        assert_eq!(
            shared_capability,
            SharedCapability::Eth {
                version: EthVersion::Eth66,
                offset: MAX_RESERVED_MESSAGE_ID + 1
            }
        )
    }

    // Versions must match exactly: a peer version below ours is not shared, so negotiation
    // fails.
    #[test]
    fn test_peer_capability_version_too_low() {
        let local: Vec<Protocol> = vec![EthVersion::Eth67.into()];
        let peer_capabilities: Vec<Capability> = vec![EthVersion::Eth66.into()];

        let shared_capability = shared_capability_offsets(local, peer_capabilities);

        assert!(matches!(
            shared_capability,
            Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))
        ))
    }

    // Likewise, a peer version above ours is not shared.
    #[test]
    fn test_peer_capability_version_too_high() {
        let local_capabilities = vec![EthVersion::Eth66.into()];
        let peer_capabilities = vec![EthVersion::Eth67.into()];

        let shared_capability = shared_capability_offsets(local_capabilities, peer_capabilities);

        assert!(matches!(
            shared_capability,
            Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))
        ))
    }

    // With eth as the only shared capability, relative offset 0 and absolute offset
    // MAX_RESERVED_MESSAGE_ID + 1 both resolve to it.
    #[test]
    fn test_find_by_offset() {
        let local_capabilities = vec![EthVersion::Eth66.into()];
        let peer_capabilities = vec![EthVersion::Eth66.into()];

        let shared = SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap();

        let shared_eth = shared.find_by_relative_offset(0).unwrap();
        assert_eq!(shared_eth.name(), "eth");

        let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 1).unwrap();
        assert_eq!(shared_eth.name(), "eth");

        // An absolute offset inside the reserved range resolves to nothing.
        assert!(shared.find_by_offset(MAX_RESERVED_MESSAGE_ID).is_none());
    }

    // Two shared capabilities. "aaa" sorts before "eth", so "aaa" takes relative offsets
    // 0..5 (it declares 5 messages) and eth starts right after it.
    #[test]
    fn test_find_by_offset_many() {
        let cap = Capability::new_static("aaa", 1);
        let proto = Protocol::new(cap.clone(), 5);
        let local_capabilities = vec![proto.clone(), EthVersion::Eth66.into()];
        let peer_capabilities = vec![cap, EthVersion::Eth66.into()];

        let shared = SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap();

        let shared_eth = shared.find_by_relative_offset(0).unwrap();
        assert_eq!(shared_eth.name(), proto.cap.name);

        let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 1).unwrap();
        assert_eq!(shared_eth.name(), proto.cap.name);

        // Relative offset 4 is the last message ID belonging to "aaa".
        let shared_eth = shared.find_by_relative_offset(4).unwrap();
        assert_eq!(shared_eth.name(), proto.cap.name);
        let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 5).unwrap();
        assert_eq!(shared_eth.name(), proto.cap.name);

        // Past "aaa"'s range, offsets resolve to eth.
        let shared_eth = shared.find_by_relative_offset(1 + proto.messages()).unwrap();
        assert_eq!(shared_eth.name(), "eth");
    }

    // RLP round-trip of a raw capability message.
    #[test]
    fn test_raw_capability_rlp() {
        let msg = RawCapabilityMessage { id: 1, payload: Bytes::from(vec![0x01, 0x02, 0x03]) };

        // Encode into a byte buffer.
        let mut encoded = Vec::new();
        msg.encode(&mut encoded);

        // Decode it back from the encoded bytes.
        let decoded = RawCapabilityMessage::decode(&mut &encoded[..]).unwrap();

        // The decoded message must equal the original.
        assert_eq!(msg, decoded);
    }
}