1#![warn(missing_docs)]
9
10use hyper::rt::{Read, ReadBuf, ReadBufCursor, Write};
11use hyper_util::client::legacy::connect::{Connected, Connection};
12use pin_project_lite::pin_project;
13use std::future::Future;
14use std::io;
15use std::pin::Pin;
16use std::task::{ready, Context, Poll};
17use std::time::Duration;
18use tokio::time::{sleep_until, Instant, Sleep};
19
20pin_project! {
21 #[derive(Debug)]
22 struct TimeoutState {
23 timeout: Option<Duration>,
24 #[pin]
25 cur: Sleep,
26 active: bool,
27 }
28}
29
30impl TimeoutState {
31 #[inline]
32 fn new() -> TimeoutState {
33 TimeoutState {
34 timeout: None,
35 cur: sleep_until(Instant::now()),
36 active: false,
37 }
38 }
39
40 #[inline]
41 fn timeout(&self) -> Option<Duration> {
42 self.timeout
43 }
44
45 #[inline]
46 fn set_timeout(&mut self, timeout: Option<Duration>) {
47 self.timeout = timeout;
49 }
50
51 #[inline]
52 fn set_timeout_pinned(mut self: Pin<&mut Self>, timeout: Option<Duration>) {
53 *self.as_mut().project().timeout = timeout;
54 self.reset();
55 }
56
57 #[inline]
58 fn reset(self: Pin<&mut Self>) {
59 let this = self.project();
60
61 if *this.active {
62 *this.active = false;
63 this.cur.reset(Instant::now());
64 }
65 }
66
67 #[inline]
68 fn restart(self: Pin<&mut Self>) {
69 let this = self.project();
70
71 if *this.active {
72 let timeout = match this.timeout {
73 Some(timeout) => *timeout,
74 None => return,
75 };
76
77 this.cur.reset(Instant::now() + timeout);
78 }
79 }
80
81 #[inline]
82 fn poll_check(self: Pin<&mut Self>, cx: &mut Context) -> io::Result<()> {
83 let mut this = self.project();
84
85 let timeout = match this.timeout {
86 Some(timeout) => *timeout,
87 None => return Ok(()),
88 };
89
90 if !*this.active {
91 this.cur.as_mut().reset(Instant::now() + timeout);
92 *this.active = true;
93 }
94
95 match this.cur.poll(cx) {
96 Poll::Ready(()) => Err(io::Error::from(io::ErrorKind::TimedOut)),
97 Poll::Pending => Ok(()),
98 }
99 }
100}
101
102pin_project! {
103 #[derive(Debug)]
105 pub struct TimeoutReader<R> {
106 #[pin]
107 reader: R,
108 #[pin]
109 state: TimeoutState,
110 reset_on_write: bool,
111 }
112}
113
114impl<R> TimeoutReader<R>
115where
116 R: Read,
117{
118 pub fn new(reader: R) -> TimeoutReader<R> {
122 TimeoutReader {
123 reader,
124 state: TimeoutState::new(),
125 reset_on_write: false,
126 }
127 }
128
129 pub fn timeout(&self) -> Option<Duration> {
131 self.state.timeout()
132 }
133
134 pub fn set_timeout(&mut self, timeout: Option<Duration>) {
139 self.state.set_timeout(timeout);
140 }
141
142 pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
147 self.project().state.set_timeout_pinned(timeout);
148 }
149
150 pub fn get_ref(&self) -> &R {
152 &self.reader
153 }
154
155 pub fn get_mut(&mut self) -> &mut R {
157 &mut self.reader
158 }
159
160 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
162 self.project().reader
163 }
164
165 pub fn into_inner(self) -> R {
167 self.reader
168 }
169}
170
171impl<R> TimeoutReader<R>
172where
173 R: Read + Write,
174{
175 pub fn set_reset_on_write(&mut self, reset: bool) {
181 self.reset_on_write = reset
182 }
183}
184
185impl<R> Read for TimeoutReader<R>
186where
187 R: Read,
188{
189 fn poll_read(
190 self: Pin<&mut Self>,
191 cx: &mut Context,
192 buf: ReadBufCursor,
193 ) -> Poll<Result<(), io::Error>> {
194 let this = self.project();
195 let r = this.reader.poll_read(cx, buf);
196 match r {
197 Poll::Pending => this.state.poll_check(cx)?,
198 _ => this.state.reset(),
199 }
200 r
201 }
202}
203
204impl<R> Write for TimeoutReader<R>
205where
206 R: Write,
207{
208 fn poll_write(
209 self: Pin<&mut Self>,
210 cx: &mut Context,
211 buf: &[u8],
212 ) -> Poll<Result<usize, io::Error>> {
213 let this = self.project();
214 let r = this.reader.poll_write(cx, buf);
215 if *this.reset_on_write && r.is_ready() {
216 this.state.restart();
217 }
218 r
219 }
220
221 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
222 let this = self.project();
223 let r = this.reader.poll_flush(cx);
224 if *this.reset_on_write && r.is_ready() {
225 this.state.restart();
226 }
227 r
228 }
229
230 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
231 let this = self.project();
232 let r = this.reader.poll_shutdown(cx);
233 if *this.reset_on_write && r.is_ready() {
234 this.state.restart();
235 }
236 r
237 }
238
239 fn poll_write_vectored(
240 self: Pin<&mut Self>,
241 cx: &mut Context,
242 bufs: &[io::IoSlice],
243 ) -> Poll<io::Result<usize>> {
244 let this = self.project();
245 let r = this.reader.poll_write_vectored(cx, bufs);
246 if *this.reset_on_write && r.is_ready() {
247 this.state.restart();
248 }
249 r
250 }
251
252 fn is_write_vectored(&self) -> bool {
253 self.reader.is_write_vectored()
254 }
255}
256
257pin_project! {
258 #[derive(Debug)]
260 pub struct TimeoutWriter<W> {
261 #[pin]
262 writer: W,
263 #[pin]
264 state: TimeoutState,
265 }
266}
267
268impl<W> TimeoutWriter<W>
269where
270 W: Write,
271{
272 pub fn new(writer: W) -> TimeoutWriter<W> {
276 TimeoutWriter {
277 writer,
278 state: TimeoutState::new(),
279 }
280 }
281
282 pub fn timeout(&self) -> Option<Duration> {
284 self.state.timeout()
285 }
286
287 pub fn set_timeout(&mut self, timeout: Option<Duration>) {
292 self.state.set_timeout(timeout);
293 }
294
295 pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
300 self.project().state.set_timeout_pinned(timeout);
301 }
302
303 pub fn get_ref(&self) -> &W {
305 &self.writer
306 }
307
308 pub fn get_mut(&mut self) -> &mut W {
310 &mut self.writer
311 }
312
313 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
315 self.project().writer
316 }
317
318 pub fn into_inner(self) -> W {
320 self.writer
321 }
322}
323
324impl<W> Write for TimeoutWriter<W>
325where
326 W: Write,
327{
328 fn poll_write(
329 self: Pin<&mut Self>,
330 cx: &mut Context,
331 buf: &[u8],
332 ) -> Poll<Result<usize, io::Error>> {
333 let this = self.project();
334 let r = this.writer.poll_write(cx, buf);
335 match r {
336 Poll::Pending => this.state.poll_check(cx)?,
337 _ => this.state.reset(),
338 }
339 r
340 }
341
342 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
343 let this = self.project();
344 let r = this.writer.poll_flush(cx);
345 match r {
346 Poll::Pending => this.state.poll_check(cx)?,
347 _ => this.state.reset(),
348 }
349 r
350 }
351
352 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
353 let this = self.project();
354 let r = this.writer.poll_shutdown(cx);
355 match r {
356 Poll::Pending => this.state.poll_check(cx)?,
357 _ => this.state.reset(),
358 }
359 r
360 }
361
362 fn poll_write_vectored(
363 self: Pin<&mut Self>,
364 cx: &mut Context,
365 bufs: &[io::IoSlice],
366 ) -> Poll<io::Result<usize>> {
367 let this = self.project();
368 let r = this.writer.poll_write_vectored(cx, bufs);
369 match r {
370 Poll::Pending => this.state.poll_check(cx)?,
371 _ => this.state.reset(),
372 }
373 r
374 }
375
376 fn is_write_vectored(&self) -> bool {
377 self.writer.is_write_vectored()
378 }
379}
380
381impl<W> Read for TimeoutWriter<W>
382where
383 W: Read,
384{
385 fn poll_read(
386 self: Pin<&mut Self>,
387 cx: &mut Context,
388 buf: ReadBufCursor,
389 ) -> Poll<Result<(), io::Error>> {
390 self.project().writer.poll_read(cx, buf)
391 }
392}
393
394pin_project! {
395 #[derive(Debug)]
397 pub struct TimeoutStream<S> {
398 #[pin]
399 stream: TimeoutReader<TimeoutWriter<S>>
400 }
401}
402
403impl<S> TimeoutStream<S>
404where
405 S: Read + Write,
406{
407 pub fn new(stream: S) -> TimeoutStream<S> {
411 let writer = TimeoutWriter::new(stream);
412 let stream = TimeoutReader::new(writer);
413 TimeoutStream { stream }
414 }
415
416 pub fn read_timeout(&self) -> Option<Duration> {
418 self.stream.timeout()
419 }
420
421 pub fn set_read_timeout(&mut self, timeout: Option<Duration>) {
426 self.stream.set_timeout(timeout)
427 }
428
429 pub fn set_read_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
434 self.project().stream.set_timeout_pinned(timeout)
435 }
436
437 pub fn write_timeout(&self) -> Option<Duration> {
439 self.stream.get_ref().timeout()
440 }
441
442 pub fn set_write_timeout(&mut self, timeout: Option<Duration>) {
447 self.stream.get_mut().set_timeout(timeout)
448 }
449
450 pub fn set_write_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
455 self.project()
456 .stream
457 .get_pin_mut()
458 .set_timeout_pinned(timeout)
459 }
460
461 pub fn set_reset_reader_on_write(&mut self, reset: bool) {
467 self.stream.set_reset_on_write(reset);
468 }
469
470 pub fn get_ref(&self) -> &S {
472 self.stream.get_ref().get_ref()
473 }
474
475 pub fn get_mut(&mut self) -> &mut S {
477 self.stream.get_mut().get_mut()
478 }
479
480 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
482 self.project().stream.get_pin_mut().get_pin_mut()
483 }
484
485 pub fn into_inner(self) -> S {
487 self.stream.into_inner().into_inner()
488 }
489}
490
491impl<S> Read for TimeoutStream<S>
492where
493 S: Read + Write,
494{
495 fn poll_read(
496 self: Pin<&mut Self>,
497 cx: &mut Context,
498 buf: ReadBufCursor,
499 ) -> Poll<Result<(), io::Error>> {
500 self.project().stream.poll_read(cx, buf)
501 }
502}
503
504impl<S> Write for TimeoutStream<S>
505where
506 S: Read + Write,
507{
508 fn poll_write(
509 self: Pin<&mut Self>,
510 cx: &mut Context,
511 buf: &[u8],
512 ) -> Poll<Result<usize, io::Error>> {
513 self.project().stream.poll_write(cx, buf)
514 }
515
516 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
517 self.project().stream.poll_flush(cx)
518 }
519
520 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
521 self.project().stream.poll_shutdown(cx)
522 }
523
524 fn poll_write_vectored(
525 self: Pin<&mut Self>,
526 cx: &mut Context,
527 bufs: &[io::IoSlice],
528 ) -> Poll<io::Result<usize>> {
529 self.project().stream.poll_write_vectored(cx, bufs)
530 }
531
532 fn is_write_vectored(&self) -> bool {
533 self.stream.is_write_vectored()
534 }
535}
536
537impl<S> Connection for TimeoutStream<S>
538where
539 S: Read + Write + Connection + Unpin,
540{
541 fn connected(&self) -> Connected {
542 self.get_ref().connected()
543 }
544}
545
546impl<S> Connection for Pin<Box<TimeoutStream<S>>>
547where
548 S: Read + Write + Connection + Unpin,
549{
550 fn connected(&self) -> Connected {
551 self.get_ref().connected()
552 }
553}
554
555pin_project! {
556 struct ReadFut<'a, R: ?Sized> {
559 reader: &'a mut R,
560 buf: &'a mut [u8],
561 }
562}
563
564#[cfg(test)]
570fn read<'a, R>(reader: &'a mut R, buf: &'a mut [u8]) -> ReadFut<'a, R>
571where
572 R: Read + Unpin + ?Sized,
573{
574 ReadFut { reader, buf }
575}
576
577impl<R> Future for ReadFut<'_, R>
578where
579 R: Read + Unpin + ?Sized,
580{
581 type Output = io::Result<usize>;
582
583 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
584 let me = self.project();
585 let mut buf = ReadBuf::new(me.buf);
586 ready!(Pin::new(me.reader).poll_read(cx, buf.unfilled()))?;
587 Poll::Ready(Ok(buf.filled().len()))
588 }
589}
590
591#[cfg(test)]
592trait ReadExt: Read {
593 fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadFut<'a, Self>
596 where
597 Self: Unpin,
598 {
599 read(self, buf)
600 }
601}
602
603pin_project! {
604 struct WriteFut<'a, W: ?Sized> {
606 writer: &'a mut W,
607 buf: &'a [u8],
608 }
609}
610
611#[cfg(test)]
614fn write<'a, W>(writer: &'a mut W, buf: &'a [u8]) -> WriteFut<'a, W>
615where
616 W: Write + Unpin + ?Sized,
617{
618 WriteFut { writer, buf }
619}
620
621impl<W> Future for WriteFut<'_, W>
622where
623 W: Write + Unpin + ?Sized,
624{
625 type Output = io::Result<usize>;
626
627 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
628 let me = self.project();
629 Pin::new(&mut *me.writer).poll_write(cx, me.buf)
630 }
631}
632
633#[cfg(test)]
634trait WriteExt: Write {
635 fn write<'a>(&'a mut self, src: &'a [u8]) -> WriteFut<'a, Self>
638 where
639 Self: Unpin,
640 {
641 write(self, src)
642 }
643}
644
// The extension traits' default methods already forward to the free `read` /
// `write` helpers, so these impls simply opt the pinned wrappers in.

#[cfg(test)]
impl<R: Read> ReadExt for Pin<&mut TimeoutReader<R>> {}

#[cfg(test)]
impl<W: Write> WriteExt for Pin<&mut TimeoutWriter<W>> {}

#[cfg(test)]
impl<S: Read + Write> ReadExt for Pin<&mut TimeoutStream<S>> {}

#[cfg(test)]
impl<S: Read + Write> WriteExt for Pin<&mut TimeoutStream<S>> {}
684
#[cfg(test)]
mod test {
    use super::*;
    use hyper_util::rt::TokioIo;
    use std::io::Write;
    use std::net::TcpListener;
    use std::thread;
    use tokio::net::TcpStream;
    use tokio::pin;

    pin_project! {
        // Test double whose reads and writes only complete once `sleep` has elapsed.
        struct DelayStream {
            #[pin]
            sleep: Sleep,
        }
    }

    impl DelayStream {
        fn new(until: Instant) -> Self {
            let sleep = sleep_until(until);
            DelayStream { sleep }
        }
    }

    impl Read for DelayStream {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            _buf: ReadBufCursor,
        ) -> Poll<Result<(), io::Error>> {
            // Completes with an empty (zero-byte) read once the delay is over.
            self.project().sleep.poll(cx).map(|()| Ok(()))
        }
    }

    impl hyper::rt::Write for DelayStream {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            // Accepts the whole buffer once the delay is over.
            let accepted = buf.len();
            self.project().sleep.poll(cx).map(|()| Ok(accepted))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    #[tokio::test]
    async fn read_timeout() {
        let delayed = DelayStream::new(Instant::now() + Duration::from_millis(500));
        let mut reader = TimeoutReader::new(delayed);
        reader.set_timeout(Some(Duration::from_millis(100)));
        pin!(reader);

        let err = reader.read(&mut [0, 1, 2]).await.err().unwrap();
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }

    #[tokio::test]
    async fn read_ok() {
        let delayed = DelayStream::new(Instant::now() + Duration::from_millis(100));
        let mut reader = TimeoutReader::new(delayed);
        reader.set_timeout(Some(Duration::from_millis(500)));
        pin!(reader);

        reader.read(&mut [0]).await.unwrap();
    }

    #[tokio::test]
    async fn write_timeout() {
        let delayed = DelayStream::new(Instant::now() + Duration::from_millis(500));
        let mut writer = TimeoutWriter::new(delayed);
        writer.set_timeout(Some(Duration::from_millis(100)));
        pin!(writer);

        let err = writer.write(&[0]).await.err().unwrap();
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }

    #[tokio::test]
    async fn write_ok() {
        let delayed = DelayStream::new(Instant::now() + Duration::from_millis(100));
        let mut writer = TimeoutWriter::new(delayed);
        writer.set_timeout(Some(Duration::from_millis(500)));
        pin!(writer);

        writer.write(&[0]).await.unwrap();
    }

    #[tokio::test]
    async fn tcp_read() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        thread::spawn(move || {
            let (mut socket, _) = listener.accept().unwrap();
            thread::sleep(Duration::from_millis(10));
            socket.write_all(b"f").unwrap();
            thread::sleep(Duration::from_millis(500));
            // The client may already have closed the connection by now.
            let _ = socket.write_all(b"f");
        });

        let tcp = TcpStream::connect(&addr).await.unwrap();
        let mut stream = TimeoutStream::new(TokioIo::new(tcp));
        stream.set_read_timeout(Some(Duration::from_millis(100)));
        pin!(stream);

        stream.read(&mut [0]).await.unwrap();
        let second = stream.read(&mut [0]).await;

        match second {
            Ok(_) => panic!("unexpected success"),
            Err(ref e) if e.kind() == io::ErrorKind::TimedOut => (),
            Err(e) => panic!("{:?}", e),
        }
    }
}