1use crate::{
4 errors::{P2PHandshakeError, P2PStreamError},
5 p2pstream::MAX_RESERVED_MESSAGE_ID,
6 protocol::{ProtoVersion, Protocol},
7 version::ParseVersionError,
8 Capability, EthMessageID, EthVersion,
9};
10use derive_more::{Deref, DerefMut};
11use std::{
12 borrow::Cow,
13 collections::{BTreeSet, HashMap},
14};
15
16#[derive(Debug, Clone, PartialEq, Eq, Hash)]
23pub enum SharedCapability {
24 Eth {
26 version: EthVersion,
28 offset: u8,
33 },
34 UnknownCapability {
36 cap: Capability,
38 offset: u8,
43 messages: u8,
46 },
47}
48
49impl SharedCapability {
50 pub(crate) fn new(
55 name: &str,
56 version: u8,
57 offset: u8,
58 messages: u8,
59 ) -> Result<Self, SharedCapabilityError> {
60 if offset <= MAX_RESERVED_MESSAGE_ID {
61 return Err(SharedCapabilityError::ReservedMessageIdOffset(offset))
62 }
63
64 match name {
65 "eth" => Ok(Self::eth(EthVersion::try_from(version)?, offset)),
66 _ => Ok(Self::UnknownCapability {
67 cap: Capability::new(name.to_string(), version as usize),
68 offset,
69 messages,
70 }),
71 }
72 }
73
74 pub(crate) const fn eth(version: EthVersion, offset: u8) -> Self {
76 Self::Eth { version, offset }
77 }
78
79 pub const fn capability(&self) -> Cow<'_, Capability> {
81 match self {
82 Self::Eth { version, .. } => Cow::Owned(Capability::eth(*version)),
83 Self::UnknownCapability { cap, .. } => Cow::Borrowed(cap),
84 }
85 }
86
87 #[inline]
89 pub fn name(&self) -> &str {
90 match self {
91 Self::Eth { .. } => "eth",
92 Self::UnknownCapability { cap, .. } => cap.name.as_ref(),
93 }
94 }
95
96 #[inline]
98 pub const fn is_eth(&self) -> bool {
99 matches!(self, Self::Eth { .. })
100 }
101
102 pub const fn version(&self) -> u8 {
104 match self {
105 Self::Eth { version, .. } => *version as u8,
106 Self::UnknownCapability { cap, .. } => cap.version as u8,
107 }
108 }
109
110 pub const fn eth_version(&self) -> Option<EthVersion> {
112 match self {
113 Self::Eth { version, .. } => Some(*version),
114 _ => None,
115 }
116 }
117
118 pub const fn message_id_offset(&self) -> u8 {
123 match self {
124 Self::Eth { offset, .. } | Self::UnknownCapability { offset, .. } => *offset,
125 }
126 }
127
128 pub const fn relative_message_id_offset(&self) -> u8 {
131 self.message_id_offset() - MAX_RESERVED_MESSAGE_ID - 1
132 }
133
134 pub const fn num_messages(&self) -> u8 {
136 match self {
137 Self::Eth { version, .. } => EthMessageID::message_count(*version),
138 Self::UnknownCapability { messages, .. } => *messages,
139 }
140 }
141}
142
143#[derive(Debug, Clone, Deref, DerefMut, PartialEq, Eq)]
147pub struct SharedCapabilities(Vec<SharedCapability>);
148
149impl SharedCapabilities {
150 #[inline]
152 pub fn try_new(
153 local_protocols: Vec<Protocol>,
154 peer_capabilities: Vec<Capability>,
155 ) -> Result<Self, P2PStreamError> {
156 shared_capability_offsets(local_protocols, peer_capabilities).map(Self)
157 }
158
159 #[inline]
161 pub fn iter_caps(&self) -> impl Iterator<Item = &SharedCapability> {
162 self.0.iter()
163 }
164
165 #[inline]
167 pub fn eth(&self) -> Result<&SharedCapability, P2PStreamError> {
168 self.iter_caps().find(|c| c.is_eth()).ok_or(P2PStreamError::CapabilityNotShared)
169 }
170
171 #[inline]
173 pub fn eth_version(&self) -> Result<EthVersion, P2PStreamError> {
174 self.iter_caps()
175 .find_map(SharedCapability::eth_version)
176 .ok_or(P2PStreamError::CapabilityNotShared)
177 }
178
179 #[inline]
181 pub fn contains(&self, cap: &Capability) -> bool {
182 self.find(cap).is_some()
183 }
184
185 #[inline]
187 pub fn find(&self, cap: &Capability) -> Option<&SharedCapability> {
188 self.0.iter().find(|c| c.version() == cap.version as u8 && c.name() == cap.name)
189 }
190
191 #[inline]
200 pub fn find_by_relative_offset(&self, offset: u8) -> Option<&SharedCapability> {
201 self.find_by_offset(offset.saturating_add(MAX_RESERVED_MESSAGE_ID + 1))
202 }
203
204 #[inline]
212 pub fn find_by_offset(&self, offset: u8) -> Option<&SharedCapability> {
213 let mut iter = self.0.iter();
214 let mut cap = iter.next()?;
215 if offset < cap.message_id_offset() {
216 return None
218 }
219
220 for next in iter {
221 if offset < next.message_id_offset() {
222 return Some(cap)
223 }
224 cap = next
225 }
226
227 Some(cap)
228 }
229
230 #[inline]
232 pub fn ensure_matching_capability(
233 &self,
234 cap: &Capability,
235 ) -> Result<&SharedCapability, UnsupportedCapabilityError> {
236 self.find(cap).ok_or_else(|| UnsupportedCapabilityError { capability: cap.clone() })
237 }
238
239 #[inline]
241 pub const fn len(&self) -> usize {
242 self.0.len()
243 }
244
245 #[inline]
247 pub const fn is_empty(&self) -> bool {
248 self.0.is_empty()
249 }
250}
251
252#[inline]
261pub fn shared_capability_offsets(
262 local_protocols: Vec<Protocol>,
263 peer_capabilities: Vec<Capability>,
264) -> Result<Vec<SharedCapability>, P2PStreamError> {
265 let our_capabilities =
267 local_protocols.into_iter().map(Protocol::split).collect::<HashMap<_, _>>();
268
269 let mut shared_capabilities: HashMap<_, ProtoVersion> = HashMap::default();
271
272 let mut shared_capability_names = BTreeSet::new();
284
285 for peer_capability in peer_capabilities {
287 if let Some(messages) = our_capabilities.get(&peer_capability).copied() {
289 if shared_capabilities
292 .get(&peer_capability.name)
293 .is_none_or(|v| peer_capability.version > v.version)
294 {
295 shared_capabilities.insert(
296 peer_capability.name.clone(),
297 ProtoVersion { version: peer_capability.version, messages },
298 );
299 shared_capability_names.insert(peer_capability.name);
300 }
301 }
302 }
303
304 if shared_capabilities.is_empty() {
306 return Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))
307 }
308
309 let mut shared_with_offsets = Vec::new();
312
313 let mut offset = MAX_RESERVED_MESSAGE_ID + 1;
317 for name in shared_capability_names {
318 let proto_version = &shared_capabilities[&name];
319 let shared_capability = SharedCapability::new(
320 &name,
321 proto_version.version as u8,
322 offset,
323 proto_version.messages,
324 )?;
325 offset += shared_capability.num_messages();
326 shared_with_offsets.push(shared_capability);
327 }
328
329 if shared_with_offsets.is_empty() {
330 return Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))
331 }
332
333 Ok(shared_with_offsets)
334}
335
336#[derive(Debug, thiserror::Error)]
338pub enum SharedCapabilityError {
339 #[error(transparent)]
341 UnsupportedVersion(#[from] ParseVersionError),
342 #[error("message id offset `{0}` is reserved")]
345 ReservedMessageIdOffset(u8),
346}
347
348#[derive(Debug, thiserror::Error)]
350#[error("unsupported capability {capability}")]
351pub struct UnsupportedCapabilityError {
352 capability: Capability,
353}
354
355impl UnsupportedCapabilityError {
356 pub const fn new(capability: Capability) -> Self {
358 Self { capability }
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Capabilities, Capability};
    use alloy_primitives::bytes::Bytes;
    use alloy_rlp::{Decodable, Encodable};
    use reth_eth_wire_types::RawCapabilityMessage;

    /// `eth/68` is parsed into the dedicated `Eth` variant.
    #[test]
    fn from_eth_68() {
        let first_offset = MAX_RESERVED_MESSAGE_ID + 1;
        let parsed = SharedCapability::new("eth", 68, first_offset, 13).unwrap();
        let expected = SharedCapability::Eth { version: EthVersion::Eth68, offset: first_offset };

        assert_eq!(parsed.name(), "eth");
        assert_eq!(parsed.version(), 68);
        assert_eq!(parsed, expected);
    }

    /// `eth/67` is parsed into the dedicated `Eth` variant.
    #[test]
    fn from_eth_67() {
        let first_offset = MAX_RESERVED_MESSAGE_ID + 1;
        let parsed = SharedCapability::new("eth", 67, first_offset, 13).unwrap();
        let expected = SharedCapability::Eth { version: EthVersion::Eth67, offset: first_offset };

        assert_eq!(parsed.name(), "eth");
        assert_eq!(parsed.version(), 67);
        assert_eq!(parsed, expected);
    }

    /// `eth/66` is parsed into the dedicated `Eth` variant.
    #[test]
    fn from_eth_66() {
        let first_offset = MAX_RESERVED_MESSAGE_ID + 1;
        let parsed = SharedCapability::new("eth", 66, first_offset, 15).unwrap();
        let expected = SharedCapability::Eth { version: EthVersion::Eth66, offset: first_offset };

        assert_eq!(parsed.name(), "eth");
        assert_eq!(parsed.version(), 66);
        assert_eq!(parsed, expected);
    }

    /// A capability list with every `eth` version answers all `supports_eth*` queries.
    #[test]
    fn capabilities_supports_eth() {
        let all_eth_versions = vec![
            Capability::new_static("eth", 66),
            Capability::new_static("eth", 67),
            Capability::new_static("eth", 68),
            Capability::new_static("eth", 69),
            Capability::new_static("eth", 70),
        ];
        let capabilities: Capabilities = all_eth_versions.into();

        assert!(capabilities.supports_eth());
        assert!(capabilities.supports_eth_v66());
        assert!(capabilities.supports_eth_v67());
        assert!(capabilities.supports_eth_v68());
        assert!(capabilities.supports_eth_v69());
        assert!(capabilities.supports_eth_v70());
    }

    /// A custom capability with version `0` is shared like any other.
    #[test]
    fn test_peer_capability_version_zero() {
        let cap = Capability::new_static("TestName", 0);
        let local_capabilities: Vec<Protocol> =
            vec![Protocol::new(cap.clone(), 0), EthVersion::Eth67.into(), EthVersion::Eth68.into()];
        let peer_capabilities = vec![cap.clone()];

        let shared = shared_capability_offsets(local_capabilities, peer_capabilities).unwrap();

        // Only the custom capability is shared; the peer did not announce `eth`.
        let expected = SharedCapability::UnknownCapability { cap, offset: 16, messages: 0 };
        assert_eq!(shared.len(), 1);
        assert_eq!(shared[0], expected)
    }

    /// The peer's lower `eth` version is picked when we support it as well.
    #[test]
    fn test_peer_lower_capability_version() {
        let local_capabilities: Vec<Protocol> =
            vec![EthVersion::Eth66.into(), EthVersion::Eth67.into(), EthVersion::Eth68.into()];
        let peer_capabilities: Vec<Capability> = vec![EthVersion::Eth66.into()];

        let shared = shared_capability_offsets(local_capabilities, peer_capabilities).unwrap();
        let first = shared[0].clone();

        let expected = SharedCapability::Eth {
            version: EthVersion::Eth66,
            offset: MAX_RESERVED_MESSAGE_ID + 1,
        };
        assert_eq!(first, expected)
    }

    /// Nothing is shared when the peer only offers a version older than ours.
    #[test]
    fn test_peer_capability_version_too_low() {
        let local: Vec<Protocol> = vec![EthVersion::Eth67.into()];
        let peer_capabilities: Vec<Capability> = vec![EthVersion::Eth66.into()];

        let outcome = shared_capability_offsets(local, peer_capabilities);

        assert!(matches!(
            outcome,
            Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))
        ))
    }

    /// Nothing is shared when the peer only offers a version newer than ours.
    #[test]
    fn test_peer_capability_version_too_high() {
        let local_capabilities = vec![EthVersion::Eth66.into()];
        let peer_capabilities = vec![EthVersion::Eth67.into()];

        let outcome = shared_capability_offsets(local_capabilities, peer_capabilities);

        assert!(matches!(
            outcome,
            Err(P2PStreamError::HandshakeError(P2PHandshakeError::NoSharedCapabilities))
        ))
    }

    /// Offset lookup with a single shared capability.
    #[test]
    fn test_find_by_offset() {
        let local_capabilities = vec![EthVersion::Eth66.into()];
        let peer_capabilities = vec![EthVersion::Eth66.into()];

        let shared = SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap();

        // Relative offset 0 and the first non-reserved absolute id both resolve to `eth`.
        let by_relative = shared.find_by_relative_offset(0).unwrap();
        assert_eq!(by_relative.name(), "eth");

        let by_absolute = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 1).unwrap();
        assert_eq!(by_absolute.name(), "eth");

        // The last reserved id belongs to no capability.
        assert!(shared.find_by_offset(MAX_RESERVED_MESSAGE_ID).is_none());
    }

    /// Offset lookup with two shared capabilities: `aaa` sorts before `eth`.
    #[test]
    fn test_find_by_offset_many() {
        let cap = Capability::new_static("aaa", 1);
        let proto = Protocol::new(cap.clone(), 5);
        let local_capabilities = vec![proto.clone(), EthVersion::Eth66.into()];
        let peer_capabilities = vec![cap, EthVersion::Eth66.into()];

        let shared = SharedCapabilities::try_new(local_capabilities, peer_capabilities).unwrap();

        // Start of the `aaa` range.
        let found = shared.find_by_relative_offset(0).unwrap();
        assert_eq!(found.name(), proto.cap.name);

        let found = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 1).unwrap();
        assert_eq!(found.name(), proto.cap.name);

        // Last message id of the `aaa` range.
        let found = shared.find_by_relative_offset(4).unwrap();
        assert_eq!(found.name(), proto.cap.name);
        let found = shared.find_by_offset(MAX_RESERVED_MESSAGE_ID + 5).unwrap();
        assert_eq!(found.name(), proto.cap.name);

        // Past the `aaa` range, inside the `eth` range.
        let found = shared.find_by_relative_offset(1 + proto.messages()).unwrap();
        assert_eq!(found.name(), "eth");
    }

    /// A raw capability message survives an RLP encode/decode round trip.
    #[test]
    fn test_raw_capability_rlp() {
        let msg = RawCapabilityMessage { id: 1, payload: Bytes::from(vec![0x01, 0x02, 0x03]) };

        let mut buf = Vec::new();
        msg.encode(&mut buf);

        let round_tripped = RawCapabilityMessage::decode(&mut &buf[..]).unwrap();

        assert_eq!(msg, round_tripped);
    }
}