Skip to content

Commit 87bed59

Browse files
Merge pull request #2465 from agriyakhetarpal/memoization
Use `@cache` and `@cached_property` for memoization
2 parents d62272f + 94b7200 commit 87bed59

File tree

6 files changed

+86
-123
lines changed

6 files changed

+86
-123
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
## Optimizations
1010

11+
- Implemented memoization via `cache` and `cached_property` from functools ([#2465](https://github.com/pybamm-team/PyBaMM/pull/2465))
1112
- `ParameterValues` now avoids trying to process children if a function parameter is an object that doesn't depend on its children ([#2477](https://github.com/pybamm-team/PyBaMM/pull/2477))
1213
- Added more rules for simplifying expressions, especially around Concatenations. Also, meshes constructed from multiple domains are now cached ([#2443](https://github.com/pybamm-team/PyBaMM/pull/2443))
1314
- Added more rules for simplifying expressions. Constants in binary operators are now moved to the left by default (e.g. `x*2` returns `2*x`) ([#2424](https://github.com/pybamm-team/PyBaMM/pull/2424))

pybamm/expression_tree/symbol.py

Lines changed: 27 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import sympy
99
from anytree.exporter import DotExporter
1010
from scipy.sparse import csr_matrix, issparse
11+
from functools import lru_cache, cached_property
1112

1213
import pybamm
1314
from pybamm.expression_tree.printing.print_name import prettify_print_name
@@ -827,6 +828,7 @@ def evaluates_to_number(self):
827828
def evaluates_to_constant_number(self):
828829
return self.evaluates_to_number() and self.is_constant()
829830

831+
@lru_cache
830832
def evaluates_on_edges(self, dimension):
831833
"""
832834
Returns True if a symbol evaluates on an edge, i.e. symbol contains a gradient
@@ -845,12 +847,9 @@ def evaluates_on_edges(self, dimension):
845847
Whether the symbol evaluates on edges (in the finite volume discretisation
846848
sense)
847849
"""
848-
try:
849-
return self._saved_evaluates_on_edges[dimension]
850-
except KeyError:
851-
eval_on_edges = self._evaluates_on_edges(dimension)
852-
self._saved_evaluates_on_edges[dimension] = eval_on_edges
853-
return eval_on_edges
850+
eval_on_edges = self._evaluates_on_edges(dimension)
851+
self._saved_evaluates_on_edges[dimension] = eval_on_edges
852+
return eval_on_edges
854853

855854
def _evaluates_on_edges(self, dimension):
856855
# Default behaviour: return False
@@ -894,48 +893,39 @@ def new_copy(self):
894893
obj._print_name = self.print_name
895894
return obj
896895

897-
@property
896+
@cached_property
898897
def size(self):
899898
"""
900899
Size of an object, found by evaluating it with appropriate t and y
901900
"""
902-
try:
903-
return self._saved_size
904-
except AttributeError:
905-
self._saved_size = np.prod(self.shape)
906-
return self._saved_size
901+
return np.prod(self.shape)
907902

908-
@property
903+
@cached_property
909904
def shape(self):
910905
"""
911906
Shape of an object, found by evaluating it with appropriate t and y.
912907
"""
908+
# Default behaviour is to try to evaluate the object directly
909+
# Try with some large y, to avoid having to unpack (slow)
913910
try:
914-
return self._saved_shape
915-
except AttributeError:
916-
# Default behaviour is to try to evaluate the object directly
917-
# Try with some large y, to avoid having to unpack (slow)
918-
try:
919-
y = np.nan * np.ones((1000, 1))
920-
evaluated_self = self.evaluate(0, y, y, inputs="shape test")
921-
# If that fails, fall back to calculating how big y should really be
922-
except ValueError:
923-
unpacker = pybamm.SymbolUnpacker(pybamm.StateVector)
924-
state_vectors_in_node = unpacker.unpack_symbol(self)
925-
min_y_size = max(
926-
max(len(x._evaluation_array) for x in state_vectors_in_node), 1
927-
)
928-
# Pick a y that won't cause RuntimeWarnings
929-
y = np.nan * np.ones((min_y_size, 1))
930-
evaluated_self = self.evaluate(0, y, y, inputs="shape test")
931-
932-
# Return shape of evaluated object
933-
if isinstance(evaluated_self, numbers.Number):
934-
self._saved_shape = ()
935-
else:
936-
self._saved_shape = evaluated_self.shape
911+
y = np.nan * np.ones((1000, 1))
912+
evaluated_self = self.evaluate(0, y, y, inputs="shape test")
913+
# If that fails, fall back to calculating how big y should really be
914+
except ValueError:
915+
unpacker = pybamm.SymbolUnpacker(pybamm.StateVector)
916+
state_vectors_in_node = unpacker.unpack_symbol(self)
917+
min_y_size = max(
918+
max(len(x._evaluation_array) for x in state_vectors_in_node), 1
919+
)
920+
# Pick a y that won't cause RuntimeWarnings
921+
y = np.nan * np.ones((min_y_size, 1))
922+
evaluated_self = self.evaluate(0, y, y, inputs="shape test")
937923

938-
return self._saved_shape
924+
# Return shape of evaluated object
925+
if isinstance(evaluated_self, numbers.Number):
926+
return ()
927+
else:
928+
return evaluated_self.shape
939929

940930
@property
941931
def size_for_testing(self):

pybamm/models/full_battery_models/base_battery_model.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pybamm
66
import numbers
7+
from functools import cached_property
78

89

910
class BatteryModelOptions(pybamm.FuzzyDict):
@@ -600,19 +601,14 @@ def phases(self):
600601
self._phases[domain] = phases
601602
return self._phases
602603

603-
@property
604+
@cached_property
604605
def whole_cell_domains(self):
605-
try:
606-
return self._whole_cell_domains
607-
except AttributeError:
608-
if self["working electrode"] == "positive":
609-
wcd = ["separator", "positive electrode"]
610-
elif self["working electrode"] == "negative":
611-
wcd = ["negative electrode", "separator"]
612-
elif self["working electrode"] == "both":
613-
wcd = ["negative electrode", "separator", "positive electrode"]
614-
self._whole_cell_domains = wcd
615-
return wcd
606+
if self["working electrode"] == "positive":
607+
return ["separator", "positive electrode"]
608+
elif self["working electrode"] == "negative":
609+
return ["negative electrode", "separator"]
610+
elif self["working electrode"] == "both":
611+
return ["negative electrode", "separator", "positive electrode"]
616612

617613
@property
618614
def electrode_types(self):

pybamm/models/full_battery_models/lithium_ion/electrode_soh.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44
import pybamm
55
import numpy as np
6+
from functools import lru_cache
67

78

89
class ElectrodeSOH(pybamm.BaseModel):
@@ -246,28 +247,18 @@ def __init__(self, parameter_values, param=None):
246247
- self.param.n.prim.U_dimensional(x, T)
247248
)
248249

250+
@lru_cache
249251
def _get_electrode_soh_sims_full(self):
250-
try:
251-
return self._full_sim
252-
except AttributeError:
253-
full_model = ElectrodeSOH(param=self.param)
254-
self._full_sim = pybamm.Simulation(
255-
full_model, parameter_values=self.parameter_values
256-
)
257-
return self._full_sim
252+
full_model = ElectrodeSOH(param=self.param)
253+
return pybamm.Simulation(full_model, parameter_values=self.parameter_values)
258254

255+
@lru_cache
259256
def _get_electrode_soh_sims_split(self):
260-
try:
261-
return self._split_sims
262-
except AttributeError:
263-
x100_model = ElectrodeSOHx100(param=self.param)
264-
x100_sim = pybamm.Simulation(
265-
x100_model, parameter_values=self.parameter_values
266-
)
267-
x0_model = ElectrodeSOHx0(param=self.param)
268-
x0_sim = pybamm.Simulation(x0_model, parameter_values=self.parameter_values)
269-
self._split_sims = [x100_sim, x0_sim]
270-
return self._split_sims
257+
x100_model = ElectrodeSOHx100(param=self.param)
258+
x100_sim = pybamm.Simulation(x100_model, parameter_values=self.parameter_values)
259+
x0_model = ElectrodeSOHx0(param=self.param)
260+
x0_sim = pybamm.Simulation(x0_model, parameter_values=self.parameter_values)
261+
return [x100_sim, x0_sim]
271262

272263
def solve(self, inputs):
273264
ics = self._set_up_solve(inputs)

pybamm/simulation.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import copy
88
import warnings
99
import sys
10+
from functools import lru_cache
1011

1112

1213
def is_notebook():
@@ -889,6 +890,7 @@ def step(
889890

890891
return self.solution
891892

893+
@lru_cache
892894
def get_esoh_solver(self, calc_esoh):
893895
if (
894896
calc_esoh is False
@@ -897,13 +899,9 @@ def get_esoh_solver(self, calc_esoh):
897899
):
898900
return None
899901

900-
try:
901-
return self._esoh_solver
902-
except AttributeError:
903-
self._esoh_solver = pybamm.lithium_ion.ElectrodeSOHSolver(
904-
self.parameter_values, self.model.param
905-
)
906-
return self._esoh_solver
902+
return pybamm.lithium_ion.ElectrodeSOHSolver(
903+
self.parameter_values, self.model.param
904+
)
907905

908906
def plot(self, output_variables=None, **kwargs):
909907
"""

pybamm/solvers/solution.py

Lines changed: 35 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pybamm
1010
import pandas as pd
1111
from scipy.io import savemat
12+
from functools import cached_property
1213

1314

1415
class NumpyEncoder(json.JSONEncoder):
@@ -344,15 +345,9 @@ def all_models(self):
344345
"""Model(s) used for solution"""
345346
return self._all_models
346347

347-
@property
348+
@cached_property
348349
def all_inputs_casadi(self):
349-
try:
350-
return self._all_inputs_casadi
351-
except AttributeError:
352-
self._all_inputs_casadi = [
353-
casadi.vertcat(*inp.values()) for inp in self.all_inputs
354-
]
355-
return self._all_inputs_casadi
350+
return [casadi.vertcat(*inp.values()) for inp in self.all_inputs]
356351

357352
@property
358353
def t_event(self):
@@ -374,63 +369,55 @@ def termination(self, value):
374369
"""Updates the reason for termination"""
375370
self._termination = value
376371

377-
@property
372+
@cached_property
378373
def first_state(self):
379374
"""
380375
A Solution object that only contains the first state. This is faster to evaluate
381376
than the full solution when only the first state is needed (e.g. to initialize
382377
a model with the solution)
383378
"""
384-
try:
385-
return self._first_state
386-
except AttributeError:
387-
new_sol = Solution(
388-
self.all_ts[0][:1],
389-
self.all_ys[0][:, :1],
390-
self.all_models[:1],
391-
self.all_inputs[:1],
392-
None,
393-
None,
394-
"success",
395-
)
396-
new_sol._all_inputs_casadi = self.all_inputs_casadi[:1]
397-
new_sol._sub_solutions = self.sub_solutions[:1]
379+
new_sol = Solution(
380+
self.all_ts[0][:1],
381+
self.all_ys[0][:, :1],
382+
self.all_models[:1],
383+
self.all_inputs[:1],
384+
None,
385+
None,
386+
"success",
387+
)
388+
new_sol._all_inputs_casadi = self.all_inputs_casadi[:1]
389+
new_sol._sub_solutions = self.sub_solutions[:1]
398390

399-
new_sol.solve_time = 0
400-
new_sol.integration_time = 0
401-
new_sol.set_up_time = 0
391+
new_sol.solve_time = 0
392+
new_sol.integration_time = 0
393+
new_sol.set_up_time = 0
402394

403-
self._first_state = new_sol
404-
return self._first_state
395+
return new_sol
405396

406-
@property
397+
@cached_property
407398
def last_state(self):
408399
"""
409400
A Solution object that only contains the final state. This is faster to evaluate
410401
than the full solution when only the final state is needed (e.g. to initialize
411402
a model with the solution)
412403
"""
413-
try:
414-
return self._last_state
415-
except AttributeError:
416-
new_sol = Solution(
417-
self.all_ts[-1][-1:],
418-
self.all_ys[-1][:, -1:],
419-
self.all_models[-1:],
420-
self.all_inputs[-1:],
421-
self.t_event,
422-
self.y_event,
423-
self.termination,
424-
)
425-
new_sol._all_inputs_casadi = self.all_inputs_casadi[-1:]
426-
new_sol._sub_solutions = self.sub_solutions[-1:]
404+
new_sol = Solution(
405+
self.all_ts[-1][-1:],
406+
self.all_ys[-1][:, -1:],
407+
self.all_models[-1:],
408+
self.all_inputs[-1:],
409+
self.t_event,
410+
self.y_event,
411+
self.termination,
412+
)
413+
new_sol._all_inputs_casadi = self.all_inputs_casadi[-1:]
414+
new_sol._sub_solutions = self.sub_solutions[-1:]
427415

428-
new_sol.solve_time = 0
429-
new_sol.integration_time = 0
430-
new_sol.set_up_time = 0
416+
new_sol.solve_time = 0
417+
new_sol.integration_time = 0
418+
new_sol.set_up_time = 0
431419

432-
self._last_state = new_sol
433-
return self._last_state
420+
return new_sol
434421

435422
@property
436423
def total_time(self):

0 commit comments

Comments
 (0)