Skip to content

Commit 5e4fa13

Browse files
committed
feat: cache external models and others for faster loading
1 parent c4e1f02 commit 5e4fa13

File tree

4 files changed

+46
-35
lines changed

4 files changed

+46
-35
lines changed

sqlmesh/core/loader.py

+42 -26
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,6 @@
2323
from sqlmesh.core.metric import Metric, MetricMeta, expand_metrics, load_metric_ddl
2424
from sqlmesh.core.model import (
2525
Model,
26-
ExternalModel,
2726
ModelCache,
2827
SeedModel,
2928
create_external_model,
@@ -59,6 +58,14 @@ class LoadedProject:
5958
user_rules: RuleSet
6059

6160

61+
class CacheBase(abc.ABC):
62+
@abc.abstractmethod
63+
def get_or_load_models(
64+
self, target_path: Path, loader: t.Callable[[], t.List[Model]]
65+
) -> t.List[Model]:
66+
"""Get or load all models from cache."""
67+
68+
6269
class Loader(abc.ABC):
6370
"""Abstract base class to load macros and models for a context"""
6471

@@ -192,6 +199,7 @@ def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]:
192199
def _load_external_models(
193200
self,
194201
audits: UniqueKeyDict[str, ModelAudit],
202+
cache: CacheBase,
195203
gateway: t.Optional[str] = None,
196204
) -> UniqueKeyDict[str, Model]:
197205
models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
@@ -208,32 +216,39 @@ def _load_external_models(
208216
if external_models_path.exists() and external_models_path.is_dir():
209217
paths_to_load.extend(self._glob_paths(external_models_path, extension=".yaml"))
210218

219+
def _load() -> t.List[Model]:
220+
try:
221+
with open(path, "r", encoding="utf-8") as file:
222+
return [
223+
create_external_model(
224+
defaults=self.config.model_defaults.dict(),
225+
path=path,
226+
project=self.config.project,
227+
audit_definitions=audits,
228+
**{
229+
"dialect": self.config.model_defaults.dialect,
230+
"default_catalog": self.context.default_catalog,
231+
**row,
232+
},
233+
)
234+
for row in YAML().load(file.read())
235+
]
236+
except Exception as ex:
237+
raise ConfigError(f"Failed to load model definition at '{path}'.\n{ex}")
238+
211239
for path in paths_to_load:
212240
self._track_file(path)
213241

214-
with open(path, "r", encoding="utf-8") as file:
215-
external_models: t.List[ExternalModel] = []
216-
for row in YAML().load(file.read()):
217-
model = create_external_model(
218-
defaults=self.config.model_defaults.dict(),
219-
path=path,
220-
project=self.config.project,
221-
audit_definitions=audits,
222-
**{
223-
"dialect": self.config.model_defaults.dialect,
224-
"default_catalog": self.context.default_catalog,
225-
**row,
226-
},
227-
)
228-
external_models.append(model)
229-
230-
# external models with no explicit gateway defined form the base set
231-
for model in (e for e in external_models if e.gateway is None):
242+
external_models = cache.get_or_load_models(path, _load)
243+
# external models with no explicit gateway defined form the base set
244+
for model in external_models:
245+
if model.gateway is None:
232246
models[model.fqn] = model
233247

234-
# however, if there is a gateway defined, gateway-specific models take precedence
235-
if gateway:
236-
for model in (e for e in external_models if e.gateway == gateway):
248+
# however, if there is a gateway defined, gateway-specific models take precedence
249+
if gateway:
250+
for model in external_models:
251+
if model.gateway == gateway:
237252
models.update({model.fqn: model})
238253

239254
return models
@@ -396,8 +411,9 @@ def _load_models(
396411
Loads all of the models within the model directory with their associated
397412
audits into a Dict and creates the dag
398413
"""
399-
sql_models = self._load_sql_models(macros, jinja_macros, audits, signals)
400-
external_models = self._load_external_models(audits, gateway)
414+
cache = SqlMeshLoader._Cache(self, self.config_path)
415+
sql_models = self._load_sql_models(macros, jinja_macros, audits, signals, cache)
416+
external_models = self._load_external_models(audits, cache, gateway)
401417
python_models = self._load_python_models(macros, jinja_macros, audits, signals)
402418

403419
all_model_names = list(sql_models) + list(external_models) + list(python_models)
@@ -413,10 +429,10 @@ def _load_sql_models(
413429
jinja_macros: JinjaMacroRegistry,
414430
audits: UniqueKeyDict[str, ModelAudit],
415431
signals: UniqueKeyDict[str, signal],
432+
cache: CacheBase,
416433
) -> UniqueKeyDict[str, Model]:
417434
"""Loads the sql models into a Dict"""
418435
models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
419-
cache = SqlMeshLoader._Cache(self, self.config_path)
420436

421437
for path in self._glob_paths(
422438
self.config_path / c.MODELS,
@@ -662,7 +678,7 @@ def _load_linting_rules(self) -> RuleSet:
662678

663679
return RuleSet(user_rules.values())
664680

665-
class _Cache:
681+
class _Cache(CacheBase):
666682
def __init__(self, loader: SqlMeshLoader, config_path: Path):
667683
self._loader = loader
668684
self.config_path = config_path

sqlmesh/core/model/cache.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@ def get_or_load(
5959
return cache_entry
6060

6161
models = loader()
62-
if isinstance(models, list) and isinstance(seq_get(models, 0), SqlModel):
62+
if isinstance(models, list):
6363
# make sure we preload full_depends_on
6464
for model in models:
6565
model.full_depends_on

sqlmesh/dbt/loader.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
GatewayConfig,
1212
ModelDefaultsConfig,
1313
)
14-
from sqlmesh.core.loader import LoadedProject, Loader
14+
from sqlmesh.core.loader import CacheBase, LoadedProject, Loader
1515
from sqlmesh.core.macros import MacroRegistry, macro
1616
from sqlmesh.core.model import Model, ModelCache
1717
from sqlmesh.core.signal import signal
@@ -145,7 +145,7 @@ def _to_sqlmesh(config: BMC, context: DbtContext) -> Model:
145145

146146
models[sqlmesh_model.fqn] = sqlmesh_model
147147

148-
models.update(self._load_external_models(audits))
148+
models.update(self._load_external_models(audits, cache))
149149

150150
return models
151151

@@ -255,7 +255,7 @@ def _compute_yaml_max_mtime_per_subfolder(self, root: Path) -> t.Dict[Path, floa
255255

256256
return result
257257

258-
class _Cache:
258+
class _Cache(CacheBase):
259259
MAX_ENTRY_NAME_LENGTH = 200
260260

261261
def __init__(

tests/core/test_loader.py

-5
Original file line number | Diff line number | Diff line change
@@ -201,8 +201,3 @@ def my_model(context, **kwargs):
201201
assert model.description == "model_payload_a"
202202
path_b.write_text(model_payload_b)
203203
context.load() # raise no error to duplicate key if the functions are identical (by registry class_method)
204-
model = context.get_model(f"{model_name}")
205-
assert (
206-
model.description != "model_payload_b"
207-
) # model will not be overwritten by model_payload_b
208-
assert model.description == "model_payload_a"

0 commit comments

Comments
 (0)