reth_ecies/
codec.rs

1//! This contains the main codec for `RLPx` ECIES messages
2
3use crate::{algorithm::ECIES, ECIESError, ECIESErrorImpl, EgressECIESValue, IngressECIESValue};
4use alloy_primitives::{bytes::BytesMut, B512 as PeerId};
5use secp256k1::SecretKey;
6use std::{fmt::Debug, io};
7use tokio_util::codec::{Decoder, Encoder};
8use tracing::{instrument, trace};
9
/// Upper bound, in bytes, on the body of the first framed message that follows the auth/ack
/// exchange (2 KiB).
const MAX_INITIAL_HANDSHAKE_SIZE: usize = 2 * 1024;
12
13/// Tokio codec for ECIES
14#[derive(Debug)]
15pub struct ECIESCodec {
16    ecies: ECIES,
17    state: ECIESState,
18}
19
/// Tracks what the ECIES codec expects to read next on a connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ECIESState {
    /// Handshake stage one: an auth message is expected. It carries the ephemeral public key, a
    /// signature of the public key, a nonce, and other metadata.
    Auth,

    /// Handshake stage two: an ack message is expected. It carries the nonce and other metadata.
    Ack,

    /// Behaves like [`ECIESState::Header`], but is only entered directly after the auth/ack
    /// exchange so that the size of the very first framed message can be validated.
    InitialHeader,

    /// A message header is expected: it is parsed, integrity-checked and decrypted before the
    /// codec moves on to [`ECIESState::Body`].
    Header,

    /// A message body is expected: it is read and handed back to the caller, after which the
    /// codec returns to [`ECIESState::Header`].
    Body,
}
44
45impl ECIESCodec {
46    /// Create a new server codec using the given secret key
47    pub(crate) fn new_server(secret_key: SecretKey) -> Result<Self, ECIESError> {
48        Ok(Self { ecies: ECIES::new_server(secret_key)?, state: ECIESState::Auth })
49    }
50
51    /// Create a new client codec using the given secret key and the server's public id
52    pub(crate) fn new_client(secret_key: SecretKey, remote_id: PeerId) -> Result<Self, ECIESError> {
53        Ok(Self { ecies: ECIES::new_client(secret_key, remote_id)?, state: ECIESState::Auth })
54    }
55}
56
57impl Decoder for ECIESCodec {
58    type Item = IngressECIESValue;
59    type Error = ECIESError;
60
61    #[instrument(level = "trace", skip_all, fields(peer=?self.ecies.remote_id, state=?self.state))]
62    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
63        loop {
64            match self.state {
65                ECIESState::Auth => {
66                    trace!("parsing auth");
67                    if buf.len() < 2 {
68                        return Ok(None)
69                    }
70
71                    let payload_size = u16::from_be_bytes([buf[0], buf[1]]) as usize;
72                    let total_size = payload_size + 2;
73
74                    if buf.len() < total_size {
75                        trace!("current len {}, need {}", buf.len(), total_size);
76                        return Ok(None)
77                    }
78
79                    self.ecies.read_auth(&mut buf.split_to(total_size))?;
80
81                    self.state = ECIESState::InitialHeader;
82                    return Ok(Some(IngressECIESValue::AuthReceive(self.ecies.remote_id())))
83                }
84                ECIESState::Ack => {
85                    trace!("parsing ack with len {}", buf.len());
86                    if buf.len() < 2 {
87                        return Ok(None)
88                    }
89
90                    let payload_size = u16::from_be_bytes([buf[0], buf[1]]) as usize;
91                    let total_size = payload_size + 2;
92
93                    if buf.len() < total_size {
94                        trace!("current len {}, need {}", buf.len(), total_size);
95                        return Ok(None)
96                    }
97
98                    self.ecies.read_ack(&mut buf.split_to(total_size))?;
99
100                    self.state = ECIESState::InitialHeader;
101                    return Ok(Some(IngressECIESValue::Ack))
102                }
103                ECIESState::InitialHeader => {
104                    if buf.len() < ECIES::header_len() {
105                        trace!("current len {}, need {}", buf.len(), ECIES::header_len());
106                        return Ok(None)
107                    }
108
109                    let body_size =
110                        self.ecies.read_header(&mut buf.split_to(ECIES::header_len()))?;
111
112                    if body_size > MAX_INITIAL_HANDSHAKE_SIZE {
113                        trace!(?body_size, max=?MAX_INITIAL_HANDSHAKE_SIZE, "Header exceeds max initial handshake size");
114                        return Err(ECIESErrorImpl::InitialHeaderBodyTooLarge {
115                            body_size,
116                            max_body_size: MAX_INITIAL_HANDSHAKE_SIZE,
117                        }
118                        .into())
119                    }
120
121                    self.state = ECIESState::Body;
122                }
123                ECIESState::Header => {
124                    if buf.len() < ECIES::header_len() {
125                        trace!("current len {}, need {}", buf.len(), ECIES::header_len());
126                        return Ok(None)
127                    }
128
129                    self.ecies.read_header(&mut buf.split_to(ECIES::header_len()))?;
130
131                    self.state = ECIESState::Body;
132                }
133                ECIESState::Body => {
134                    if buf.len() < self.ecies.body_len() {
135                        return Ok(None)
136                    }
137
138                    let mut data = buf.split_to(self.ecies.body_len());
139                    let mut ret = BytesMut::new();
140                    ret.extend_from_slice(self.ecies.read_body(&mut data)?);
141
142                    self.state = ECIESState::Header;
143                    return Ok(Some(IngressECIESValue::Message(ret)))
144                }
145            }
146        }
147    }
148}
149
150impl Encoder<EgressECIESValue> for ECIESCodec {
151    type Error = io::Error;
152
153    #[instrument(level = "trace", skip(self, buf), fields(peer=?self.ecies.remote_id, state=?self.state))]
154    fn encode(&mut self, item: EgressECIESValue, buf: &mut BytesMut) -> Result<(), Self::Error> {
155        match item {
156            EgressECIESValue::Auth => {
157                self.state = ECIESState::Ack;
158                self.ecies.write_auth(buf);
159                Ok(())
160            }
161            EgressECIESValue::Ack => {
162                self.state = ECIESState::InitialHeader;
163                self.ecies.write_ack(buf);
164                Ok(())
165            }
166            EgressECIESValue::Message(data) => {
167                self.ecies.write_header(buf, data.len());
168                self.ecies.write_body(buf, &data);
169                Ok(())
170            }
171        }
172    }
173}