Skip to content

Commit 24e465d

Browse files
Merge pull request #2477 from pybamm-team/issue-2403-function-parameter-constant
#403 add special case for constant function parameter
2 parents 8478e9a + 7bcfd80 commit 24e465d

File tree

3 files changed

+28
-13
lines changed

3 files changed

+28
-13
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
## Optimizations
99

10+
- `ParameterValues` now avoids trying to process children if a function parameter is an object that doesn't depend on its children ([#2477](https://github.com/pybamm-team/PyBaMM/pull/2477))
1011
- Added more rules for simplifying expressions, especially around Concatenations. Also, meshes constructed from multiple domains are now cached ([#2443](https://github.com/pybamm-team/PyBaMM/pull/2443))
1112
- Added more rules for simplifying expressions. Constants in binary operators are now moved to the left by default (e.g. `x*2` returns `2*x`) ([#2424](https://github.com/pybamm-team/PyBaMM/pull/2424))
1213

pybamm/parameters/parameter_values.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ def _process_symbol(self, symbol):
568568
# Check not NaN (parameter in csv file but no value given)
569569
if np.isnan(value):
570570
raise ValueError(f"Parameter '{symbol.name}' not found")
571-
# Scalar inherits name (for updating parameters)
571+
# Scalar inherits name
572572
return pybamm.Scalar(value, name=symbol.name)
573573
elif isinstance(value, pybamm.Symbol):
574574
new_value = self.process_symbol(value)
@@ -578,18 +578,29 @@ def _process_symbol(self, symbol):
578578
raise TypeError("Cannot process parameter '{}'".format(value))
579579

580580
elif isinstance(symbol, pybamm.FunctionParameter):
581-
new_children = []
582-
for child in symbol.children:
583-
if symbol.diff_variable is not None and any(
584-
x == symbol.diff_variable for x in child.pre_order()
585-
):
586-
# Wrap with NotConstant to avoid simplification,
587-
# which would stop symbolic diff from working properly
588-
new_child = pybamm.NotConstant(child)
589-
new_children.append(self.process_symbol(new_child))
590-
else:
591-
new_children.append(self.process_symbol(child))
592581
function_name = self[symbol.name]
582+
if isinstance(
583+
function_name,
584+
(numbers.Number, pybamm.Interpolant, pybamm.InputParameter),
585+
) or (
586+
isinstance(function_name, pybamm.Symbol)
587+
and function_name.size_for_testing == 1
588+
):
589+
# no need to process children, they will only be used for shape
590+
new_children = symbol.children
591+
else:
592+
# process children
593+
new_children = []
594+
for child in symbol.children:
595+
if symbol.diff_variable is not None and any(
596+
x == symbol.diff_variable for x in child.pre_order()
597+
):
598+
# Wrap with NotConstant to avoid simplification,
599+
# which would stop symbolic diff from working properly
600+
new_child = pybamm.NotConstant(child)
601+
new_children.append(self.process_symbol(new_child))
602+
else:
603+
new_children.append(self.process_symbol(child))
593604

594605
# Create Function or Interpolant or Scalar object
595606
if isinstance(function_name, tuple):

tests/unit/test_parameters/test_parameter_values.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,10 @@ def test_function(var):
343343
self.assertEqual(processed_func.evaluate(inputs={"a": 3}), 369)
344344

345345
# process constant function
346-
const = pybamm.FunctionParameter("const", {"a": a})
346+
# this should work even if the parameter in the function is not provided
347+
const = pybamm.FunctionParameter(
348+
"const", {"a": pybamm.Parameter("not provided")}
349+
)
347350
processed_const = parameter_values.process_symbol(const)
348351
self.assertIsInstance(processed_const, pybamm.Scalar)
349352
self.assertEqual(processed_const.evaluate(), 254)

0 commit comments

Comments
 (0)