From 01dfd067154b9c78622d7409712b2ed5730e26be Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 11 Oct 2024 17:42:11 +0200 Subject: [PATCH 01/14] refactor: tokio main, isolate crossbeam channels --- crates/pg_lsp/Cargo.toml | 1 + crates/pg_lsp/src/main.rs | 7 +- crates/pg_lsp/src/server.rs | 184 +++++++++++++++++++++++------------- 3 files changed, 122 insertions(+), 70 deletions(-) diff --git a/crates/pg_lsp/Cargo.toml b/crates/pg_lsp/Cargo.toml index 122e3ccd..84e97785 100644 --- a/crates/pg_lsp/Cargo.toml +++ b/crates/pg_lsp/Cargo.toml @@ -32,6 +32,7 @@ pg_base_db.workspace = true pg_schema_cache.workspace = true pg_workspace.workspace = true pg_diagnostics.workspace = true +tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread", "sync"] } [dev-dependencies] diff --git a/crates/pg_lsp/src/main.rs b/crates/pg_lsp/src/main.rs index eb5eddb6..803e0f39 100644 --- a/crates/pg_lsp/src/main.rs +++ b/crates/pg_lsp/src/main.rs @@ -1,9 +1,12 @@ use lsp_server::Connection; use pg_lsp::server::Server; -fn main() -> anyhow::Result<()> { +#[tokio::main] +async fn main() -> anyhow::Result<()> { let (connection, threads) = Connection::stdio(); - Server::init(connection)?; + let server = Server::init(connection)?; + + server.run().await?; threads.join()?; Ok(()) diff --git a/crates/pg_lsp/src/server.rs b/crates/pg_lsp/src/server.rs index 927d7f16..351112a0 100644 --- a/crates/pg_lsp/src/server.rs +++ b/crates/pg_lsp/src/server.rs @@ -3,7 +3,6 @@ mod dispatch; pub mod options; use async_std::task::{self}; -use crossbeam_channel::{unbounded, Receiver, Sender}; use lsp_server::{Connection, ErrorCode, Message, RequestId}; use lsp_types::{ notification::{ @@ -33,6 +32,8 @@ use std::{collections::HashSet, sync::Arc, time::Duration}; use text_size::TextSize; use threadpool::ThreadPool; +use tokio::sync::{mpsc, oneshot}; + use crate::{ client::{client_flags::ClientFlags, LspClient}, utils::{file_path, from_proto, line_index_ext::LineIndexExt, normalize_uri, to_proto}, @@ -68,11 +69,39 @@ impl DbConnection { } } +/// `lsp-servers` `Connection` type uses a crossbeam channel, which is not compatible with tokio's async runtime. +/// For now, we move it into a separate task and use tokio's channels to communicate. +fn get_client_receiver( + connection: Connection, +) -> (mpsc::UnboundedReceiver, oneshot::Receiver<()>) { + let (message_tx, message_rx) = mpsc::unbounded_channel(); + let (close_tx, close_rx) = oneshot::channel(); + + tokio::task::spawn(async move { + // TODO: improve Result handling + loop { + let msg = connection.receiver.recv().unwrap(); + + match msg { + Message::Request(r) if connection.handle_shutdown(&r).unwrap() => { + close_tx.send(()).unwrap(); + return; + } + + _ => message_tx.send(msg).unwrap(), + }; + } + }); + + (message_rx, close_rx) +} + pub struct Server { - connection: Arc, + client_rx: mpsc::UnboundedReceiver, + close_rx: oneshot::Receiver<()>, client: LspClient, - internal_tx: Sender, - internal_rx: Receiver, + internal_tx: mpsc::UnboundedSender, + internal_rx: mpsc::UnboundedReceiver, pool: Arc, client_flags: Arc, ide: Arc, @@ -81,10 +110,10 @@ pub struct Server { } impl Server { - pub fn init(connection: Connection) -> anyhow::Result<()> { + pub fn init(connection: Connection) -> anyhow::Result { let client = LspClient::new(connection.sender.clone()); - let (internal_tx, internal_rx) = unbounded(); + let (internal_tx, internal_rx) = mpsc::unbounded_channel(); let (id, params) = connection.initialize_start()?; let params: InitializeParams = serde_json::from_value(params)?; @@ -110,8 +139,11 @@ impl Server { let cloned_pool = pool.clone(); let cloned_client = client.clone(); + let (client_rx, close_rx) = get_client_receiver(connection); + let server = Self { - connection: Arc::new(connection), + close_rx, + client_rx, internal_rx, internal_tx, client, @@ -158,8 +190,7 @@ impl Server { pool, }; - server.run()?; - Ok(()) + Ok(server) } fn compute_now(&self) { @@ -763,67 +794,84 @@ impl Server { Ok(()) } - fn process_messages(&mut self) -> anyhow::Result<()> { + async fn process_messages(&mut self) -> anyhow::Result<()> { loop { - crossbeam_channel::select! { - recv(&self.connection.receiver) -> msg => { - match msg? { - Message::Request(request) => { - if self.connection.handle_shutdown(&request)? { - return Ok(()); - } - - if let Some(response) = dispatch::RequestDispatcher::new(request) - .on::(|id, params| self.inlay_hint(id, params))? - .on::(|id, params| self.hover(id, params))? - .on::(|id, params| self.execute_command(id, params))? - .on::(|id, params| { - self.completion(id, params) - })? - .on::(|id, params| { - self.code_actions(id, params) - })? - .default() - { - self.client.send_response(response)?; - } - } - Message::Notification(notification) => { - dispatch::NotificationDispatcher::new(notification) - .on::(|params| { - self.did_change_configuration(params) - })? - .on::(|params| self.did_close(params))? - .on::(|params| self.did_open(params))? - .on::(|params| self.did_change(params))? - .on::(|params| self.did_save(params))? - .on::(|params| self.did_close(params))? - .default(); - } - Message::Response(response) => { - self.client.recv_response(response)?; - } - }; + tokio::select! { + _ = &mut self.close_rx => { + return Ok(()) }, - recv(&self.internal_rx) -> msg => { - match msg? { - InternalMessage::SetSchemaCache(c) => { - self.ide.set_schema_cache(c); - self.compute_now(); - } - InternalMessage::RefreshSchemaCache => { - self.refresh_schema_cache(); - } - InternalMessage::PublishDiagnostics(uri) => { - self.publish_diagnostics(uri)?; - } - InternalMessage::SetOptions(options) => { - self.update_options(options); - } - }; + + msg = self.internal_rx.recv() => { + match msg { + // TODO: handle internal sender close? Is that valid state? + None => return Ok(()), + Some(m) => self.handle_internal_message(m) + } + }, + + msg = self.client_rx.recv() => { + match msg { + // the client sender is closed, we can return + None => return Ok(()), + Some(m) => self.handle_message(m) + } + }, + }?; + } + } + + fn handle_message(&mut self, msg: Message) -> anyhow::Result<()> { + match msg { + Message::Request(request) => { + if let Some(response) = dispatch::RequestDispatcher::new(request) + .on::(|id, params| self.inlay_hint(id, params))? + .on::(|id, params| self.hover(id, params))? + .on::(|id, params| self.execute_command(id, params))? + .on::(|id, params| self.completion(id, params))? + .on::(|id, params| self.code_actions(id, params))? + .default() + { + self.client.send_response(response)?; } - }; + } + Message::Notification(notification) => { + dispatch::NotificationDispatcher::new(notification) + .on::(|params| { + self.did_change_configuration(params) + })? + .on::(|params| self.did_close(params))? + .on::(|params| self.did_open(params))? + .on::(|params| self.did_change(params))? + .on::(|params| self.did_save(params))? + .on::(|params| self.did_close(params))? + .default(); + } + Message::Response(response) => { + self.client.recv_response(response)?; + } } + + Ok(()) + } + + fn handle_internal_message(&mut self, msg: InternalMessage) -> anyhow::Result<()> { + match msg { + InternalMessage::SetSchemaCache(c) => { + self.ide.set_schema_cache(c); + self.compute_now(); + } + InternalMessage::RefreshSchemaCache => { + self.refresh_schema_cache(); + } + InternalMessage::PublishDiagnostics(uri) => { + self.publish_diagnostics(uri)?; + } + InternalMessage::SetOptions(options) => { + self.update_options(options); + } + } + + Ok(()) } fn pull_options(&mut self) { @@ -881,10 +929,10 @@ impl Server { } } - pub fn run(mut self) -> anyhow::Result<()> { + pub async fn run(mut self) -> anyhow::Result<()> { self.register_configuration(); self.pull_options(); - self.process_messages()?; + self.process_messages().await?; self.pool.join(); Ok(()) } From 29fa6299cce3083896a97c32ea43f579e99711e9 Mon Sep 17 00:00:00 2001 From: Julian Date: Sun, 13 Oct 2024 10:40:53 +0200 Subject: [PATCH 02/14] so far --- Cargo.lock | 74 +++++++++++++++++++++++++++++++++++-- crates/pg_lsp/src/server.rs | 22 +++++------ 2 files changed, 81 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0044279e..8fbbee9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "addr2line" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" +dependencies = [ + "gimli", +] + [[package]] name = "adler" version = "1.0.2" @@ -216,6 +225,21 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "backtrace" +version = "0.3.73" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + [[package]] name = "base64" version = "0.21.7" @@ -311,11 +335,11 @@ checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "cc" -version = "1.0.83" +version = "1.1.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +checksum = "58e804ac3194a48bb129643eb1d62fcc20d18c6b8c181704489353d13120bcd1" dependencies = [ - "libc", + "shlex", ] [[package]] @@ -816,6 +840,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "gimli" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" + [[package]] name = "glob" version = "0.3.1" @@ -1250,6 +1280,15 @@ dependencies = [ "libc", ] +[[package]] +name = "object" +version = "0.36.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" version = "1.19.0" @@ -1447,6 +1486,7 @@ dependencies = [ "sqlx", "text-size", "threadpool", + "tokio", ] [[package]] @@ -1879,6 +1919,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + [[package]] name = "rustc-hash" version = "1.1.0" @@ -2452,6 +2498,28 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tokio" +version = "1.40.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" +dependencies = [ + "backtrace", + "pin-project-lite", + "tokio-macros", +] + +[[package]] +name = "tokio-macros" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.71", +] + [[package]] name = "tracing" version = "0.1.40" diff --git a/crates/pg_lsp/src/server.rs b/crates/pg_lsp/src/server.rs index 351112a0..078ed925 100644 --- a/crates/pg_lsp/src/server.rs +++ b/crates/pg_lsp/src/server.rs @@ -758,7 +758,7 @@ impl Server { }); } - fn refresh_schema_cache(&self) { + async fn refresh_schema_cache(&self) { if self.db_conn.is_none() { return; } @@ -767,17 +767,15 @@ impl Server { let conn = self.db_conn.as_ref().unwrap().pool.clone(); let client = self.client.clone(); - async_std::task::spawn(async move { - client - .send_notification::(ShowMessageParams { - typ: lsp_types::MessageType::INFO, - message: "Refreshing schema cache...".to_string(), - }) - .unwrap(); - let schema_cache = SchemaCache::load(&conn).await; - tx.send(InternalMessage::SetSchemaCache(schema_cache)) - .unwrap(); - }); + client + .send_notification::(ShowMessageParams { + typ: lsp_types::MessageType::INFO, + message: "Refreshing schema cache...".to_string(), + }) + .unwrap(); + let schema_cache = SchemaCache::load(&conn).await; + tx.send(InternalMessage::SetSchemaCache(schema_cache)) + .unwrap(); } fn did_change_configuration( From 5742e63a92167fdb358f65f25e4bdf2e0e3b349f Mon Sep 17 00:00:00 2001 From: Julian Date: Sun, 13 Oct 2024 13:32:20 +0200 Subject: [PATCH 03/14] fully remove async_std::task --- crates/pg_lsp/src/server.rs | 40 ++++++++++++++----------------- crates/pg_schema_cache/src/lib.rs | 2 +- xtask/src/install.rs | 5 +--- 3 files changed, 20 insertions(+), 27 deletions(-) diff --git a/crates/pg_lsp/src/server.rs b/crates/pg_lsp/src/server.rs index 078ed925..488732e5 100644 --- a/crates/pg_lsp/src/server.rs +++ b/crates/pg_lsp/src/server.rs @@ -2,7 +2,6 @@ mod debouncer; mod dispatch; pub mod options; -use async_std::task::{self}; use lsp_server::{Connection, ErrorCode, Message, RequestId}; use lsp_types::{ notification::{ @@ -240,7 +239,7 @@ impl Server { }); } - fn start_listening(&self) { + async fn start_listening(&self) { if self.db_conn.is_none() { return; } @@ -248,27 +247,25 @@ impl Server { let pool = self.db_conn.as_ref().unwrap().pool.clone(); let tx = self.internal_tx.clone(); - task::spawn(async move { - let mut listener = PgListener::connect_with(&pool).await.unwrap(); - listener - .listen_all(["postgres_lsp", "pgrst"]) - .await - .unwrap(); + let mut listener = PgListener::connect_with(&pool).await.unwrap(); + listener + .listen_all(["postgres_lsp", "pgrst"]) + .await + .unwrap(); - loop { - match listener.recv().await { - Ok(notification) => { - if notification.payload().to_string() == "reload schema" { - tx.send(InternalMessage::RefreshSchemaCache).unwrap(); - } - } - Err(e) => { - eprintln!("Listener error: {}", e); - break; + loop { + match listener.recv().await { + Ok(notification) => { + if notification.payload().to_string() == "reload schema" { + tx.send(InternalMessage::RefreshSchemaCache).unwrap(); } } + Err(e) => { + eprintln!("Listener error: {}", e); + break; + } } - }); + } } async fn update_db_connection(&mut self, connection_string: Option) { @@ -298,9 +295,8 @@ impl Server { }) .unwrap(); - self.refresh_schema_cache(); - - self.start_listening(); + self.refresh_schema_cache().await; + self.start_listening().await; } fn update_options(&mut self, options: Options) { diff --git a/crates/pg_schema_cache/src/lib.rs b/crates/pg_schema_cache/src/lib.rs index aed612c6..82454fc3 100644 --- a/crates/pg_schema_cache/src/lib.rs +++ b/crates/pg_schema_cache/src/lib.rs @@ -4,11 +4,11 @@ #![feature(future_join)] mod functions; -mod versions; mod schema_cache; mod schemas; mod tables; mod types; +mod versions; use sqlx::postgres::PgPool; diff --git a/xtask/src/install.rs b/xtask/src/install.rs index 85c03e13..c149bd5a 100644 --- a/xtask/src/install.rs +++ b/xtask/src/install.rs @@ -137,10 +137,7 @@ fn install_client(sh: &Shell, client_opt: ClientOpt) -> anyhow::Result<()> { } fn install_server(sh: &Shell) -> anyhow::Result<()> { - let cmd = cmd!( - sh, - "cargo install --path crates/pg_lsp --locked --force" - ); + let cmd = cmd!(sh, "cargo install --path crates/pg_lsp --locked --force"); cmd.run()?; Ok(()) } From f64661e5c299f27a052af423ab9242ddc7d2dea6 Mon Sep 17 00:00:00 2001 From: Julian Date: Sun, 13 Oct 2024 14:06:19 +0200 Subject: [PATCH 04/14] move initializer to ClientFlags --- crates/pg_lsp/src/client/client_flags.rs | 34 +++++++++++++++++++++--- crates/pg_lsp/src/main.rs | 2 +- crates/pg_lsp/src/server.rs | 10 +++---- crates/pg_lsp/src/utils/from_proto.rs | 22 --------------- 4 files changed, 36 insertions(+), 32 deletions(-) diff --git a/crates/pg_lsp/src/client/client_flags.rs b/crates/pg_lsp/src/client/client_flags.rs index 8fca812d..6209f443 100644 --- a/crates/pg_lsp/src/client/client_flags.rs +++ b/crates/pg_lsp/src/client/client_flags.rs @@ -1,10 +1,36 @@ +use lsp_types::InitializeParams; + /// Contains information about the client's capabilities. /// This is used to determine which features the server can use. #[derive(Debug, Clone)] pub struct ClientFlags { - /// If `true`, the server can pull the configuration from the client. - pub configuration_pull: bool, + /// If `true`, the server can pull configuration from the client. + pub has_configuration: bool, + + /// If `true`, the client notifies the server when its configuration changes. + pub will_push_configuration: bool, +} + +impl ClientFlags { + pub(crate) fn from_initialize_request_params(params: &InitializeParams) -> Self { + let has_configuration = params + .capabilities + .workspace + .as_ref() + .and_then(|w| w.configuration) + .unwrap_or(false); + + let will_push_configuration = params + .capabilities + .workspace + .as_ref() + .and_then(|w| w.did_change_configuration) + .and_then(|c| c.dynamic_registration) + .unwrap_or(false); - /// If `true`, the client notifies the server when the configuration changes. - pub configuration_push: bool, + Self { + has_configuration, + will_push_configuration, + } + } } diff --git a/crates/pg_lsp/src/main.rs b/crates/pg_lsp/src/main.rs index 803e0f39..9c678fac 100644 --- a/crates/pg_lsp/src/main.rs +++ b/crates/pg_lsp/src/main.rs @@ -4,8 +4,8 @@ use pg_lsp::server::Server; #[tokio::main] async fn main() -> anyhow::Result<()> { let (connection, threads) = Connection::stdio(); - let server = Server::init(connection)?; + let server = Server::init(connection)?; server.run().await?; threads.join()?; diff --git a/crates/pg_lsp/src/server.rs b/crates/pg_lsp/src/server.rs index 488732e5..1a65dcc0 100644 --- a/crates/pg_lsp/src/server.rs +++ b/crates/pg_lsp/src/server.rs @@ -127,7 +127,7 @@ impl Server { connection.initialize_finish(id, serde_json::to_value(result)?)?; - let client_flags = Arc::new(from_proto::client_flags(params.capabilities)); + let client_flags = Arc::new(ClientFlags::from_initialize_request_params(¶ms)); let pool = Arc::new(threadpool::Builder::new().build()); @@ -200,7 +200,7 @@ impl Server { self.compute_debouncer.clear(); - self.pool.execute(move || { + tokio::spawn(async move { client .send_notification::(ShowMessageParams { typ: lsp_types::MessageType::INFO, @@ -778,7 +778,7 @@ impl Server { &mut self, params: DidChangeConfigurationParams, ) -> anyhow::Result<()> { - if self.client_flags.configuration_pull { + if self.client_flags.has_configuration { self.pull_options(); } else { let options = self.client.parse_options(params.settings)?; @@ -869,7 +869,7 @@ impl Server { } fn pull_options(&mut self) { - if !self.client_flags.configuration_pull { + if !self.client_flags.has_configuration { return; } @@ -899,7 +899,7 @@ impl Server { } fn register_configuration(&mut self) { - if self.client_flags.configuration_push { + if self.client_flags.will_push_configuration { let registration = Registration { id: "pull-config".to_string(), method: DidChangeConfiguration::METHOD.to_string(), diff --git a/crates/pg_lsp/src/utils/from_proto.rs b/crates/pg_lsp/src/utils/from_proto.rs index 47708be7..eaae06ce 100644 --- a/crates/pg_lsp/src/utils/from_proto.rs +++ b/crates/pg_lsp/src/utils/from_proto.rs @@ -1,5 +1,3 @@ -use crate::client::client_flags::ClientFlags; - use super::line_index_ext::LineIndexExt; use pg_base_db::{Change, Document}; @@ -17,23 +15,3 @@ pub fn content_changes( }) .collect() } - -pub fn client_flags(capabilities: lsp_types::ClientCapabilities) -> ClientFlags { - let configuration_pull = capabilities - .workspace - .as_ref() - .and_then(|cap| cap.configuration) - .unwrap_or(false); - - let configuration_push = capabilities - .workspace - .as_ref() - .and_then(|cap| cap.did_change_configuration) - .and_then(|cap| cap.dynamic_registration) - .unwrap_or(false); - - ClientFlags { - configuration_pull, - configuration_push, - } -} From 12c86fd8c39a5087218fafc2c8564a82aa892afe Mon Sep 17 00:00:00 2001 From: Julian Date: Sun, 13 Oct 2024 15:46:28 +0200 Subject: [PATCH 05/14] so far so good --- Cargo.lock | 14 ++++ crates/pg_lsp/Cargo.toml | 1 + crates/pg_lsp/src/server.rs | 146 ++++++++++++++++++++---------------- 3 files changed, 97 insertions(+), 64 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8fbbee9e..fd12e85a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1487,6 +1487,7 @@ dependencies = [ "text-size", "threadpool", "tokio", + "tokio-util", ] [[package]] @@ -2520,6 +2521,19 @@ dependencies = [ "syn 2.0.71", ] +[[package]] +name = "tokio-util" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "tracing" version = "0.1.40" diff --git a/crates/pg_lsp/Cargo.toml b/crates/pg_lsp/Cargo.toml index 84e97785..9d23bce9 100644 --- a/crates/pg_lsp/Cargo.toml +++ b/crates/pg_lsp/Cargo.toml @@ -33,6 +33,7 @@ pg_schema_cache.workspace = true pg_workspace.workspace = true pg_diagnostics.workspace = true tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread", "sync"] } +tokio-util = "0.7.12" [dev-dependencies] diff --git a/crates/pg_lsp/src/server.rs b/crates/pg_lsp/src/server.rs index 1a65dcc0..2e038e0e 100644 --- a/crates/pg_lsp/src/server.rs +++ b/crates/pg_lsp/src/server.rs @@ -27,11 +27,11 @@ use pg_hover::HoverParams; use pg_schema_cache::SchemaCache; use pg_workspace::Workspace; use serde::{de::DeserializeOwned, Serialize}; -use std::{collections::HashSet, sync::Arc, time::Duration}; +use std::{collections::HashSet, future::Future, sync::Arc, time::Duration}; use text_size::TextSize; -use threadpool::ThreadPool; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; use crate::{ client::{client_flags::ClientFlags, LspClient}, @@ -72,9 +72,9 @@ impl DbConnection { /// For now, we move it into a separate task and use tokio's channels to communicate. fn get_client_receiver( connection: Connection, -) -> (mpsc::UnboundedReceiver, oneshot::Receiver<()>) { + cancel_token: Arc, +) -> mpsc::UnboundedReceiver { let (message_tx, message_rx) = mpsc::unbounded_channel(); - let (close_tx, close_rx) = oneshot::channel(); tokio::task::spawn(async move { // TODO: improve Result handling @@ -83,7 +83,7 @@ fn get_client_receiver( match msg { Message::Request(r) if connection.handle_shutdown(&r).unwrap() => { - close_tx.send(()).unwrap(); + cancel_token.cancel(); return; } @@ -92,16 +92,15 @@ fn get_client_receiver( } }); - (message_rx, close_rx) + message_rx } pub struct Server { client_rx: mpsc::UnboundedReceiver, - close_rx: oneshot::Receiver<()>, + cancel_token: Arc, client: LspClient, internal_tx: mpsc::UnboundedSender, internal_rx: mpsc::UnboundedReceiver, - pool: Arc, client_flags: Arc, ide: Arc, db_conn: Option, @@ -138,10 +137,12 @@ impl Server { let cloned_pool = pool.clone(); let cloned_client = client.clone(); - let (client_rx, close_rx) = get_client_receiver(connection); + let cancel_token = Arc::new(CancellationToken::new()); + + let client_rx = get_client_receiver(connection, cancel_token.clone()); let server = Self { - close_rx, + cancel_token, client_rx, internal_rx, internal_tx, @@ -186,7 +187,6 @@ impl Server { }); }, ), - pool, }; Ok(server) @@ -200,7 +200,7 @@ impl Server { self.compute_debouncer.clear(); - tokio::spawn(async move { + self.spawn_with_cancel(async move { client .send_notification::(ShowMessageParams { typ: lsp_types::MessageType::INFO, @@ -714,15 +714,17 @@ impl Server { Q: FnOnce() -> anyhow::Result + Send + 'static, { let client = self.client.clone(); - self.pool.execute(move || match query() { - Ok(result) => { - let response = lsp_server::Response::new_ok(id, result); - client.send_response(response).unwrap(); - } - Err(why) => { - client - .send_error(id, ErrorCode::InternalError, why.to_string()) - .unwrap(); + self.spawn_with_cancel(async move { + match query() { + Ok(result) => { + let response = lsp_server::Response::new_ok(id, result); + client.send_response(response).unwrap(); + } + Err(why) => { + client + .send_error(id, ErrorCode::InternalError, why.to_string()) + .unwrap(); + } } }); } @@ -748,9 +750,11 @@ impl Server { let client = self.client.clone(); let ide = Arc::clone(&self.ide); - self.pool.execute(move || { + self.spawn_with_cancel(async move { let response = lsp_server::Response::new_ok(id, query(&ide)); - client.send_response(response).unwrap(); + client + .send_response(response) + .expect("Failed to send query to client"); }); } @@ -791,22 +795,21 @@ impl Server { async fn process_messages(&mut self) -> anyhow::Result<()> { loop { tokio::select! { - _ = &mut self.close_rx => { + _ = self.cancel_token.cancelled() => { + // Close the loop, proceed to shutdown. return Ok(()) }, msg = self.internal_rx.recv() => { match msg { - // TODO: handle internal sender close? Is that valid state? - None => return Ok(()), - Some(m) => self.handle_internal_message(m) + None => panic!("The LSP's internal sender closed. This should never happen."), + Some(m) => self.handle_internal_message(m).await } }, msg = self.client_rx.recv() => { match msg { - // the client sender is closed, we can return - None => return Ok(()), + None => panic!("The LSP's client closed, but not via an 'exit' method. This should never happen."), Some(m) => self.handle_message(m) } }, @@ -848,14 +851,14 @@ impl Server { Ok(()) } - fn handle_internal_message(&mut self, msg: InternalMessage) -> anyhow::Result<()> { + async fn handle_internal_message(&mut self, msg: InternalMessage) -> anyhow::Result<()> { match msg { InternalMessage::SetSchemaCache(c) => { self.ide.set_schema_cache(c); self.compute_now(); } InternalMessage::RefreshSchemaCache => { - self.refresh_schema_cache(); + self.refresh_schema_cache().await; } InternalMessage::PublishDiagnostics(uri) => { self.publish_diagnostics(uri)?; @@ -869,10 +872,6 @@ impl Server { } fn pull_options(&mut self) { - if !self.client_flags.has_configuration { - return; - } - let params = ConfigurationParams { items: vec![ConfigurationItem { section: Some("postgres_lsp".to_string()), @@ -881,53 +880,72 @@ impl Server { }; let client = self.client.clone(); - let sender = self.internal_tx.clone(); - self.pool.execute(move || { + let internal_tx = self.internal_tx.clone(); + self.spawn_with_cancel(async move { match client.send_request::(params) { Ok(mut json) => { let options = client .parse_options(json.pop().expect("invalid configuration request")) .unwrap(); - sender.send(InternalMessage::SetOptions(options)).unwrap(); + if let Err(why) = internal_tx.send(InternalMessage::SetOptions(options)) { + println!("Failed to set internal options: {}", why); + } } - Err(_why) => { - // log::error!("Retrieving configuration failed: {}", why); + Err(why) => { + println!("Retrieving configuration failed: {}", why); } }; }); } fn register_configuration(&mut self) { - if self.client_flags.will_push_configuration { - let registration = Registration { - id: "pull-config".to_string(), - method: DidChangeConfiguration::METHOD.to_string(), - register_options: None, - }; + let registration = Registration { + id: "pull-config".to_string(), + method: DidChangeConfiguration::METHOD.to_string(), + register_options: None, + }; - let params = RegistrationParams { - registrations: vec![registration], - }; + let params = RegistrationParams { + registrations: vec![registration], + }; - let client = self.client.clone(); - self.pool.execute(move || { - if let Err(_why) = client.send_request::(params) { - // log::error!( - // "Failed to register \"{}\" notification: {}", - // DidChangeConfiguration::METHOD, - // why - // ); - } - }); - } + let client = self.client.clone(); + self.spawn_with_cancel(async move { + if let Err(why) = client.send_request::(params) { + println!( + "Failed to register \"{}\" notification: {}", + DidChangeConfiguration::METHOD, + why + ); + } + }); + } + + fn spawn_with_cancel(&self, f: F) -> tokio::task::JoinHandle<()> + where + F: Future + Send + 'static, + { + let cancel_token = self.cancel_token.clone(); + tokio::spawn(async move { + tokio::select! { + _ = cancel_token.cancelled() => {}, + _ = f => {} + }; + }) } pub async fn run(mut self) -> anyhow::Result<()> { - self.register_configuration(); - self.pull_options(); + if self.client_flags.will_push_configuration { + self.register_configuration(); + } + + if self.client_flags.has_configuration { + self.pull_options(); + } + self.process_messages().await?; - self.pool.join(); + Ok(()) } } From c22d977f53d62d6e038027b10a8cbc8c747dc9cb Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 18 Oct 2024 10:16:21 +0200 Subject: [PATCH 06/14] so far so good --- crates/pg_lsp/src/db_connection.rs | 61 ++++++++++++ crates/pg_lsp/src/lib.rs | 1 + crates/pg_lsp/src/server.rs | 152 +++++++++++------------------ 3 files changed, 117 insertions(+), 97 deletions(-) create mode 100644 crates/pg_lsp/src/db_connection.rs diff --git a/crates/pg_lsp/src/db_connection.rs b/crates/pg_lsp/src/db_connection.rs new file mode 100644 index 00000000..c74f8642 --- /dev/null +++ b/crates/pg_lsp/src/db_connection.rs @@ -0,0 +1,61 @@ +use pg_schema_cache::SchemaCache; +use sqlx::{postgres::PgListener, PgPool}; + +#[derive(Debug)] +pub(crate) struct DbConnection { + pub pool: PgPool, + connection_string: String, +} + +impl DbConnection { + pub(crate) async fn new(connection_string: String) -> Result { + let pool = PgPool::connect(&connection_string).await?; + Ok(Self { + pool, + connection_string: connection_string, + }) + } + + pub(crate) async fn refresh_db_connection( + self, + connection_string: Option, + ) -> anyhow::Result { + if connection_string.is_none() + || connection_string.as_ref() == Some(&self.connection_string) + { + return Ok(self); + } + + self.pool.close().await; + + let conn = DbConnection::new(connection_string.unwrap()).await?; + + Ok(conn) + } + + pub(crate) async fn start_listening(&self, on_schema_update: F) -> anyhow::Result<()> + where + F: Fn() -> () + Send + 'static, + { + let mut listener = PgListener::connect_with(&self.pool).await?; + listener.listen_all(["postgres_lsp", "pgrst"]).await?; + + loop { + match listener.recv().await { + Ok(notification) => { + if notification.payload().to_string() == "reload schema" { + on_schema_update(); + } + } + Err(e) => { + eprintln!("Listener error: {}", e); + return Err(e.into()); + } + } + } + } + + pub(crate) async fn get_schema_cache(&self) -> SchemaCache { + SchemaCache::load(&self.pool).await + } +} diff --git a/crates/pg_lsp/src/lib.rs b/crates/pg_lsp/src/lib.rs index ac95c913..97474d52 100644 --- a/crates/pg_lsp/src/lib.rs +++ b/crates/pg_lsp/src/lib.rs @@ -1,3 +1,4 @@ mod client; +mod db_connection; pub mod server; mod utils; diff --git a/crates/pg_lsp/src/server.rs b/crates/pg_lsp/src/server.rs index 2e038e0e..fc1fda71 100644 --- a/crates/pg_lsp/src/server.rs +++ b/crates/pg_lsp/src/server.rs @@ -35,6 +35,7 @@ use tokio_util::sync::CancellationToken; use crate::{ client::{client_flags::ClientFlags, LspClient}, + db_connection::DbConnection, utils::{file_path, from_proto, line_index_ext::LineIndexExt, normalize_uri, to_proto}, }; @@ -52,22 +53,6 @@ enum InternalMessage { SetSchemaCache(SchemaCache), } -#[derive(Debug)] -struct DbConnection { - pub pool: PgPool, - connection_string: String, -} - -impl DbConnection { - pub async fn new(connection_string: &str) -> Result { - let pool = PgPool::connect(connection_string).await?; - Ok(Self { - pool, - connection_string: connection_string.to_owned(), - }) - } -} - /// `lsp-servers` `Connection` type uses a crossbeam channel, which is not compatible with tokio's async runtime. /// For now, we move it into a separate task and use tokio's channels to communicate. fn get_client_receiver( @@ -110,21 +95,9 @@ pub struct Server { impl Server { pub fn init(connection: Connection) -> anyhow::Result { let client = LspClient::new(connection.sender.clone()); + let cancel_token = Arc::new(CancellationToken::new()); - let (internal_tx, internal_rx) = mpsc::unbounded_channel(); - - let (id, params) = connection.initialize_start()?; - let params: InitializeParams = serde_json::from_value(params)?; - - let result = InitializeResult { - capabilities: Self::capabilities(), - server_info: Some(ServerInfo { - name: "Postgres LSP".to_owned(), - version: Some(env!("CARGO_PKG_VERSION").to_owned()), - }), - }; - - connection.initialize_finish(id, serde_json::to_value(result)?)?; + let (params, client_rx) = Self::establish_client_connection(connection, &cancel_token)?; let client_flags = Arc::new(ClientFlags::from_initialize_request_params(¶ms)); @@ -132,15 +105,13 @@ impl Server { let ide = Arc::new(Workspace::new()); + let (internal_tx, internal_rx) = mpsc::unbounded_channel(); + let cloned_tx = internal_tx.clone(); let cloned_ide = ide.clone(); let cloned_pool = pool.clone(); let cloned_client = client.clone(); - let cancel_token = Arc::new(CancellationToken::new()); - - let client_rx = get_client_receiver(connection, cancel_token.clone()); - let server = Self { cancel_token, client_rx, @@ -239,68 +210,31 @@ impl Server { }); } - async fn start_listening(&self) { - if self.db_conn.is_none() { - return; - } - - let pool = self.db_conn.as_ref().unwrap().pool.clone(); - let tx = self.internal_tx.clone(); - - let mut listener = PgListener::connect_with(&pool).await.unwrap(); - listener - .listen_all(["postgres_lsp", "pgrst"]) - .await - .unwrap(); - - loop { - match listener.recv().await { - Ok(notification) => { - if notification.payload().to_string() == "reload schema" { - tx.send(InternalMessage::RefreshSchemaCache).unwrap(); - } - } - Err(e) => { - eprintln!("Listener error: {}", e); - break; - } - } - } - } - - async fn update_db_connection(&mut self, connection_string: Option) { - if connection_string == self.db_conn.as_ref().map(|c| c.connection_string.clone()) { - return; - } - if let Some(conn) = self.db_conn.take() { - conn.pool.close().await; - } - - if connection_string.is_none() { - return; - } - - let new_conn = DbConnection::new(connection_string.unwrap().as_str()).await; - - if new_conn.is_err() { - return; + async fn update_options(&mut self, options: Options) -> anyhow::Result<()> { + if options.db_connection_string.is_none() { + return Ok(()); } - self.db_conn = Some(new_conn.unwrap()); - - self.client - .send_notification::(ShowMessageParams { - typ: lsp_types::MessageType::INFO, - message: "Connection to database established".to_string(), - }) - .unwrap(); + let new_conn = if self.db_conn.is_none() { + DbConnection::new(options.db_connection_string.clone().unwrap()).await? + } else { + let current_conn = self.db_conn.take().unwrap(); + current_conn + .refresh_db_connection(options.db_connection_string.clone()) + .await? + }; - self.refresh_schema_cache().await; - self.start_listening().await; - } + let internal_tx = self.internal_tx.clone(); + self.spawn_with_cancel(async move { + new_conn.start_listening(move || { + internal_tx + .send(InternalMessage::RefreshSchemaCache) + .unwrap(); + // TODO: handle result + }).await.unwrap() + }); - fn update_options(&mut self, options: Options) { - async_std::task::block_on(self.update_db_connection(options.db_connection_string)); + Ok(()) } fn capabilities() -> ServerCapabilities { @@ -922,16 +856,40 @@ impl Server { }); } - fn spawn_with_cancel(&self, f: F) -> tokio::task::JoinHandle<()> + fn establish_client_connection( + connection: Connection, + cancel_token: &Arc, + ) -> anyhow::Result<(InitializeParams, mpsc::UnboundedReceiver)> { + let (id, params) = connection.initialize_start()?; + + let params: InitializeParams = serde_json::from_value(params)?; + + let result = InitializeResult { + capabilities: Self::capabilities(), + server_info: Some(ServerInfo { + name: "Postgres LSP".to_owned(), + version: Some(env!("CARGO_PKG_VERSION").to_owned()), + }), + }; + + connection.initialize_finish(id, serde_json::to_value(result)?)?; + + let client_rx = get_client_receiver(connection, cancel_token.clone()); + + Ok((params, client_rx)) + } + + fn spawn_with_cancel(&self, f: F) -> tokio::task::JoinHandle> where - F: Future + Send + 'static, + F: Future + Send + 'static, + O: Send + 'static, { let cancel_token = self.cancel_token.clone(); tokio::spawn(async move { tokio::select! { - _ = cancel_token.cancelled() => {}, - _ = f => {} - }; + _ = cancel_token.cancelled() => None, + output = f => Some(output) + } }) } From da769cbb5aef8315a774aa77d6cec115a3d73acc Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 18 Oct 2024 12:10:52 +0200 Subject: [PATCH 07/14] udpate database via messages --- crates/pg_lsp/src/client.rs | 8 +++ crates/pg_lsp/src/db_connection.rs | 61 +++++++++-------- crates/pg_lsp/src/server.rs | 106 ++++++++++++++++------------- 3 files changed, 99 insertions(+), 76 deletions(-) diff --git a/crates/pg_lsp/src/client.rs b/crates/pg_lsp/src/client.rs index 85ff5a61..85176942 100644 --- a/crates/pg_lsp/src/client.rs +++ b/crates/pg_lsp/src/client.rs @@ -50,6 +50,14 @@ impl LspClient { Ok(()) } + /// This will ignore any errors that occur while sending the notification. + pub fn send_info_notification(&self, message: &str) { + let _ = self.send_notification::(ShowMessageParams { + message: message.into(), + typ: MessageType::INFO, + }); + } + pub fn send_request(&self, params: R::Params) -> Result where R: lsp_types::request::Request, diff --git a/crates/pg_lsp/src/db_connection.rs b/crates/pg_lsp/src/db_connection.rs index c74f8642..51ba633d 100644 --- a/crates/pg_lsp/src/db_connection.rs +++ b/crates/pg_lsp/src/db_connection.rs @@ -1,10 +1,12 @@ use pg_schema_cache::SchemaCache; use sqlx::{postgres::PgListener, PgPool}; +use tokio::task::JoinHandle; #[derive(Debug)] pub(crate) struct DbConnection { pub pool: PgPool, connection_string: String, + schema_update_handle: Option>, } impl DbConnection { @@ -13,49 +15,52 @@ impl DbConnection { Ok(Self { pool, connection_string: connection_string, + schema_update_handle: None, }) } - pub(crate) async fn refresh_db_connection( - self, - connection_string: Option, - ) -> anyhow::Result { - if connection_string.is_none() - || connection_string.as_ref() == Some(&self.connection_string) - { - return Ok(self); - } + pub(crate) fn connected_to(&self, connection_string: &str) -> bool { + connection_string == self.connection_string + } + pub(crate) async fn close(self) { + if self.schema_update_handle.is_some() { + self.schema_update_handle.unwrap().abort(); + } self.pool.close().await; - - let conn = DbConnection::new(connection_string.unwrap()).await?; - - Ok(conn) } - pub(crate) async fn start_listening(&self, on_schema_update: F) -> anyhow::Result<()> + pub(crate) async fn listen_for_schema_updates( + &mut self, + on_schema_update: F, + ) -> anyhow::Result<()> where - F: Fn() -> () + Send + 'static, + F: Fn(SchemaCache) -> () + Send + 'static, { let mut listener = PgListener::connect_with(&self.pool).await?; listener.listen_all(["postgres_lsp", "pgrst"]).await?; - loop { - match listener.recv().await { - Ok(notification) => { - if notification.payload().to_string() == "reload schema" { - on_schema_update(); + let pool = self.pool.clone(); + + let handle: JoinHandle<()> = tokio::spawn(async move { + loop { + match listener.recv().await { + Ok(not) => { + if not.payload().to_string() == "reload schema" { + let schema_cache = SchemaCache::load(&pool).await; + on_schema_update(schema_cache); + }; + } + Err(why) => { + eprintln!("Error receiving notification: {:?}", why); + break; } - } - Err(e) => { - eprintln!("Listener error: {}", e); - return Err(e.into()); } } - } - } + }); + + self.schema_update_handle = Some(handle); - pub(crate) async fn get_schema_cache(&self) -> SchemaCache { - SchemaCache::load(&self.pool).await + Ok(()) } } diff --git a/crates/pg_lsp/src/server.rs b/crates/pg_lsp/src/server.rs index fc1fda71..6d6af373 100644 --- a/crates/pg_lsp/src/server.rs +++ b/crates/pg_lsp/src/server.rs @@ -40,17 +40,14 @@ use crate::{ }; use self::{debouncer::EventDebouncer, options::Options}; -use sqlx::{ - postgres::{PgListener, PgPool}, - Executor, -}; +use sqlx::{postgres::PgPool, Executor}; #[derive(Debug)] enum InternalMessage { PublishDiagnostics(lsp_types::Url), SetOptions(Options), - RefreshSchemaCache, SetSchemaCache(SchemaCache), + SetDatabaseConnection(DbConnection), } /// `lsp-servers` `Connection` type uses a crossbeam channel, which is not compatible with tokio's async runtime. @@ -210,29 +207,54 @@ impl Server { }); } - async fn update_options(&mut self, options: Options) -> anyhow::Result<()> { - if options.db_connection_string.is_none() { + fn update_db_connection(&self, options: Options) -> anyhow::Result<()> { + if options.db_connection_string.is_none() + || self + .db_conn + .as_ref() + .is_some_and(|c| c.connected_to(options.db_connection_string.as_ref().unwrap())) + { return Ok(()); } - let new_conn = if self.db_conn.is_none() { - DbConnection::new(options.db_connection_string.clone().unwrap()).await? - } else { - let current_conn = self.db_conn.take().unwrap(); - current_conn - .refresh_db_connection(options.db_connection_string.clone()) - .await? - }; + let connection_string = options.db_connection_string.unwrap(); let internal_tx = self.internal_tx.clone(); + let client = self.client.clone(); self.spawn_with_cancel(async move { - new_conn.start_listening(move || { + match DbConnection::new(connection_string.into()).await { + Ok(conn) => { + internal_tx + .send(InternalMessage::SetDatabaseConnection(conn)) + .unwrap(); + } + Err(why) => { + client.send_info_notification(&format!("Unable to update database connection: {}", why)); + + } + } + }); + + Ok(()) + } + + async fn listen_for_schema_updates(&mut self) -> anyhow::Result<()> { + if self.db_conn.is_none() { + eprintln!("Error trying to listen for schema updates: No database connection"); + return Ok(()); + } + + let internal_tx = self.internal_tx.clone(); + self.db_conn + .as_mut() + .unwrap() + .listen_for_schema_updates(move |schema_cache| { internal_tx - .send(InternalMessage::RefreshSchemaCache) + .send(InternalMessage::SetSchemaCache(schema_cache)) .unwrap(); - // TODO: handle result - }).await.unwrap() - }); + // TODO: handle result + }) + .await?; Ok(()) } @@ -692,26 +714,6 @@ impl Server { }); } - async fn refresh_schema_cache(&self) { - if self.db_conn.is_none() { - return; - } - - let tx = self.internal_tx.clone(); - let conn = self.db_conn.as_ref().unwrap().pool.clone(); - let client = self.client.clone(); - - client - .send_notification::(ShowMessageParams { - typ: lsp_types::MessageType::INFO, - message: "Refreshing schema cache...".to_string(), - }) - .unwrap(); - let schema_cache = SchemaCache::load(&conn).await; - tx.send(InternalMessage::SetSchemaCache(schema_cache)) - .unwrap(); - } - fn did_change_configuration( &mut self, params: DidChangeConfigurationParams, @@ -720,7 +722,7 @@ impl Server { self.pull_options(); } else { let options = self.client.parse_options(params.settings)?; - self.update_options(options); + self.update_db_connection(options); } Ok(()) @@ -744,14 +746,14 @@ impl Server { msg = self.client_rx.recv() => { match msg { None => panic!("The LSP's client closed, but not via an 'exit' method. This should never happen."), - Some(m) => self.handle_message(m) + Some(m) => self.handle_message(m).await } }, }?; } } - fn handle_message(&mut self, msg: Message) -> anyhow::Result<()> { + async fn handle_message(&mut self, msg: Message) -> anyhow::Result<()> { match msg { Message::Request(request) => { if let Some(response) = dispatch::RequestDispatcher::new(request) @@ -768,7 +770,8 @@ impl Server { Message::Notification(notification) => { dispatch::NotificationDispatcher::new(notification) .on::(|params| { - self.did_change_configuration(params) + self.did_change_configuration(params); + Ok(()) })? .on::(|params| self.did_close(params))? .on::(|params| self.did_open(params))? @@ -788,17 +791,24 @@ impl Server { async fn handle_internal_message(&mut self, msg: InternalMessage) -> anyhow::Result<()> { match msg { InternalMessage::SetSchemaCache(c) => { + self.client + .send_info_notification("Refreshing Schema Cache..."); self.ide.set_schema_cache(c); + self.client.send_info_notification("Updated Schema Cache."); self.compute_now(); } - InternalMessage::RefreshSchemaCache => { - self.refresh_schema_cache().await; - } InternalMessage::PublishDiagnostics(uri) => { self.publish_diagnostics(uri)?; } InternalMessage::SetOptions(options) => { - self.update_options(options); + self.update_db_connection(options); + } + InternalMessage::SetDatabaseConnection(conn) => { + let current = self.db_conn.replace(conn); + if current.is_some() { + current.unwrap().close().await + } + self.listen_for_schema_updates(); } } From 6f672dba2d69ac38fd010df746c0303e82c40ed1 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 18 Oct 2024 12:12:55 +0200 Subject: [PATCH 08/14] ok ok --- crates/pg_lsp/src/server.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/pg_lsp/src/server.rs b/crates/pg_lsp/src/server.rs index 6d6af373..535ddbbd 100644 --- a/crates/pg_lsp/src/server.rs +++ b/crates/pg_lsp/src/server.rs @@ -801,14 +801,14 @@ impl Server { self.publish_diagnostics(uri)?; } InternalMessage::SetOptions(options) => { - self.update_db_connection(options); + self.update_db_connection(options)?; } InternalMessage::SetDatabaseConnection(conn) => { let current = self.db_conn.replace(conn); if current.is_some() { current.unwrap().close().await } - self.listen_for_schema_updates(); + self.listen_for_schema_updates().await?; } } From f47f04474bbac2dd486f6268c62ae0598801bef7 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 18 Oct 2024 12:15:45 +0200 Subject: [PATCH 09/14] clean up linting --- crates/pg_lsp/src/server.rs | 5 ++--- crates/pg_lsp/src/server/debouncer/thread.rs | 14 -------------- 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/crates/pg_lsp/src/server.rs b/crates/pg_lsp/src/server.rs index 535ddbbd..2c3550e5 100644 --- a/crates/pg_lsp/src/server.rs +++ b/crates/pg_lsp/src/server.rs @@ -722,7 +722,7 @@ impl Server { self.pull_options(); } else { let options = self.client.parse_options(params.settings)?; - self.update_db_connection(options); + self.update_db_connection(options)?; } Ok(()) @@ -770,8 +770,7 @@ impl Server { Message::Notification(notification) => { dispatch::NotificationDispatcher::new(notification) .on::(|params| { - self.did_change_configuration(params); - Ok(()) + self.did_change_configuration(params) })? .on::(|params| self.did_close(params))? .on::(|params| self.did_open(params))? diff --git a/crates/pg_lsp/src/server/debouncer/thread.rs b/crates/pg_lsp/src/server/debouncer/thread.rs index 1aa85939..9329b7cd 100644 --- a/crates/pg_lsp/src/server/debouncer/thread.rs +++ b/crates/pg_lsp/src/server/debouncer/thread.rs @@ -8,7 +8,6 @@ use super::buffer::{EventBuffer, Get, State}; struct DebouncerThread { mutex: Arc>, thread: JoinHandle<()>, - stopped: Arc, } impl DebouncerThread { @@ -36,14 +35,8 @@ impl DebouncerThread { Self { mutex, thread, - stopped, } } - - fn stop(self) -> JoinHandle<()> { - self.stopped.store(true, Ordering::Relaxed); - self.thread - } } /// Threaded debouncer wrapping [EventBuffer]. Accepts a common delay and a @@ -68,13 +61,6 @@ impl EventDebouncer { pub fn clear(&self) { self.0.mutex.lock().unwrap().clear(); } - - /// Signals the debouncer thread to quit and returns a - /// [std::thread::JoinHandle] which can be `.join()`ed in the consumer - /// thread. The common idiom is: `debouncer.stop().join().unwrap();` - pub fn stop(self) -> JoinHandle<()> { - self.0.stop() - } } #[cfg(test)] From 17f884f9e214703ec8adcd518cd0c6440d427559 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 18 Oct 2024 12:18:22 +0200 Subject: [PATCH 10/14] undo odd changes --- crates/pg_schema_cache/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/pg_schema_cache/src/lib.rs b/crates/pg_schema_cache/src/lib.rs index 82454fc3..aed612c6 100644 --- a/crates/pg_schema_cache/src/lib.rs +++ b/crates/pg_schema_cache/src/lib.rs @@ -4,11 +4,11 @@ #![feature(future_join)] mod functions; +mod versions; mod schema_cache; mod schemas; mod tables; mod types; -mod versions; use sqlx::postgres::PgPool; From 69e19a06623b300e335321e51907298d313759ad Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 18 Oct 2024 12:24:14 +0200 Subject: [PATCH 11/14] tidy --- crates/pg_lsp/src/server.rs | 28 ++++++++++---------- crates/pg_lsp/src/server/debouncer/thread.rs | 5 +--- crates/pg_schema_cache/src/lib.rs | 2 +- 3 files changed, 16 insertions(+), 19 deletions(-) diff --git a/crates/pg_lsp/src/server.rs b/crates/pg_lsp/src/server.rs index 2c3550e5..af9b7ad4 100644 --- a/crates/pg_lsp/src/server.rs +++ b/crates/pg_lsp/src/server.rs @@ -94,11 +94,8 @@ impl Server { let client = LspClient::new(connection.sender.clone()); let cancel_token = Arc::new(CancellationToken::new()); - let (params, client_rx) = Self::establish_client_connection(connection, &cancel_token)?; + let (client_flags, client_rx) = Self::establish_client_connection(connection, &cancel_token)?; - let client_flags = Arc::new(ClientFlags::from_initialize_request_params(¶ms)); - - let pool = Arc::new(threadpool::Builder::new().build()); let ide = Arc::new(Workspace::new()); @@ -106,7 +103,7 @@ impl Server { let cloned_tx = internal_tx.clone(); let cloned_ide = ide.clone(); - let cloned_pool = pool.clone(); + let pool = Arc::new(threadpool::Builder::new().build()); let cloned_client = client.clone(); let server = Self { @@ -115,7 +112,7 @@ impl Server { internal_rx, internal_tx, client, - client_flags, + client_flags: Arc::new(client_flags), db_conn: None, ide, compute_debouncer: EventDebouncer::new( @@ -124,7 +121,7 @@ impl Server { let inner_cloned_ide = cloned_ide.clone(); let inner_cloned_tx = cloned_tx.clone(); let inner_cloned_client = cloned_client.clone(); - cloned_pool.execute(move || { + pool.execute(move || { inner_cloned_client .send_notification::(ShowMessageParams { typ: lsp_types::MessageType::INFO, @@ -229,8 +226,10 @@ impl Server { .unwrap(); } Err(why) => { - client.send_info_notification(&format!("Unable to update database connection: {}", why)); - + client.send_info_notification(&format!( + "Unable to update database connection: {}", + why + )); } } }); @@ -868,7 +867,7 @@ impl Server { fn establish_client_connection( connection: Connection, cancel_token: &Arc, - ) -> anyhow::Result<(InitializeParams, mpsc::UnboundedReceiver)> { + ) -> anyhow::Result<(ClientFlags, mpsc::UnboundedReceiver)> { let (id, params) = connection.initialize_start()?; let params: InitializeParams = serde_json::from_value(params)?; @@ -885,9 +884,12 @@ impl Server { let client_rx = get_client_receiver(connection, cancel_token.clone()); - Ok((params, client_rx)) + let client_flags = ClientFlags::from_initialize_request_params(¶ms); + + Ok((client_flags, client_rx)) } + /// Spawns an asynchronous task that can be cancelled with the `Server`'s `cancel_token`. fn spawn_with_cancel(&self, f: F) -> tokio::task::JoinHandle> where F: Future + Send + 'static, @@ -911,8 +913,6 @@ impl Server { self.pull_options(); } - self.process_messages().await?; - - Ok(()) + self.process_messages().await } } diff --git a/crates/pg_lsp/src/server/debouncer/thread.rs b/crates/pg_lsp/src/server/debouncer/thread.rs index 9329b7cd..a7486f21 100644 --- a/crates/pg_lsp/src/server/debouncer/thread.rs +++ b/crates/pg_lsp/src/server/debouncer/thread.rs @@ -32,10 +32,7 @@ impl DebouncerThread { } } }); - Self { - mutex, - thread, - } + Self { mutex, thread } } } diff --git a/crates/pg_schema_cache/src/lib.rs b/crates/pg_schema_cache/src/lib.rs index aed612c6..82454fc3 100644 --- a/crates/pg_schema_cache/src/lib.rs +++ b/crates/pg_schema_cache/src/lib.rs @@ -4,11 +4,11 @@ #![feature(future_join)] mod functions; -mod versions; mod schema_cache; mod schemas; mod tables; mod types; +mod versions; use sqlx::postgres::PgPool; From f011c25f459ccea2138c50fa592573a5dd38e93a Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 18 Oct 2024 12:24:49 +0200 Subject: [PATCH 12/14] what --- crates/pg_schema_cache/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/pg_schema_cache/src/lib.rs b/crates/pg_schema_cache/src/lib.rs index 82454fc3..aed612c6 100644 --- a/crates/pg_schema_cache/src/lib.rs +++ b/crates/pg_schema_cache/src/lib.rs @@ -4,11 +4,11 @@ #![feature(future_join)] mod functions; +mod versions; mod schema_cache; mod schemas; mod tables; mod types; -mod versions; use sqlx::postgres::PgPool; From 577d0f94253fa15aca26978aa6028733b7ddf292 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 18 Oct 2024 12:27:47 +0200 Subject: [PATCH 13/14] improve error handling --- crates/pg_lsp/src/server.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/crates/pg_lsp/src/server.rs b/crates/pg_lsp/src/server.rs index af9b7ad4..637debbf 100644 --- a/crates/pg_lsp/src/server.rs +++ b/crates/pg_lsp/src/server.rs @@ -59,9 +59,15 @@ fn get_client_receiver( let (message_tx, message_rx) = mpsc::unbounded_channel(); tokio::task::spawn(async move { - // TODO: improve Result handling loop { - let msg = connection.receiver.recv().unwrap(); + let msg = match connection.receiver.recv() { + Ok(msg) => msg, + Err(e) => { + eprint!("Connection was closed by LSP client: {}", e); + cancel_token.cancel(); + return; + } + }; match msg { Message::Request(r) if connection.handle_shutdown(&r).unwrap() => { @@ -69,6 +75,7 @@ fn get_client_receiver( return; } + // any non-shutdown request is forwarded to the server _ => message_tx.send(msg).unwrap(), }; } From 5abcc137884745883f6636d810825852c2e1535b Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 18 Oct 2024 12:30:37 +0200 Subject: [PATCH 14/14] improve error handling #2 --- crates/pg_lsp/src/server.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/pg_lsp/src/server.rs b/crates/pg_lsp/src/server.rs index 637debbf..65d2a354 100644 --- a/crates/pg_lsp/src/server.rs +++ b/crates/pg_lsp/src/server.rs @@ -216,6 +216,7 @@ impl Server { || self .db_conn .as_ref() + // if the connection is already connected to the same database, do nothing .is_some_and(|c| c.connected_to(options.db_connection_string.as_ref().unwrap())) { return Ok(()); @@ -257,8 +258,7 @@ impl Server { .listen_for_schema_updates(move |schema_cache| { internal_tx .send(InternalMessage::SetSchemaCache(schema_cache)) - .unwrap(); - // TODO: handle result + .expect("LSP Server: Failed to send internal message."); }) .await?;