diff --git a/crates/pg_cli/src/cli_options.rs b/crates/pg_cli/src/cli_options.rs index 24d7a3c3..20e18c8c 100644 --- a/crates/pg_cli/src/cli_options.rs +++ b/crates/pg_cli/src/cli_options.rs @@ -18,6 +18,10 @@ pub struct CliOptions { #[bpaf(long("use-server"), switch, fallback(false))] pub use_server: bool, + /// Skip connecting to the database and only run checks that don't require a database connection. + #[bpaf(long("skip-db"), switch, fallback(false))] + pub skip_db: bool, + /// Print additional diagnostics, and some diagnostics show more information. Also, print out what files were processed and which ones were modified. #[bpaf(long("verbose"), switch, fallback(false))] pub verbose: bool, diff --git a/crates/pg_cli/src/commands/mod.rs b/crates/pg_cli/src/commands/mod.rs index 6b47ecc6..6fe3472d 100644 --- a/crates/pg_cli/src/commands/mod.rs +++ b/crates/pg_cli/src/commands/mod.rs @@ -35,8 +35,10 @@ pub enum PgLspCommand { Check { #[bpaf(external(partial_configuration), hide_usage, optional)] configuration: Option, + #[bpaf(external, hide_usage)] cli_options: CliOptions, + /// Use this option when you want to format code piped from `stdin`, and print the output to `stdout`. /// /// The file doesn't need to exist on disk, what matters is the extension of the file. Based on the extension, we know how to check the code. @@ -286,6 +288,7 @@ pub(crate) trait CommandRunner: Sized { configuration, vcs_base_path, gitignore_matches, + skip_db: cli_options.skip_db, })?; let execution = self.get_execution(cli_options, console, workspace)?; diff --git a/crates/pg_cli/src/execute/process_file/check.rs b/crates/pg_cli/src/execute/process_file/check.rs index fa5b522b..866fa740 100644 --- a/crates/pg_cli/src/execute/process_file/check.rs +++ b/crates/pg_cli/src/execute/process_file/check.rs @@ -28,6 +28,7 @@ pub(crate) fn check_with_guard<'ctx>( let (only, skip) = (Vec::new(), Vec::new()); let max_diagnostics = ctx.remaining_diagnostics.load(Ordering::Relaxed); + let pull_diagnostics_result = workspace_file .guard() .pull_diagnostics( diff --git a/crates/pg_cli/src/execute/traverse.rs b/crates/pg_cli/src/execute/traverse.rs index 873b8905..ee51125c 100644 --- a/crates/pg_cli/src/execute/traverse.rs +++ b/crates/pg_cli/src/execute/traverse.rs @@ -123,6 +123,7 @@ pub(crate) fn traverse( let skipped = skipped.load(Ordering::Relaxed); let suggested_fixes_skipped = printer.skipped_fixes(); let diagnostics_not_printed = printer.not_printed_diagnostics(); + Ok(TraverseResult { summary: TraversalSummary { changed, @@ -381,6 +382,7 @@ impl<'ctx> DiagnosticsPrinter<'ctx> { } } } + diagnostics_to_print } } diff --git a/crates/pg_cli/src/reporter/gitlab.rs b/crates/pg_cli/src/reporter/gitlab.rs index ea3fd285..473f4f63 100644 --- a/crates/pg_cli/src/reporter/gitlab.rs +++ b/crates/pg_cli/src/reporter/gitlab.rs @@ -13,8 +13,8 @@ use std::{ }; pub struct GitLabReporter { - pub execution: Execution, - pub diagnostics: DiagnosticsPayload, + pub(crate) execution: Execution, + pub(crate) diagnostics: DiagnosticsPayload, } impl Reporter for GitLabReporter { diff --git a/crates/pg_completions/src/context.rs b/crates/pg_completions/src/context.rs index a5fb0c6b..d3ff1052 100644 --- a/crates/pg_completions/src/context.rs +++ b/crates/pg_completions/src/context.rs @@ -266,7 +266,7 @@ mod tests { position: (position as u32).into(), text, tree: Some(&tree), - schema: &pg_schema_cache::SchemaCache::new(), + schema: &pg_schema_cache::SchemaCache::default(), }; let ctx = CompletionContext::new(¶ms); @@ -298,7 +298,7 @@ mod tests { position: (position as u32).into(), text, tree: Some(&tree), - schema: &pg_schema_cache::SchemaCache::new(), + schema: &pg_schema_cache::SchemaCache::default(), }; let ctx = CompletionContext::new(¶ms); @@ -332,7 +332,7 @@ mod tests { position: (position as u32).into(), text, tree: Some(&tree), - schema: &pg_schema_cache::SchemaCache::new(), + schema: &pg_schema_cache::SchemaCache::default(), }; let ctx = CompletionContext::new(¶ms); @@ -357,7 +357,7 @@ mod tests { position: (position as u32).into(), text, tree: Some(&tree), - schema: &pg_schema_cache::SchemaCache::new(), + schema: &pg_schema_cache::SchemaCache::default(), }; let ctx = CompletionContext::new(¶ms); @@ -385,7 +385,7 @@ mod tests { position: (position as u32).into(), text, tree: Some(&tree), - schema: &pg_schema_cache::SchemaCache::new(), + schema: &pg_schema_cache::SchemaCache::default(), }; let ctx = CompletionContext::new(¶ms); @@ -411,7 +411,7 @@ mod tests { position: (position as u32).into(), text, tree: Some(&tree), - schema: &pg_schema_cache::SchemaCache::new(), + schema: &pg_schema_cache::SchemaCache::default(), }; let ctx = CompletionContext::new(¶ms); @@ -436,7 +436,7 @@ mod tests { position: (position as u32).into(), text, tree: Some(&tree), - schema: &pg_schema_cache::SchemaCache::new(), + schema: &pg_schema_cache::SchemaCache::default(), }; let ctx = CompletionContext::new(¶ms); diff --git a/crates/pg_configuration/src/database.rs b/crates/pg_configuration/src/database.rs index 9b35ae61..2feb0330 100644 --- a/crates/pg_configuration/src/database.rs +++ b/crates/pg_configuration/src/database.rs @@ -26,6 +26,10 @@ pub struct DatabaseConfiguration { /// The name of the database. #[partial(bpaf(long("database")))] pub database: String, + + /// The connection timeout in seconds. + #[partial(bpaf(long("conn_timeout_secs"), fallback(Some(10)), debug_fallback))] + pub conn_timeout_secs: u16, } impl Default for DatabaseConfiguration { @@ -36,15 +40,7 @@ impl Default for DatabaseConfiguration { username: "postgres".to_string(), password: "postgres".to_string(), database: "postgres".to_string(), + conn_timeout_secs: 10, } } } - -impl DatabaseConfiguration { - pub fn to_connection_string(&self) -> String { - format!( - "postgres://{}:{}@{}:{}/{}", - self.username, self.password, self.host, self.port, self.database - ) - } -} diff --git a/crates/pg_configuration/src/lib.rs b/crates/pg_configuration/src/lib.rs index 9dd64f89..7190271d 100644 --- a/crates/pg_configuration/src/lib.rs +++ b/crates/pg_configuration/src/lib.rs @@ -104,6 +104,7 @@ impl PartialConfiguration { username: Some("postgres".to_string()), password: Some("postgres".to_string()), database: Some("postgres".to_string()), + conn_timeout_secs: Some(10), }), } } diff --git a/crates/pg_lsp/src/handlers/completions.rs b/crates/pg_lsp/src/handlers/completions.rs index f13526cd..83ade9f4 100644 --- a/crates/pg_lsp/src/handlers/completions.rs +++ b/crates/pg_lsp/src/handlers/completions.rs @@ -1,6 +1,6 @@ use crate::session::Session; use anyhow::Result; -use pg_workspace::workspace; +use pg_workspace::{workspace, WorkspaceError}; use tower_lsp::lsp_types::{self, CompletionItem, CompletionItemLabelDetails}; #[tracing::instrument(level = "trace", skip_all)] @@ -26,12 +26,22 @@ pub fn get_completions( pg_lsp_converters::negotiated_encoding(client_capabilities), )?; - let completion_result = session + let completion_result = match session .workspace .get_completions(workspace::CompletionParams { path, position: offset, - })?; + }) { + Ok(result) => result, + Err(e) => match e { + WorkspaceError::DatabaseConnectionError(_) => { + return Ok(lsp_types::CompletionResponse::Array(vec![])); + } + _ => { + return Err(e.into()); + } + }, + }; let items: Vec = completion_result .into_iter() diff --git a/crates/pg_lsp/src/session.rs b/crates/pg_lsp/src/session.rs index 9af4c83d..d8379617 100644 --- a/crates/pg_lsp/src/session.rs +++ b/crates/pg_lsp/src/session.rs @@ -464,6 +464,7 @@ impl Session { configuration: fs_configuration, vcs_base_path, gitignore_matches, + skip_db: false, }); if let Err(error) = result { diff --git a/crates/pg_schema_cache/src/schema_cache.rs b/crates/pg_schema_cache/src/schema_cache.rs index 8d73e631..77a0526a 100644 --- a/crates/pg_schema_cache/src/schema_cache.rs +++ b/crates/pg_schema_cache/src/schema_cache.rs @@ -18,10 +18,6 @@ pub struct SchemaCache { } impl SchemaCache { - pub fn new() -> SchemaCache { - SchemaCache::default() - } - pub async fn load(pool: &PgPool) -> Result { let (schemas, tables, functions, types, versions, columns) = futures_util::try_join!( Schema::load(pool), diff --git a/crates/pg_workspace/src/settings.rs b/crates/pg_workspace/src/settings.rs index 9f6fa661..5d8f9acd 100644 --- a/crates/pg_workspace/src/settings.rs +++ b/crates/pg_workspace/src/settings.rs @@ -5,6 +5,7 @@ use std::{ num::NonZeroU64, path::{Path, PathBuf}, sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}, + time::Duration, }; use ignore::gitignore::{Gitignore, GitignoreBuilder}; @@ -266,6 +267,7 @@ pub struct DatabaseSettings { pub username: String, pub password: String, pub database: String, + pub conn_timeout_secs: Duration, } impl Default for DatabaseSettings { @@ -276,19 +278,11 @@ impl Default for DatabaseSettings { username: "postgres".to_string(), password: "postgres".to_string(), database: "postgres".to_string(), + conn_timeout_secs: Duration::from_secs(10), } } } -impl DatabaseSettings { - pub fn to_connection_string(&self) -> String { - format!( - "postgres://{}:{}@{}:{}/{}", - self.username, self.password, self.host, self.port, self.database - ) - } -} - impl From for DatabaseSettings { fn from(value: PartialDatabaseConfiguration) -> Self { let d = DatabaseSettings::default(); @@ -298,6 +292,10 @@ impl From for DatabaseSettings { username: value.username.unwrap_or(d.username), password: value.password.unwrap_or(d.password), database: value.database.unwrap_or(d.database), + conn_timeout_secs: value + .conn_timeout_secs + .map(|s| Duration::from_secs(s.into())) + .unwrap_or(d.conn_timeout_secs), } } } diff --git a/crates/pg_workspace/src/workspace.rs b/crates/pg_workspace/src/workspace.rs index cbfd3756..16e6d135 100644 --- a/crates/pg_workspace/src/workspace.rs +++ b/crates/pg_workspace/src/workspace.rs @@ -90,6 +90,7 @@ pub struct UpdateSettingsParams { pub vcs_base_path: Option, pub gitignore_matches: Vec, pub workspace_directory: Option, + pub skip_db: bool, } #[derive(Debug, serde::Serialize, serde::Deserialize)] @@ -119,9 +120,6 @@ pub trait Workspace: Send + Sync + RefUnwindSafe { params: CompletionParams, ) -> Result; - /// Refresh the schema cache for this workspace - fn refresh_schema_cache(&self) -> Result<(), WorkspaceError>; - /// Update the global settings for this workspace fn update_settings(&self, params: UpdateSettingsParams) -> Result<(), WorkspaceError>; diff --git a/crates/pg_workspace/src/workspace/client.rs b/crates/pg_workspace/src/workspace/client.rs index c5059be9..36e67363 100644 --- a/crates/pg_workspace/src/workspace/client.rs +++ b/crates/pg_workspace/src/workspace/client.rs @@ -117,10 +117,6 @@ where self.request("pglsp/get_file_content", params) } - fn refresh_schema_cache(&self) -> Result<(), WorkspaceError> { - self.request("pglsp/refresh_schema_cache", ()) - } - fn pull_diagnostics( &self, params: super::PullDiagnosticsParams, diff --git a/crates/pg_workspace/src/workspace/server.rs b/crates/pg_workspace/src/workspace/server.rs index bf90c2c3..a8734da2 100644 --- a/crates/pg_workspace/src/workspace/server.rs +++ b/crates/pg_workspace/src/workspace/server.rs @@ -1,8 +1,10 @@ -use std::{fs, future::Future, panic::RefUnwindSafe, path::Path, sync::RwLock}; +use std::{fs, panic::RefUnwindSafe, path::Path, sync::RwLock}; use analyser::AnalyserVisitorBuilder; +use async_helper::run_async; use change::StatementChange; use dashmap::{DashMap, DashSet}; +use db_connection::DbConnection; use document::{Document, Statement}; use futures::{stream, StreamExt}; use pg_analyse::{AnalyserOptions, AnalysisFilter}; @@ -10,11 +12,8 @@ use pg_analyser::{Analyser, AnalyserConfig, AnalyserContext}; use pg_diagnostics::{serde::Diagnostic as SDiagnostic, Diagnostic, DiagnosticExt, Severity}; use pg_fs::{ConfigName, PgLspPath}; use pg_query::PgQueryStore; -use pg_schema_cache::SchemaCache; use pg_typecheck::TypecheckParams; -use sqlx::PgPool; -use std::sync::LazyLock; -use tokio::runtime::Runtime; +use schema_cache_manager::SchemaCacheManager; use tracing::info; use tree_sitter::TreeSitterStore; @@ -31,46 +30,21 @@ use super::{ }; mod analyser; +mod async_helper; mod change; +mod db_connection; mod document; mod migration; mod pg_query; +mod schema_cache_manager; mod tree_sitter; -/// Simple helper to manage the db connection and the associated connection string -#[derive(Default)] -struct DbConnection { - pool: Option, - connection_string: Option, -} - -// Global Tokio Runtime -static RUNTIME: LazyLock = - LazyLock::new(|| Runtime::new().expect("Failed to create Tokio runtime")); - -impl DbConnection { - pub(crate) fn get_pool(&self) -> Option { - self.pool.clone() - } - - pub(crate) fn set_connection(&mut self, connection_string: &str) -> Result<(), WorkspaceError> { - if self.connection_string.is_none() - || self.connection_string.as_ref().unwrap() != connection_string - { - self.connection_string = Some(connection_string.to_string()); - self.pool = Some(PgPool::connect_lazy(connection_string)?); - } - - Ok(()) - } -} - pub(super) struct WorkspaceServer { /// global settings object for this workspace settings: RwLock, /// Stores the schema cache for this workspace - schema_cache: RwLock, + schema_cache: SchemaCacheManager, /// Stores the document (text content + version number) associated with a URL documents: DashMap, @@ -105,7 +79,7 @@ impl WorkspaceServer { tree_sitter: TreeSitterStore::new(), pg_query: PgQueryStore::new(), changed_stmts: DashSet::default(), - schema_cache: RwLock::default(), + schema_cache: SchemaCacheManager::default(), connection: RwLock::default(), } } @@ -119,38 +93,6 @@ impl WorkspaceServer { SettingsHandleMut::new(&self.settings) } - fn refresh_db_connection(&self) -> Result<(), WorkspaceError> { - let s = self.settings(); - - let connection_string = s.as_ref().db.to_connection_string(); - self.connection - .write() - .unwrap() - .set_connection(&connection_string)?; - - self.reload_schema_cache()?; - - Ok(()) - } - - fn reload_schema_cache(&self) -> Result<(), WorkspaceError> { - tracing::info!("Reloading schema cache"); - // TODO return error if db connection is not available - if let Some(c) = self.connection.read().unwrap().get_pool() { - let maybe_schema_cache = run_async(async move { SchemaCache::load(&c).await })?; - let schema_cache = maybe_schema_cache?; - - let mut cache = self.schema_cache.write().unwrap(); - *cache = schema_cache; - } else { - let mut cache = self.schema_cache.write().unwrap(); - *cache = SchemaCache::default(); - } - tracing::info!("Schema cache reloaded"); - - Ok(()) - } - fn is_ignored_by_migration_config(&self, path: &Path) -> bool { let set = self.settings(); set.as_ref() @@ -201,11 +143,6 @@ impl WorkspaceServer { } impl Workspace for WorkspaceServer { - #[tracing::instrument(level = "trace", skip(self))] - fn refresh_schema_cache(&self) -> Result<(), WorkspaceError> { - self.reload_schema_cache() - } - /// Update the global settings for this workspace /// /// ## Panics @@ -222,10 +159,17 @@ impl Workspace for WorkspaceServer { params.gitignore_matches.as_slice(), )?; - self.refresh_db_connection()?; - tracing::info!("Updated settings in workspace"); + if !params.skip_db { + self.connection + .write() + .unwrap() + .set_conn_settings(&self.settings().as_ref().db); + } + + tracing::info!("Updated Db connection settings"); + Ok(()) } @@ -356,53 +300,55 @@ impl Workspace for WorkspaceServer { let mut diagnostics: Vec = vec![]; - // run diagnostics for each statement in parallel if its mostly i/o work - if let Ok(connection) = self.connection.read() { - if let Some(pool) = connection.get_pool() { - let typecheck_params: Vec<_> = doc - .iter_statements_with_text_and_range() - .map(|(stmt, range, text)| { - let ast = self.pg_query.get_ast(&stmt); - let tree = self.tree_sitter.get_parse_tree(&stmt); - (text.to_string(), ast, tree, *range) - }) - .collect(); - - let pool_clone = pool.clone(); - let path_clone = params.path.clone(); - let async_results = run_async(async move { - stream::iter(typecheck_params) - .map(|(text, ast, tree, range)| { - let pool = pool_clone.clone(); - let path = path_clone.clone(); - async move { - if let Some(ast) = ast { - pg_typecheck::check_sql(TypecheckParams { - conn: &pool, - sql: &text, - ast: &ast, - tree: tree.as_deref(), - }) - .await - .map(|d| { - let r = d.location().span.map(|span| span + range.start()); - - d.with_file_path(path.as_path().display().to_string()) - .with_file_span(r.unwrap_or(range)) - }) - } else { - None - } + if let Some(pool) = self + .connection + .read() + .expect("DbConnection RwLock panicked") + .get_pool() + { + let typecheck_params: Vec<_> = doc + .iter_statements_with_text_and_range() + .map(|(stmt, range, text)| { + let ast = self.pg_query.get_ast(&stmt); + let tree = self.tree_sitter.get_parse_tree(&stmt); + (text.to_string(), ast, tree, *range) + }) + .collect(); + + // run diagnostics for each statement in parallel if its mostly i/o work + let path_clone = params.path.clone(); + let async_results = run_async(async move { + stream::iter(typecheck_params) + .map(|(text, ast, tree, range)| { + let pool = pool.clone(); + let path = path_clone.clone(); + async move { + if let Some(ast) = ast { + pg_typecheck::check_sql(TypecheckParams { + conn: &pool, + sql: &text, + ast: &ast, + tree: tree.as_deref(), + }) + .await + .map(|d| { + let r = d.location().span.map(|span| span + range.start()); + + d.with_file_path(path.as_path().display().to_string()) + .with_file_span(r.unwrap_or(range)) + }) + } else { + None } - }) - .buffer_unordered(10) - .collect::>() - .await - })?; - - for result in async_results.into_iter().flatten() { - diagnostics.push(SDiagnostic::new(result)); - } + } + }) + .buffer_unordered(10) + .collect::>() + .await + })?; + + for result in async_results.into_iter().flatten() { + diagnostics.push(SDiagnostic::new(result)); } } @@ -470,6 +416,11 @@ impl Workspace for WorkspaceServer { ¶ms.position ); + let pool = match self.connection.read().unwrap().get_pool() { + Some(pool) => pool, + None => return Ok(pg_completions::CompletionResult::default()), + }; + let doc = self .documents .get(¶ms.path) @@ -497,14 +448,11 @@ impl Workspace for WorkspaceServer { tracing::debug!("Found the statement. We're looking for position {:?}. Statement Range {:?} to {:?}. Statement: {}", position, stmt_range.start(), stmt_range.end(), text); - let schema_cache = self - .schema_cache - .read() - .map_err(|_| WorkspaceError::runtime("Unable to load SchemaCache"))?; + let schema_cache = self.schema_cache.load(pool)?; let result = pg_completions::complete(pg_completions::CompletionParams { position, - schema: &schema_cache, + schema: schema_cache.as_ref(), tree: tree.as_deref(), text: text.to_string(), }); @@ -518,15 +466,3 @@ impl Workspace for WorkspaceServer { fn is_dir(path: &Path) -> bool { path.is_dir() || (path.is_symlink() && fs::read_link(path).is_ok_and(|path| path.is_dir())) } - -/// Use this function to run async functions in the workspace, which is a sync trait called from an -/// async context. -/// -/// Checkout https://greptime.com/blogs/2023-03-09-bridging-async-and-sync-rust for details. -fn run_async(future: F) -> Result -where - F: Future + Send + 'static, - R: Send + 'static, -{ - futures::executor::block_on(async { RUNTIME.spawn(future).await.map_err(|e| e.into()) }) -} diff --git a/crates/pg_workspace/src/workspace/server/async_helper.rs b/crates/pg_workspace/src/workspace/server/async_helper.rs new file mode 100644 index 00000000..896a63a4 --- /dev/null +++ b/crates/pg_workspace/src/workspace/server/async_helper.rs @@ -0,0 +1,21 @@ +use std::{future::Future, sync::LazyLock}; + +use tokio::runtime::Runtime; + +use crate::WorkspaceError; + +// Global Tokio Runtime +static RUNTIME: LazyLock = + LazyLock::new(|| Runtime::new().expect("Failed to create Tokio runtime")); + +/// Use this function to run async functions in the workspace, which is a sync trait called from an +/// async context. +/// +/// Checkout https://greptime.com/blogs/2023-03-09-bridging-async-and-sync-rust for details. +pub fn run_async(future: F) -> Result +where + F: Future + Send + 'static, + R: Send + 'static, +{ + futures::executor::block_on(async { RUNTIME.spawn(future).await.map_err(|e| e.into()) }) +} diff --git a/crates/pg_workspace/src/workspace/server/db_connection.rs b/crates/pg_workspace/src/workspace/server/db_connection.rs new file mode 100644 index 00000000..3a747342 --- /dev/null +++ b/crates/pg_workspace/src/workspace/server/db_connection.rs @@ -0,0 +1,35 @@ +use std::time::Duration; + +use sqlx::{pool::PoolOptions, postgres::PgConnectOptions, PgPool, Postgres}; + +use crate::settings::DatabaseSettings; + +#[derive(Default)] +pub struct DbConnection { + pool: Option, +} + +impl DbConnection { + /// There might be no pool available if the user decides to skip db checks. + pub(crate) fn get_pool(&self) -> Option { + self.pool.clone() + } + + pub(crate) fn set_conn_settings(&mut self, settings: &DatabaseSettings) { + let config = PgConnectOptions::new() + .host(&settings.host) + .port(settings.port) + .username(&settings.username) + .password(&settings.password) + .database(&settings.database); + + let timeout = settings.conn_timeout_secs.clone(); + + let pool = PoolOptions::::new() + .acquire_timeout(timeout) + .acquire_slow_threshold(Duration::from_secs(2)) + .connect_lazy_with(config); + + self.pool = Some(pool); + } +} diff --git a/crates/pg_workspace/src/workspace/server/schema_cache_manager.rs b/crates/pg_workspace/src/workspace/server/schema_cache_manager.rs new file mode 100644 index 00000000..9df910e3 --- /dev/null +++ b/crates/pg_workspace/src/workspace/server/schema_cache_manager.rs @@ -0,0 +1,83 @@ +use std::sync::{RwLock, RwLockReadGuard}; + +use pg_schema_cache::SchemaCache; +use sqlx::PgPool; + +use crate::WorkspaceError; + +use super::async_helper::run_async; + +pub(crate) struct SchemaCacheHandle<'a> { + inner: RwLockReadGuard<'a, SchemaCacheManagerInner>, +} + +impl<'a> SchemaCacheHandle<'a> { + pub(crate) fn new(cache: &'a RwLock) -> Self { + Self { + inner: cache.read().unwrap(), + } + } + + pub(crate) fn wrap(inner: RwLockReadGuard<'a, SchemaCacheManagerInner>) -> Self { + Self { inner } + } +} + +impl AsRef for SchemaCacheHandle<'_> { + fn as_ref(&self) -> &SchemaCache { + &self.inner.cache + } +} + +#[derive(Default)] +pub(crate) struct SchemaCacheManagerInner { + cache: SchemaCache, + conn_str: String, +} + +#[derive(Default)] +pub struct SchemaCacheManager { + inner: RwLock, +} + +impl SchemaCacheManager { + pub fn load(&self, pool: PgPool) -> Result { + let inner = self.inner.read().unwrap(); + + if pool_to_conn_str(&pool) == inner.conn_str { + Ok(SchemaCacheHandle::wrap(inner)) + } else { + let new_conn_str = pool_to_conn_str(&pool); + + let maybe_refreshed = run_async(async move { SchemaCache::load(&pool).await })?; + let refreshed = maybe_refreshed?; + + let mut inner = self.inner.write().unwrap(); + + inner.cache = refreshed; + inner.conn_str = new_conn_str; + + Ok(SchemaCacheHandle::new(&self.inner)) + } + } +} + +fn pool_to_conn_str(pool: &PgPool) -> String { + let conn = pool.connect_options(); + + match conn.get_database() { + None => format!( + "postgres://{}:@{}:{}", + conn.get_username(), + conn.get_host(), + conn.get_port() + ), + Some(db) => format!( + "postgres://{}:@{}:{}/{}", + conn.get_username(), + conn.get_host(), + conn.get_port(), + db + ), + } +} diff --git a/pglsp.toml b/pglsp.toml index ea394361..4f7d325d 100644 --- a/pglsp.toml +++ b/pglsp.toml @@ -12,6 +12,7 @@ port = 5432 username = "postgres" password = "postgres" database = "postgres" +conn_timeout_secs = 10 # [migrations] # migrations_dir = "migrations"