1use crate::{
4 errors::{P2PHandshakeError, P2PStreamError},
5 p2pstream::MAX_RESERVED_MESSAGE_ID,
6 protocol::{ProtoVersion, Protocol},
7 version::ParseVersionError,
8 Capability, EthMessageID, EthVersion,
9};
10use derive_more::{Deref, DerefMut};
11use std::{
12 borrow::Cow,
13 collections::{BTreeSet, HashMap},
14};
15
/// A capability supported by both this node and the peer, together with the absolute message ID
/// offset assigned to it during capability negotiation.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SharedCapability {
    /// The `eth` capability.
    Eth {
        /// The `eth` version shared with the peer.
        version: EthVersion,
        /// Absolute message ID at which this capability's message range starts.
        offset: u8,
    },
    /// Any capability other than `eth`.
    UnknownCapability {
        /// The shared capability (name and version).
        cap: Capability,
        /// Absolute message ID at which this capability's message range starts.
        offset: u8,
        /// Number of message IDs this capability occupies, starting at `offset`.
        messages: u8,
    },
}
48
49impl SharedCapability {
50 pub(crate) fn new(
55 name: &str,
56 version: u8,
57 offset: u8,
58 messages: u8,
59 ) -> Result<Self, SharedCapabilityError> {
60 if offset <= MAX_RESERVED_MESSAGE_ID {
61 return Err(SharedCapabilityError::ReservedMessageIdOffset(offset))
62 }
63
64 match name {
65 "eth" => Ok(Self::eth(EthVersion::try_from(version)?, offset)),
66 _ => Ok(Self::UnknownCapability {
67 cap: Capability::new(name.to_string(), version as usize),
68 offset,
69 messages,
70 }),
71 }
72 }
73
74 pub(crate) const fn eth(version: EthVersion, offset: u8) -> Self {
76 Self::Eth { version, offset }
77 }
78
79 pub const fn capability(&self) -> Cow<'_, Capability> {
81 match self {
82 Self::Eth { version, .. } => Cow::Owned(Capability::eth(*version)),
83 Self::UnknownCapability { cap, .. } => Cow::Borrowed(cap),
84 }
85 }
86
87 #[inline]
89 pub fn name(&self) -> &str {
90 match self {
91 Self::Eth { .. } => "eth",
92 Self::UnknownCapability { cap, .. } => cap.name.as_ref(),
93 }
94 }
95
96 #[inline]
98 pub const fn is_eth(&self) -> bool {
99 matches!(self, Self::Eth { .. })
100 }
101
102 pub const fn version(&self) -> u8 {
104 match self {
105 Self::Eth { version, .. } => *version as u8,
106 Self::UnknownCapability { cap, .. } => cap.version as u8,
107 }
108 }
109
110 pub const fn eth_version(&self) -> Option<EthVersion> {
112 match self {
113 Self::Eth { version, .. } => Some(*version),
114 _ => None,
115 }
116 }
117
118 pub const fn message_id_offset(&self) -> u8 {
123 match self {
124 Self::Eth { offset, .. } | Self::UnknownCapability { offset, .. } => *offset,
125 }
126 }
127
128 pub const fn relative_message_id_offset(&self) -> u8 {
131 self.message_id_offset() - MAX_RESERVED_MESSAGE_ID - 1
132 }
133
134 pub const fn num_messages(&self) -> u8 {
136 match self {
137 Self::Eth { version, .. } => EthMessageID::message_count(*version),
138 Self::UnknownCapability { messages, .. } => *messages,
139 }
140 }
141}
142
/// The capabilities shared with a peer, each carrying its assigned message ID offset.
///
/// When built through [`SharedCapabilities::try_new`], entries are ordered alphabetically by name
/// and therefore by ascending message ID offset. Lookups such as `find_by_offset` rely on that
/// ordering.
#[derive(Debug, Clone, Deref, DerefMut, PartialEq, Eq)]
pub struct SharedCapabilities(Vec<SharedCapability>);
148
149impl SharedCapabilities {
150 #[inline]
152 pub fn try_new(
153 local_protocols: Vec<Protocol>,
154 peer_capabilities: Vec<Capability>,
155 ) -> Result<Self, P2PStreamError> {
156 shared_capability_offsets(local_protocols, peer_capabilities).map(Self)
157 }
158
159 #[inline]
161 pub fn iter_caps(&self) -> impl Iterator<Item = &SharedCapability> {
162 self.0.iter()
163 }
164
165 #[inline]
167 pub fn eth(&self) -> Result<&SharedCapability, P2PStreamError> {
168 self.iter_caps().find(|c| c.is_eth()).ok_or(P2PStreamError::CapabilityNotShared)
169 }
170
171 #[inline]
173 pub fn eth_version(&self) -> Result<EthVersion, P2PStreamError> {
174 self.iter_caps()
175 .find_map(SharedCapability::eth_version)
176 .ok_or(P2PStreamError::CapabilityNotShared)
177 }
178
179 #[inline]
181 pub fn contains(&self, cap: &Capability) -> bool {
182 self.find(cap).is_some()
183 }
184
185 #[inline]
187 pub fn find(&self, cap: &Capability) -> Option<&SharedCapability> {
188 self.0.iter().find(|c| c.version() == cap.version as u8 && c.name() == cap.name)
189 }
190
191 #[inline]
197 pub fn relative_message_id(&self, cap: &Capability, message_id: u8) -> Option<u8> {
198 let shared = self.find(cap)?;
199 if message_id >= shared.num_messages() {
200 return None
201 }
202
203 shared.relative_message_id_offset().checked_add(message_id)
204 }
205
206 #[inline]
211 pub fn capability_message_id(&self, cap: &Capability, relative_message_id: u8) -> Option<u8> {
212 let shared = self.find(cap)?;
213 let start = shared.relative_message_id_offset();
214 let end = start.checked_add(shared.num_messages())?;
215
216 (start..end).contains(&relative_message_id).then(|| relative_message_id - start)
217 }
218
219 #[inline]
228 pub fn find_by_relative_offset(&self, offset: u8) -> Option<&SharedCapability> {
229 self.find_by_offset(offset.saturating_add(MAX_RESERVED_MESSAGE_ID + 1))
230 }
231
232 #[inline]
240 pub fn find_by_offset(&self, offset: u8) -> Option<&SharedCapability> {
241 let mut iter = self.0.iter();
242 let mut cap = iter.next()?;
243 if offset < cap.message_id_offset() {
244 return None
246 }
247
248 for next in iter {
249 if offset < next.message_id_offset() {
250 return Some(cap)
251 }
252 cap = next
253 }
254
255 Some(cap)
256 }
257
258 #[inline]
260 pub fn ensure_matching_capability(
261 &self,
262 cap: &Capability,
263 ) -> Result<&SharedCapability, UnsupportedCapabilityError> {
264 self.find(cap).ok_or_else(|| UnsupportedCapabilityError { capability: cap.clone() })
265 }
266
267 #[inline]
269 pub const fn len(&self) -> usize {
270 self.0.len()
271 }
272
273 #[inline]
275 pub const fn is_empty(&self) -> bool {
276 self.0.is_empty()
277 }
278}
279
280#[inline]
289pub fn shared_capability_offsets(
290 local_protocols: Vec<Protocol>,
291 peer_capabilities: Vec<Capability>,
292) -> Result<Vec<SharedCapability>, P2PStreamError> {
293 let our_capabilities =
295 local_protocols.into_iter().map(Protocol::split).collect::<HashMap<_, _>>();
296
297 let mut shared_capabilities: HashMap<_, ProtoVersion> = HashMap::default();
299
300 let mut shared_capability_names = BTreeSet::new();
312
313 for peer_capability in peer_capabilities {
315 if let Some(messages) = our_capabilities.get(&peer_capability).copied() {
317 if shared_capabilities
320 .get(&peer_capability.name)
321 .is_none_or(|v| peer_capability.version > v.version)
322 {
323 shared_capabilities.insert(
324 peer_capability.name.clone(),
325 ProtoVersion { version: peer_capability.version, messages },
326 );
327 shared_capability_names.insert(peer_capability.name);
328 }
329 }
330 }
331
332 if shared_capabilities.is_empty() {
334 return Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))
335 }
336
337 let mut shared_with_offsets = Vec::new();
340
341 let mut offset = MAX_RESERVED_MESSAGE_ID + 1;
345 for name in shared_capability_names {
346 let proto_version = &shared_capabilities[&name];
347 let shared_capability = SharedCapability::new(
348 &name,
349 proto_version.version as u8,
350 offset,
351 proto_version.messages,
352 )?;
353 offset += shared_capability.num_messages();
354 shared_with_offsets.push(shared_capability);
355 }
356
357 if shared_with_offsets.is_empty() {
358 return Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))
359 }
360
361 Ok(shared_with_offsets)
362}
363
/// Errors that can occur while constructing a [`SharedCapability`].
#[derive(Debug, thiserror::Error)]
pub enum SharedCapabilityError {
    /// The `eth` version could not be converted into an [`EthVersion`].
    #[error(transparent)]
    UnsupportedVersion(#[from] ParseVersionError),
    /// The message ID offset is in the reserved range (at or below `MAX_RESERVED_MESSAGE_ID`).
    #[error("message id offset `{0}` is reserved")]
    ReservedMessageIdOffset(u8),
}
375
/// Error indicating that a capability is not supported, for example because it is not among the
/// capabilities shared with a peer.
#[derive(Debug, thiserror::Error)]
#[error("unsupported capability {capability}")]
pub struct UnsupportedCapabilityError {
    /// The capability that was requested but is not supported.
    capability: Capability,
}
382
383impl UnsupportedCapabilityError {
384 pub const fn new(capability: Capability) -> Self {
386 Self { capability }
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Capabilities, Capability, SnapVersion};
    use alloy_primitives::bytes::Bytes;
    use alloy_rlp::{Decodable, Encodable};
    use reth_eth_wire_types::RawCapabilityMessage;

    // "eth"/68 maps to the `Eth` variant; the message count argument is not stored for eth.
    #[test]
    fn from_eth_68() {
        let capability = SharedCapability::new("eth", 68, MAX_RESERVED_MESSAGE_ID + 1, 13).unwrap();

        assert_eq!(capability.name(), "eth");
        assert_eq!(capability.version(), 68);
        assert_eq!(
            capability,
            SharedCapability::Eth {
                version: EthVersion::Eth68,
                offset: MAX_RESERVED_MESSAGE_ID + 1
            }
        );
    }

    // "eth"/67 maps to the `Eth` variant.
    #[test]
    fn from_eth_67() {
        let capability = SharedCapability::new("eth", 67, MAX_RESERVED_MESSAGE_ID + 1, 13).unwrap();

        assert_eq!(capability.name(), "eth");
        assert_eq!(capability.version(), 67);
        assert_eq!(
            capability,
            SharedCapability::Eth {
                version: EthVersion::Eth67,
                offset: MAX_RESERVED_MESSAGE_ID + 1
            }
        );
    }

    // "eth"/66 maps to the `Eth` variant.
    #[test]
    fn from_eth_66() {
        let capability = SharedCapability::new("eth", 66, MAX_RESERVED_MESSAGE_ID + 1, 15).unwrap();

        assert_eq!(capability.name(), "eth");
        assert_eq!(capability.version(), 66);
        assert_eq!(
            capability,
            SharedCapability::Eth {
                version: EthVersion::Eth66,
                offset: MAX_RESERVED_MESSAGE_ID + 1
            }
        );
    }

    // A `Capabilities` list containing eth/66..=70 reports support for each of those versions.
    #[test]
    fn capabilities_supports_eth() {
        let capabilities: Capabilities = vec![
            Capability::new_static("eth", 66),
            Capability::new_static("eth", 67),
            Capability::new_static("eth", 68),
            Capability::new_static("eth", 69),
            Capability::new_static("eth", 70),
        ]
        .into();

        assert!(capabilities.supports_eth());
        assert!(capabilities.supports_eth_v66());
        assert!(capabilities.supports_eth_v67());
        assert!(capabilities.supports_eth_v68());
        assert!(capabilities.supports_eth_v69());
        assert!(capabilities.supports_eth_v70());
    }

    // A non-eth capability with version 0 and zero messages is still shared and receives the
    // first non-reserved offset (16).
    #[test]
    fn test_peer_capability_version_zero() {
        let cap = Capability::new_static("TestName", 0);
        let local_capabilities: Vec<Protocol> =
            vec![Protocol::new(cap.clone(), 0), EthVersion::Eth67.into(), EthVersion::Eth68.into()];
        let peer_capabilities = vec![cap.clone()];

        let shared = shared_capability_offsets(local_capabilities, peer_capabilities).unwrap();
        assert_eq!(shared.len(), 1);
        assert_eq!(shared[0], SharedCapability::UnknownCapability { cap, offset: 16, messages: 0 })
    }

    // We support eth/66..=68 but the peer only eth/66: eth/66 is the shared version.
    #[test]
    fn test_peer_lower_capability_version() {
        let local_capabilities: Vec<Protocol> =
            vec![EthVersion::Eth66.into(), EthVersion::Eth67.into(), EthVersion::Eth68.into()];
        let peer_capabilities: Vec<Capability> = vec![EthVersion::Eth66.into()];

        let shared_capability =
            shared_capability_offsets(local_capabilities, peer_capabilities).unwrap()[0].clone();

        assert_eq!(
            shared_capability,
            SharedCapability::Eth {
                version: EthVersion::Eth66,
                offset: MAX_RESERVED_MESSAGE_ID + 1
            }
        )
    }

    // Versions must match exactly: peer eth/66 vs. local eth/67 shares nothing.
    #[test]
    fn test_peer_capability_version_too_low() {
        let local: Vec<Protocol> = vec![EthVersion::Eth67.into()];
        let peer_capabilities: Vec<Capability> = vec![EthVersion::Eth66.into()];

        let shared_capability = shared_capability_offsets(local, peer_capabilities);

        assert!(matches!(
            shared_capability,
            Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))
        ))
    }

    // Versions must match exactly: peer eth/67 vs. local eth/66 shares nothing.
    #[test]
    fn test_peer_capability_version_too_high() {
        let local_capabilities = vec![EthVersion::Eth66.into()];
        let peer_capabilities = vec![EthVersion::Eth67.into()];

        let shared_capability = shared_capability_offsets(local_capabilities, peer_capabilities);

        assert!(matches!(
            shared_capability,
            Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))
        ))
    }

    // With only eth shared, relative offset 0 and absolute offset MAX_RESERVED_MESSAGE_ID + 1
    // both resolve to eth, while a reserved ID resolves to nothing.
    #[test]
    fn test_find_by_offset() {
        let local_capabilities = vec![EthVersion::Eth66.into()];
        let peer_capabilities = vec![EthVersion::Eth66.into()];

        let shared = SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap();

        let shared_eth = shared.find_by_relative_offset(0).unwrap();
        assert_eq!(shared_eth.name(), "eth");

        let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 1).unwrap();
        assert_eq!(shared_eth.name(), "eth");

        assert!(shared.find_by_offset(MAX_RESERVED_MESSAGE_ID).is_none());
    }

    // "aaa" (5 messages) sorts before "eth", so it occupies the first range (relative 0..=4);
    // a relative offset past that range lands in eth's range.
    #[test]
    fn test_find_by_offset_many() {
        let cap = Capability::new_static("aaa", 1);
        let proto = Protocol::new(cap.clone(), 5);
        let local_capabilities = vec![proto.clone(), EthVersion::Eth66.into()];
        let peer_capabilities = vec![cap, EthVersion::Eth66.into()];

        let shared = SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap();

        let shared_eth = shared.find_by_relative_offset(0).unwrap();
        assert_eq!(shared_eth.name(), proto.cap.name);

        let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 1).unwrap();
        assert_eq!(shared_eth.name(), proto.cap.name);

        // Relative offset 4 is the last message of "aaa".
        let shared_eth = shared.find_by_relative_offset(4).unwrap();
        assert_eq!(shared_eth.name(), proto.cap.name);
        let shared_eth = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 5).unwrap();
        assert_eq!(shared_eth.name(), proto.cap.name);

        // Past the end of "aaa"'s range, offsets resolve to eth.
        let shared_eth = shared.find_by_relative_offset(1 + proto.messages()).unwrap();
        assert_eq!(shared_eth.name(), "eth");
    }

    // Sorted order is eth, foo, snap, so snap's relative IDs start after eth's messages plus
    // foo's 3 messages; the mapping round-trips through `capability_message_id`.
    #[test]
    fn relative_message_id_accounts_for_intermediate_capabilities() {
        let intermediate_cap = Capability::new_static("foo", 1);
        let intermediate = Protocol::new(intermediate_cap.clone(), 3);
        let snap = Capability::snap(SnapVersion::V1);
        let eth = Capability::eth(EthVersion::Eth69);
        let local_capabilities =
            vec![EthVersion::Eth69.into(), intermediate, Protocol::snap(SnapVersion::V1)];
        let peer_capabilities = vec![eth, intermediate_cap, snap.clone()];

        let shared = SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap();
        let snap_id = shared.relative_message_id(&snap, 2).unwrap();

        assert_eq!(snap_id, EthMessageID::message_count(EthVersion::Eth69) + 3 + 2);
        assert_eq!(shared.capability_message_id(&snap, snap_id), Some(2));
    }

    // A relative ID from foo's range is not accepted as a snap message, and a snap message ID
    // equal to snap's message count is out of range.
    #[test]
    fn capability_message_id_rejects_other_capability_range() {
        let intermediate_cap = Capability::new_static("foo", 1);
        let intermediate = Protocol::new(intermediate_cap.clone(), 3);
        let snap = Capability::snap(SnapVersion::V1);
        let local_capabilities =
            vec![EthVersion::Eth69.into(), intermediate, Protocol::snap(SnapVersion::V1)];
        let peer_capabilities =
            vec![Capability::eth(EthVersion::Eth69), intermediate_cap.clone(), snap.clone()];

        let shared = SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap();
        let intermediate_id = shared.relative_message_id(&intermediate_cap, 1).unwrap();

        assert_eq!(shared.capability_message_id(&snap, intermediate_id), None);
        assert_eq!(shared.relative_message_id(&snap, SnapVersion::V1.message_count()), None);
    }

    // `RawCapabilityMessage` survives an RLP encode/decode round trip.
    #[test]
    fn test_raw_capability_rlp() {
        let msg = RawCapabilityMessage { id: 1, payload: Bytes::from(vec![0x01, 0x02, 0x03]) };

        let mut encoded = Vec::new();
        msg.encode(&mut encoded);

        let decoded = RawCapabilityMessage::decode(&mut &encoded[..]).unwrap();

        assert_eq!(msg, decoded);
    }
}