1use super::compression::{
2    compress, CompressionEncoding, CompressionSettings, SingleMessageCompressionOverride,
3};
4use super::{BufferSettings, EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE};
5use crate::Status;
6use bytes::{BufMut, Bytes, BytesMut};
7use http::HeaderMap;
8use http_body::{Body, Frame};
9use pin_project::pin_project;
10use std::{
11    pin::Pin,
12    task::{ready, Context, Poll},
13};
14use tokio_stream::{adapters::Fuse, Stream, StreamExt};
15
/// Stream adapter that encodes each message pulled from `source` with `encoder` into a
/// length-prefixed frame and yields the encoded bytes in batches.
#[pin_project(project = EncodedBytesProj)]
#[derive(Debug)]
struct EncodedBytes<T, U> {
    /// Message source; fused so it can safely be polled again after it has finished.
    #[pin]
    source: Fuse<U>,
    /// Serializes each message into bytes.
    encoder: T,
    /// Compression applied to every message; `None` means messages are sent uncompressed.
    compression_encoding: Option<CompressionEncoding>,
    /// Per-message size limit; `None` falls back to `DEFAULT_MAX_SEND_MESSAGE_SIZE`.
    max_message_size: Option<usize>,
    /// Accumulates encoded frames until they are yielded downstream.
    buf: BytesMut,
    /// Scratch space holding a serialized message before it is compressed into `buf`.
    uncompression_buf: BytesMut,
    /// Error held back so that the bytes buffered before it can be yielded first.
    error: Option<Status>,
}
33
34impl<T: Encoder, U: Stream> EncodedBytes<T, U> {
35    fn new(
36        encoder: T,
37        source: U,
38        compression_encoding: Option<CompressionEncoding>,
39        compression_override: SingleMessageCompressionOverride,
40        max_message_size: Option<usize>,
41    ) -> Self {
42        let buffer_settings = encoder.buffer_settings();
43        let buf = BytesMut::with_capacity(buffer_settings.buffer_size);
44
45        let compression_encoding =
46            if compression_override == SingleMessageCompressionOverride::Disable {
47                None
48            } else {
49                compression_encoding
50            };
51
52        let uncompression_buf = if compression_encoding.is_some() {
53            BytesMut::with_capacity(buffer_settings.buffer_size)
54        } else {
55            BytesMut::new()
56        };
57
58        Self {
59            source: source.fuse(),
60            encoder,
61            compression_encoding,
62            max_message_size,
63            buf,
64            uncompression_buf,
65            error: None,
66        }
67    }
68}
69
70impl<T, U> Stream for EncodedBytes<T, U>
71where
72    T: Encoder<Error = Status>,
73    U: Stream<Item = Result<T::Item, Status>>,
74{
75    type Item = Result<Bytes, Status>;
76
77    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
78        let EncodedBytesProj {
79            mut source,
80            encoder,
81            compression_encoding,
82            max_message_size,
83            buf,
84            uncompression_buf,
85            error,
86        } = self.project();
87        let buffer_settings = encoder.buffer_settings();
88
89        if let Some(status) = error.take() {
90            return Poll::Ready(Some(Err(status)));
91        }
92
93        loop {
94            match source.as_mut().poll_next(cx) {
95                Poll::Pending if buf.is_empty() => {
96                    return Poll::Pending;
97                }
98                Poll::Ready(None) if buf.is_empty() => {
99                    return Poll::Ready(None);
100                }
101                Poll::Pending | Poll::Ready(None) => {
102                    return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
103                }
104                Poll::Ready(Some(Ok(item))) => {
105                    if let Err(status) = encode_item(
106                        encoder,
107                        buf,
108                        uncompression_buf,
109                        *compression_encoding,
110                        *max_message_size,
111                        buffer_settings,
112                        item,
113                    ) {
114                        return Poll::Ready(Some(Err(status)));
115                    }
116
117                    if buf.len() >= buffer_settings.yield_threshold {
118                        return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
119                    }
120                }
121                Poll::Ready(Some(Err(status))) => {
122                    if buf.is_empty() {
123                        return Poll::Ready(Some(Err(status)));
124                    }
125                    *error = Some(status);
126                    return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
127                }
128            }
129        }
130    }
131}
132
133fn encode_item<T>(
134    encoder: &mut T,
135    buf: &mut BytesMut,
136    uncompression_buf: &mut BytesMut,
137    compression_encoding: Option<CompressionEncoding>,
138    max_message_size: Option<usize>,
139    buffer_settings: BufferSettings,
140    item: T::Item,
141) -> Result<(), Status>
142where
143    T: Encoder<Error = Status>,
144{
145    let offset = buf.len();
146
147    buf.reserve(HEADER_SIZE);
148    unsafe {
149        buf.advance_mut(HEADER_SIZE);
150    }
151
152    if let Some(encoding) = compression_encoding {
153        uncompression_buf.clear();
154
155        encoder
156            .encode(item, &mut EncodeBuf::new(uncompression_buf))
157            .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?;
158
159        let uncompressed_len = uncompression_buf.len();
160
161        compress(
162            CompressionSettings {
163                encoding,
164                buffer_growth_interval: buffer_settings.buffer_size,
165            },
166            uncompression_buf,
167            buf,
168            uncompressed_len,
169        )
170        .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?;
171    } else {
172        encoder
173            .encode(item, &mut EncodeBuf::new(buf))
174            .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?;
175    }
176
177    finish_encoding(compression_encoding, max_message_size, &mut buf[offset..])
179}
180
181fn finish_encoding(
182    compression_encoding: Option<CompressionEncoding>,
183    max_message_size: Option<usize>,
184    buf: &mut [u8],
185) -> Result<(), Status> {
186    let len = buf.len() - HEADER_SIZE;
187    let limit = max_message_size.unwrap_or(DEFAULT_MAX_SEND_MESSAGE_SIZE);
188    if len > limit {
189        return Err(Status::out_of_range(format!(
190            "Error, encoded message length too large: found {} bytes, the limit is: {} bytes",
191            len, limit
192        )));
193    }
194
195    if len > u32::MAX as usize {
196        return Err(Status::resource_exhausted(format!(
197            "Cannot return body with more than 4GB of data but got {len} bytes"
198        )));
199    }
200    {
201        let mut buf = &mut buf[..HEADER_SIZE];
202        buf.put_u8(compression_encoding.is_some() as u8);
203        buf.put_u32(len as u32);
204    }
205
206    Ok(())
207}
208
/// Which side of the call an `EncodeBody` serves; decides how errors reach the peer.
#[derive(Debug)]
enum Role {
    /// Errors are returned from the body as body errors; no trailers are produced.
    Client,
    /// Errors and the final status are sent to the peer as trailers.
    Server,
}
214
/// HTTP body that encodes a stream of gRPC messages into data frames.
///
/// On the server side the body ends with a trailers frame carrying the gRPC status.
#[pin_project]
#[derive(Debug)]
pub struct EncodeBody<T, U> {
    /// Stream producing the encoded message bytes.
    #[pin]
    inner: EncodedBytes<T, U>,
    /// Role and end-of-stream bookkeeping.
    state: EncodeState,
}
223
/// Bookkeeping that controls how an `EncodeBody` finishes.
#[derive(Debug)]
struct EncodeState {
    /// Status reported in the server's trailers; an OK status is sent when this is `None`.
    error: Option<Status>,
    /// Client or server side; only the server emits trailers.
    role: Role,
    /// Set once the server's trailers frame has been produced; backs `Body::is_end_stream`.
    is_end_stream: bool,
}
230
231impl<T: Encoder, U: Stream> EncodeBody<T, U> {
232    pub fn new_client(
235        encoder: T,
236        source: U,
237        compression_encoding: Option<CompressionEncoding>,
238        max_message_size: Option<usize>,
239    ) -> Self {
240        Self {
241            inner: EncodedBytes::new(
242                encoder,
243                source,
244                compression_encoding,
245                SingleMessageCompressionOverride::default(),
246                max_message_size,
247            ),
248            state: EncodeState {
249                error: None,
250                role: Role::Client,
251                is_end_stream: false,
252            },
253        }
254    }
255
256    pub fn new_server(
259        encoder: T,
260        source: U,
261        compression_encoding: Option<CompressionEncoding>,
262        compression_override: SingleMessageCompressionOverride,
263        max_message_size: Option<usize>,
264    ) -> Self {
265        Self {
266            inner: EncodedBytes::new(
267                encoder,
268                source,
269                compression_encoding,
270                compression_override,
271                max_message_size,
272            ),
273            state: EncodeState {
274                error: None,
275                role: Role::Server,
276                is_end_stream: false,
277            },
278        }
279    }
280}
281
282impl EncodeState {
283    fn trailers(&mut self) -> Option<Result<HeaderMap, Status>> {
284        match self.role {
285            Role::Client => None,
286            Role::Server => {
287                if self.is_end_stream {
288                    return None;
289                }
290
291                self.is_end_stream = true;
292                let status = if let Some(status) = self.error.take() {
293                    status
294                } else {
295                    Status::ok("")
296                };
297                Some(status.to_header_map())
298            }
299        }
300    }
301}
302
303impl<T, U> Body for EncodeBody<T, U>
304where
305    T: Encoder<Error = Status>,
306    U: Stream<Item = Result<T::Item, Status>>,
307{
308    type Data = Bytes;
309    type Error = Status;
310
311    fn is_end_stream(&self) -> bool {
312        self.state.is_end_stream
313    }
314
315    fn poll_frame(
316        self: Pin<&mut Self>,
317        cx: &mut Context<'_>,
318    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
319        let self_proj = self.project();
320        match ready!(self_proj.inner.poll_next(cx)) {
321            Some(Ok(d)) => Some(Ok(Frame::data(d))).into(),
322            Some(Err(status)) => match self_proj.state.role {
323                Role::Client => Some(Err(status)).into(),
324                Role::Server => {
325                    self_proj.state.is_end_stream = true;
326                    Some(Ok(Frame::trailers(status.to_header_map()?))).into()
327                }
328            },
329            None => self_proj
330                .state
331                .trailers()
332                .map(|t| t.map(Frame::trailers))
333                .into(),
334        }
335    }
336}