Skip to content

Commit 8f62ef9

Browse files
committed
Remove unused interpolation code
1 parent f753ce4 commit 8f62ef9

File tree

2 files changed

+6
-321
lines changed

2 files changed

+6
-321
lines changed

fenics_ice/interpolation.py

+4-319
Original file line numberDiff line numberDiff line change
@@ -2,175 +2,22 @@
22
"""
33

44
from .backend import (
5-
Cell, Mesh, MeshEditor, Point, UserExpression, backend_Constant,
6-
backend_Function, backend_ScalarType, parameters)
5+
Cell, Mesh, MeshEditor, Point, backend_ScalarType, parameters)
76
from ..interface import (
8-
check_space_type, comm_dup_cached, packed, space_comm, space_eq,
9-
var_assign, var_comm, var_get_values, var_id, var_inner, var_is_scalar,
10-
var_local_size, var_new, var_new_conjugate_dual, var_replacement,
11-
var_scalar_value, var_set_values)
7+
comm_dup_cached, space_comm, var_comm, var_get_values, var_local_size,
8+
var_new, var_set_values)
129

13-
from ..equation import Equation, ZeroAssignment
14-
from ..equations import MatrixActionRHS
15-
from ..linear_equation import LinearEquation, Matrix
16-
from ..manager import manager_disabled
17-
18-
from .expr import (
19-
ExprEquation, derivative, eliminate_zeros, expr_zero, extract_dependencies,
20-
extract_variables)
21-
from .variables import ReplacementConstant
10+
from ..linear_equation import Matrix
2211

2312
import functools
2413
import mpi4py.MPI as MPI
2514
import numpy as np
26-
try:
27-
import ufl_legacy as ufl
28-
except ModuleNotFoundError:
29-
import ufl
3015

3116
__all__ = \
3217
[
33-
"ExprInterpolation",
34-
"Interpolation",
35-
"PointInterpolation"
3618
]
3719

3820

39-
@manager_disabled()
40-
def interpolate_expression(x, expr, *, adj_x=None):
41-
if adj_x is None:
42-
check_space_type(x, "primal")
43-
else:
44-
check_space_type(x, "conjugate_dual")
45-
check_space_type(adj_x, "conjugate_dual")
46-
for dep in extract_variables(expr):
47-
check_space_type(dep, "primal")
48-
49-
expr = eliminate_zeros(expr)
50-
51-
class Expr(UserExpression):
52-
def eval(self, value, x):
53-
value[:] = expr(tuple(x))
54-
55-
def value_shape(self):
56-
return x.ufl_shape
57-
58-
if adj_x is None:
59-
if isinstance(x, backend_Constant):
60-
if isinstance(expr, backend_Constant):
61-
value = expr
62-
else:
63-
if len(x.ufl_shape) > 0:
64-
raise ValueError("Scalar Constant required")
65-
value = x.values()
66-
Expr().eval(value, ())
67-
value, = value
68-
var_assign(x, value)
69-
elif isinstance(x, backend_Function):
70-
try:
71-
x.assign(expr)
72-
except RuntimeError:
73-
x.interpolate(Expr())
74-
else:
75-
raise TypeError(f"Unexpected type: {type(x)}")
76-
else:
77-
expr_val = var_new_conjugate_dual(adj_x)
78-
expr_arguments = ufl.algorithms.extract_arguments(expr)
79-
if len(expr_arguments) > 0:
80-
test, = expr_arguments
81-
if len(test.ufl_shape) > 0:
82-
raise NotImplementedError("Case not implemented")
83-
expr = ufl.replace(expr, {test: ufl.classes.IntValue(1)})
84-
interpolate_expression(expr_val, expr)
85-
86-
if isinstance(x, backend_Constant):
87-
if len(x.ufl_shape) > 0:
88-
raise ValueError("Scalar Constant required")
89-
var_assign(x, var_inner(adj_x, expr_val))
90-
elif isinstance(x, backend_Function):
91-
if not space_eq(x.function_space(), adj_x.function_space()):
92-
raise ValueError("Unable to perform transpose interpolation")
93-
var_set_values(
94-
x, var_get_values(expr_val).conjugate() * var_get_values(adj_x)) # noqa: E501
95-
else:
96-
raise TypeError(f"Unexpected type: {type(x)}")
97-
98-
99-
class ExprInterpolation(ExprEquation):
100-
r"""Represents interpolation of `rhs` onto the space for `x`.
101-
102-
The forward residual :math:`\mathcal{F}` is defined so that :math:`\partial
103-
\mathcal{F} / \partial x` is the identity.
104-
105-
:arg x: The forward solution.
106-
:arg rhs: A :class:`ufl.core.expr.Expr` defining the expression to
107-
interpolate onto the space for `x`. Should not depend on `x`.
108-
"""
109-
110-
def __init__(self, x, rhs):
111-
deps, nl_deps = extract_dependencies(rhs, space_type="primal")
112-
if var_id(x) in deps:
113-
raise ValueError("Invalid dependency")
114-
deps, nl_deps = list(deps.values()), tuple(nl_deps.values())
115-
deps.insert(0, x)
116-
117-
super().__init__(x, deps, nl_deps=nl_deps, ic=False, adj_ic=False)
118-
self._rhs = rhs
119-
120-
def drop_references(self):
121-
replace_map = {dep: var_replacement(dep)
122-
for dep in self.dependencies()}
123-
124-
super().drop_references()
125-
self._rhs = ufl.replace(self._rhs, replace_map)
126-
127-
def forward_solve(self, x, deps=None):
128-
interpolate_expression(x, self._replace(self._rhs, deps))
129-
130-
def adjoint_derivative_action(self, nl_deps, dep_index, adj_x):
131-
eq_deps = self.dependencies()
132-
if dep_index <= 0 or dep_index >= len(eq_deps):
133-
raise ValueError("Unexpected dep_index")
134-
135-
dep = eq_deps[dep_index]
136-
137-
if isinstance(dep, (backend_Constant, ReplacementConstant)):
138-
if len(dep.ufl_shape) > 0:
139-
raise NotImplementedError("Case not implemented")
140-
dF = derivative(self._rhs, dep, argument=ufl.classes.IntValue(1))
141-
else:
142-
dF = derivative(self._rhs, dep)
143-
dF = eliminate_zeros(dF)
144-
dF = self._nonlinear_replace(dF, nl_deps)
145-
146-
F = var_new_conjugate_dual(dep)
147-
interpolate_expression(F, dF, adj_x=adj_x)
148-
return (-1.0, F)
149-
150-
def adjoint_jacobian_solve(self, adj_x, nl_deps, b):
151-
return b
152-
153-
def tangent_linear(self, tlm_map):
154-
x = self.x()
155-
156-
tlm_rhs = expr_zero(x)
157-
for dep in self.dependencies():
158-
if dep != x:
159-
tau_dep = tlm_map[dep]
160-
if tau_dep is not None:
161-
tlm_rhs = (tlm_rhs
162-
+ derivative(self._rhs, dep, argument=tau_dep))
163-
164-
if isinstance(tlm_rhs, ufl.classes.Zero):
165-
return ZeroAssignment(tlm_map[x])
166-
else:
167-
return ExprInterpolation(tlm_map[x], tlm_rhs)
168-
169-
170-
def function_coords(x):
171-
return x.function_space().tabulate_dof_coordinates()
172-
173-
17421
def has_ghost_cells(mesh):
17522
for cell in range(mesh.num_cells()):
17623
if Cell(mesh, cell).is_ghost():
@@ -436,165 +283,3 @@ def adjoint_action(self, nl_deps, adj_x, b, b_index=0, *, method="assign"):
436283
b.vector()[:] -= self._P_T.dot(var_get_values(adj_x))
437284
else:
438285
raise ValueError(f"Invalid method: '{method:s}'")
439-
440-
441-
class Interpolation(LinearEquation):
442-
r"""Represents interpolation of the scalar-valued function `y` onto the
443-
space for `x`.
444-
445-
The forward residual :math:`\mathcal{F}` is defined so that :math:`\partial
446-
\mathcal{F} / \partial x` is the identity.
447-
448-
Internally this builds (or uses a supplied) interpolation matrix for the
449-
local process *only*. This behaves correctly if the there are no edges
450-
between owned and non-owned nodes in the degree of freedom graph associated
451-
with the discrete function space for `y`.
452-
453-
:arg x: A scalar-valued DOLFIN `Function` defining the forward solution.
454-
:arg y: A scalar-valued DOLFIN `Function` to interpolate onto the space for
455-
`x`.
456-
:arg X_coords: A :class:`numpy.ndarray` defining the coordinates at which
457-
to interpolate `y`. Shape is `(n, d)` where `n` is the number of
458-
process local degrees of freedom for `x` and `d` is the geometric
459-
dimension. Defaults to the process local degree of freedom locations
460-
for `x`. Ignored if `P` is supplied.
461-
:arg P: The interpolation matrix. A :class:`scipy.sparse.spmatrix`.
462-
:arg tolerance: Maximum permitted distance (as returned by the DOLFIN
463-
`BoundingBoxTree.compute_closest_entity` method) of an interpolation
464-
point from a cell in the mesh for `y`. Ignored if `P` is supplied.
465-
"""
466-
467-
def __init__(self, x, y, *, x_coords=None, P=None,
468-
tolerance=0.0):
469-
check_space_type(x, "primal")
470-
check_space_type(y, "primal")
471-
472-
if not isinstance(x, backend_Function):
473-
raise TypeError("Solution must be a Function")
474-
if len(x.ufl_shape) > 0:
475-
raise ValueError("Solution must be a scalar-valued Function")
476-
if not isinstance(y, backend_Function):
477-
raise TypeError("y must be a Function")
478-
if len(y.ufl_shape) > 0:
479-
raise ValueError("y must be a scalar-valued Function")
480-
if (x_coords is not None) and (var_comm(x).size > 1):
481-
raise TypeError("Cannot prescribe x_coords in parallel")
482-
483-
if P is None:
484-
y_space = y.function_space()
485-
486-
if x_coords is None:
487-
x_coords = function_coords(x)
488-
489-
y_cells, y_distances = point_cells(x_coords, y_space.mesh())
490-
if (y_distances > tolerance).any():
491-
raise RuntimeError("Unable to locate one or more cells")
492-
493-
y_colors = greedy_coloring(y_space)
494-
P = interpolation_matrix(x_coords, y, y_cells, y_colors)
495-
else:
496-
P = P.copy()
497-
498-
super().__init__(
499-
x, MatrixActionRHS(LocalMatrix(P), y))
500-
501-
502-
class PointInterpolation(Equation):
503-
r"""Represents interpolation of a scalar-valued function at given points.
504-
505-
The forward residual :math:`\mathcal{F}` is defined so that :math:`\partial
506-
\mathcal{F} / \partial x` is the identity.
507-
508-
Internally this builds (or uses a supplied) interpolation matrix for the
509-
local process *only*. This behaves correctly if the there are no edges
510-
between owned and non-owned nodes in the degree of freedom graph associated
511-
with the discrete function space for `y`.
512-
513-
:arg X: A scalar variable, or a :class:`Sequence` of scalar variables,
514-
defining the forward solution.
515-
:arg y: A scalar-valued DOLFIN `Function` to interpolate.
516-
:arg X_coords: A :class:`numpy.ndarray` defining the coordinates at which
517-
to interpolate `y`. Shape is `(n, d)` where `n` is the number of
518-
interpolation points and `d` is the geometric dimension. Ignored if `P`
519-
is supplied.
520-
:arg P: The interpolation matrix. A :class:`scipy.sparse.spmatrix`.
521-
:arg tolerance: Maximum permitted distance (as returned by the DOLFIN
522-
`BoundingBoxTree.compute_closest_entity` method) of an interpolation
523-
point from a cell in the mesh for `y`. Ignored if `P` is supplied.
524-
"""
525-
526-
def __init__(self, X, y, X_coords=None, *,
527-
P=None, tolerance=0.0):
528-
X = packed(X)
529-
for x in X:
530-
check_space_type(x, "primal")
531-
if not var_is_scalar(x):
532-
raise ValueError("Solution must be a scalar variable, or a "
533-
"Sequence of scalar variables")
534-
check_space_type(y, "primal")
535-
536-
if X_coords is None:
537-
if P is None:
538-
raise TypeError("X_coords required when P is not supplied")
539-
else:
540-
if len(X) != X_coords.shape[0]:
541-
raise ValueError("Invalid number of variables")
542-
if not isinstance(y, backend_Function):
543-
raise TypeError("y must be a Function")
544-
if len(y.ufl_shape) > 0:
545-
raise ValueError("y must be a scalar-valued Function")
546-
547-
if P is None:
548-
y_space = y.function_space()
549-
550-
y_cells = point_owners(X_coords, y_space, tolerance=tolerance)
551-
y_colors = greedy_coloring(y_space)
552-
P = interpolation_matrix(X_coords, y, y_cells, y_colors)
553-
else:
554-
P = P.copy()
555-
556-
super().__init__(X, list(X) + [y], nl_deps=[], ic=False, adj_ic=False)
557-
self._P = P
558-
self._P_T = P.T
559-
560-
def forward_solve(self, X, deps=None):
561-
y = (self.dependencies() if deps is None else deps)[-1]
562-
563-
check_space_type(y, "primal")
564-
y_v = var_get_values(y)
565-
x_v_local = np.full(len(X), np.nan, dtype=backend_ScalarType)
566-
for i in range(len(X)):
567-
x_v_local[i] = self._P.getrow(i).dot(y_v)
568-
569-
comm = var_comm(y)
570-
x_v = np.full(len(X), np.nan, dtype=x_v_local.dtype)
571-
comm.Allreduce(x_v_local, x_v, op=MPI.SUM)
572-
573-
for i, x in enumerate(X):
574-
var_assign(x, x_v[i])
575-
576-
def adjoint_derivative_action(self, nl_deps, dep_index, adj_X):
577-
if dep_index != len(self.X()):
578-
raise ValueError("Unexpected dep_index")
579-
580-
adj_x_v = np.full(len(adj_X), np.nan, dtype=backend_ScalarType)
581-
for i, adj_x in enumerate(adj_X):
582-
adj_x_v[i] = var_scalar_value(adj_x)
583-
584-
F = var_new_conjugate_dual(self.dependencies()[-1])
585-
var_set_values(F, self._P_T.dot(adj_x_v))
586-
return (-1.0, F)
587-
588-
def adjoint_jacobian_solve(self, adj_X, nl_deps, B):
589-
return B
590-
591-
def tangent_linear(self, tlm_map):
592-
X = self.X()
593-
y = self.dependencies()[-1]
594-
595-
tlm_y = tlm_map[y]
596-
if tlm_y is None:
597-
return ZeroAssignment([tlm_map[x] for x in X])
598-
else:
599-
return PointInterpolation([tlm_map[x] for x in X], tlm_y,
600-
P=self._P)

fenics_ice/solver.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def interior(x_coords, y_space):
5555

5656

5757
def interpolation_matrix(x_coords, y_space):
58-
from tlm_adjoint.fenics.interpolation import (
58+
from .interpolation import (
5959
greedy_coloring, interpolation_matrix, point_owners)
6060

6161
y_cells = point_owners(x_coords, y_space, tolerance=np.inf)
@@ -1292,7 +1292,7 @@ def comp_J_inv(self, verbose=False):
12921292
cache_jacobian=False, cache_rhs_assembly=False).solve()
12931293

12941294
if not hasattr(self, "_cached_J_mismatch_data"):
1295-
from tlm_adjoint.fenics.interpolation import LocalMatrix
1295+
from .interpolation import LocalMatrix
12961296
from scipy.sparse import spdiags
12971297

12981298
obs_local, P = interpolation_matrix(uv_obs_pts, interp_space)

0 commit comments

Comments
 (0)