@@ -1267,7 +1267,7 @@ def __init__(self, A, B):
1267
1267
fsA = A .arg_function_spaces [- 1 ]
1268
1268
fsB = B .arg_function_spaces [0 ]
1269
1269
1270
- assert fsA == fsB . dual ( ), (
1270
+ assert space_equivalence ( fsA , fsB ), (
1271
1271
"Cannot perform argument contraction over middle indices. "
1272
1272
"They should be in dual function spaces."
1273
1273
)
@@ -1317,7 +1317,7 @@ def __new__(cls, A, B, decomposition=None):
1317
1317
fsA = A .arg_function_spaces [0 ]
1318
1318
fsB = B .arg_function_spaces [0 ]
1319
1319
1320
- assert fsA == fsB , (
1320
+ assert space_equivalence ( fsA , fsB ) , (
1321
1321
"Cannot perform argument contraction over middle indices. "
1322
1322
"They must be in the same function space."
1323
1323
)
@@ -1401,6 +1401,19 @@ def _output_string(self, prec=None):
1401
1401
return "(%s).diag" % tensor
1402
1402
1403
1403
1404
+ def space_equivalence (A , B ):
1405
+ """Checks that two function spaces are equivalent.
1406
+
1407
+ :arg A: A function space.
1408
+ :arg B: Another function space.
1409
+
1410
+ Returns `True` if they have matching meshes, elements, and rank. Otherwise,
1411
+ `False` is returned.
1412
+ """
1413
+
1414
+ return A .mesh () == B .mesh () and A .ufl_element () == B .ufl_element ()
1415
+
1416
+
1404
1417
def as_slate (F ):
1405
1418
"""Convert an assembled or unassembled expression into a Slate Tensor.
1406
1419
0 commit comments