Skip to content

Commit 2a07936

Browse files
committed
SLATE: support dual space
1 parent 80b6edd commit 2a07936

File tree

2 files changed

+13
-25
lines changed

2 files changed

+13
-25
lines changed

firedrake/slate/slate.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ def arg_function_spaces(self):
464464
if isinstance(tensor, BaseForm):
465465
return tuple(a.function_space() for a in tensor.arguments())
466466
else:
467-
return (tensor.ufl_function_space(),)
467+
return (tensor.ufl_function_space().dual(),)
468468

469469
@cached_property
470470
def _argument(self):
@@ -1084,14 +1084,14 @@ def arg_function_spaces(self):
10841084
is defined on.
10851085
"""
10861086
tensor, = self.operands
1087-
return tensor.arg_function_spaces[::-1]
1087+
return tuple(V.dual() for V in reversed(tensor.arg_function_spaces))
10881088

10891089
def arguments(self):
10901090
"""Returns the expected arguments of the resulting tensor of
10911091
performing a specific unary operation on a tensor.
10921092
"""
10931093
tensor, = self.operands
1094-
return tensor.arguments()[::-1]
1094+
return tuple(a.reconstruct(a.function_space().dual()) for a in reversed(tensor.arguments()))
10951095

10961096
def _output_string(self, prec=None):
10971097
"""Creates a string representation of the inverse of a tensor."""
@@ -1219,7 +1219,7 @@ def __init__(self, A, B):
12191219
raise ValueError("Illegal op on a %s-tensor with a %s-tensor."
12201220
% (A.shape, B.shape))
12211221

1222-
assert all(space_equivalence(fsA, fsB) for fsA, fsB in
1222+
assert all(fsA == fsB for fsA, fsB in
12231223
zip(A.arg_function_spaces, B.arg_function_spaces)), (
12241224
"Function spaces associated with operands must match."
12251225
)
@@ -1267,9 +1267,9 @@ def __init__(self, A, B):
12671267
fsA = A.arg_function_spaces[-1]
12681268
fsB = B.arg_function_spaces[0]
12691269

1270-
assert space_equivalence(fsA, fsB), (
1270+
assert fsA == fsB.dual(), (
12711271
"Cannot perform argument contraction over middle indices. "
1272-
"They must be in the same function space."
1272+
"They should be in dual function spaces."
12731273
)
12741274

12751275
super(Mul, self).__init__(A, B)
@@ -1314,12 +1314,12 @@ def __new__(cls, A, B, decomposition=None):
13141314
if A.shape[1] != B.shape[0]:
13151315
raise ValueError(f"Illegal op on a {A.shape}-tensor with a {B.shape}-tensor.")
13161316

1317-
fsA = A.arg_function_spaces[0]
1317+
fsA = A.arg_function_spaces[1]
13181318
fsB = B.arg_function_spaces[0]
13191319

1320-
assert space_equivalence(fsA, fsB), (
1320+
assert fsA.dual() == fsB, (
13211321
"Cannot perform argument contraction over middle indices. "
1322-
"They must be in the same function space."
1322+
"They should be in dual function spaces."
13231323
)
13241324

13251325
# For matrices smaller than 5x5, exact formulae can be used
@@ -1341,7 +1341,8 @@ def __init__(self, A, B, decomposition=None):
13411341

13421342
super(Solve, self).__init__(A_factored, B)
13431343

1344-
self._args = A_factored.arguments()[::-1][:-1] + B.arguments()[1:]
1344+
Ainv_args = [a.reconstruct(a.function_space().dual()) for a in reversed(A.arguments())]
1345+
self._args = Ainv_args[:-1] + B.arguments()[1:]
13451346
self._arg_fs = [arg.function_space() for arg in self._args]
13461347

13471348
@cached_property
@@ -1400,19 +1401,6 @@ def _output_string(self, prec=None):
14001401
return "(%s).diag" % tensor
14011402

14021403

1403-
def space_equivalence(A, B):
1404-
"""Checks that two function spaces are equivalent.
1405-
1406-
:arg A: A function space.
1407-
:arg B: Another function space.
1408-
1409-
Returns `True` if they have matching meshes, elements, and rank. Otherwise,
1410-
`False` is returned.
1411-
"""
1412-
1413-
return A.mesh() == B.mesh() and A.ufl_element() == B.ufl_element()
1414-
1415-
14161404
def as_slate(F):
14171405
"""Convert an assembled or unassembled expression into a Slate Tensor.
14181406

firedrake/ufl_expr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def extract_domains(func):
379379
list of firedrake.mesh.MeshGeometry
380380
Extracted domains.
381381
"""
382-
if isinstance(func, (function.Function, cofunction.Cofunction)):
382+
if isinstance(func, (function.Function, cofunction.Cofunction, Argument, Coargument)):
383383
return [func.function_space().mesh()]
384384
else:
385385
return ufl.domain.extract_domains(func)
@@ -398,7 +398,7 @@ def extract_unique_domain(func):
398398
list of firedrake.mesh.MeshGeometry
399399
Extracted domains.
400400
"""
401-
if isinstance(func, (function.Function, cofunction.Cofunction)):
401+
if isinstance(func, (function.Function, cofunction.Cofunction, Argument, Coargument)):
402402
return func.function_space().mesh()
403403
else:
404404
return ufl.domain.extract_unique_domain(func)

0 commit comments

Comments
 (0)