Skip to content

Commit 80b6edd

Browse files
authored
Fix Real space assembly (#4331)
* Fix Real space assembly
1 parent c730277 commit 80b6edd

File tree

2 files changed

+45
-5
lines changed

2 files changed

+45
-5
lines changed

tests/firedrake/regression/test_real_space.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,22 @@ def test_real_nonsquare_two_form_assembly():
5858

5959

6060
@pytest.mark.skipcomplex
61-
def test_real_mixed_one_form_assembly():
61+
@pytest.mark.parametrize("coefficient", (False, True))
62+
def test_real_mixed_one_form_assembly(coefficient):
6263
mesh = UnitIntervalMesh(3)
6364
rfs = FunctionSpace(mesh, "Real", 0)
6465
cgfs = FunctionSpace(mesh, "CG", 1)
6566

6667
mfs = cgfs*rfs
6768
v, q = TestFunctions(mfs)
6869

69-
A = assemble(conj(v) * dx + q * dx)
70+
if coefficient:
71+
z = Function(mfs)
72+
z.subfunctions[1].assign(1)
73+
u, p = split(z)
74+
A = assemble(inner(u, v) * dx + inner(p, q) * dx)
75+
else:
76+
A = assemble(conj(v) * dx + q * dx)
7077

7178
qq = TestFunction(rfs)
7279

@@ -131,6 +138,33 @@ def test_real_mixed_empty_component_assembly():
131138
assemble(derivative(inner(grad(v), grad(v)) * dx, w))
132139

133140

141+
@pytest.mark.skipcomplex
142+
@pytest.mark.parametrize("coefficient", (False, True))
143+
def test_real_extruded_mixed_one_form_assembly(coefficient):
144+
m = UnitIntervalMesh(3)
145+
mesh = ExtrudedMesh(m, 10)
146+
rfs = FunctionSpace(mesh, "Real", 0)
147+
cgfs = FunctionSpace(mesh, "CG", 1)
148+
149+
mfs = cgfs*rfs
150+
v, q = TestFunctions(mfs)
151+
152+
if coefficient:
153+
z = Function(mfs)
154+
z.subfunctions[1].assign(1)
155+
u, p = split(z)
156+
A = assemble(inner(u, v) * dx + inner(p, q) * dx)
157+
else:
158+
A = assemble(conj(v) * dx + q * dx)
159+
160+
qq = TestFunction(rfs)
161+
162+
AA = assemble(qq * dx)
163+
164+
np.testing.assert_almost_equal(A.dat.data[1],
165+
AA.dat.data)
166+
167+
134168
@pytest.mark.skipcomplex
135169
def test_real_extruded_mixed_two_form_assembly():
136170
m = UnitIntervalMesh(3)

tsfc/kernel_interface/common.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from numpy import asarray
1717
from tsfc import fem, ufl_utils
1818
from finat.element_factory import as_fiat_cell, create_element
19+
from finat.ufl import MixedElement
1920
from tsfc.kernel_interface import KernelInterface
2021
from tsfc.logging import logger
2122
from ufl.utils.sequences import max_degree
@@ -298,10 +299,16 @@ def set_quad_rule(params, cell, integral_type, functions):
298299
# Check if the integral has a quad degree or quad element attached,
299300
# otherwise use the estimated polynomial degree attached by compute_form_data
300301
quad_rule = params.get("quadrature_rule", "default")
302+
elements = []
303+
for f in functions:
304+
e = f.ufl_element()
305+
if type(e) is MixedElement:
306+
elements.extend(e.sub_elements)
307+
else:
308+
elements.append(e)
301309
try:
302310
quadrature_degree = params["quadrature_degree"]
303311
except KeyError:
304-
elements = [f.ufl_function_space().ufl_element() for f in functions]
305312
quad_data = set((e.degree(), e.quadrature_scheme() or "default") for e in elements
306313
if e.family() in {"Quadrature", "Boundary Quadrature"})
307314
if len(quad_data) == 0:
@@ -320,8 +327,7 @@ def set_quad_rule(params, cell, integral_type, functions):
320327
if isinstance(quad_rule, str):
321328
scheme = quad_rule
322329
fiat_cell = as_fiat_cell(cell)
323-
finat_elements = set(create_element(f.ufl_element()) for f in functions
324-
if f.ufl_element().family() != "Real")
330+
finat_elements = set(create_element(e) for e in elements if e.family() != "Real")
325331
fiat_cells = [fiat_cell] + [finat_el.complex for finat_el in finat_elements]
326332
fiat_cell = max_complex(fiat_cells)
327333

0 commit comments

Comments
 (0)