reth_ecies/
codec.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
//! This contains the main codec for `RLPx` ECIES messages

use crate::{algorithm::ECIES, ECIESError, ECIESErrorImpl, EgressECIESValue, IngressECIESValue};
use alloy_primitives::{bytes::BytesMut, B512 as PeerId};
use secp256k1::SecretKey;
use std::{fmt::Debug, io};
use tokio_util::codec::{Decoder, Encoder};
use tracing::{instrument, trace};

/// Largest body size accepted for the first frame decoded after the ECIES handshake (the frame
/// read in [`ECIESState::InitialHeader`]). Currently 2 KiB.
const MAX_INITIAL_HANDSHAKE_SIZE: usize = 2 * 1024;

/// Tokio codec for ECIES-framed `RLPx` connections.
///
/// Pairs the ECIES session with an [`ECIESState`] that records what the decoder expects next in
/// the input buffer. The next item is a handshake message (auth/ack), a frame header, or a frame
/// body.
#[derive(Debug)]
pub struct ECIESCodec {
    /// ECIES session used to read and write the auth/ack handshake messages and the frames that
    /// follow them.
    ecies: ECIES,
    /// Current stage of the connection. It drives what [`Decoder::decode`] parses next and is
    /// updated by [`Encoder::encode`] when handshake messages are written.
    state: ECIESState,
}

/// Current ECIES state of a connection, i.e. which part of the byte stream the codec expects next.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ECIESState {
    /// The first stage of the ECIES handshake. Here the connection initiator sends an auth message
    /// containing the ephemeral public key, signature of the public key, nonce, and other metadata.
    ///
    /// Both server and client codecs start here. A server codec decodes the peer's auth message in
    /// this state and then moves to [`ECIESState::InitialHeader`]. A client codec leaves this
    /// state by encoding its own auth message, which moves it to [`ECIESState::Ack`].
    Auth,

    /// The second stage of the ECIES handshake. Here the recipient answers the auth message with an
    /// ack message containing the nonce and other metadata.
    ///
    /// A client codec waits in this state after sending auth. Decoding the ack moves it to
    /// [`ECIESState::InitialHeader`].
    Ack,

    /// This is the same as the [`ECIESState::Header`] stage, but it is used only for the first frame
    /// after the handshake. The codec enters it after decoding an auth or ack message, or after
    /// encoding an ack. The body size announced by this header is limited to
    /// `MAX_INITIAL_HANDSHAKE_SIZE` so that the initial message can be properly validated.
    InitialHeader,

    /// A frame header is expected. The header bytes are read with `ECIES::read_header`, which
    /// yields the size of the frame body that follows.
    Header,

    /// The body of the current frame is expected. It is read with `ECIES::read_body` and returned by
    /// the ECIES codec as a message, after which the codec goes back to [`ECIESState::Header`].
    Body,
}

impl ECIESCodec {
    /// Builds a codec for the accepting (server) side of a connection from our secret key.
    ///
    /// The codec starts in [`ECIESState::Auth`], ready to decode the remote peer's auth message.
    pub(crate) fn new_server(secret_key: SecretKey) -> Result<Self, ECIESError> {
        let ecies = ECIES::new_server(secret_key)?;
        Ok(Self { state: ECIESState::Auth, ecies })
    }

    /// Builds a codec for the dialing (client) side of a connection, given our secret key and the
    /// remote peer's id.
    ///
    /// The codec starts in [`ECIESState::Auth`]; encoding the auth message advances it.
    pub(crate) fn new_client(secret_key: SecretKey, remote_id: PeerId) -> Result<Self, ECIESError> {
        let ecies = ECIES::new_client(secret_key, remote_id)?;
        Ok(Self { state: ECIESState::Auth, ecies })
    }
}

impl Decoder for ECIESCodec {
    type Item = IngressECIESValue;
    type Error = ECIESError;

    /// Decodes the next handshake message or frame from `buf`, driven by the current
    /// [`ECIESState`].
    ///
    /// Returns `Ok(None)` when `buf` does not yet hold the next complete unit for the current
    /// stage. A frame header that has already been consumed is not re-read: the state stays at
    /// [`ECIESState::Body`] until the whole body is buffered.
    ///
    /// # Errors
    ///
    /// Returns an error in two cases:
    /// - the ECIES layer rejects an auth/ack message, header, or body;
    /// - the first frame after the handshake announces a body larger than
    ///   `MAX_INITIAL_HANDSHAKE_SIZE`.
    #[instrument(level = "trace", skip_all, fields(peer=?self.ecies.remote_id, state=?self.state))]
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        loop {
            match self.state {
                ECIESState::Auth => {
                    trace!("parsing auth");
                    let mut packet = match split_size_prefixed(buf) {
                        Some(packet) => packet,
                        None => return Ok(None),
                    };

                    self.ecies.read_auth(&mut packet)?;

                    self.state = ECIESState::InitialHeader;
                    return Ok(Some(IngressECIESValue::AuthReceive(self.ecies.remote_id())))
                }
                ECIESState::Ack => {
                    trace!("parsing ack with len {}", buf.len());
                    let mut packet = match split_size_prefixed(buf) {
                        Some(packet) => packet,
                        None => return Ok(None),
                    };

                    self.ecies.read_ack(&mut packet)?;

                    self.state = ECIESState::InitialHeader;
                    return Ok(Some(IngressECIESValue::Ack))
                }
                ECIESState::InitialHeader | ECIESState::Header => {
                    let mut header = match split_exact(buf, ECIES::header_len()) {
                        Some(header) => header,
                        None => return Ok(None),
                    };

                    let body_size = self.ecies.read_header(&mut header)?;

                    // Only the first frame after the handshake is size-limited here, so a peer
                    // cannot make us buffer a large body before its first message is validated.
                    let is_initial = self.state == ECIESState::InitialHeader;
                    if is_initial && body_size > MAX_INITIAL_HANDSHAKE_SIZE {
                        trace!(?body_size, max=?MAX_INITIAL_HANDSHAKE_SIZE, "Header exceeds max initial handshake size");
                        return Err(ECIESErrorImpl::InitialHeaderBodyTooLarge {
                            body_size,
                            max_body_size: MAX_INITIAL_HANDSHAKE_SIZE,
                        }
                        .into())
                    }

                    self.state = ECIESState::Body;
                }
                ECIESState::Body => {
                    let body_len = self.ecies.body_len();
                    if buf.len() < body_len {
                        return Ok(None)
                    }

                    let mut data = buf.split_to(body_len);
                    let mut ret = BytesMut::new();
                    ret.extend_from_slice(self.ecies.read_body(&mut data)?);

                    self.state = ECIESState::Header;
                    return Ok(Some(IngressECIESValue::Message(ret)))
                }
            }
        }
    }
}

/// Splits one handshake packet (auth or ack) off the front of `buf`.
///
/// The packet starts with its payload size encoded as a big-endian `u16`. The returned buffer
/// contains that two-byte prefix followed by the payload. Returns `None`, leaving `buf`
/// untouched, until the whole packet has been buffered.
fn split_size_prefixed(buf: &mut BytesMut) -> Option<BytesMut> {
    if buf.len() < 2 {
        return None
    }

    let payload_size = usize::from(u16::from_be_bytes([buf[0], buf[1]]));
    split_exact(buf, payload_size + 2)
}

/// Splits exactly `len` bytes off the front of `buf`. Returns `None`, leaving `buf` untouched, if
/// fewer than `len` bytes are buffered.
fn split_exact(buf: &mut BytesMut, len: usize) -> Option<BytesMut> {
    if buf.len() < len {
        trace!("current len {}, need {}", buf.len(), len);
        return None
    }

    Some(buf.split_to(len))
}

impl Encoder<EgressECIESValue> for ECIESCodec {
    type Error = io::Error;

    /// Serializes `item` into `buf` and advances the handshake state where needed.
    ///
    /// - Writing an auth message moves the codec to [`ECIESState::Ack`].
    /// - Writing an ack message moves it to [`ECIESState::InitialHeader`].
    /// - A message is written as a frame header followed by the frame body and leaves the state
    ///   unchanged.
    ///
    /// Always returns `Ok(())`.
    #[instrument(level = "trace", skip(self, buf), fields(peer=?self.ecies.remote_id, state=?self.state))]
    fn encode(&mut self, item: EgressECIESValue, buf: &mut BytesMut) -> Result<(), Self::Error> {
        match item {
            EgressECIESValue::Auth => {
                self.state = ECIESState::Ack;
                self.ecies.write_auth(buf);
            }
            EgressECIESValue::Ack => {
                self.state = ECIESState::InitialHeader;
                self.ecies.write_ack(buf);
            }
            EgressECIESValue::Message(data) => {
                let len = data.len();
                self.ecies.write_header(buf, len);
                self.ecies.write_body(buf, &data);
            }
        }
        Ok(())
    }
}