Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compio-quic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,4 @@ required-features = ["compio-driver/fd-sync"]
[[bench]]
name = "quic"
harness = false
required-features = ["io-compat"]
5 changes: 2 additions & 3 deletions compio-quic/benches/quic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{

use compio_buf::bytes::Bytes;
use criterion::{BenchmarkId, Criterion, Throughput};
use futures_util::{StreamExt, stream::FuturesUnordered};
use futures_util::{AsyncWriteExt, StreamExt, stream::FuturesUnordered};
use rand::{RngCore, rng};

macro_rules! compio_spawn {
Expand Down Expand Up @@ -125,8 +125,7 @@ async fn compio_quic_echo_client(
send.finish().unwrap();
},
async {
let mut buf = vec![];
recv.read_to_end(&mut buf).await.unwrap();
recv.read_to_end(usize::MAX).await.unwrap();
}
);
})
Expand Down
4 changes: 2 additions & 2 deletions compio-quic/examples/quic-client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::net::{IpAddr, Ipv6Addr, SocketAddr};

use compio_io::AsyncWrite;
use compio_quic::ClientBuilder;
use tracing_subscriber::EnvFilter;

Expand Down Expand Up @@ -30,8 +31,7 @@ async fn main() {
send.write(&[1, 2, 3]).await.unwrap();
send.finish().unwrap();

let mut buf = vec![];
recv.read_to_end(&mut buf).await.unwrap();
let buf = recv.read_to_end(usize::MAX).await.unwrap();
println!("{buf:?}");

conn.close(1u32.into(), b"bye");
Expand Down
8 changes: 3 additions & 5 deletions compio-quic/examples/quic-dispatcher.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::num::NonZeroUsize;

use compio_dispatcher::Dispatcher;
use compio_io::AsyncWriteExt;
use compio_quic::{ClientBuilder, Endpoint, ServerBuilder};
use compio_runtime::spawn;
use futures_util::{StreamExt, stream::FuturesUnordered};
Expand Down Expand Up @@ -40,9 +41,7 @@ async fn main() {
.await
.unwrap();
let mut send = conn.open_uni().unwrap();
send.write_all(format!("Hello world {i}!").as_bytes())
.await
.unwrap();
send.write_all(format!("Hello world {i}!")).await.unwrap();
send.finish().unwrap();
send.stopped().await.unwrap();
}
Expand All @@ -63,8 +62,7 @@ async fn main() {
.dispatch(move || async move {
let conn = incoming.await.unwrap();
let mut recv = conn.accept_uni().await.unwrap();
let mut buf = vec![];
recv.read_to_end(&mut buf).await.unwrap();
let buf = recv.read_to_end(usize::MAX).await.unwrap();
println!("{}", std::str::from_utf8(&buf).unwrap());
})
.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions compio-quic/examples/quic-server.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use compio_io::AsyncWrite;
use compio_quic::ServerBuilder;
use tracing_subscriber::EnvFilter;

Expand All @@ -24,8 +25,7 @@ async fn main() {

let (mut send, mut recv) = conn.accept_bi().await.unwrap();

let mut buf = vec![];
recv.read_to_end(&mut buf).await.unwrap();
let buf = recv.read_to_end(usize::MAX).await.unwrap();
println!("{buf:?}");

send.write(&[4, 5, 6]).await.unwrap();
Expand Down
138 changes: 62 additions & 76 deletions compio-quic/src/recv_stream.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
use std::{
collections::BTreeMap,
io,
mem::MaybeUninit,
sync::Arc,
task::{Context, Poll},
};

use compio_buf::{
BufResult, IoBufMut,
bytes::{BufMut, Bytes},
};
use compio_buf::{BufResult, IoBufMut, bytes::Bytes};
use compio_io::AsyncRead;
use futures_util::{future::poll_fn, ready};
use futures_util::future::poll_fn;
use quinn_proto::{Chunk, Chunks, ClosedStream, ReadableError, StreamId, VarInt};
use thiserror::Error;

Expand Down Expand Up @@ -221,61 +218,52 @@ impl RecvStream {
}
}

fn poll_read(
/// Attempts to read from the stream into the provided buffer
///
/// On success, returns `Poll::Ready(Ok(num_bytes_read))` and places data
/// into `buf`. If this returns zero bytes read (and `buf` has a
/// non-zero length), that indicates that the remote
/// side has [`finish`]ed the stream and the local side has already read all
/// bytes.
///
/// If no data is available for reading, this returns `Poll::Pending` and
/// arranges for the current task (via `cx.waker()`) to be notified when
/// the stream becomes readable or is closed.
///
/// [`finish`]: crate::SendStream::finish
pub fn poll_read_uninit(
&mut self,
cx: &mut Context,
mut buf: impl BufMut,
) -> Poll<Result<Option<usize>, ReadError>> {
if !buf.has_remaining_mut() {
return Poll::Ready(Ok(Some(0)));
buf: &mut [MaybeUninit<u8>],
) -> Poll<Result<usize, ReadError>> {
if buf.is_empty() {
return Poll::Ready(Ok(0));
}

self.execute_poll_read(cx, true, |chunks| {
let mut read = 0;
loop {
if !buf.has_remaining_mut() {
// We know `read` is `true` because `buf.remaining()` was not 0 before
if read >= buf.len() {
// We know `read > 0` because `buf` cannot be empty here
return ReadStatus::Readable(read);
}

match chunks.next(buf.remaining_mut()) {
match chunks.next(buf.len() - read) {
Ok(Some(chunk)) => {
read += chunk.bytes.len();
buf.put(chunk.bytes);
let bytes = chunk.bytes;
let len = bytes.len();
buf[read..read + len].copy_from_slice(unsafe {
std::slice::from_raw_parts(bytes.as_ptr().cast(), len)
});
read += len;
}
res => {
return (if read == 0 { None } else { Some(read) }, res.err()).into();
}
}
}
})
}

/// Read data contiguously from the stream.
///
/// Yields the number of bytes read into `buf` on success, or `None` if the
/// stream was finished.
///
/// This operation is cancel-safe.
pub async fn read(&mut self, mut buf: impl BufMut) -> Result<Option<usize>, ReadError> {
poll_fn(|cx| self.poll_read(cx, &mut buf)).await
}

/// Read an exact number of bytes contiguously from the stream.
///
/// See [`read()`] for details. This operation is *not* cancel-safe.
///
/// [`read()`]: RecvStream::read
pub async fn read_exact(&mut self, mut buf: impl BufMut) -> Result<(), ReadExactError> {
poll_fn(|cx| {
while buf.has_remaining_mut() {
if ready!(self.poll_read(cx, &mut buf))?.is_none() {
return Poll::Ready(Err(ReadExactError::FinishedEarly(buf.remaining_mut())));
}
}
Poll::Ready(Ok(()))
})
.await
.map(|res| res.map(|n| n.unwrap_or_default()))
}

/// Read the next segment of data.
Expand Down Expand Up @@ -349,42 +337,42 @@ impl RecvStream {

/// Convenience method to read all remaining data into a buffer.
///
/// Uses unordered reads to be more efficient than using [`AsyncRead`]. If
/// unordered reads have already been made, the resulting buffer may have
/// gaps containing zero.
/// Fails with [`ReadError::TooLong`] on reading more than `size_limit`
/// bytes, discarding all data read. Uses unordered reads to be more
/// efficient than using `AsyncRead` would allow. `size_limit` should be
/// set to limit worst-case memory use.
///
/// Depending on [`BufMut`] implementation, this method may fail with
/// [`ReadError::BufferTooShort`] if the buffer is not large enough to
/// hold the entire stream. For example when using a `&mut [u8]` it will
/// never receive bytes more than the length of the slice, but when using a
/// `&mut Vec<u8>` it will allocate more memory as needed.
/// If unordered reads have already been made, the resulting buffer may have
/// gaps containing zeros.
///
/// This operation is *not* cancel-safe.
pub async fn read_to_end(&mut self, mut buf: impl BufMut) -> Result<usize, ReadError> {
pub async fn read_to_end(&mut self, size_limit: usize) -> Result<Vec<u8>, ReadError> {
let mut start = u64::MAX;
let mut end = 0;
let mut chunks = BTreeMap::new();
let mut chunks = vec![];
loop {
let Some(chunk) = self.read_chunk(usize::MAX, false).await? else {
break;
};
start = start.min(chunk.offset);
end = end.max(chunk.offset + chunk.bytes.len() as u64);
if end - start > buf.remaining_mut() as u64 {
return Err(ReadError::BufferTooShort);
if (end - start) > size_limit as u64 {
return Err(ReadError::TooLong);
}
chunks.insert(chunk.offset, chunk.bytes);
chunks.push((chunk.offset, chunk.bytes));
}
if start == u64::MAX || start >= end {
// no data read
return Ok(vec![]);
}
let mut last = 0;
let len = (end - start) as usize;
let mut buffer = vec![0u8; len];
for (offset, bytes) in chunks {
let offset = (offset - start) as usize;
if offset > last {
buf.put_bytes(0, offset - last);
}
last = offset + bytes.len();
buf.put(bytes);
let buf_len = bytes.len();
buffer[offset..offset + buf_len].copy_from_slice(&bytes);
}
Ok((end - start) as usize)
Ok(buffer)
}
}

Expand Down Expand Up @@ -450,11 +438,11 @@ pub enum ReadError {
/// [`Connecting::into_0rtt()`]: crate::Connecting::into_0rtt()
#[error("0-RTT rejected")]
ZeroRttRejected,
/// The stream is larger than the user-supplied buffer capacity.
/// The stream is larger than the user-supplied limit.
///
/// Can only occur when using [`read_to_end()`](RecvStream::read_to_end).
#[error("buffer too short")]
BufferTooShort,
#[error("the stream is larger than the user-supplied limit")]
TooLong,
}

impl From<ReadableError> for ReadError {
Expand All @@ -481,7 +469,8 @@ impl From<ReadError> for io::Error {
let kind = match x {
Reset { .. } | ZeroRttRejected => io::ErrorKind::ConnectionReset,
ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
IllegalOrderedRead | BufferTooShort => io::ErrorKind::InvalidInput,
IllegalOrderedRead => io::ErrorKind::InvalidInput,
TooLong => io::ErrorKind::InvalidData,
};
Self::new(kind, x)
}
Expand All @@ -500,14 +489,9 @@ pub enum ReadExactError {

impl AsyncRead for RecvStream {
async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
let res = self
.read(buf.as_uninit())
let res = poll_fn(|cx| self.poll_read_uninit(cx, buf.as_uninit()))
.await
.map(|n| {
let n = n.unwrap_or_default();
unsafe { buf.advance_to(n) }
n
})
.inspect(|&n| unsafe { buf.advance_to(n) })
.map_err(Into::into);
BufResult(res, buf)
}
Expand All @@ -520,9 +504,11 @@ impl futures_util::AsyncRead for RecvStream {
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
// SAFETY: buf is valid
self.get_mut()
.poll_read(cx, buf)
.map_ok(Option::unwrap_or_default)
.poll_read_uninit(cx, unsafe {
std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len())
})
.map_err(Into::into)
}
}
Expand Down
34 changes: 4 additions & 30 deletions compio-quic/src/send_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,35 +179,6 @@ impl SendStream {
}
}

/// Write bytes to the stream.
///
/// Yields the number of bytes written on success. Congestion and flow
/// control may cause this to be shorter than `buf.len()`, indicating
/// that only a prefix of `buf` was written.
///
/// This operation is cancel-safe.
pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write(buf))).await
}

/// Convenience method to write an entire buffer to the stream.
///
/// This operation is *not* cancel-safe.
pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), WriteError> {
let mut count = 0;
poll_fn(|cx| {
loop {
if count == buf.len() {
return Poll::Ready(Ok(()));
}
let n =
ready!(self.execute_poll_write(cx, |mut stream| stream.write(&buf[count..])))?;
count += n;
}
})
.await
}

/// Write chunks to the stream.
///
/// Yields the number of bytes and chunks written on success.
Expand Down Expand Up @@ -329,7 +300,10 @@ impl From<WriteError> for io::Error {

impl AsyncWrite for SendStream {
async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
let res = self.write(buf.as_slice()).await.map_err(Into::into);
let res =
poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write(buf.as_slice())))
.await
.map_err(Into::into);
BufResult(res, buf)
}

Expand Down
Loading