Skip to content

Commit 7171a83

Browse files
authored
Merge pull request #126 from dngoldberg/branch_obs_sensitivity
i did pytests and all seemed ok
2 parents ee3d698 + c0b5e6a commit 7171a83

File tree

8 files changed

+419
-57
lines changed

8 files changed

+419
-57
lines changed

Diff for: example_cases/ismipc_30x30/ismipc_30x30.sh

+5-5
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ cp output_momsolve/ismipc_U_obs.h5 input/
1313

1414
#Run each phase of the model in turn
1515
RUN_DIR=$FENICS_ICE_BASE_DIR/runs/
16-
python $RUN_DIR/run_inv.py ismipc_30x30.toml
17-
python $RUN_DIR/run_forward.py ismipc_30x30.toml
18-
python $RUN_DIR/run_eigendec.py ismipc_30x30.toml
19-
python $RUN_DIR/run_errorprop.py ismipc_30x30.toml
20-
python $RUN_DIR/run_invsigma.py ismipc_30x30.toml
16+
#python $RUN_DIR/run_inv.py ismipc_30x30.toml
17+
#python $RUN_DIR/run_forward.py ismipc_30x30.toml
18+
#python $RUN_DIR/run_eigendec.py ismipc_30x30.toml
19+
#python $RUN_DIR/run_errorprop.py ismipc_30x30.toml
20+
mpirun -n 4 python $RUN_DIR/run_obs_sens_prop.py ismipc_30x30.toml

Diff for: example_cases/ismipc_30x30/ismipc_30x30.toml

+10-3
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ random_seed = 0
3232
[mesh]
3333

3434
mesh_filename = "ismip_mesh.xml"
35-
periodic_bc = true
35+
periodic_bc = false
3636

3737
[obs]
3838

@@ -67,7 +67,8 @@ sliding_law = 'linear' #budd, linear
6767
[momsolve.picard_params]
6868
nonlinear_solver = "newton"
6969
[momsolve.picard_params.newton_solver]
70-
linear_solver = "umfpack"
70+
linear_solver = "cg"
71+
preconditioner = "hypre_amg"
7172
maximum_iterations = 200
7273
absolute_tolerance = 1.0e-0
7374
relative_tolerance = 1.0e-3
@@ -77,7 +78,9 @@ error_on_nonconvergence = false
7778
[momsolve.newton_params]
7879
nonlinear_solver = "newton"
7980
[momsolve.newton_params.newton_solver]
80-
linear_solver = "umfpack"
81+
#linear_solver = "umfpack"
82+
linear_solver = "cg"
83+
preconditioner = "hypre_amg"
8184
maximum_iterations = 25
8285
absolute_tolerance = 1.0e-7
8386
relative_tolerance = 1.0e-8
@@ -128,6 +131,10 @@ qoi = 'h2' #or 'vaf'
128131
name = "Bottom Edge"
129132
id = 4
130133

134+
[mass_solve]
135+
136+
use_cg_thickness = true
137+
131138
[testing]
132139

133140
expected_init_alpha = 531.6114524861194

Diff for: fenics_ice/config.py

+40-3
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,19 @@ def parse(self):
124124
except KeyError:
125125
pass
126126

127+
try:
128+
obs_sens_dict = self.config_dict['obs_sens']
129+
except KeyError:
130+
obs_sens_dict = {}
131+
self.obs_sens = ObsSensCfg(**obs_sens_dict)
132+
133+
try:
134+
mass_solve_dict = self.config_dict['mass_solve']
135+
except KeyError:
136+
mass_solve_dict = {}
137+
self.mass_solve = MassSolveCfg(**mass_solve_dict)
138+
139+
127140
def check_dirs(self):
128141
"""
129142
Check input directory exists & create output dir if necessary.
@@ -138,13 +151,15 @@ def check_dirs(self):
138151
self.time.phase_name,
139152
self.eigendec.phase_name,
140153
self.error_prop.phase_name,
141-
self.inv_sigma.phase_name]
154+
self.inv_sigma.phase_name,
155+
self.obs_sens.phase_name]
142156

143157
ph_suffix = [self.inversion.phase_suffix,
144158
self.time.phase_suffix,
145159
self.eigendec.phase_suffix,
146160
self.error_prop.phase_suffix,
147-
self.inv_sigma.phase_suffix]
161+
self.inv_sigma.phase_suffix,
162+
self.obs_sens.phase_suffix]
148163

149164
for ph, suff in zip(ph_names, ph_suffix):
150165
out_dir = (outdir / ph / suff)
@@ -253,6 +268,17 @@ class ErrorPropCfg(ConfigPrinter):
253268
phase_name: str = 'error_prop'
254269
phase_suffix: str = ''
255270

271+
272+
@dataclass(frozen=True)
273+
class ObsSensCfg(ConfigPrinter):
274+
"""
275+
Configuration related to observation sensitivities
276+
"""
277+
qoi: str = 'vaf'
278+
phase_name: str = 'obs_sens'
279+
phase_suffix: str = ''
280+
281+
256282
@dataclass(frozen=True)
257283
class SampleCfg(ConfigPrinter):
258284
"""
@@ -394,6 +420,18 @@ def __post_init__(self):
394420
assert self.min_thickness >= 0.0
395421

396422

423+
@dataclass(frozen=True)
424+
class MassSolveCfg(ConfigPrinter):
425+
"""
426+
Options for mass balance solver
427+
"""
428+
429+
use_cg_thickness: bool = False
430+
431+
def __post_init__(self):
432+
""" """
433+
434+
397435
@dataclass(frozen=True)
398436
class MomsolveCfg(ConfigPrinter):
399437
"""
@@ -513,7 +551,6 @@ def __post_init__(self):
513551

514552
for fname in fname_default_suff:
515553
self.set_default_filename(fname, fname_default_suff[fname])
516-
#embed()
517554

518555
@dataclass(frozen=True)
519556
class TimeCfg(ConfigPrinter):

Diff for: fenics_ice/inout.py

-1
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,6 @@ def write_variable(var, params, name=None, outdir=None, phase_name='', phase_suf
294294
outvar.rename(name, "")
295295
# Prefix the run name
296296
outfname = Path(outdir) / phase_name / phase_suffix / "_".join((params.io.run_name+phase_suffix, name))
297-
#embed()
298297

299298
# Write out output according to user specified format in toml
300299
output_var_format = params.io.output_var_format

Diff for: fenics_ice/model.py

+41-33
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,43 @@
2626
from pathlib import Path
2727
from numpy.random import randn
2828
import logging
29-
from IPython import embed
3029

3130
log = logging.getLogger("fenics_ice")
3231

32+
# Functions for repeated ungridded interpolation
33+
# TODO - this will not handle extrapolation/missing data
34+
# nicely - unfound simplex are returned '-1' which takes the last
35+
# tri.simplices...
36+
37+
# at the moment i have moved these from vel_obs_from_data, so they
38+
# can be called directly from a run script.
39+
# the ismipc test, which calls this function, still seems to perform fine
40+
# but this refactoring may make things less efficient.
41+
def interp_weights(xy, uv, periodic_bc, d=2):
42+
"""Compute the nearest vertices & weights (for reuse)"""
43+
from scipy.spatial import Delaunay
44+
tri = Delaunay(xy)
45+
simplex = tri.find_simplex(uv)
46+
47+
if not np.all(simplex >= 0):
48+
if not periodic_bc:
49+
log.warning("Some points missing in interpolation "
50+
"of velocity obs to function space.")
51+
else:
52+
log.warning("Some points missing in interpolation "
53+
"of velocity obs to function space.")
54+
55+
vertices = np.take(tri.simplices, simplex, axis=0)
56+
temp = np.take(tri.transform, simplex, axis=0)
57+
delta = uv - temp[:, d]
58+
bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)
59+
return vertices, np.hstack((bary, 1 - bary.sum(axis=1,
60+
keepdims=True)))
61+
62+
def interpolate(values, vtx, wts):
63+
"""Bilinear interpolation, given vertices & weights above"""
64+
return np.einsum('nj,nj->n', np.take(values, vtx), wts)
65+
3366
class model:
3467
"""
3568
The 'model' object is the core of any fenics_ice simulation. It handles loading input
@@ -122,7 +155,10 @@ def init_fields_from_data(self):
122155
self.bed = self.field_from_data("bed", self.Q)
123156
self.bmelt = self.field_from_data("bmelt", self.M, default=0.0)
124157
self.smb = self.field_from_data("smb", self.M, default=0.0)
125-
self.H_np = self.field_from_data("thick", self.M, min_val=min_thick)
158+
if (self.params.mass_solve.use_cg_thickness and self.params.mesh.periodic_bc):
159+
self.H_np = self.field_from_data("thick", self.Qp, min_val=min_thick)
160+
else:
161+
self.H_np = self.field_from_data("thick", self.M, min_val=min_thick)
126162

127163
if self.params.melt.use_melt_parameterisation:
128164

@@ -308,44 +344,16 @@ def vel_obs_from_data(self):
308344
use_cloud_point=self.params.inversion.use_cloud_point_velocities)
309345
else:
310346
inout.read_vel_obs(infile, model=self)
311-
# Functions for repeated ungridded interpolation
312-
# TODO - this will not handle extrapolation/missing data
313-
# nicely - unfound simplex are returned '-1' which takes the last
314-
# tri.simplices...
315-
def interp_weights(xy, uv, d=2):
316-
"""Compute the nearest vertices & weights (for reuse)"""
317-
from scipy.spatial import Delaunay
318-
tri = Delaunay(xy)
319-
simplex = tri.find_simplex(uv)
320-
321-
if not np.all(simplex >= 0):
322-
if not self.params.mesh.periodic_bc:
323-
log.warning("Some points missing in interpolation "
324-
"of velocity obs to function space.")
325-
else:
326-
log.warning("Some points missing in interpolation "
327-
"of velocity obs to function space.")
328-
329-
vertices = np.take(tri.simplices, simplex, axis=0)
330-
temp = np.take(tri.transform, simplex, axis=0)
331-
delta = uv - temp[:, d]
332-
bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)
333-
return vertices, np.hstack((bary, 1 - bary.sum(axis=1,
334-
keepdims=True)))
335-
336-
def interpolate(values, vtx, wts):
337-
"""Bilinear interpolation, given vertices & weights above"""
338-
return np.einsum('nj,nj->n', np.take(values, vtx), wts)
339347

340348
# Grab coordinates of both Lagrangian & DG function spaces
341349
# and compute (once) the interpolating arrays
342350
Q_coords = self.Q.tabulate_dof_coordinates()
343351
M_coords = self.M.tabulate_dof_coordinates()
344352

345353
vtx_Q, wts_Q = interp_weights(self.vel_obs['uv_comp_pts'],
346-
Q_coords)
354+
Q_coords, self.params.mesh.periodic_bc)
347355
vtx_M, wts_M = interp_weights(self.vel_obs['uv_comp_pts'],
348-
M_coords)
356+
M_coords, self.params.mesh.periodic_bc)
349357

350358
# Define new functions to hold results
351359
self.u_obs_Q = Function(self.Q, name="u_obs")
@@ -378,7 +386,7 @@ def interpolate(values, vtx, wts):
378386
# We need to do the same as above but for cloud point data
379387
# so we can write out a nicer output in the mesh coordinates
380388
vtx_Q_c, wts_Q_c = interp_weights(self.vel_obs['uv_obs_pts'],
381-
Q_coords)
389+
Q_coords, self.params.mesh.periodic_bc)
382390

383391
# Define new functions to hold results
384392
self.u_cloud_Q = Function(self.Q, name="u_obs_cloud")

Diff for: fenics_ice/solver.py

+44-12
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import time
3030
import ufl
3131
import weakref
32-
from IPython import embed
3332

3433
log = logging.getLogger("fenics_ice")
3534

@@ -77,11 +76,26 @@ def interpolation_matrix(x_coords, y_space):
7776
return x_local, P
7877

7978

79+
def Amat_obs_action(P, Rvec, vec_cg, dg_space):
80+
# This function implements the Rvec*P*D action on a P1 function
81+
# where D is a projection into DG space
82+
#
83+
# to be called for each component of velocity
84+
#
85+
86+
test, trial = TestFunction(dg_space), TrialFunction(dg_space)
87+
vec_dg = Function(dg_space)
88+
solve(inner(trial, test) * dx == inner(vec_cg, test) * dx,
89+
vec_dg, solver_parameters={"linear_solver": "lu"})
90+
91+
return Rvec * (P @ vec_dg.vector().get_local())
92+
93+
8094
class ssa_solver:
8195
"""
8296
The ssa_solver object is currently the only kind of fenics_ice 'solver' available.
8397
"""
84-
def __init__(self, model, mixed_space=False):
98+
def __init__(self, model, mixed_space=False, obs_sensitivity=False):
8599

86100
# Enable aggressive compiler options
87101
parameters["form_compiler"]["optimize"] = False
@@ -93,6 +107,7 @@ def __init__(self, model, mixed_space=False):
93107
self.model.solvers.append(self)
94108
self.params = model.params
95109
self.mixed_space = mixed_space
110+
self.obs_sensitivity = obs_sensitivity
96111

97112
# Mesh/Function Spaces
98113
self.mesh = model.mesh
@@ -146,10 +161,15 @@ def __init__(self, model, mixed_space=False):
146161
self.U = Function(self.V, name="U")
147162
self.U_np = Function(self.V, name="U_np")
148163
self.Phi = TestFunction(self.V)
149-
self.Ksi = TestFunction(self.M)
150164
self.pTau = TestFunction(self.Qp)
151165

152-
self.trial_H = TrialFunction(self.M)
166+
if not (self.params.mass_solve.use_cg_thickness and self.params.mesh.periodic_bc):
167+
self.trial_H = TrialFunction(self.M)
168+
self.Ksi = TestFunction(self.M)
169+
else:
170+
self.trial_H = TrialFunction(self.Qp)
171+
self.Ksi = TestFunction(self.Qp)
172+
153173

154174
# Facets
155175
self.ff = model.ff
@@ -607,21 +627,26 @@ def def_thickadv_eq(self):
607627
+ inner(jump(Ksi), jump(0.5 * (dot(U_np, nm) + abs(dot(U_np, nm))) * trial_H))
608628
* dS
609629

610-
# Outflow at boundaries
611-
+ conditional(dot(U_np, nm) > 0, 1.0, 0.0)*inner(Ksi, dot(U_np * trial_H, nm))
612-
* ds
613-
614-
# Inflow at boundaries
615-
+ conditional(dot(U_np, nm) < 0, 1.0, 0.0)*inner(Ksi, dot(U_np * H_init, nm))
616-
* ds
617-
618630
# basal melting
619631
+ bmelt*Ksi*dx
620632

621633
# surface mass balance
622634
- smb*Ksi*dx
623635
)
624636

637+
638+
if not (self.params.mass_solve.use_cg_thickness and self.params.mesh.periodic_bc):
639+
self.thickadv = self.thickadv + (
640+
641+
# Outflow at boundaries
642+
+ conditional(dot(U_np, nm) > 0, 1.0, 0.0)*inner(Ksi, dot(U_np * trial_H, nm))
643+
* ds
644+
645+
# Inflow at boundaries
646+
+ conditional(dot(U_np, nm) < 0, 1.0, 0.0)*inner(Ksi, dot(U_np * H_init, nm))
647+
* ds
648+
)
649+
625650
# # Forward euler
626651
# self.thickadv = (inner(Ksi, ((trial_H - H_np) / dt)) * dx
627652
# - inner(grad(Ksi), U_np * H_np) * dx
@@ -1303,9 +1328,16 @@ def comp_J_inv(self, verbose=False):
13031328
J_v_obs, op=MPI.SUM)
13041329
J_v_obs, = J_v_obs
13051330

1331+
u_std_local = u_std[obs_local]
1332+
v_std_local = v_std[obs_local]
1333+
13061334
self._cached_J_mismatch_data \
13071335
= (interp_space,
13081336
u_PRP, v_PRP, l_u_obs, l_v_obs, J_u_obs, J_v_obs)
1337+
if (self.obs_sensitivity):
1338+
self._cached_Amat_vars = \
1339+
(P, u_std_local, v_std_local, obs_local, interp_space)
1340+
13091341
(interp_space,
13101342
u_PRP, v_PRP, l_u_obs, l_v_obs, J_u_obs, J_v_obs) = \
13111343
self._cached_J_mismatch_data

Diff for: runs/run_forward.py

+1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def run_forward(config_file):
7272

7373
# Run the forward model
7474
Q = slvr.timestep(adjoint_flag=1, qoi_func=qoi_func)
75+
7576
# Run the adjoint model, computing gradient of Qoi w.r.t cntrl
7677
dQ_ts = compute_gradient(Q, cntrl) # Isaac 27
7778

0 commit comments

Comments
 (0)