1use crate::codec::UserError;
2use crate::frame::{Reason, StreamId};
3use crate::{client, server};
4
5use crate::frame::DEFAULT_INITIAL_WINDOW_SIZE;
6use crate::proto::*;
7
8use bytes::Bytes;
9use futures_core::Stream;
10use std::io;
11use std::marker::PhantomData;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14use std::time::Duration;
15use tokio::io::AsyncRead;
16
/// An HTTP/2 connection: the frame codec plus all connection-level state.
///
/// * `T` is the underlying transport read from and written to by the codec.
/// * `P` selects client or server behaviour (see the `client::Peer` and
///   `server::Peer` specific `impl` blocks).
/// * `B` is the buffer type carried by outgoing frames (via `Prioritized<B>`)
///   and stored by `Streams`; it defaults to `Bytes`.
#[derive(Debug)]
pub(crate) struct Connection<T, P, B: Buf = Bytes>
where
    P: Peer,
{
    /// Reads and writes frame values on the transport `T`.
    codec: Codec<T, Prioritized<B>>,

    /// All connection state that does not depend on the transport type `T`.
    inner: ConnectionInner<P, B>,
}
28
29#[derive(Debug)]
32struct ConnectionInner<P, B: Buf = Bytes>
33where
34 P: Peer,
35{
36 state: State,
38
39 error: Option<frame::GoAway>,
44
45 go_away: GoAway,
47
48 ping_pong: PingPong,
50
51 settings: Settings,
53
54 streams: Streams<B, P>,
56
57 span: tracing::Span,
59
60 _phantom: PhantomData<P>,
62}
63
64struct DynConnection<'a, B: Buf = Bytes> {
65 state: &'a mut State,
66
67 go_away: &'a mut GoAway,
68
69 streams: DynStreams<'a, B>,
70
71 error: &'a mut Option<frame::GoAway>,
72
73 ping_pong: &'a mut PingPong,
74}
75
76#[derive(Debug, Clone)]
77pub(crate) struct Config {
78 pub next_stream_id: StreamId,
79 pub initial_max_send_streams: usize,
80 pub max_send_buffer_size: usize,
81 pub reset_stream_duration: Duration,
82 pub reset_stream_max: usize,
83 pub remote_reset_stream_max: usize,
84 pub local_error_reset_streams_max: Option<usize>,
85 pub settings: frame::Settings,
86}
87
88#[derive(Debug)]
89enum State {
90 Open,
92
93 Closing(Reason, Initiator),
95
96 Closed(Reason, Initiator),
98}
99
100impl<T, P, B> Connection<T, P, B>
101where
102 T: AsyncRead + AsyncWrite + Unpin,
103 P: Peer,
104 B: Buf,
105{
106 pub fn new(codec: Codec<T, Prioritized<B>>, config: Config) -> Connection<T, P, B> {
107 fn streams_config(config: &Config) -> streams::Config {
108 streams::Config {
109 initial_max_send_streams: config.initial_max_send_streams,
110 local_max_buffer_size: config.max_send_buffer_size,
111 local_next_stream_id: config.next_stream_id,
112 local_push_enabled: config.settings.is_push_enabled().unwrap_or(true),
113 extended_connect_protocol_enabled: config
114 .settings
115 .is_extended_connect_protocol_enabled()
116 .unwrap_or(false),
117 local_reset_duration: config.reset_stream_duration,
118 local_reset_max: config.reset_stream_max,
119 remote_reset_max: config.remote_reset_stream_max,
120 remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE,
121 remote_max_initiated: config
122 .settings
123 .max_concurrent_streams()
124 .map(|max| max as usize),
125 local_max_error_reset_streams: config.local_error_reset_streams_max,
126 }
127 }
128 let streams = Streams::new(streams_config(&config));
129 Connection {
130 codec,
131 inner: ConnectionInner {
132 state: State::Open,
133 error: None,
134 go_away: GoAway::new(),
135 ping_pong: PingPong::new(),
136 settings: Settings::new(config.settings),
137 streams,
138 span: tracing::debug_span!("Connection", peer = %P::NAME),
139 _phantom: PhantomData,
140 },
141 }
142 }
143
144 pub(crate) fn set_target_window_size(&mut self, size: WindowSize) {
146 let _res = self.inner.streams.set_target_connection_window_size(size);
147 debug_assert!(_res.is_ok());
149 }
150
151 pub(crate) fn set_initial_window_size(&mut self, size: WindowSize) -> Result<(), UserError> {
153 let mut settings = frame::Settings::default();
154 settings.set_initial_window_size(Some(size));
155 self.inner.settings.send_settings(settings)
156 }
157
158 pub(crate) fn set_enable_connect_protocol(&mut self) -> Result<(), UserError> {
160 let mut settings = frame::Settings::default();
161 settings.set_enable_connect_protocol(Some(1));
162 self.inner.settings.send_settings(settings)
163 }
164
165 pub(crate) fn max_send_streams(&self) -> usize {
168 self.inner.streams.max_send_streams()
169 }
170
171 pub(crate) fn max_recv_streams(&self) -> usize {
174 self.inner.streams.max_recv_streams()
175 }
176
177 #[cfg(feature = "unstable")]
178 pub fn num_wired_streams(&self) -> usize {
179 self.inner.streams.num_wired_streams()
180 }
181
182 fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
187 let _e = self.inner.span.enter();
188 let span = tracing::trace_span!("poll_ready");
189 let _e = span.enter();
190 ready!(self.inner.ping_pong.send_pending_pong(cx, &mut self.codec))?;
192 ready!(self.inner.ping_pong.send_pending_ping(cx, &mut self.codec))?;
193 ready!(self
194 .inner
195 .settings
196 .poll_send(cx, &mut self.codec, &mut self.inner.streams))?;
197 ready!(self.inner.streams.send_pending_refusal(cx, &mut self.codec))?;
198
199 Poll::Ready(Ok(()))
200 }
201
202 fn poll_go_away(&mut self, cx: &mut Context) -> Poll<Option<io::Result<Reason>>> {
207 self.inner.go_away.send_pending_go_away(cx, &mut self.codec)
208 }
209
210 pub fn go_away_from_user(&mut self, e: Reason) {
211 self.inner.as_dyn().go_away_from_user(e)
212 }
213
214 fn take_error(&mut self, ours: Reason, initiator: Initiator) -> Result<(), Error> {
215 let (debug_data, theirs) = self
216 .inner
217 .error
218 .take()
219 .as_ref()
220 .map_or((Bytes::new(), Reason::NO_ERROR), |frame| {
221 (frame.debug_data().clone(), frame.reason())
222 });
223
224 match (ours, theirs) {
225 (Reason::NO_ERROR, Reason::NO_ERROR) => Ok(()),
226 (ours, Reason::NO_ERROR) => Err(Error::GoAway(Bytes::new(), ours, initiator)),
227 (_, theirs) => Err(Error::remote_go_away(debug_data, theirs)),
232 }
233 }
234
235 pub fn maybe_close_connection_if_no_streams(&mut self) {
238 if !self.inner.streams.has_streams_or_other_references() {
241 self.inner.as_dyn().go_away_now(Reason::NO_ERROR);
242 }
243 }
244
245 pub fn has_streams(&self) -> bool {
247 self.inner.streams.has_streams()
248 }
249
250 pub fn has_streams_or_other_references(&self) -> bool {
252 self.inner.streams.has_streams_or_other_references()
255 }
256
257 pub(crate) fn take_user_pings(&mut self) -> Option<UserPings> {
258 self.inner.ping_pong.take_user_pings()
259 }
260
261 pub fn poll(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
263 let span = self.inner.span.clone();
268 let _e = span.enter();
269 let span = tracing::trace_span!("poll");
270 let _e = span.enter();
271
272 loop {
273 tracing::trace!(connection.state = ?self.inner.state);
274 match self.inner.state {
276 State::Open => {
278 let result = match self.poll2(cx) {
279 Poll::Ready(result) => result,
280 Poll::Pending => {
282 ready!(self.inner.streams.poll_complete(cx, &mut self.codec))?;
286
287 if (self.inner.error.is_some()
288 || self.inner.go_away.should_close_on_idle())
289 && !self.inner.streams.has_streams()
290 {
291 self.inner.as_dyn().go_away_now(Reason::NO_ERROR);
292 continue;
293 }
294
295 return Poll::Pending;
296 }
297 };
298
299 self.inner.as_dyn().handle_poll2_result(result)?
300 }
301 State::Closing(reason, initiator) => {
302 tracing::trace!("connection closing after flush");
303 ready!(self.codec.shutdown(cx))?;
305
306 self.inner.state = State::Closed(reason, initiator);
308 }
309 State::Closed(reason, initiator) => {
310 return Poll::Ready(self.take_error(reason, initiator));
311 }
312 }
313 }
314 }
315
316 fn poll2(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
317 self.clear_expired_reset_streams();
321
322 loop {
323 if let Some(reason) = ready!(self.poll_go_away(cx)?) {
329 if self.inner.go_away.should_close_now() {
330 if self.inner.go_away.is_user_initiated() {
331 return Poll::Ready(Ok(()));
334 } else {
335 return Poll::Ready(Err(Error::library_go_away(reason)));
336 }
337 }
338 debug_assert_eq!(
340 reason,
341 Reason::NO_ERROR,
342 "graceful GOAWAY should be NO_ERROR"
343 );
344 }
345 ready!(self.poll_ready(cx))?;
346
347 match self
348 .inner
349 .as_dyn()
350 .recv_frame(ready!(Pin::new(&mut self.codec).poll_next(cx)?))?
351 {
352 ReceivedFrame::Settings(frame) => {
353 self.inner.settings.recv_settings(
354 frame,
355 &mut self.codec,
356 &mut self.inner.streams,
357 )?;
358 }
359 ReceivedFrame::Continue => (),
360 ReceivedFrame::Done => {
361 return Poll::Ready(Ok(()));
362 }
363 }
364 }
365 }
366
367 fn clear_expired_reset_streams(&mut self) {
368 self.inner.streams.clear_expired_reset_streams();
369 }
370}
371
372impl<P, B> ConnectionInner<P, B>
373where
374 P: Peer,
375 B: Buf,
376{
377 fn as_dyn(&mut self) -> DynConnection<'_, B> {
378 let ConnectionInner {
379 state,
380 go_away,
381 streams,
382 error,
383 ping_pong,
384 ..
385 } = self;
386 let streams = streams.as_dyn();
387 DynConnection {
388 state,
389 go_away,
390 streams,
391 error,
392 ping_pong,
393 }
394 }
395}
396
397impl<B> DynConnection<'_, B>
398where
399 B: Buf,
400{
401 fn go_away(&mut self, id: StreamId, e: Reason) {
402 let frame = frame::GoAway::new(id, e);
403 self.streams.send_go_away(id);
404 self.go_away.go_away(frame);
405 }
406
407 fn go_away_now(&mut self, e: Reason) {
408 let last_processed_id = self.streams.last_processed_id();
409 let frame = frame::GoAway::new(last_processed_id, e);
410 self.go_away.go_away_now(frame);
411 }
412
413 fn go_away_now_data(&mut self, e: Reason, data: Bytes) {
414 let last_processed_id = self.streams.last_processed_id();
415 let frame = frame::GoAway::with_debug_data(last_processed_id, e, data);
416 self.go_away.go_away_now(frame);
417 }
418
419 fn go_away_from_user(&mut self, e: Reason) {
420 let last_processed_id = self.streams.last_processed_id();
421 let frame = frame::GoAway::new(last_processed_id, e);
422 self.go_away.go_away_from_user(frame);
423
424 self.streams.handle_error(Error::user_go_away(e));
426 }
427
428 fn handle_poll2_result(&mut self, result: Result<(), Error>) -> Result<(), Error> {
429 match result {
430 Ok(()) => {
432 *self.state = State::Closing(Reason::NO_ERROR, Initiator::Library);
433 Ok(())
434 }
435 Err(Error::GoAway(debug_data, reason, initiator)) => {
439 let e = Error::GoAway(debug_data.clone(), reason, initiator);
440 tracing::debug!(error = ?e, "Connection::poll; connection error");
441
442 if self
445 .go_away
446 .going_away()
447 .map_or(false, |frame| frame.reason() == reason)
448 {
449 tracing::trace!(" -> already going away");
450 *self.state = State::Closing(reason, initiator);
451 return Ok(());
452 }
453
454 self.streams.handle_error(e);
456 self.go_away_now_data(reason, debug_data);
457 Ok(())
458 }
459 Err(Error::Reset(id, reason, initiator)) => {
463 debug_assert_eq!(initiator, Initiator::Library);
464 tracing::trace!(?id, ?reason, "stream error");
465 self.streams.send_reset(id, reason);
466 Ok(())
467 }
468 Err(Error::Io(kind, inner)) => {
473 tracing::debug!(error = ?kind, "Connection::poll; IO error");
474 let e = Error::Io(kind, inner);
475
476 self.streams.handle_error(e.clone());
478
479 if self.streams.is_server()
486 && self.streams.is_buffer_empty()
487 && matches!(kind, io::ErrorKind::UnexpectedEof)
488 {
489 *self.state = State::Closed(Reason::NO_ERROR, Initiator::Library);
490 return Ok(());
491 }
492
493 Err(e)
495 }
496 }
497 }
498
499 fn recv_frame(&mut self, frame: Option<Frame>) -> Result<ReceivedFrame, Error> {
500 use crate::frame::Frame::*;
501 match frame {
502 Some(Headers(frame)) => {
503 tracing::trace!(?frame, "recv HEADERS");
504 self.streams.recv_headers(frame)?;
505 }
506 Some(Data(frame)) => {
507 tracing::trace!(?frame, "recv DATA");
508 self.streams.recv_data(frame)?;
509 }
510 Some(Reset(frame)) => {
511 tracing::trace!(?frame, "recv RST_STREAM");
512 self.streams.recv_reset(frame)?;
513 }
514 Some(PushPromise(frame)) => {
515 tracing::trace!(?frame, "recv PUSH_PROMISE");
516 self.streams.recv_push_promise(frame)?;
517 }
518 Some(Settings(frame)) => {
519 tracing::trace!(?frame, "recv SETTINGS");
520 return Ok(ReceivedFrame::Settings(frame));
521 }
522 Some(GoAway(frame)) => {
523 tracing::trace!(?frame, "recv GOAWAY");
524 self.streams.recv_go_away(&frame)?;
529 *self.error = Some(frame);
530 }
531 Some(Ping(frame)) => {
532 tracing::trace!(?frame, "recv PING");
533 let status = self.ping_pong.recv_ping(frame);
534 if status.is_shutdown() {
535 assert!(
536 self.go_away.is_going_away(),
537 "received unexpected shutdown ping"
538 );
539
540 let last_processed_id = self.streams.last_processed_id();
541 self.go_away(last_processed_id, Reason::NO_ERROR);
542 }
543 }
544 Some(WindowUpdate(frame)) => {
545 tracing::trace!(?frame, "recv WINDOW_UPDATE");
546 self.streams.recv_window_update(frame)?;
547 }
548 Some(Priority(frame)) => {
549 tracing::trace!(?frame, "recv PRIORITY");
550 }
552 None => {
553 tracing::trace!("codec closed");
554 self.streams.recv_eof(false).expect("mutex poisoned");
555 return Ok(ReceivedFrame::Done);
556 }
557 }
558 Ok(ReceivedFrame::Continue)
559 }
560}
561
562enum ReceivedFrame {
563 Settings(frame::Settings),
564 Continue,
565 Done,
566}
567
568impl<T, B> Connection<T, client::Peer, B>
569where
570 T: AsyncRead + AsyncWrite,
571 B: Buf,
572{
573 pub(crate) fn streams(&self) -> &Streams<B, client::Peer> {
574 &self.inner.streams
575 }
576}
577
578impl<T, B> Connection<T, server::Peer, B>
579where
580 T: AsyncRead + AsyncWrite + Unpin,
581 B: Buf,
582{
583 pub fn next_incoming(&mut self) -> Option<StreamRef<B>> {
584 self.inner.streams.next_incoming()
585 }
586
587 pub fn go_away_gracefully(&mut self) {
589 if self.inner.go_away.is_going_away() {
590 return;
592 }
593
594 self.inner.as_dyn().go_away(StreamId::MAX, Reason::NO_ERROR);
606
607 self.inner.ping_pong.ping_shutdown();
610 }
611}
612
613impl<T, P, B> Drop for Connection<T, P, B>
614where
615 P: Peer,
616 B: Buf,
617{
618 fn drop(&mut self) {
619 let _ = self.inner.streams.recv_eof(true);
621 }
622}