Skip to content

Postgres: Move io to background task. #3891

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions sqlx-core/src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,19 @@ pub use futures_util::io::AsyncReadExt;

#[cfg(feature = "_rt-tokio")]
pub use tokio::io::AsyncReadExt;

pub async fn read_from(
mut source: impl AsyncRead + Unpin,
data: &mut Vec<u8>,
) -> std::io::Result<usize> {
match () {
// Tokio lets us read into the buffer without zeroing first
#[cfg(feature = "_rt-tokio")]
_ => source.read_buf(data).await,
#[cfg(not(feature = "_rt-tokio"))]
_ => {
data.resize(data.capacity(), 0);
source.read(data).await
}
}
}
3 changes: 2 additions & 1 deletion sqlx-core/src/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ mod socket;
pub mod tls;

pub use socket::{
connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket, WriteBuffer,
connect_tcp, connect_uds, BufferedSocket, Socket, SocketExt, SocketIntoBox, WithSocket,
WriteBuffer,
};
108 changes: 97 additions & 11 deletions sqlx-core/src/net/socket/buffered.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use crate::error::Error;
use crate::net::Socket;
use crate::net::{Socket, SocketExt};
use bytes::BytesMut;
use futures_util::Sink;
use std::ops::ControlFlow;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use std::{cmp, io};

use crate::io::{AsyncRead, AsyncReadExt, ProtocolDecode, ProtocolEncode};
use crate::io::{read_from, AsyncRead, ProtocolDecode, ProtocolEncode};

// Tokio, async-std, and std all use this as the default capacity for their buffered I/O.
const DEFAULT_BUF_SIZE: usize = 8192;
Expand All @@ -13,6 +16,7 @@ pub struct BufferedSocket<S> {
socket: S,
write_buf: WriteBuffer,
read_buf: ReadBuffer,
wants_bytes: usize,
}

pub struct WriteBuffer {
Expand Down Expand Up @@ -42,6 +46,7 @@ impl<S: Socket> BufferedSocket<S> {
read: BytesMut::new(),
available: BytesMut::with_capacity(DEFAULT_BUF_SIZE),
},
wants_bytes: 0,
}
}

Expand All @@ -56,6 +61,25 @@ impl<S: Socket> BufferedSocket<S> {
.await
}

pub fn poll_try_read<F, R>(
&mut self,
cx: &mut Context<'_>,
mut try_read: F,
) -> Poll<Result<R, Error>>
where
F: FnMut(&mut BytesMut) -> Result<ControlFlow<R, usize>, Error>,
{
loop {
// Ensure we have enough bytes, only read if we want bytes.
ready!(self.poll_handle_read(cx)?);

match try_read(&mut self.read_buf.read)? {
ControlFlow::Continue(read_len) => self.wants_bytes = read_len,
ControlFlow::Break(ret) => return Poll::Ready(Ok(ret)),
};
}
}

/// Retryable read operation.
///
/// The callback should check the contents of the buffer passed to it and either:
Expand Down Expand Up @@ -125,8 +149,8 @@ impl<S: Socket> BufferedSocket<S> {
pub async fn flush(&mut self) -> io::Result<()> {
while !self.write_buf.is_empty() {
let written = self.socket.write(self.write_buf.get()).await?;
// Consume does the sanity check.
self.write_buf.consume(written);
self.write_buf.sanity_check();
}

self.socket.flush().await?;
Expand Down Expand Up @@ -154,8 +178,39 @@ impl<S: Socket> BufferedSocket<S> {
socket: Box::new(self.socket),
write_buf: self.write_buf,
read_buf: self.read_buf,
wants_bytes: self.wants_bytes,
}
}

fn poll_handle_read(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
// Because of how `BytesMut` works, we should only be shifting capacity back and forth
// between `read` and `available` unless we have to read an oversize message.

while self.read_buf.len() < self.wants_bytes {
self.read_buf
.reserve(self.wants_bytes - self.read_buf.len());

let read = ready!(self.socket.poll_read(cx, &mut self.read_buf.available)?);

if read == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
format!(
"expected to read {} bytes, got {} bytes at EOF",
self.wants_bytes,
self.read_buf.len()
),
)));
}

// we've read at least enough for `wants_bytes`, so we don't want more.
self.wants_bytes = 0;

self.read_buf.advance(read);
}

Poll::Ready(Ok(()))
}
}

impl WriteBuffer {
Expand Down Expand Up @@ -206,14 +261,8 @@ impl WriteBuffer {
/// Read into the buffer from `source`, returning the number of bytes read.
///
/// The buffer is automatically advanced by the number of bytes read.
pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> io::Result<usize> {
let read = match () {
// Tokio lets us read into the buffer without zeroing first
#[cfg(feature = "_rt-tokio")]
_ => source.read_buf(self.buf_mut()).await?,
#[cfg(not(feature = "_rt-tokio"))]
_ => source.read(self.init_remaining_mut()).await?,
};
pub async fn read_from(&mut self, source: impl AsyncRead + Unpin) -> io::Result<usize> {
let read = read_from(source, self.buf_mut()).await?;

if read > 0 {
self.advance(read);
Expand Down Expand Up @@ -326,4 +375,41 @@ impl ReadBuffer {
self.available = BytesMut::with_capacity(DEFAULT_BUF_SIZE);
}
}

fn len(&self) -> usize {
self.read.len()
}
}

impl<S: Socket> Sink<&[u8]> for BufferedSocket<S> {
type Error = crate::Error;

fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
if self.write_buf.bytes_written >= DEFAULT_BUF_SIZE {
self.poll_flush(cx)
} else {
Poll::Ready(Ok(()))
}
}

fn start_send(mut self: Pin<&mut Self>, item: &[u8]) -> crate::Result<()> {
self.write_buffer_mut().put_slice(item);
Ok(())
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
let this = &mut *self;

while !this.write_buf.is_empty() {
let written = ready!(this.socket.poll_write(cx, this.write_buf.get())?);
// Consume does the sanity check.
this.write_buf.consume(written);
}
this.socket.poll_flush(cx).map_err(Into::into)
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
ready!(self.as_mut().poll_flush(cx))?;
self.socket.poll_shutdown(cx).map_err(Into::into)
}
}
122 changes: 31 additions & 91 deletions sqlx-core/src/net/socket/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::future::Future;
use std::future::{poll_fn, Future};
use std::io;
use std::path::Path;
use std::pin::Pin;
use std::task::{ready, Context, Poll};

use bytes::BufMut;
Expand All @@ -27,125 +26,66 @@ pub trait Socket: Send + Sync + Unpin + 'static {
}

fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>>;

fn read<'a, B: ReadBuf>(&'a mut self, buf: &'a mut B) -> Read<'a, Self, B>
where
Self: Sized,
{
Read { socket: self, buf }
}

fn write<'a>(&'a mut self, buf: &'a [u8]) -> Write<'a, Self>
where
Self: Sized,
{
Write { socket: self, buf }
}

fn flush(&mut self) -> Flush<'_, Self>
where
Self: Sized,
{
Flush { socket: self }
}

fn shutdown(&mut self) -> Shutdown<'_, Self>
where
Self: Sized,
{
Shutdown { socket: self }
}
}

pub struct Read<'a, S: ?Sized, B> {
socket: &'a mut S,
buf: &'a mut B,
}

impl<S: ?Sized, B> Future for Read<'_, S, B>
where
S: Socket,
B: ReadBuf,
{
type Output = io::Result<usize>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = &mut *self;

while this.buf.has_remaining_mut() {
match this.socket.try_read(&mut *this.buf) {
pub trait SocketExt: Socket {
fn poll_read(
&mut self,
cx: &mut Context<'_>,
buf: &mut dyn ReadBuf,
) -> Poll<Result<usize, io::Error>> {
while buf.has_remaining_mut() {
match self.try_read(buf) {
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
ready!(this.socket.poll_read_ready(cx))?;
ready!(self.poll_read_ready(cx))?;
}
ready => return Poll::Ready(ready),
}
}

Poll::Ready(Ok(0))
}
}

pub struct Write<'a, S: ?Sized> {
socket: &'a mut S,
buf: &'a [u8],
}

impl<S: ?Sized> Future for Write<'_, S>
where
S: Socket,
{
type Output = io::Result<usize>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = &mut *self;

while !this.buf.is_empty() {
match this.socket.try_write(this.buf) {
fn poll_write(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
while !buf.is_empty() {
match self.try_write(buf) {
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
ready!(this.socket.poll_write_ready(cx))?;
ready!(self.poll_write_ready(cx))?;
}
ready => return Poll::Ready(ready),
}
}

Poll::Ready(Ok(0))
}
}

pub struct Flush<'a, S: ?Sized> {
socket: &'a mut S,
}

impl<S: Socket + ?Sized> Future for Flush<'_, S> {
type Output = io::Result<()>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.socket.poll_flush(cx)
#[inline(always)]
fn shutdown(&mut self) -> impl Future<Output = io::Result<()>> {
poll_fn(|cx| self.poll_shutdown(cx))
}
}

pub struct Shutdown<'a, S: ?Sized> {
socket: &'a mut S,
}
#[inline(always)]
fn flush(&mut self) -> impl Future<Output = io::Result<()>> {
poll_fn(|cx| self.poll_flush(cx))
}

impl<S: ?Sized> Future for Shutdown<'_, S>
where
S: Socket,
{
type Output = io::Result<()>;
#[inline(always)]
fn write(&mut self, buf: &[u8]) -> impl Future<Output = io::Result<usize>> {
poll_fn(|cx| self.poll_write(cx, buf))
}

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.socket.poll_shutdown(cx)
#[inline(always)]
fn read(&mut self, buf: &mut impl ReadBuf) -> impl Future<Output = io::Result<usize>> {
poll_fn(|cx| self.poll_read(cx, buf))
}
}

impl<S: Socket> SocketExt for S {}

pub trait WithSocket {
type Output;

fn with_socket<S: Socket>(
self,
socket: S,
) -> impl std::future::Future<Output = Self::Output> + Send;
fn with_socket<S: Socket>(self, socket: S) -> impl Future<Output = Self::Output> + Send;
}

pub struct SocketIntoBox;
Expand Down
12 changes: 7 additions & 5 deletions sqlx-postgres/src/connection/describe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,11 +414,13 @@ WHERE rngtypid = $1

/// Check whether EXPLAIN statements are supported by the current connection
fn is_explain_available(&self) -> bool {
let parameter_statuses = &self.inner.stream.parameter_statuses;
let is_cockroachdb = parameter_statuses.contains_key("crdb_version");
let is_materialize = parameter_statuses.contains_key("mz_version");
let is_questdb = parameter_statuses.contains_key("questdb_version");
!is_cockroachdb && !is_materialize && !is_questdb
self.inner.shared.with_lock(|shared| {
let parameter_statuses = &shared.parameter_statuses;
let is_cockroachdb = parameter_statuses.contains_key("crdb_version");
let is_materialize = parameter_statuses.contains_key("mz_version");
let is_questdb = parameter_statuses.contains_key("questdb_version");
!is_cockroachdb && !is_materialize && !is_questdb
})
}

pub(crate) async fn get_nullable_for_columns(
Expand Down
Loading
Loading