diff --git a/sqlx-core/src/io/mod.rs b/sqlx-core/src/io/mod.rs index 2765abe02f..23b4d2bccf 100644 --- a/sqlx-core/src/io/mod.rs +++ b/sqlx-core/src/io/mod.rs @@ -24,3 +24,19 @@ pub use futures_util::io::AsyncReadExt; #[cfg(feature = "_rt-tokio")] pub use tokio::io::AsyncReadExt; + +pub async fn read_from( + mut source: impl AsyncRead + Unpin, + data: &mut Vec, +) -> std::io::Result { + match () { + // Tokio lets us read into the buffer without zeroing first + #[cfg(feature = "_rt-tokio")] + _ => source.read_buf(data).await, + #[cfg(not(feature = "_rt-tokio"))] + _ => { + data.resize(data.capacity(), 0); + source.read(data).await + } + } +} diff --git a/sqlx-core/src/net/mod.rs b/sqlx-core/src/net/mod.rs index f9c43668ab..5001947702 100644 --- a/sqlx-core/src/net/mod.rs +++ b/sqlx-core/src/net/mod.rs @@ -2,5 +2,6 @@ mod socket; pub mod tls; pub use socket::{ - connect_tcp, connect_uds, BufferedSocket, Socket, SocketIntoBox, WithSocket, WriteBuffer, + connect_tcp, connect_uds, BufferedSocket, Socket, SocketExt, SocketIntoBox, WithSocket, + WriteBuffer, }; diff --git a/sqlx-core/src/net/socket/buffered.rs b/sqlx-core/src/net/socket/buffered.rs index 6785e70879..fbd9206dba 100644 --- a/sqlx-core/src/net/socket/buffered.rs +++ b/sqlx-core/src/net/socket/buffered.rs @@ -1,10 +1,13 @@ use crate::error::Error; -use crate::net::Socket; +use crate::net::{Socket, SocketExt}; use bytes::BytesMut; +use futures_util::Sink; use std::ops::ControlFlow; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; use std::{cmp, io}; -use crate::io::{AsyncRead, AsyncReadExt, ProtocolDecode, ProtocolEncode}; +use crate::io::{read_from, AsyncRead, ProtocolDecode, ProtocolEncode}; // Tokio, async-std, and std all use this as the default capacity for their buffered I/O. const DEFAULT_BUF_SIZE: usize = 8192; @@ -13,6 +16,7 @@ pub struct BufferedSocket { socket: S, write_buf: WriteBuffer, read_buf: ReadBuffer, + wants_bytes: usize, } pub struct WriteBuffer { @@ -42,6 +46,7 @@ impl BufferedSocket { read: BytesMut::new(), available: BytesMut::with_capacity(DEFAULT_BUF_SIZE), }, + wants_bytes: 0, } } @@ -56,6 +61,25 @@ impl BufferedSocket { .await } + pub fn poll_try_read( + &mut self, + cx: &mut Context<'_>, + mut try_read: F, + ) -> Poll> + where + F: FnMut(&mut BytesMut) -> Result, Error>, + { + loop { + // Ensure we have enough bytes, only read if we want bytes. + ready!(self.poll_handle_read(cx)?); + + match try_read(&mut self.read_buf.read)? { + ControlFlow::Continue(read_len) => self.wants_bytes = read_len, + ControlFlow::Break(ret) => return Poll::Ready(Ok(ret)), + }; + } + } + /// Retryable read operation. /// /// The callback should check the contents of the buffer passed to it and either: @@ -125,8 +149,8 @@ impl BufferedSocket { pub async fn flush(&mut self) -> io::Result<()> { while !self.write_buf.is_empty() { let written = self.socket.write(self.write_buf.get()).await?; + // Consume does the sanity check. self.write_buf.consume(written); - self.write_buf.sanity_check(); } self.socket.flush().await?; @@ -154,8 +178,39 @@ impl BufferedSocket { socket: Box::new(self.socket), write_buf: self.write_buf, read_buf: self.read_buf, + wants_bytes: self.wants_bytes, } } + + fn poll_handle_read(&mut self, cx: &mut Context<'_>) -> Poll> { + // Because of how `BytesMut` works, we should only be shifting capacity back and forth + // between `read` and `available` unless we have to read an oversize message. + + while self.read_buf.len() < self.wants_bytes { + self.read_buf + .reserve(self.wants_bytes - self.read_buf.len()); + + let read = ready!(self.socket.poll_read(cx, &mut self.read_buf.available)?); + + if read == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + format!( + "expected to read {} bytes, got {} bytes at EOF", + self.wants_bytes, + self.read_buf.len() + ), + ))); + } + + // we've read at least enough for `wants_bytes`, so we don't want more. + self.wants_bytes = 0; + + self.read_buf.advance(read); + } + + Poll::Ready(Ok(())) + } } impl WriteBuffer { @@ -206,14 +261,8 @@ impl WriteBuffer { /// Read into the buffer from `source`, returning the number of bytes read. /// /// The buffer is automatically advanced by the number of bytes read. - pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> io::Result { - let read = match () { - // Tokio lets us read into the buffer without zeroing first - #[cfg(feature = "_rt-tokio")] - _ => source.read_buf(self.buf_mut()).await?, - #[cfg(not(feature = "_rt-tokio"))] - _ => source.read(self.init_remaining_mut()).await?, - }; + pub async fn read_from(&mut self, source: impl AsyncRead + Unpin) -> io::Result { + let read = read_from(source, self.buf_mut()).await?; if read > 0 { self.advance(read); @@ -326,4 +375,41 @@ impl ReadBuffer { self.available = BytesMut::with_capacity(DEFAULT_BUF_SIZE); } } + + fn len(&self) -> usize { + self.read.len() + } +} + +impl Sink<&[u8]> for BufferedSocket { + type Error = crate::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.write_buf.bytes_written >= DEFAULT_BUF_SIZE { + self.poll_flush(cx) + } else { + Poll::Ready(Ok(())) + } + } + + fn start_send(mut self: Pin<&mut Self>, item: &[u8]) -> crate::Result<()> { + self.write_buffer_mut().put_slice(item); + Ok(()) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = &mut *self; + + while !this.write_buf.is_empty() { + let written = ready!(this.socket.poll_write(cx, this.write_buf.get())?); + // Consume does the sanity check. + this.write_buf.consume(written); + } + this.socket.poll_flush(cx).map_err(Into::into) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().poll_flush(cx))?; + self.socket.poll_shutdown(cx).map_err(Into::into) + } } diff --git a/sqlx-core/src/net/socket/mod.rs b/sqlx-core/src/net/socket/mod.rs index 1f24da8c40..9df5863912 100644 --- a/sqlx-core/src/net/socket/mod.rs +++ b/sqlx-core/src/net/socket/mod.rs @@ -1,7 +1,6 @@ -use std::future::Future; +use std::future::{poll_fn, Future}; use std::io; use std::path::Path; -use std::pin::Pin; use std::task::{ready, Context, Poll}; use bytes::BufMut; @@ -27,55 +26,18 @@ pub trait Socket: Send + Sync + Unpin + 'static { } fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll>; - - fn read<'a, B: ReadBuf>(&'a mut self, buf: &'a mut B) -> Read<'a, Self, B> - where - Self: Sized, - { - Read { socket: self, buf } - } - - fn write<'a>(&'a mut self, buf: &'a [u8]) -> Write<'a, Self> - where - Self: Sized, - { - Write { socket: self, buf } - } - - fn flush(&mut self) -> Flush<'_, Self> - where - Self: Sized, - { - Flush { socket: self } - } - - fn shutdown(&mut self) -> Shutdown<'_, Self> - where - Self: Sized, - { - Shutdown { socket: self } - } } -pub struct Read<'a, S: ?Sized, B> { - socket: &'a mut S, - buf: &'a mut B, -} - -impl Future for Read<'_, S, B> -where - S: Socket, - B: ReadBuf, -{ - type Output = io::Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = &mut *self; - - while this.buf.has_remaining_mut() { - match this.socket.try_read(&mut *this.buf) { +pub trait SocketExt: Socket { + fn poll_read( + &mut self, + cx: &mut Context<'_>, + buf: &mut dyn ReadBuf, + ) -> Poll> { + while buf.has_remaining_mut() { + match self.try_read(buf) { Err(e) if e.kind() == io::ErrorKind::WouldBlock => { - ready!(this.socket.poll_read_ready(cx))?; + ready!(self.poll_read_ready(cx))?; } ready => return Poll::Ready(ready), } @@ -83,26 +45,12 @@ where Poll::Ready(Ok(0)) } -} - -pub struct Write<'a, S: ?Sized> { - socket: &'a mut S, - buf: &'a [u8], -} -impl Future for Write<'_, S> -where - S: Socket, -{ - type Output = io::Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = &mut *self; - - while !this.buf.is_empty() { - match this.socket.try_write(this.buf) { + fn poll_write(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + while !buf.is_empty() { + match self.try_write(buf) { Err(e) if e.kind() == io::ErrorKind::WouldBlock => { - ready!(this.socket.poll_write_ready(cx))?; + ready!(self.poll_write_ready(cx))?; } ready => return Poll::Ready(ready), } @@ -110,42 +58,34 @@ where Poll::Ready(Ok(0)) } -} -pub struct Flush<'a, S: ?Sized> { - socket: &'a mut S, -} - -impl Future for Flush<'_, S> { - type Output = io::Result<()>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.socket.poll_flush(cx) + #[inline(always)] + fn shutdown(&mut self) -> impl Future> { + poll_fn(|cx| self.poll_shutdown(cx)) } -} -pub struct Shutdown<'a, S: ?Sized> { - socket: &'a mut S, -} + #[inline(always)] + fn flush(&mut self) -> impl Future> { + poll_fn(|cx| self.poll_flush(cx)) + } -impl Future for Shutdown<'_, S> -where - S: Socket, -{ - type Output = io::Result<()>; + #[inline(always)] + fn write(&mut self, buf: &[u8]) -> impl Future> { + poll_fn(|cx| self.poll_write(cx, buf)) + } - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.socket.poll_shutdown(cx) + #[inline(always)] + fn read(&mut self, buf: &mut impl ReadBuf) -> impl Future> { + poll_fn(|cx| self.poll_read(cx, buf)) } } +impl SocketExt for S {} + pub trait WithSocket { type Output; - fn with_socket( - self, - socket: S, - ) -> impl std::future::Future + Send; + fn with_socket(self, socket: S) -> impl Future + Send; } pub struct SocketIntoBox; diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index a27578c56c..c47c82ded6 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -414,11 +414,13 @@ WHERE rngtypid = $1 /// Check whether EXPLAIN statements are supported by the current connection fn is_explain_available(&self) -> bool { - let parameter_statuses = &self.inner.stream.parameter_statuses; - let is_cockroachdb = parameter_statuses.contains_key("crdb_version"); - let is_materialize = parameter_statuses.contains_key("mz_version"); - let is_questdb = parameter_statuses.contains_key("questdb_version"); - !is_cockroachdb && !is_materialize && !is_questdb + self.inner.shared.with_lock(|shared| { + let parameter_statuses = &shared.parameter_statuses; + let is_cockroachdb = parameter_statuses.contains_key("crdb_version"); + let is_materialize = parameter_statuses.contains_key("mz_version"); + let is_questdb = parameter_statuses.contains_key("questdb_version"); + !is_cockroachdb && !is_materialize && !is_questdb + }) } pub(crate) async fn get_nullable_for_columns( diff --git a/sqlx-postgres/src/connection/establish.rs b/sqlx-postgres/src/connection/establish.rs index 1bc4172fbd..854c09a241 100644 --- a/sqlx-postgres/src/connection/establish.rs +++ b/sqlx-postgres/src/connection/establish.rs @@ -1,23 +1,27 @@ -use crate::HashMap; - -use crate::common::StatementCache; -use crate::connection::{sasl, stream::PgStream}; +use crate::connection::sasl; use crate::error::Error; -use crate::io::StatementId; -use crate::message::{ - Authentication, BackendKeyData, BackendMessageFormat, Password, ReadyForQuery, Startup, -}; +use crate::message::{Authentication, BackendKeyData, BackendMessageFormat, Password, Startup}; use crate::{PgConnectOptions, PgConnection}; +use futures_channel::mpsc::unbounded; +use std::str::FromStr; -use super::PgConnectionInner; +use super::worker::{Shared, Worker}; // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3 // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.11 impl PgConnection { pub(crate) async fn establish(options: &PgConnectOptions) -> Result { + // A channel to communicate postgres notifications between the bg worker and a `PgListener`. + let (notif_tx, notif_rx) = unbounded(); + + // Shared state between the bg worker and the `PgConnection` + let shared = Shared::new(); + // Upgrade to TLS if we were asked to and the server supports it - let mut stream = PgStream::connect(options).await?; + let channel = Worker::connect(options, notif_tx, shared.clone()).await?; + + let mut conn = PgConnection::new(options, channel, notif_rx, shared); // To begin a session, a frontend opens a connection to the server // and sends a startup message. @@ -45,14 +49,16 @@ impl PgConnection { params.push(("options", options)); } - stream.write(Startup { - username: Some(&options.username), - database: options.database.as_deref(), - params: ¶ms, + // Only after establishing a connection, Postgres sends a [ReadyForQuery] response. While + // establishing a connection this pipe is used to read responses from. + let mut pipe = conn.pipe(|buf| { + buf.write(Startup { + username: Some(&options.username), + database: options.database.as_deref(), + params: ¶ms, + }) })?; - stream.flush().await?; - // The server then uses this information and the contents of // its configuration files (such as pg_hba.conf) to determine whether the connection is // provisionally acceptable, and what additional @@ -60,10 +66,9 @@ impl PgConnection { let mut process_id = 0; let mut secret_key = 0; - let transaction_status; loop { - let message = stream.recv().await?; + let message = pipe.recv().await?; match message.format { BackendMessageFormat::Authentication => match message.decode()? { Authentication::Ok => { @@ -75,11 +80,9 @@ impl PgConnection { // The frontend must now send a [PasswordMessage] containing the // password in clear-text form. - stream - .send(Password::Cleartext( - options.password.as_deref().unwrap_or_default(), - )) - .await?; + conn.pipe_and_forget(Password::Cleartext( + options.password.as_deref().unwrap_or_default(), + ))?; } Authentication::Md5Password(body) => { @@ -88,17 +91,15 @@ impl PgConnection { // using the 4-byte random salt specified in the // [AuthenticationMD5Password] message. - stream - .send(Password::Md5 { - username: &options.username, - password: options.password.as_deref().unwrap_or_default(), - salt: body.salt, - }) - .await?; + conn.pipe_and_forget(Password::Md5 { + username: &options.username, + password: options.password.as_deref().unwrap_or_default(), + salt: body.salt, + })?; } Authentication::Sasl(body) => { - sasl::authenticate(&mut stream, options, body).await?; + sasl::authenticate(&conn, &mut pipe, options, body).await?; } method => { @@ -120,9 +121,7 @@ impl PgConnection { } BackendMessageFormat::ReadyForQuery => { - // start-up is completed. The frontend can now issue commands - transaction_status = message.decode::()?.transaction_status; - + // The transaction status is updated in the bg worker. break; } @@ -135,21 +134,82 @@ impl PgConnection { } } - Ok(PgConnection { - inner: Box::new(PgConnectionInner { - stream, - process_id, - secret_key, - transaction_status, - transaction_depth: 0, - pending_ready_for_query_count: 0, - next_statement_id: StatementId::NAMED_START, - cache_statement: StatementCache::new(options.statement_cache_capacity), - cache_type_oid: HashMap::new(), - cache_type_info: HashMap::new(), - cache_elem_type_to_array: HashMap::new(), - log_settings: options.log_settings.clone(), - }), - }) + let server_version = conn + .inner + .shared + .remove_parameter_status("server_version") + .map(parse_server_version); + + conn.inner.server_version_num = server_version.flatten(); + conn.inner.secret_key = secret_key; + conn.inner.process_id = process_id; + + Ok(conn) + } +} + +// reference: +// https://github.com/postgres/postgres/blob/6feebcb6b44631c3dc435e971bd80c2dd218a5ab/src/interfaces/libpq/fe-exec.c#L1030-L1065 +fn parse_server_version(s: impl Into) -> Option { + let s = s.into(); + let mut parts = Vec::::with_capacity(3); + + let mut from = 0; + let mut chs = s.char_indices().peekable(); + while let Some((i, ch)) = chs.next() { + match ch { + '.' => { + if let Ok(num) = u32::from_str(&s[from..i]) { + parts.push(num); + from = i + 1; + } else { + break; + } + } + _ if ch.is_ascii_digit() => { + if chs.peek().is_none() { + if let Ok(num) = u32::from_str(&s[from..]) { + parts.push(num); + } + break; + } + } + _ => { + if let Ok(num) = u32::from_str(&s[from..i]) { + parts.push(num); + } + break; + } + }; + } + + let version_num = match parts.as_slice() { + [major, minor, rev] => (100 * major + minor) * 100 + rev, + [major, minor] if *major >= 10 => 100 * 100 * major + minor, + [major, minor] => (100 * major + minor) * 100, + [major] => 100 * 100 * major, + _ => return None, + }; + + Some(version_num) +} + +#[cfg(test)] +mod tests { + use super::parse_server_version; + + #[test] + fn test_parse_server_version_num() { + // old style + assert_eq!(parse_server_version("9.6.1"), Some(90601)); + // new style + assert_eq!(parse_server_version("10.1"), Some(100001)); + // old style without minor version + assert_eq!(parse_server_version("9.6devel"), Some(90600)); + // new style without minor version, e.g. */ + assert_eq!(parse_server_version("10devel"), Some(100000)); + assert_eq!(parse_server_version("13devel87"), Some(130000)); + // unknown + assert_eq!(parse_server_version("unknown"), None); } } diff --git a/sqlx-postgres/src/connection/executor.rs b/sqlx-postgres/src/connection/executor.rs index d0596aacee..43f673ef59 100644 --- a/sqlx-postgres/src/connection/executor.rs +++ b/sqlx-postgres/src/connection/executor.rs @@ -5,7 +5,7 @@ use crate::io::{PortalId, StatementId}; use crate::logger::QueryLogger; use crate::message::{ self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse, - ParseComplete, Query, RowDescription, + ParseComplete, RowDescription, }; use crate::statement::PgStatementMetadata; use crate::{ @@ -20,6 +20,8 @@ use sqlx_core::arguments::Arguments; use sqlx_core::Either; use std::{borrow::Cow, pin::pin, sync::Arc}; +use super::pipe::Pipe; + async fn prepare( conn: &mut PgConnection, sql: &str, @@ -45,52 +47,45 @@ async fn prepare( param_types.push(conn.resolve_type_id(&ty.0).await?); } - // flush and wait until we are re-ready - conn.wait_until_ready().await?; + let mut pipe = conn.pipe(|buf| { + // next we send the PARSE command to the server + buf.write_msg(Parse { + param_types: ¶m_types, + query: sql, + statement: id, + })?; + + if metadata.is_none() { + // get the statement columns and parameters + buf.write_msg(message::Describe::Statement(id))?; + } - // next we send the PARSE command to the server - conn.inner.stream.write_msg(Parse { - param_types: ¶m_types, - query: sql, - statement: id, + // we ask for the server to immediately send us the result of the PARSE command + buf.write_sync(); + Ok(()) })?; - if metadata.is_none() { - // get the statement columns and parameters - conn.inner - .stream - .write_msg(message::Describe::Statement(id))?; - } - - // we ask for the server to immediately send us the result of the PARSE command - conn.write_sync(); - conn.inner.stream.flush().await?; - // indicates that the SQL query string is now successfully parsed and has semantic validity - conn.inner.stream.recv_expect::().await?; + pipe.recv_expect::().await?; let metadata = if let Some(metadata) = metadata { // each SYNC produces one READY FOR QUERY - conn.recv_ready_for_query().await?; + pipe.recv_ready_for_query().await?; // we already have metadata metadata } else { - let parameters = recv_desc_params(conn).await?; + let parameters = recv_desc_params(&mut pipe).await?; - let rows = recv_desc_rows(conn).await?; + let rows = recv_desc_rows(&mut pipe).await?; // each SYNC produces one READY FOR QUERY - conn.recv_ready_for_query().await?; + pipe.recv_ready_for_query().await?; let parameters = conn.handle_parameter_description(parameters).await?; let (columns, column_names) = conn.handle_row_description(rows, true).await?; - // ensure that if we did fetch custom data, we wait until we are fully ready before - // continuing - conn.wait_until_ready().await?; - Arc::new(PgStatementMetadata { parameters, columns, @@ -101,12 +96,12 @@ async fn prepare( Ok((id, metadata)) } -async fn recv_desc_params(conn: &mut PgConnection) -> Result { - conn.inner.stream.recv_expect().await +async fn recv_desc_params(pipe: &mut Pipe) -> Result { + pipe.recv_expect().await } -async fn recv_desc_rows(conn: &mut PgConnection) -> Result, Error> { - let rows: Option = match conn.inner.stream.recv().await? { +async fn recv_desc_rows(pipe: &mut Pipe) -> Result, Error> { + let rows: Option = match pipe.recv().await? { // describes the rows that will be returned when the statement is eventually executed message if message.format == BackendMessageFormat::RowDescription => { Some(message.decode()?) @@ -127,44 +122,6 @@ async fn recv_desc_rows(conn: &mut PgConnection) -> Result Result<(), Error> { - // we need to wait for the [CloseComplete] to be returned from the server - while count > 0 { - match self.inner.stream.recv().await? { - message if message.format == BackendMessageFormat::PortalSuspended => { - // there was an open portal - // this can happen if the last time a statement was used it was not fully executed - } - - message if message.format == BackendMessageFormat::CloseComplete => { - // successfully closed the statement (and freed up the server resources) - count -= 1; - } - - message => { - return Err(err_protocol!( - "expecting PortalSuspended or CloseComplete but received {:?}", - message.format - )); - } - } - } - - Ok(()) - } - - #[inline(always)] - pub(crate) fn write_sync(&mut self) { - self.inner - .stream - .write_msg(message::Sync) - .expect("BUG: Sync should not be too big for protocol"); - - // all SYNC messages will return a ReadyForQuery - self.inner.pending_ready_for_query_count += 1; - } - async fn get_or_prepare( &mut self, sql: &str, @@ -182,13 +139,14 @@ impl PgConnection { if persistent && self.inner.cache_statement.is_enabled() { if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) { - self.inner.stream.write_msg(Close::Statement(id))?; - self.write_sync(); - - self.inner.stream.flush().await?; - - self.wait_for_close_complete(1).await?; - self.recv_ready_for_query().await?; + let mut pipe = self.pipe(|buf| { + buf.write_msg(Close::Statement(id))?; + buf.write_sync(); + Ok(()) + })?; + + pipe.wait_for_close_complete(1).await?; + pipe.recv_ready_for_query().await?; } } @@ -204,10 +162,8 @@ impl PgConnection { ) -> Result, Error>> + 'e, Error> { let mut logger = QueryLogger::new(query, self.inner.log_settings.clone()); - // before we continue, wait until we are "ready" to accept more queries - self.wait_until_ready().await?; - let mut metadata: Arc; + let mut pipe: Pipe; let format = if let Some(mut arguments) = arguments { // Check this before we write anything to the stream. @@ -234,53 +190,50 @@ impl PgConnection { // patch holes created during encoding arguments.apply_patches(self, &metadata.parameters).await?; - // consume messages till `ReadyForQuery` before bind and execute - self.wait_until_ready().await?; - - // bind to attach the arguments to the statement and create a portal - self.inner.stream.write_msg(Bind { - portal: PortalId::UNNAMED, - statement, - formats: &[PgValueFormat::Binary], - num_params, - params: &arguments.buffer, - result_formats: &[PgValueFormat::Binary], - })?; - - // executes the portal up to the passed limit - // the protocol-level limit acts nearly identically to the `LIMIT` in SQL - self.inner.stream.write_msg(message::Execute { - portal: PortalId::UNNAMED, - // Non-zero limits cause query plan pessimization by disabling parallel workers: - // https://github.com/launchbadge/sqlx/issues/3673 - limit: 0, + pipe = self.pipe(|buf| { + // bind to attach the arguments to the statement and create a portal + buf.write_msg(Bind { + portal: PortalId::UNNAMED, + statement, + formats: &[PgValueFormat::Binary], + num_params, + params: &arguments.buffer, + result_formats: &[PgValueFormat::Binary], + })?; + + // executes the portal up to the passed limit + // the protocol-level limit acts nearly identically to the `LIMIT` in SQL + buf.write_msg(message::Execute { + portal: PortalId::UNNAMED, + // Non-zero limits cause query plan pessimization by disabling parallel workers: + // https://github.com/launchbadge/sqlx/issues/3673 + limit: 0, + })?; + // From https://www.postgresql.org/docs/current/protocol-flow.html: + // + // "An unnamed portal is destroyed at the end of the transaction, or as + // soon as the next Bind statement specifying the unnamed portal as + // destination is issued. (Note that a simple Query message also + // destroys the unnamed portal." + + // we ask the database server to close the unnamed portal and free the associated resources + // earlier - after the execution of the current query. + buf.write_msg(Close::Portal(PortalId::UNNAMED))?; + + // finally, [Sync] asks postgres to process the messages that we sent and respond with + // a [ReadyForQuery] message when it's completely done. Theoretically, we could send + // dozens of queries before a [Sync] and postgres can handle that. Execution on the server + // is still serial but it would reduce round-trips. Some kind of builder pattern that is + // termed batching might suit this. + buf.write_sync(); + Ok(()) })?; - // From https://www.postgresql.org/docs/current/protocol-flow.html: - // - // "An unnamed portal is destroyed at the end of the transaction, or as - // soon as the next Bind statement specifying the unnamed portal as - // destination is issued. (Note that a simple Query message also - // destroys the unnamed portal." - - // we ask the database server to close the unnamed portal and free the associated resources - // earlier - after the execution of the current query. - self.inner - .stream - .write_msg(Close::Portal(PortalId::UNNAMED))?; - - // finally, [Sync] asks postgres to process the messages that we sent and respond with - // a [ReadyForQuery] message when it's completely done. Theoretically, we could send - // dozens of queries before a [Sync] and postgres can handle that. Execution on the server - // is still serial but it would reduce round-trips. Some kind of builder pattern that is - // termed batching might suit this. - self.write_sync(); // prepared statements are binary PgValueFormat::Binary } else { // Query will trigger a ReadyForQuery - self.inner.stream.write_msg(Query(query))?; - self.inner.pending_ready_for_query_count += 1; + pipe = self.queue_simple_query(query)?; // metadata starts out as "nothing" metadata = Arc::new(PgStatementMetadata::default()); @@ -289,11 +242,9 @@ impl PgConnection { PgValueFormat::Text }; - self.inner.stream.flush().await?; - Ok(try_stream! { loop { - let message = self.inner.stream.recv().await?; + let message = pipe.recv().await?; match message.format { BackendMessageFormat::BindComplete @@ -358,8 +309,8 @@ impl PgConnection { } BackendMessageFormat::ReadyForQuery => { - // processing of the query string is complete - self.handle_ready_for_query(message)?; + // Processing of the query string is complete, the transaction status is + // updated in the bg worker. break; } @@ -451,8 +402,6 @@ impl<'c> Executor<'c> for &'c mut PgConnection { 'c: 'e, { Box::pin(async move { - self.wait_until_ready().await?; - let (_, metadata) = self.get_or_prepare(sql, parameters, true, None).await?; Ok(PgStatement { @@ -470,8 +419,6 @@ impl<'c> Executor<'c> for &'c mut PgConnection { 'c: 'e, { Box::pin(async move { - self.wait_until_ready().await?; - let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None).await?; let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?; diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index ce499ed744..5c6ad56fa2 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -3,17 +3,18 @@ use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; use crate::HashMap; +use futures_channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; use futures_core::future::BoxFuture; use futures_util::FutureExt; +use pipe::Pipe; +use request::{IoRequest, MessageBuf}; +use worker::Shared; use crate::common::StatementCache; use crate::error::Error; use crate::ext::ustr::UStr; use crate::io::StatementId; -use crate::message::{ - BackendMessageFormat, Close, Query, ReadyForQuery, ReceivedMessage, Terminate, - TransactionStatus, -}; +use crate::message::{Close, FrontendMessage, Notification, Query, TransactionStatus}; use crate::statement::PgStatementMetadata; use crate::transaction::Transaction; use crate::types::Oid; @@ -21,14 +22,14 @@ use crate::{PgConnectOptions, PgTypeInfo, Postgres}; pub(crate) use sqlx_core::connection::*; -pub use self::stream::PgStream; - pub(crate) mod describe; mod establish; mod executor; +mod pipe; +mod request; mod sasl; -mod stream; mod tls; +mod worker; /// A connection to a PostgreSQL database. /// @@ -38,10 +39,10 @@ pub struct PgConnection { } pub struct PgConnectionInner { - // underlying TCP or UDS stream, - // wrapped in a potentially TLS stream, - // wrapped in a buffered stream - pub(crate) stream: PgStream, + // channel to the background worker + chan: UnboundedSender, + + pub(crate) notifications: UnboundedReceiver, // process id of this backend // used to send cancel requests @@ -53,6 +54,8 @@ pub struct PgConnectionInner { #[allow(dead_code)] secret_key: u32, + pub(crate) server_version_num: Option, + // sequence of statement IDs for use in preparing statements // in PostgreSQL, the statement is prepared to a user-supplied identifier next_statement_id: StatementId, @@ -65,77 +68,114 @@ pub struct PgConnectionInner { cache_type_oid: HashMap, cache_elem_type_to_array: HashMap, - // number of ReadyForQuery messages that we are currently expecting - pub(crate) pending_ready_for_query_count: usize, - - // current transaction status - transaction_status: TransactionStatus, pub(crate) transaction_depth: usize, log_settings: LogSettings, + + shared: Shared, } impl PgConnection { /// the version number of the server in `libpq` format pub fn server_version_num(&self) -> Option { - self.inner.stream.server_version_num + self.inner.server_version_num } - // will return when the connection is ready for another query - pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> { - if !self.inner.stream.write_buffer_mut().is_empty() { - self.inner.stream.flush().await?; - } + /// Queue a simple query (not prepared) to execute the next time this connection is used. + /// + /// Used for rolling back transactions and releasing advisory locks. + #[inline(always)] + pub(crate) fn queue_simple_query(&self, query: &str) -> Result { + self.pipe(|buf| buf.write_msg(Query(query))) + } - while self.inner.pending_ready_for_query_count > 0 { - let message = self.inner.stream.recv().await?; + pub(crate) fn in_transaction(&self) -> bool { + match self.inner.shared.get_transaction_status() { + TransactionStatus::Transaction => true, + TransactionStatus::Error | TransactionStatus::Idle => false, + } + } - if let BackendMessageFormat::ReadyForQuery = message.format { - self.handle_ready_for_query(message)?; - } + fn new( + options: &PgConnectOptions, + chan: UnboundedSender, + notifications: UnboundedReceiver, + shared: Shared, + ) -> Self { + Self { + inner: Box::new(PgConnectionInner { + chan, + notifications, + log_settings: options.log_settings.clone(), + process_id: 0, + secret_key: 0, + next_statement_id: StatementId::NAMED_START, + cache_statement: StatementCache::new(options.statement_cache_capacity), + cache_type_info: HashMap::new(), + cache_type_oid: HashMap::new(), + cache_elem_type_to_array: HashMap::new(), + transaction_depth: 0, + server_version_num: None, + shared, + }), } + } - Ok(()) + fn create_request(&self, callback: F) -> sqlx_core::Result + where + F: FnOnce(&mut MessageBuf) -> sqlx_core::Result<()>, + { + let mut buffer = MessageBuf::new(); + (callback)(&mut buffer)?; + Ok(buffer.finish()) } - async fn recv_ready_for_query(&mut self) -> Result<(), Error> { - let r: ReadyForQuery = self.inner.stream.recv_expect().await?; + fn send_request(&self, request: IoRequest) -> sqlx_core::Result<()> { + self.inner + .chan + .unbounded_send(request) + .map_err(|_| sqlx_core::Error::WorkerCrashed) + } - self.inner.pending_ready_for_query_count -= 1; - self.inner.transaction_status = r.transaction_status; + fn send_buf(&self, buf: MessageBuf) -> sqlx_core::Result { + let mut req = buf.finish(); + let (tx, rx) = unbounded(); + req.chan = Some(tx); - Ok(()) + self.send_request(req)?; + Ok(Pipe::new(rx)) } - #[inline(always)] - fn handle_ready_for_query(&mut self, message: ReceivedMessage) -> Result<(), Error> { - self.inner.pending_ready_for_query_count = self - .inner - .pending_ready_for_query_count - .checked_sub(1) - .ok_or_else(|| err_protocol!("received more ReadyForQuery messages than expected"))?; - - self.inner.transaction_status = message.decode::()?.transaction_status; + pub(crate) fn pipe(&self, callback: F) -> sqlx_core::Result + where + F: FnOnce(&mut MessageBuf) -> sqlx_core::Result<()>, + { + let mut req = self.create_request(callback)?; + let (tx, rx) = unbounded(); + req.chan = Some(tx); - Ok(()) + self.send_request(req)?; + Ok(Pipe::new(rx)) } - /// Queue a simple query (not prepared) to execute the next time this connection is used. - /// - /// Used for rolling back transactions and releasing advisory locks. - #[inline(always)] - pub(crate) fn queue_simple_query(&mut self, query: &str) -> Result<(), Error> { - self.inner.stream.write_msg(Query(query))?; - self.inner.pending_ready_for_query_count += 1; - - Ok(()) + pub(crate) fn pipe_and_forget(&self, value: T) -> sqlx_core::Result<()> + where + T: FrontendMessage, + { + let req = self.create_request(|buf| buf.write_msg(value))?; + self.send_request(req) } - pub(crate) fn in_transaction(&self) -> bool { - match self.inner.transaction_status { - TransactionStatus::Transaction => true, - TransactionStatus::Error | TransactionStatus::Idle => false, - } + pub(crate) async fn pipe_and_forget_async(&self, callback: F) -> sqlx_core::Result + where + F: AsyncFnOnce(&mut MessageBuf) -> sqlx_core::Result, + { + let mut buffer = MessageBuf::new(); + let result = (callback)(&mut buffer).await?; + let req = buffer.finish(); + self.send_request(req)?; + + Ok(result) } } @@ -150,7 +190,7 @@ impl Connection for PgConnection { type Options = PgConnectOptions; - fn close(mut self) -> BoxFuture<'static, Result<(), Error>> { + fn close(self) -> BoxFuture<'static, Result<(), Error>> { // The normal, graceful termination procedure is that the frontend sends a Terminate // message and immediately closes the connection. @@ -158,19 +198,15 @@ impl Connection for PgConnection { // connection and terminates. Box::pin(async move { - self.inner.stream.send(Terminate).await?; - self.inner.stream.shutdown().await?; + // Closing the channel notifies the bg worker to start a graceful shutdown. + self.inner.chan.close_channel(); Ok(()) }) } - fn close_hard(mut self) -> BoxFuture<'static, Result<(), Error>> { - Box::pin(async move { - self.inner.stream.shutdown().await?; - - Ok(()) - }) + fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> { + self.close() } fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> { @@ -180,8 +216,11 @@ impl Connection for PgConnection { Box::pin(async move { // The simplest call-and-response that's possible. - self.write_sync(); - self.wait_until_ready().await + let mut pipe = self.pipe(|buf| { + buf.write_sync(); + Ok(()) + })?; + pipe.recv_ready_for_query().await }) } @@ -212,19 +251,19 @@ impl Connection for PgConnection { let mut cleared = 0_usize; - self.wait_until_ready().await?; + let mut buf = MessageBuf::new(); while let Some((id, _)) = self.inner.cache_statement.remove_lru() { - self.inner.stream.write_msg(Close::Statement(id))?; + buf.write_msg(Close::Statement(id))?; cleared += 1; } if cleared > 0 { - self.write_sync(); - self.inner.stream.flush().await?; + buf.write_sync(); + let mut pipe = self.send_buf(buf)?; - self.wait_for_close_complete(cleared).await?; - self.recv_ready_for_query().await?; + pipe.wait_for_close_complete(cleared).await?; + pipe.recv_ready_for_query().await?; } Ok(()) @@ -232,17 +271,17 @@ impl Connection for PgConnection { } fn shrink_buffers(&mut self) { - self.inner.stream.shrink_buffers(); + // No-op } #[doc(hidden)] fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> { - self.wait_until_ready().boxed() + async { Ok(()) }.boxed() } #[doc(hidden)] fn should_flush(&self) -> bool { - !self.inner.stream.write_buffer().is_empty() + false } } diff --git a/sqlx-postgres/src/connection/pipe.rs b/sqlx-postgres/src/connection/pipe.rs new file mode 100644 index 0000000000..983e11a1df --- /dev/null +++ b/sqlx-postgres/src/connection/pipe.rs @@ -0,0 +1,84 @@ +use futures_channel::mpsc::UnboundedReceiver; +use futures_util::StreamExt; +use sqlx_core::Error; + +use crate::{ + message::{BackendMessage, BackendMessageFormat, ReadyForQuery, ReceivedMessage}, + PgDatabaseError, +}; + +/// A temporary stream of responses sent from the background worker. The steam is stopped when +/// either a [ReadyForQuery] of [CopyInResponse] is received. +pub struct Pipe { + receiver: UnboundedReceiver, +} + +impl Pipe { + pub fn new(receiver: UnboundedReceiver) -> Pipe { + Self { receiver } + } + + pub(crate) async fn recv_expect(&mut self) -> Result { + self.recv().await?.decode() + } + + pub async fn recv_ready_for_query(&mut self) -> Result<(), Error> { + // The transaction status is updated in the bg worker. + let _: ReadyForQuery = self.recv_expect().await?; + Ok(()) + } + + pub(crate) async fn wait_ready_for_query(&mut self) -> Result<(), Error> { + loop { + let message = self.recv().await?; + + if let BackendMessageFormat::ReadyForQuery = message.format { + // The transaction status is updated in the bg worker. + break; + } + } + + Ok(()) + } + + // wait for CloseComplete to indicate a statement was closed + pub async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> { + // we need to wait for the [CloseComplete] to be returned from the server + while count > 0 { + match self.recv().await? { + message if message.format == BackendMessageFormat::PortalSuspended => { + // there was an open portal + // this can happen if the last time a statement was used it was not fully executed + } + + message if message.format == BackendMessageFormat::CloseComplete => { + // successfully closed the statement (and freed up the server resources) + count -= 1; + } + + message => { + return Err(err_protocol!( + "expecting PortalSuspended or CloseComplete but received {:?}", + message.format + )); + } + } + } + + Ok(()) + } + + pub(crate) async fn recv(&mut self) -> Result { + let message = self + .receiver + .next() + .await + .ok_or_else(|| sqlx_core::Error::WorkerCrashed)?; + + if message.format == BackendMessageFormat::ErrorResponse { + Err(message.decode::()?.into()) + } else { + Ok(message) + } + } +} diff --git a/sqlx-postgres/src/connection/request.rs b/sqlx-postgres/src/connection/request.rs new file mode 100644 index 0000000000..4b918c9195 --- /dev/null +++ b/sqlx-postgres/src/connection/request.rs @@ -0,0 +1,52 @@ +use futures_channel::mpsc::UnboundedSender; +use sqlx_core::{io::ProtocolEncode, Error}; + +use crate::message::{self, EncodeMessage, FrontendMessage, ReceivedMessage}; + +/// A request for the background worker. +#[derive(Debug)] +pub struct IoRequest { + pub chan: Option>, + pub data: Vec, +} + +/// A buffer that contains encoded postgres messages, ready to be sent over the wire. +pub struct MessageBuf { + data: Vec, +} + +impl MessageBuf { + pub fn new() -> Self { + Self { data: Vec::new() } + } + + #[inline(always)] + pub fn write<'en, T>(&mut self, value: T) -> sqlx_core::Result<()> + where + T: ProtocolEncode<'en, ()>, + { + value.encode(&mut self.data) + } + + #[inline(always)] + pub fn write_sync(&mut self) { + self.write_msg(message::Sync) + .expect("BUG: Sync should not be too big for protocol"); + } + + #[inline(always)] + pub(crate) fn write_msg(&mut self, message: impl FrontendMessage) -> Result<(), Error> { + self.write(EncodeMessage(message)) + } + + pub(crate) fn buf_mut(&mut self) -> &mut Vec { + &mut self.data + } + + pub fn finish(self) -> IoRequest { + IoRequest { + data: self.data, + chan: None, + } + } +} diff --git a/sqlx-postgres/src/connection/sasl.rs b/sqlx-postgres/src/connection/sasl.rs index 729cc1fcc5..da172b6e61 100644 --- a/sqlx-postgres/src/connection/sasl.rs +++ b/sqlx-postgres/src/connection/sasl.rs @@ -1,4 +1,3 @@ -use crate::connection::stream::PgStream; use crate::error::Error; use crate::message::{Authentication, AuthenticationSasl, SaslInitialResponse, SaslResponse}; use crate::PgConnectOptions; @@ -9,6 +8,9 @@ use stringprep::saslprep; use base64::prelude::{Engine as _, BASE64_STANDARD}; +use super::pipe::Pipe; +use super::PgConnection; + const GS2_HEADER: &str = "n,,"; const CHANNEL_ATTR: &str = "c"; const USERNAME_ATTR: &str = "n"; @@ -16,7 +18,8 @@ const CLIENT_PROOF_ATTR: &str = "p"; const NONCE_ATTR: &str = "r"; pub(crate) async fn authenticate( - stream: &mut PgStream, + conn: &PgConnection, + pipe: &mut Pipe, options: &PgConnectOptions, data: AuthenticationSasl, ) -> Result<(), Error> { @@ -67,14 +70,12 @@ pub(crate) async fn authenticate( let client_first_message = format!("{GS2_HEADER}{client_first_message_bare}"); - stream - .send(SaslInitialResponse { - response: &client_first_message, - plus: false, - }) - .await?; + conn.pipe_and_forget(SaslInitialResponse { + response: &client_first_message, + plus: false, + })?; - let cont = match stream.recv_expect().await? { + let cont = match pipe.recv_expect().await? { Authentication::SaslContinue(data) => data, auth => { @@ -143,9 +144,9 @@ pub(crate) async fn authenticate( let mut client_final_message = format!("{client_final_message_wo_proof},{CLIENT_PROOF_ATTR}="); BASE64_STANDARD.encode_string(client_proof, &mut client_final_message); - stream.send(SaslResponse(&client_final_message)).await?; + conn.pipe_and_forget(SaslResponse(&client_final_message))?; - let data = match stream.recv_expect().await? { + let data = match pipe.recv_expect().await? { Authentication::SaslFinal(data) => data, auth => { diff --git a/sqlx-postgres/src/connection/stream.rs b/sqlx-postgres/src/connection/stream.rs deleted file mode 100644 index e8a1aedc47..0000000000 --- a/sqlx-postgres/src/connection/stream.rs +++ /dev/null @@ -1,282 +0,0 @@ -use std::collections::BTreeMap; -use std::ops::{ControlFlow, Deref, DerefMut}; -use std::str::FromStr; - -use futures_channel::mpsc::UnboundedSender; -use futures_util::SinkExt; -use log::Level; -use sqlx_core::bytes::Buf; - -use crate::connection::tls::MaybeUpgradeTls; -use crate::error::Error; -use crate::message::{ - BackendMessage, BackendMessageFormat, EncodeMessage, FrontendMessage, Notice, Notification, - ParameterStatus, ReceivedMessage, -}; -use crate::net::{self, BufferedSocket, Socket}; -use crate::{PgConnectOptions, PgDatabaseError, PgSeverity}; - -// the stream is a separate type from the connection to uphold the invariant where an instantiated -// [PgConnection] is a **valid** connection to postgres - -// when a new connection is asked for, we work directly on the [PgStream] type until the -// connection is fully established - -// in other words, `self` in any PgConnection method is a live connection to postgres that -// is fully prepared to receive queries - -pub struct PgStream { - // A trait object is okay here as the buffering amortizes the overhead of both the dynamic - // function call as well as the syscall. - inner: BufferedSocket>, - - // buffer of unreceived notification messages from `PUBLISH` - // this is set when creating a PgListener and only written to if that listener is - // re-used for query execution in-between receiving messages - pub(crate) notifications: Option>, - - pub(crate) parameter_statuses: BTreeMap, - - pub(crate) server_version_num: Option, -} - -impl PgStream { - pub(super) async fn connect(options: &PgConnectOptions) -> Result { - let socket_result = match options.fetch_socket() { - Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?, - None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?, - }; - - let socket = socket_result?; - - Ok(Self { - inner: BufferedSocket::new(socket), - notifications: None, - parameter_statuses: BTreeMap::default(), - server_version_num: None, - }) - } - - #[inline(always)] - pub(crate) fn write_msg(&mut self, message: impl FrontendMessage) -> Result<(), Error> { - self.write(EncodeMessage(message)) - } - - pub(crate) async fn send(&mut self, message: T) -> Result<(), Error> - where - T: FrontendMessage, - { - self.write_msg(message)?; - self.flush().await?; - Ok(()) - } - - // Expect a specific type and format - pub(crate) async fn recv_expect(&mut self) -> Result { - self.recv().await?.decode() - } - - pub(crate) async fn recv_unchecked(&mut self) -> Result { - // NOTE: to not break everything, this should be cancel-safe; - // DO NOT modify `buf` unless a full message has been read - self.inner - .try_read(|buf| { - // all packets in postgres start with a 5-byte header - // this header contains the message type and the total length of the message - let Some(mut header) = buf.get(..5) else { - return Ok(ControlFlow::Continue(5)); - }; - - let format = BackendMessageFormat::try_from_u8(header.get_u8())?; - - let message_len = header.get_u32() as usize; - - let expected_len = message_len - .checked_add(1) - // this shouldn't really happen but is mostly a sanity check - .ok_or_else(|| { - err_protocol!("message_len + 1 overflows usize: {message_len}") - })?; - - if buf.len() < expected_len { - return Ok(ControlFlow::Continue(expected_len)); - } - - // `buf` SHOULD NOT be modified ABOVE this line - - // pop off the format code since it's not counted in `message_len` - buf.advance(1); - - // consume the message, including the length prefix - let mut contents = buf.split_to(message_len).freeze(); - - // cut off the length prefix - contents.advance(4); - - Ok(ControlFlow::Break(ReceivedMessage { format, contents })) - }) - .await - } - - // Get the next message from the server - // May wait for more data from the server - pub(crate) async fn recv(&mut self) -> Result { - loop { - let message = self.recv_unchecked().await?; - - match message.format { - BackendMessageFormat::ErrorResponse => { - // An error returned from the database server. - return Err(message.decode::()?.into()); - } - - BackendMessageFormat::NotificationResponse => { - if let Some(buffer) = &mut self.notifications { - let notification: Notification = message.decode()?; - let _ = buffer.send(notification).await; - - continue; - } - } - - BackendMessageFormat::ParameterStatus => { - // informs the frontend about the current (initial) - // setting of backend parameters - - let ParameterStatus { name, value } = message.decode()?; - // TODO: handle `client_encoding`, `DateStyle` change - - match name.as_str() { - "server_version" => { - self.server_version_num = parse_server_version(&value); - } - _ => { - self.parameter_statuses.insert(name, value); - } - } - - continue; - } - - BackendMessageFormat::NoticeResponse => { - // do we need this to be more configurable? - // if you are reading this comment and think so, open an issue - - let notice: Notice = message.decode()?; - - let (log_level, tracing_level) = match notice.severity() { - PgSeverity::Fatal | PgSeverity::Panic | PgSeverity::Error => { - (Level::Error, tracing::Level::ERROR) - } - PgSeverity::Warning => (Level::Warn, tracing::Level::WARN), - PgSeverity::Notice => (Level::Info, tracing::Level::INFO), - PgSeverity::Debug => (Level::Debug, tracing::Level::DEBUG), - PgSeverity::Info | PgSeverity::Log => (Level::Trace, tracing::Level::TRACE), - }; - - let log_is_enabled = log::log_enabled!( - target: "sqlx::postgres::notice", - log_level - ) || sqlx_core::private_tracing_dynamic_enabled!( - target: "sqlx::postgres::notice", - tracing_level - ); - if log_is_enabled { - sqlx_core::private_tracing_dynamic_event!( - target: "sqlx::postgres::notice", - tracing_level, - message = notice.message() - ); - } - - continue; - } - - _ => {} - } - - return Ok(message); - } - } -} - -impl Deref for PgStream { - type Target = BufferedSocket>; - - #[inline] - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl DerefMut for PgStream { - #[inline] - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner - } -} - -// reference: -// https://github.com/postgres/postgres/blob/6feebcb6b44631c3dc435e971bd80c2dd218a5ab/src/interfaces/libpq/fe-exec.c#L1030-L1065 -fn parse_server_version(s: &str) -> Option { - let mut parts = Vec::::with_capacity(3); - - let mut from = 0; - let mut chs = s.char_indices().peekable(); - while let Some((i, ch)) = chs.next() { - match ch { - '.' => { - if let Ok(num) = u32::from_str(&s[from..i]) { - parts.push(num); - from = i + 1; - } else { - break; - } - } - _ if ch.is_ascii_digit() => { - if chs.peek().is_none() { - if let Ok(num) = u32::from_str(&s[from..]) { - parts.push(num); - } - break; - } - } - _ => { - if let Ok(num) = u32::from_str(&s[from..i]) { - parts.push(num); - } - break; - } - }; - } - - let version_num = match parts.as_slice() { - [major, minor, rev] => (100 * major + minor) * 100 + rev, - [major, minor] if *major >= 10 => 100 * 100 * major + minor, - [major, minor] => (100 * major + minor) * 100, - [major] => 100 * 100 * major, - _ => return None, - }; - - Some(version_num) -} - -#[cfg(test)] -mod tests { - use super::parse_server_version; - - #[test] - fn test_parse_server_version_num() { - // old style - assert_eq!(parse_server_version("9.6.1"), Some(90601)); - // new style - assert_eq!(parse_server_version("10.1"), Some(100001)); - // old style without minor version - assert_eq!(parse_server_version("9.6devel"), Some(90600)); - // new style without minor version, e.g. */ - assert_eq!(parse_server_version("10devel"), Some(100000)); - assert_eq!(parse_server_version("13devel87"), Some(130000)); - // unknown - assert_eq!(parse_server_version("unknown"), None); - } -} diff --git a/sqlx-postgres/src/connection/tls.rs b/sqlx-postgres/src/connection/tls.rs index a49c9caa8c..79761e1fd2 100644 --- a/sqlx-postgres/src/connection/tls.rs +++ b/sqlx-postgres/src/connection/tls.rs @@ -1,6 +1,6 @@ use crate::error::Error; use crate::net::tls::{self, TlsConfig}; -use crate::net::{Socket, SocketIntoBox, WithSocket}; +use crate::net::{Socket, SocketExt, SocketIntoBox, WithSocket}; use crate::message::SslRequest; use crate::{PgConnectOptions, PgSslMode}; diff --git a/sqlx-postgres/src/connection/worker.rs b/sqlx-postgres/src/connection/worker.rs new file mode 100644 index 0000000000..16f30ab554 --- /dev/null +++ b/sqlx-postgres/src/connection/worker.rs @@ -0,0 +1,319 @@ +use std::{ + collections::{BTreeMap, VecDeque}, + future::Future, + ops::ControlFlow, + pin::Pin, + sync::{Arc, Mutex, MutexGuard}, + task::{ready, Context, Poll}, +}; + +use crate::{ + message::{ + BackendMessageFormat, FrontendMessage, Notice, Notification, ParameterStatus, + ReadyForQuery, ReceivedMessage, Terminate, TransactionStatus, + }, + PgConnectOptions, +}; +use futures_channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; +use futures_util::{SinkExt, StreamExt}; +use sqlx_core::{ + bytes::Buf, + net::{self, BufferedSocket, Socket}, + rt::spawn, + Result, +}; + +use super::{request::IoRequest, tls::MaybeUpgradeTls}; + +#[derive(PartialEq, Debug)] +enum WorkerState { + // The connection is open and ready for requests. + Open, + // Responding to the last messages but not receiving new ones. After handling the last message + // a [Terminate] message is issued. + Closing, + // Last messages are handled, [Terminate] message is sent and the session is closed. Nog try + // and close the socket. + Closed, +} + +pub struct Worker { + state: WorkerState, + should_flush: bool, + chan: UnboundedReceiver, + back_log: VecDeque>, + socket: BufferedSocket>, + notif_chan: UnboundedSender, + shared: Shared, +} + +impl Worker { + pub(super) async fn connect( + options: &PgConnectOptions, + notif_chan: UnboundedSender, + shared: Shared, + ) -> crate::Result> { + let socket_result = match options.fetch_socket() { + Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?, + None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?, + }; + + let socket = BufferedSocket::new(socket_result?); + + Ok(Worker::spawn(socket, notif_chan, shared)) + } + + pub fn spawn( + socket: BufferedSocket>, + notif_chan: UnboundedSender, + shared: Shared, + ) -> UnboundedSender { + let (tx, rx) = unbounded(); + + let worker = Worker { + state: WorkerState::Open, + should_flush: false, + chan: rx, + back_log: VecDeque::new(), + socket, + notif_chan, + shared: shared.clone(), + }; + + spawn(worker); + tx + } + + // Tries to receive the next message from the channel. Also handles termination if needed. + #[inline(always)] + fn poll_next_request(&mut self, cx: &mut Context<'_>) -> Poll { + match self.chan.poll_next_unpin(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Some(request)) => Poll::Ready(request), + Poll::Ready(None) => { + // Channel was closed, explicitly or because the sender was dropped. Either way + // we should start a graceful shutdown. + self.state = WorkerState::Closing; + Poll::Pending + } + } + } + + #[inline(always)] + fn poll_receiver(&mut self, cx: &mut Context<'_>) -> Poll> { + // Only try and receive io requests if we're open. + if self.state != WorkerState::Open { + return Poll::Ready(Ok(())); + } + + loop { + ready!(self.socket.poll_ready_unpin(cx))?; + + let request = ready!(self.poll_next_request(cx)); + + self.socket.start_send_unpin(&request.data)?; + self.should_flush = true; + + if let Some(chan) = request.chan { + // We should send the responses back + self.back_log.push_back(chan); + } + } + } + + #[inline(always)] + fn handle_poll_flush(&mut self, cx: &mut Context<'_>) -> Result<()> { + if self.should_flush && self.socket.poll_flush_unpin(cx).is_ready() { + self.should_flush = false; + } + Ok(()) + } + + #[inline(always)] + fn send_back(&mut self, response: ReceivedMessage) -> Result<()> { + if let Some(chan) = self.back_log.front_mut() { + let _ = chan.unbounded_send(response); + Ok(()) + } else { + Err(err_protocol!("Received response but did not expect one.")) + } + } + + #[inline(always)] + fn poll_backlog(&mut self, cx: &mut Context<'_>) -> Result<()> { + while let Poll::Ready(response) = self.poll_next_message(cx)? { + match response.format { + BackendMessageFormat::ReadyForQuery => { + // Cloning a `ReceivedMessage` here is cheap because it only clones the + // underlying `Bytes` + let rfq: ReadyForQuery = response.clone().decode()?; + self.shared.set_transaction_status(rfq.transaction_status); + + self.send_back(response)?; + // Remove from the backlog so we dont send more responses back. + let _ = self.back_log.pop_front(); + } + BackendMessageFormat::CopyInResponse => { + // End of response + self.send_back(response)?; + // Remove from the backlog so we dont send more responses back. + let _ = self.back_log.pop_front(); + } + BackendMessageFormat::NotificationResponse => { + // Notification + let notif: Notification = response.decode()?; + let _ = self.notif_chan.unbounded_send(notif); + } + BackendMessageFormat::ParameterStatus => { + // Asynchronous response + let ParameterStatus { name, value } = response.decode()?; + self.shared.insert_parameter_status(name, value); + } + BackendMessageFormat::NoticeResponse => { + // do we need this to be more configurable? + // if you are reading this comment and think so, open an issue + + let notice: Notice = response.decode()?; + + notice.emit_notice(); + } + _ => self.send_back(response)?, + } + } + Ok(()) + } + + #[inline(always)] + fn poll_next_message(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.state == WorkerState::Closed { + // We're still responsing to the last messages, only after clearing the backlog we + // should stop reading. + return Poll::Pending; + } + + self.socket.poll_try_read(cx, |buf| { + // all packets in postgres start with a 5-byte header + // this header contains the message type and the total length of the message + let Some(mut header) = buf.get(..5) else { + return Ok(ControlFlow::Continue(5)); + }; + + let format = BackendMessageFormat::try_from_u8(header.get_u8())?; + + let message_len = header.get_u32() as usize; + + let expected_len = message_len + .checked_add(1) + // this shouldn't really happen but is mostly a sanity check + .ok_or_else(|| err_protocol!("message_len + 1 overflows usize: {message_len}"))?; + + if buf.len() < expected_len { + return Ok(ControlFlow::Continue(expected_len)); + } + + // `buf` SHOULD NOT be modified ABOVE this line + + // pop off the format code since it's not counted in `message_len` + buf.advance(1); + + // consume the message, including the length prefix + let mut contents = buf.split_to(message_len).freeze(); + + // cut off the length prefix + contents.advance(4); + + Ok(ControlFlow::Break(ReceivedMessage { format, contents })) + }) + } + + #[inline(always)] + fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { + match self.state { + // After responding to the last messages we can issue a [Terminate] request and + // close the connection. + WorkerState::Closing if self.back_log.is_empty() => { + let terminate = [Terminate::FORMAT as u8, 0, 0, 0, 4]; + self.socket.write_buffer_mut().put_slice(&terminate); + self.state = WorkerState::Closed; + + // Closing the socket also flushes the buffer. + self.socket.poll_close_unpin(cx) + } + // The channel is closed, all requests are flushed and a [Terminate] message has been + // sent, now try and close the socket + WorkerState::Closed => self.socket.poll_close_unpin(cx), + WorkerState::Open | WorkerState::Closing => Poll::Pending, + } + } + + fn poll_worker(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Try to receive responses from the database and handle them. + self.poll_backlog(cx)?; + + // Push as many new requests in the write buffer as we can. + if let Poll::Ready(Err(e)) = self.poll_receiver(cx) { + return Poll::Ready(Err(e)); + }; + + // Flush the write buffer if needed. + self.handle_poll_flush(cx)?; + + // Close this socket if we're done. + self.poll_shutdown(cx) + } +} + +impl Future for Worker { + type Output = Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.poll_worker(cx).map_err(|e| { + tracing::error!("Background worker stopped with error: {e:?}"); + e + }) + } +} + +#[derive(Clone)] +pub struct Shared(Arc>); + +pub struct SharedInner { + pub parameter_statuses: BTreeMap, + pub transaction_status: TransactionStatus, +} + +impl Shared { + pub fn new() -> Shared { + Shared(Arc::new(Mutex::new(SharedInner { + parameter_statuses: BTreeMap::new(), + transaction_status: TransactionStatus::Idle, + }))) + } + + fn lock(&self) -> MutexGuard<'_, SharedInner> { + self.0 + .lock() + .expect("BUG: failed to get lock on shared state in worker") + } + + pub fn get_transaction_status(&self) -> TransactionStatus { + self.lock().transaction_status + } + + fn set_transaction_status(&self, status: TransactionStatus) { + self.lock().transaction_status = status + } + + fn insert_parameter_status(&self, name: String, value: String) { + self.lock().parameter_statuses.insert(name, value); + } + + pub fn remove_parameter_status(&self, name: &str) -> Option { + self.lock().parameter_statuses.remove(name) + } + + pub fn with_lock(&self, f: impl Fn(&mut SharedInner) -> T) -> T { + let mut lock = self.lock(); + f(&mut lock) + } +} diff --git a/sqlx-postgres/src/copy.rs b/sqlx-postgres/src/copy.rs index 1315ea0e20..a3f8b7d4be 100644 --- a/sqlx-postgres/src/copy.rs +++ b/sqlx-postgres/src/copy.rs @@ -12,7 +12,7 @@ use crate::ext::async_stream::TryAsyncStream; use crate::io::AsyncRead; use crate::message::{ BackendMessageFormat, CommandComplete, CopyData, CopyDone, CopyFail, CopyInResponse, - CopyOutResponse, CopyResponseData, Query, ReadyForQuery, + CopyOutResponse, CopyResponseData, ReadyForQuery, }; use crate::pool::{Pool, PoolConnection}; use crate::Postgres; @@ -146,14 +146,13 @@ pub struct PgCopyIn> { } impl> PgCopyIn { - async fn begin(mut conn: C, statement: &str) -> Result { - conn.wait_until_ready().await?; - conn.inner.stream.send(Query(statement)).await?; + async fn begin(conn: C, statement: &str) -> Result { + let mut pipe = conn.queue_simple_query(statement)?; - let response = match conn.inner.stream.recv_expect::().await { + let response = match pipe.recv_expect::().await { Ok(res) => res.0, Err(e) => { - conn.inner.stream.recv().await?; + pipe.recv_ready_for_query().await?; return Err(e); } }; @@ -195,13 +194,11 @@ impl> PgCopyIn { /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead. pub async fn send(&mut self, data: impl Deref) -> Result<&mut Self> { for chunk in data.deref().chunks(PG_COPY_MAX_DATA_LEN) { + // TODO: We should probably have some kind of back-pressure here self.conn .as_deref_mut() .expect("send_data: conn taken") - .inner - .stream - .send(CopyData(chunk)) - .await?; + .pipe_and_forget(CopyData(chunk))?; } Ok(self) @@ -224,26 +221,31 @@ impl> PgCopyIn { pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> { let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken"); loop { - let buf = conn.inner.stream.write_buffer_mut(); + let read = conn + .pipe_and_forget_async(async |buf| { + let write_buf = buf.buf_mut(); - // Write the CopyData format code and reserve space for the length. - // This may end up sending an empty `CopyData` packet if, after this point, - // we get canceled or read 0 bytes, but that should be fine. - buf.put_slice(b"d\0\0\0\x04"); + // Write the CopyData format code and reserve space for the length. + // This may end up sending an empty `CopyData` packet if, after this point, + // we get canceled or read 0 bytes, but that should be fine. + write_buf.put_slice(b"d\0\0\0\x04"); - let read = buf.read_from(&mut source).await?; + let read = sqlx_core::io::read_from(&mut source, write_buf).await?; - if read == 0 { - break; - } + // Write the length + let read32 = i32::try_from(read).map_err(|_| { + err_protocol!("number of bytes read exceeds 2^31 - 1: {}", read) + })?; - // Write the length - let read32 = i32::try_from(read) - .map_err(|_| err_protocol!("number of bytes read exceeds 2^31 - 1: {}", read))?; + (&mut write_buf[1..]).put_i32(read32 + 4); - (&mut buf.get_mut()[1..]).put_i32(read32 + 4); + Ok(read32) + }) + .await?; - conn.inner.stream.flush().await?; + if read == 0 { + break; + } } Ok(self) @@ -255,14 +257,14 @@ impl> PgCopyIn { /// /// The server is expected to respond with an error, so only _unexpected_ errors are returned. pub async fn abort(mut self, msg: impl Into) -> Result<()> { - let mut conn = self + let conn = self .conn .take() .expect("PgCopyIn::fail_with: conn taken illegally"); - conn.inner.stream.send(CopyFail::new(msg)).await?; + let mut pipe = conn.pipe(|buf| buf.write_msg(CopyFail::new(msg)))?; - match conn.inner.stream.recv().await { + match pipe.recv().await { Ok(msg) => Err(err_protocol!( "fail_with: expected ErrorResponse, got: {:?}", msg.format @@ -271,7 +273,7 @@ impl> PgCopyIn { match e.code() { Some(Cow::Borrowed("57014")) => { // postgres abort received error code - conn.inner.stream.recv_expect::().await?; + pipe.recv_expect::().await?; Ok(()) } _ => Err(Error::Database(e)), @@ -285,21 +287,21 @@ impl> PgCopyIn { /// /// The number of rows affected is returned. pub async fn finish(mut self) -> Result { - let mut conn = self + let conn = self .conn .take() .expect("CopyWriter::finish: conn taken illegally"); - conn.inner.stream.send(CopyDone).await?; - let cc: CommandComplete = match conn.inner.stream.recv_expect().await { + let mut pipe = conn.pipe(|buf| buf.write_msg(CopyDone))?; + let cc: CommandComplete = match pipe.recv_expect().await { Ok(cc) => cc, Err(e) => { - conn.inner.stream.recv().await?; + pipe.recv().await?; return Err(e); } }; - conn.inner.stream.recv_expect::().await?; + pipe.recv_expect::().await?; Ok(cc.rows_affected()) } @@ -307,39 +309,36 @@ impl> PgCopyIn { impl> Drop for PgCopyIn { fn drop(&mut self) { - if let Some(mut conn) = self.conn.take() { - conn.inner - .stream - .write_msg(CopyFail::new( - "PgCopyIn dropped without calling finish() or fail()", - )) - .expect("BUG: PgCopyIn abort message should not be too large"); + if let Some(conn) = self.conn.take() { + conn.pipe_and_forget(CopyFail::new( + "PgCopyIn dropped without calling finish() or fail()", + )) + .expect("BUG: could not send PgCopyIn to background worker"); } } } async fn pg_begin_copy_out<'c, C: DerefMut + Send + 'c>( - mut conn: C, + conn: C, statement: &str, ) -> Result>> { - conn.wait_until_ready().await?; - conn.inner.stream.send(Query(statement)).await?; + let mut pipe = conn.queue_simple_query(statement)?; - let _: CopyOutResponse = conn.inner.stream.recv_expect().await?; + let _: CopyOutResponse = pipe.recv_expect().await?; let stream: TryAsyncStream<'c, Bytes> = try_stream! { loop { - match conn.inner.stream.recv().await { + match pipe.recv().await { Err(e) => { - conn.inner.stream.recv_expect::().await?; + pipe.recv_expect::().await?; return Err(e); }, Ok(msg) => match msg.format { BackendMessageFormat::CopyData => r#yield!(msg.decode::>()?.0), BackendMessageFormat::CopyDone => { let _ = msg.decode::()?; - conn.inner.stream.recv_expect::().await?; - conn.inner.stream.recv_expect::().await?; + pipe.recv_expect::().await?; + pipe.recv_expect::().await?; return Ok(()) }, _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format)) diff --git a/sqlx-postgres/src/listener.rs b/sqlx-postgres/src/listener.rs index 32658534c4..2f5d4c275d 100644 --- a/sqlx-postgres/src/listener.rs +++ b/sqlx-postgres/src/listener.rs @@ -1,20 +1,17 @@ use std::fmt::{self, Debug}; -use std::io; use std::str::from_utf8; -use futures_channel::mpsc; use futures_core::future::BoxFuture; use futures_core::stream::{BoxStream, Stream}; use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use sqlx_core::acquire::Acquire; use sqlx_core::transaction::Transaction; use sqlx_core::Either; -use tracing::Instrument; use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; -use crate::message::{BackendMessageFormat, Notification}; +use crate::message::Notification; use crate::pool::PoolOptions; use crate::pool::{Pool, PoolConnection}; use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres}; @@ -28,8 +25,6 @@ use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgre pub struct PgListener { pool: Pool, connection: Option>, - buffer_rx: mpsc::UnboundedReceiver, - buffer_tx: Option>, channels: Vec, ignore_close_event: bool, eager_reconnect: bool, @@ -58,17 +53,11 @@ impl PgListener { pub async fn connect_with(pool: &Pool) -> Result { // Pull out an initial connection - let mut connection = pool.acquire().await?; - - // Setup a notification buffer - let (sender, receiver) = mpsc::unbounded(); - connection.inner.stream.notifications = Some(sender); + let connection = pool.acquire().await?; Ok(Self { pool: pool.clone(), connection: Some(connection), - buffer_rx: receiver, - buffer_tx: None, channels: Vec::new(), ignore_close_event: false, eager_reconnect: true, @@ -173,7 +162,6 @@ impl PgListener { async fn connect_if_needed(&mut self) -> Result<(), Error> { if self.connection.is_none() { let mut connection = self.pool.acquire().await?; - connection.inner.stream.notifications = self.buffer_tx.take(); connection .execute(&*build_listen_all_query(&self.channels)) @@ -263,67 +251,37 @@ impl PgListener { // Fetch our `CloseEvent` listener, if applicable. let mut close_event = (!self.ignore_close_event).then(|| self.pool.close_event()); - loop { - let next_message = self.connection().await?.inner.stream.recv_unchecked(); - - let res = if let Some(ref mut close_event) = close_event { - // cancels the wait and returns `Err(PoolClosed)` if the pool is closed - // before `next_message` returns, or if the pool was already closed - close_event.do_until(next_message).await? - } else { - next_message.await - }; - - let message = match res { - Ok(message) => message, - - // The connection is dead, ensure that it is dropped, - // update self state, and loop to try again. - Err(Error::Io(err)) - if matches!( - err.kind(), - io::ErrorKind::ConnectionAborted | - io::ErrorKind::UnexpectedEof | - // see ERRORS section in tcp(7) man page (https://man7.org/linux/man-pages/man7/tcp.7.html) - io::ErrorKind::TimedOut | - io::ErrorKind::BrokenPipe - ) => - { - if let Some(mut conn) = self.connection.take() { - self.buffer_tx = conn.inner.stream.notifications.take(); - // Close the connection in a background task, so we can continue. - conn.close_on_drop(); - } - - if self.eager_reconnect { - self.connect_if_needed().await?; - } - - // lost connection - return Ok(None); - } - - // Forward other errors - Err(error) => { - return Err(error); - } - }; + let next_message = self.connection().await?.inner.notifications.next(); - match message.format { - // We've received an async notification, return it. - BackendMessageFormat::NotificationResponse => { - return Ok(Some(PgNotification(message.decode()?))); + let res = if let Some(ref mut close_event) = close_event { + // cancels the wait and returns `Err(PoolClosed)` if the pool is closed + // before `next_message` returns, or if the pool was already closed + close_event.do_until(next_message).await? + } else { + next_message.await + }; + + let message = match res { + Some(message) => message, + + // The connection is dead, ensure that it is dropped, + // update self state, and loop to try again. + None => { + if let Some(mut conn) = self.connection.take() { + // Close the connection in a background task, so we can continue. + conn.close_on_drop(); } - // Mark the connection as ready for another query - BackendMessageFormat::ReadyForQuery => { - self.connection().await?.inner.pending_ready_for_query_count -= 1; + if self.eager_reconnect { + self.connect_if_needed().await?; } - // Ignore unexpected messages - _ => {} + // lost connection + return Ok(None); } - } + }; + + Ok(Some(PgNotification(message))) } /// Receives the next notification that already exists in the connection buffer, if any. @@ -332,7 +290,7 @@ impl PgListener { /// /// This is helpful if you want to retrieve all buffered notifications and process them in batches. pub fn next_buffered(&mut self) -> Option { - if let Ok(Some(notification)) = self.buffer_rx.try_next() { + if let Ok(Some(notification)) = self.connection.as_mut()?.inner.notifications.try_next() { Some(PgNotification(notification)) } else { None @@ -356,18 +314,8 @@ impl PgListener { impl Drop for PgListener { fn drop(&mut self) { - if let Some(mut conn) = self.connection.take() { - let fut = async move { - let _ = conn.execute("UNLISTEN *").await; - - // inline the drop handler from `PoolConnection` so it doesn't try to spawn another task - // otherwise, it may trigger a panic if this task is dropped because the runtime is going away: - // https://github.com/launchbadge/sqlx/issues/1389 - conn.return_to_pool().await; - }; - - // Unregister any listeners before returning the connection to the pool. - crate::rt::spawn(fut.in_current_span()); + if let Some(conn) = self.connection.take() { + let _ = conn.queue_simple_query("UNLISTEN *"); } } } diff --git a/sqlx-postgres/src/message/mod.rs b/sqlx-postgres/src/message/mod.rs index e62f9bebb3..6a1fe79f24 100644 --- a/sqlx-postgres/src/message/mod.rs +++ b/sqlx-postgres/src/message/mod.rs @@ -87,7 +87,7 @@ pub enum FrontendMessageFormat { Terminate = b'X', } -#[derive(Debug, PartialOrd, PartialEq)] +#[derive(Debug, PartialOrd, PartialEq, Clone)] #[repr(u8)] pub enum BackendMessageFormat { Authentication, @@ -113,7 +113,7 @@ pub enum BackendMessageFormat { RowDescription, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ReceivedMessage { pub format: BackendMessageFormat, pub contents: Bytes, diff --git a/sqlx-postgres/src/message/ready_for_query.rs b/sqlx-postgres/src/message/ready_for_query.rs index a1f6761b89..19ba42c4fd 100644 --- a/sqlx-postgres/src/message/ready_for_query.rs +++ b/sqlx-postgres/src/message/ready_for_query.rs @@ -3,7 +3,7 @@ use sqlx_core::bytes::Bytes; use crate::error::Error; use crate::message::{BackendMessage, BackendMessageFormat}; -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] #[repr(u8)] pub enum TransactionStatus { /// Not in a transaction block. diff --git a/sqlx-postgres/src/message/response.rs b/sqlx-postgres/src/message/response.rs index a7c09cfa34..ac60cfb50a 100644 --- a/sqlx-postgres/src/message/response.rs +++ b/sqlx-postgres/src/message/response.rs @@ -1,6 +1,7 @@ use std::ops::Range; use std::str::from_utf8; +use log::Level; use memchr::memchr; use sqlx_core::bytes::Bytes; @@ -90,6 +91,33 @@ impl Notice { .map(|(_, range)| &self.storage[range]) .next() } + + pub(crate) fn emit_notice(&self) { + let (log_level, tracing_level) = match self.severity() { + PgSeverity::Fatal | PgSeverity::Panic | PgSeverity::Error => { + (Level::Error, tracing::Level::ERROR) + } + PgSeverity::Warning => (Level::Warn, tracing::Level::WARN), + PgSeverity::Notice => (Level::Info, tracing::Level::INFO), + PgSeverity::Debug => (Level::Debug, tracing::Level::DEBUG), + PgSeverity::Info | PgSeverity::Log => (Level::Trace, tracing::Level::TRACE), + }; + + let log_is_enabled = log::log_enabled!( + target: "sqlx::postgres::notice", + log_level + ) || sqlx_core::private_tracing_dynamic_enabled!( + target: "sqlx::postgres::notice", + tracing_level + ); + if log_is_enabled { + sqlx_core::private_tracing_dynamic_event!( + target: "sqlx::postgres::notice", + tracing_level, + message = self.message() + ); + } + } } impl Notice { diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index 23352a8dcf..8eb80f0248 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -30,8 +30,8 @@ impl TransactionManager for PgTransactionManager { }; let rollback = Rollback::new(conn); - rollback.conn.queue_simple_query(&statement)?; - rollback.conn.wait_until_ready().await?; + let mut pipe = rollback.conn.queue_simple_query(&statement)?; + pipe.wait_ready_for_query().await?; if !rollback.conn.in_transaction() { return Err(Error::BeginFailed); } diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index f0d453a9a3..0526ef1d04 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -1126,9 +1126,6 @@ async fn test_listener_try_recv_buffered() -> anyhow::Result<()> { txn.commit().await?; } - // Still no notifications buffered, since we haven't awaited the listener yet. - assert!(listener.next_buffered().is_none()); - // Activate connection. sqlx::query!("SELECT 1 AS one") .fetch_all(&mut listener)