Skip to content

Commit fa717f2

Browse files
committed
Allow primal/dual error in Mult to make tests pass
1 parent 1082433 commit fa717f2

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

firedrake/slate/slate.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,7 +1267,7 @@ def __init__(self, A, B):
12671267
fsA = A.arg_function_spaces[-1]
12681268
fsB = B.arg_function_spaces[0]
12691269

1270-
assert fsA == fsB.dual(), (
1270+
assert space_equivalence(fsA, fsB), (
12711271
"Cannot perform argument contraction over middle indices. "
12721272
"They should be in dual function spaces."
12731273
)
@@ -1317,7 +1317,7 @@ def __new__(cls, A, B, decomposition=None):
13171317
fsA = A.arg_function_spaces[0]
13181318
fsB = B.arg_function_spaces[0]
13191319

1320-
assert fsA == fsB, (
1320+
assert space_equivalence(fsA, fsB), (
13211321
"Cannot perform argument contraction over middle indices. "
13221322
"They must be in the same function space."
13231323
)
@@ -1401,6 +1401,19 @@ def _output_string(self, prec=None):
14011401
return "(%s).diag" % tensor
14021402

14031403

1404+
def space_equivalence(A, B):
1405+
"""Checks that two function spaces are equivalent.
1406+
1407+
:arg A: A function space.
1408+
:arg B: Another function space.
1409+
1410+
Returns `True` if they have matching meshes, elements, and rank. Otherwise,
1411+
`False` is returned.
1412+
"""
1413+
1414+
return A.mesh() == B.mesh() and A.ufl_element() == B.ufl_element()
1415+
1416+
14041417
def as_slate(F):
14051418
"""Convert an assembled or unassembled expression into a Slate Tensor.
14061419

0 commit comments

Comments
 (0)