@@ -464,7 +464,7 @@ def arg_function_spaces(self):
464
464
if isinstance (tensor , BaseForm ):
465
465
return tuple (a .function_space () for a in tensor .arguments ())
466
466
else :
467
- return (tensor .ufl_function_space (),)
467
+ return (tensor .ufl_function_space (). dual () ,)
468
468
469
469
@cached_property
470
470
def _argument (self ):
@@ -1084,14 +1084,14 @@ def arg_function_spaces(self):
1084
1084
is defined on.
1085
1085
"""
1086
1086
tensor , = self .operands
1087
- return tensor .arg_function_spaces [:: - 1 ]
1087
+ return tuple ( V . dual () for V in reversed ( tensor .arg_function_spaces ))
1088
1088
1089
1089
def arguments (self ):
1090
1090
"""Returns the expected arguments of the resulting tensor of
1091
1091
performing a specific unary operation on a tensor.
1092
1092
"""
1093
1093
tensor , = self .operands
1094
- return tensor .arguments ()[:: - 1 ]
1094
+ return tuple ( a . reconstruct ( a . function_space (). dual ()) for a in reversed ( tensor .arguments ()))
1095
1095
1096
1096
def _output_string (self , prec = None ):
1097
1097
"""Creates a string representation of the inverse of a tensor."""
@@ -1219,7 +1219,7 @@ def __init__(self, A, B):
1219
1219
raise ValueError ("Illegal op on a %s-tensor with a %s-tensor."
1220
1220
% (A .shape , B .shape ))
1221
1221
1222
- assert all (space_equivalence ( fsA , fsB ) for fsA , fsB in
1222
+ assert all (fsA == fsB for fsA , fsB in
1223
1223
zip (A .arg_function_spaces , B .arg_function_spaces )), (
1224
1224
"Function spaces associated with operands must match."
1225
1225
)
@@ -1267,9 +1267,9 @@ def __init__(self, A, B):
1267
1267
fsA = A .arg_function_spaces [- 1 ]
1268
1268
fsB = B .arg_function_spaces [0 ]
1269
1269
1270
- assert space_equivalence ( fsA , fsB ), (
1270
+ assert fsA == fsB . dual ( ), (
1271
1271
"Cannot perform argument contraction over middle indices. "
1272
- "They must be in the same function space ."
1272
+ "They should be in dual function spaces ."
1273
1273
)
1274
1274
1275
1275
super (Mul , self ).__init__ (A , B )
@@ -1314,12 +1314,12 @@ def __new__(cls, A, B, decomposition=None):
1314
1314
if A .shape [1 ] != B .shape [0 ]:
1315
1315
raise ValueError (f"Illegal op on a { A .shape } -tensor with a { B .shape } -tensor." )
1316
1316
1317
- fsA = A .arg_function_spaces [0 ]
1317
+ fsA = A .arg_function_spaces [1 ]
1318
1318
fsB = B .arg_function_spaces [0 ]
1319
1319
1320
- assert space_equivalence ( fsA , fsB ) , (
1320
+ assert fsA . dual () == fsB , (
1321
1321
"Cannot perform argument contraction over middle indices. "
1322
- "They must be in the same function space ."
1322
+ "They should be in dual function spaces ."
1323
1323
)
1324
1324
1325
1325
# For matrices smaller than 5x5, exact formulae can be used
@@ -1341,7 +1341,8 @@ def __init__(self, A, B, decomposition=None):
1341
1341
1342
1342
super (Solve , self ).__init__ (A_factored , B )
1343
1343
1344
- self ._args = A_factored .arguments ()[::- 1 ][:- 1 ] + B .arguments ()[1 :]
1344
+ Ainv_args = [a .reconstruct (a .function_space ().dual ()) for a in reversed (A .arguments ())]
1345
+ self ._args = Ainv_args [:- 1 ] + B .arguments ()[1 :]
1345
1346
self ._arg_fs = [arg .function_space () for arg in self ._args ]
1346
1347
1347
1348
@cached_property
@@ -1400,19 +1401,6 @@ def _output_string(self, prec=None):
1400
1401
return "(%s).diag" % tensor
1401
1402
1402
1403
1403
- def space_equivalence (A , B ):
1404
- """Checks that two function spaces are equivalent.
1405
-
1406
- :arg A: A function space.
1407
- :arg B: Another function space.
1408
-
1409
- Returns `True` if they have matching meshes, elements, and rank. Otherwise,
1410
- `False` is returned.
1411
- """
1412
-
1413
- return A .mesh () == B .mesh () and A .ufl_element () == B .ufl_element ()
1414
-
1415
-
1416
1404
def as_slate (F ):
1417
1405
"""Convert an assembled or unassembled expression into a Slate Tensor.
1418
1406
0 commit comments