Skip to content

Commit 4293ccd

Browse files
committed
add a test
1 parent 2a07936 commit 4293ccd

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

firedrake/slate/slate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,12 +1314,12 @@ def __new__(cls, A, B, decomposition=None):
13141314
if A.shape[1] != B.shape[0]:
13151315
raise ValueError(f"Illegal op on a {A.shape}-tensor with a {B.shape}-tensor.")
13161316

1317-
fsA = A.arg_function_spaces[1]
1317+
fsA = A.arg_function_spaces[0]
13181318
fsB = B.arg_function_spaces[0]
13191319

1320-
assert fsA.dual() == fsB, (
1320+
assert fsA == fsB, (
13211321
"Cannot perform argument contraction over middle indices. "
1322-
"They should be in dual function spaces."
1322+
"They must be in the same function space."
13231323
)
13241324

13251325
# For matrices smaller than 5x5, exact formulae can be used
@@ -1341,7 +1341,7 @@ def __init__(self, A, B, decomposition=None):
13411341

13421342
super(Solve, self).__init__(A_factored, B)
13431343

1344-
Ainv_args = [a.reconstruct(a.function_space().dual()) for a in reversed(A.arguments())]
1344+
Ainv_args = tuple(a.reconstruct(a.function_space().dual()) for a in reversed(A.arguments()))
13451345
self._args = Ainv_args[:-1] + B.arguments()[1:]
13461346
self._arg_fs = [arg.function_space() for arg in self._args]
13471347

tests/firedrake/slate/test_assemble_tensors.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,24 @@ def test_assemble_matrix(rank_two_tensor):
130130
assert np.allclose(M.M.values, assemble(rank_two_tensor.form).M.values, rtol=1e-14)
131131

132132

133+
def test_assemble_solve(mesh):
134+
V = FunctionSpace(mesh, "DG", 0)
135+
u = TrialFunction(V)
136+
v = TestFunction(V)
137+
138+
M = inner(u, v)*dx
139+
f = Cofunction(V.dual())
140+
f.dat.data[...] = 1
141+
142+
u1 = Function(V)
143+
u2 = Function(V)
144+
# Assemble a SLATE tensor into f
145+
assemble(Inverse(Tensor(M)) * AssembledVector(f), tensor=u1)
146+
# Assemble a different tensor into f
147+
solve(M == f, u2)
148+
assert np.allclose(u1.dat.data, u2.dat.data, rtol=1e-14)
149+
150+
133151
def test_assemble_vector_into_tensor(mesh):
134152
V = FunctionSpace(mesh, "DG", 1)
135153
v = TestFunction(V)

0 commit comments

Comments
 (0)