Skip to content

Commit ff46fdb

Browse files
committed
feat: cache external models and others for faster loading
1 parent c4e1f02 commit ff46fdb

File tree

4 files changed: +59 −33 lines changed

4 files changed: +59 −33 lines changed

sqlmesh/core/loader.py

+48 −24
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@
2323
from sqlmesh.core.metric import Metric, MetricMeta, expand_metrics, load_metric_ddl
2424
from sqlmesh.core.model import (
2525
Model,
26-
ExternalModel,
2726
ModelCache,
2827
SeedModel,
2928
create_external_model,
@@ -59,6 +58,14 @@ class LoadedProject:
5958
user_rules: RuleSet
6059

6160

61+
class CacheBase(abc.ABC):
62+
@abc.abstractmethod
63+
def get_or_load_models(
64+
self, target_path: Path, loader: t.Callable[[], t.List[Model]]
65+
) -> t.List[Model]:
66+
"""Get or load all models from cache."""
67+
68+
6269
class Loader(abc.ABC):
6370
"""Abstract base class to load macros and models for a context"""
6471

@@ -69,6 +76,10 @@ def __init__(self, context: GenericContext, path: Path) -> None:
6976
self.config = self.context.configs[self.config_path]
7077
self._variables_by_gateway: t.Dict[str, t.Dict[str, t.Any]] = {}
7178

79+
@abc.abstractmethod
80+
def _cache(self, **kwargs: t.Any) -> CacheBase:
81+
"""Returns an instance of a CacheBase."""
82+
7283
def load(self) -> LoadedProject:
7384
"""
7485
Loads all macros and models in the context's path.
@@ -208,32 +219,41 @@ def _load_external_models(
208219
if external_models_path.exists() and external_models_path.is_dir():
209220
paths_to_load.extend(self._glob_paths(external_models_path, extension=".yaml"))
210221

222+
def _load() -> t.List[Model]:
223+
try:
224+
with open(path, "r", encoding="utf-8") as file:
225+
return [
226+
create_external_model(
227+
defaults=self.config.model_defaults.dict(),
228+
path=path,
229+
project=self.config.project,
230+
audit_definitions=audits,
231+
**{
232+
"dialect": self.config.model_defaults.dialect,
233+
"default_catalog": self.context.default_catalog,
234+
**row,
235+
},
236+
)
237+
for row in YAML().load(file.read())
238+
]
239+
except Exception as ex:
240+
raise ConfigError(f"Failed to load model definition at '{path}'.\n{ex}")
241+
242+
cache = self._cache()
243+
211244
for path in paths_to_load:
212245
self._track_file(path)
213246

214-
with open(path, "r", encoding="utf-8") as file:
215-
external_models: t.List[ExternalModel] = []
216-
for row in YAML().load(file.read()):
217-
model = create_external_model(
218-
defaults=self.config.model_defaults.dict(),
219-
path=path,
220-
project=self.config.project,
221-
audit_definitions=audits,
222-
**{
223-
"dialect": self.config.model_defaults.dialect,
224-
"default_catalog": self.context.default_catalog,
225-
**row,
226-
},
227-
)
228-
external_models.append(model)
229-
230-
# external models with no explicit gateway defined form the base set
231-
for model in (e for e in external_models if e.gateway is None):
247+
external_models = cache.get_or_load_models(path, _load)
248+
# external models with no explicit gateway defined form the base set
249+
for model in external_models:
250+
if model.gateway is None:
232251
models[model.fqn] = model
233252

234-
# however, if there is a gateway defined, gateway-specific models take precedence
235-
if gateway:
236-
for model in (e for e in external_models if e.gateway == gateway):
253+
# however, if there is a gateway defined, gateway-specific models take precedence
254+
if gateway:
255+
for model in external_models:
256+
if model.gateway == gateway:
237257
models.update({model.fqn: model})
238258

239259
return models
@@ -339,6 +359,9 @@ def _get_variables(self, gateway_name: t.Optional[str] = None) -> t.Dict[str, t.
339359
class SqlMeshLoader(Loader):
340360
"""Loads macros and models for a context using the SQLMesh file formats"""
341361

362+
def _cache(self, **kwargs: t.Any) -> CacheBase:
363+
return SqlMeshLoader._Cache(self, self.config_path)
364+
342365
def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
343366
"""Loads all user defined macros."""
344367
# Store a copy of the macro registry
@@ -416,7 +439,8 @@ def _load_sql_models(
416439
) -> UniqueKeyDict[str, Model]:
417440
"""Loads the sql models into a Dict"""
418441
models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
419-
cache = SqlMeshLoader._Cache(self, self.config_path)
442+
443+
cache = self._cache()
420444

421445
for path in self._glob_paths(
422446
self.config_path / c.MODELS,
@@ -662,7 +686,7 @@ def _load_linting_rules(self) -> RuleSet:
662686

663687
return RuleSet(user_rules.values())
664688

665-
class _Cache:
689+
class _Cache(CacheBase):
666690
def __init__(self, loader: SqlMeshLoader, config_path: Path):
667691
self._loader = loader
668692
self.config_path = config_path

sqlmesh/core/model/cache.py

+1 −1
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@ def get_or_load(
5959
return cache_entry
6060

6161
models = loader()
62-
if isinstance(models, list) and isinstance(seq_get(models, 0), SqlModel):
62+
if isinstance(models, list):
6363
# make sure we preload full_depends_on
6464
for model in models:
6565
model.full_depends_on

sqlmesh/dbt/loader.py

+10 −3
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
GatewayConfig,
1212
ModelDefaultsConfig,
1313
)
14-
from sqlmesh.core.loader import LoadedProject, Loader
14+
from sqlmesh.core.loader import CacheBase, LoadedProject, Loader
1515
from sqlmesh.core.macros import MacroRegistry, macro
1616
from sqlmesh.core.model import Model, ModelCache
1717
from sqlmesh.core.signal import signal
@@ -83,6 +83,9 @@ def __init__(self, context: GenericContext, path: Path) -> None:
8383
self._macros_max_mtime: t.Optional[float] = None
8484
super().__init__(context, path)
8585

86+
def _cache(self, **kwargs: t.Any) -> CacheBase:
87+
return DbtLoader._Cache(self, **kwargs)
88+
8689
def load(self) -> LoadedProject:
8790
self._projects = []
8891
return super().load()
@@ -120,7 +123,11 @@ def _to_sqlmesh(config: BMC, context: DbtContext) -> Model:
120123
yaml_max_mtimes = self._compute_yaml_max_mtime_per_subfolder(
121124
project.context.project_root
122125
)
123-
cache = DbtLoader._Cache(self, project, macros_max_mtime, yaml_max_mtimes)
126+
cache = self._cache(
127+
project=project,
128+
macros_max_mtime=macros_max_mtime,
129+
yaml_max_mtimes=yaml_max_mtimes,
130+
)
124131

125132
logger.debug("Converting models to sqlmesh")
126133
# Now that config is rendered, create the sqlmesh models
@@ -255,7 +262,7 @@ def _compute_yaml_max_mtime_per_subfolder(self, root: Path) -> t.Dict[Path, floa
255262

256263
return result
257264

258-
class _Cache:
265+
class _Cache(CacheBase):
259266
MAX_ENTRY_NAME_LENGTH = 200
260267

261268
def __init__(

tests/core/test_loader.py

-5
Original file line number | Diff line number | Diff line change
@@ -201,8 +201,3 @@ def my_model(context, **kwargs):
201201
assert model.description == "model_payload_a"
202202
path_b.write_text(model_payload_b)
203203
context.load() # raise no error to duplicate key if the functions are identical (by registry class_method)
204-
model = context.get_model(f"{model_name}")
205-
assert (
206-
model.description != "model_payload_b"
207-
) # model will not be overwritten by model_payload_b
208-
assert model.description == "model_payload_a"

0 commit comments

Comments
 (0)