
BUG: Disk checkpointing with tape.timestepper() causes incorrect BC evaluation at taping. #4206


Open
sghelichkhani opened this issue Apr 7, 2025 · 3 comments · May be fixed by #4284

@sghelichkhani
Contributor

sghelichkhani commented Apr 7, 2025

When using disk checkpointing with tape.timestepper() (i.e., with SingleDiskStorageSchedule), we observe inconsistent behavior in the forward solve even before invoking any functional or derivative evaluation. This appears specifically when updating a DirichletBC field during a time-stepping loop.

Consider the following reproducer:

from firedrake import *
from firedrake.adjoint import *
from checkpoint_schedules import SingleDiskStorageSchedule as scheduler
# from checkpoint_schedules import SingleMemoryStorageSchedule as scheduler

continue_annotation()
enable_disk_checkpointing()

tape = get_working_tape()
tape.enable_checkpointing(scheduler())


mesh = checkpointable_mesh(UnitSquareMesh(5, 5))
V = FunctionSpace(mesh, "CG", 2)
T = Function(V).interpolate(0.0)
u = TrialFunction(V)
v = TestFunction(V)
a = inner(grad(u), grad(v)) * dx
x = SpatialCoordinate(mesh)
F = Function(V)
control = Control(F)
F.interpolate(sin(x[0] * pi) * sin(2 * x[1] * pi))
L = F * v * dx

bcs = [DirichletBC(V, T, (1,))]
uu = Function(V)

for i in tape.timestepper(iter(range(3))):
    T.assign(T + 1.0)
    solve(a == L, uu, bcs=bcs, solver_parameters={"ksp_type": "preonly", "pc_type": "lu"})

obj = assemble(uu * uu * dx)
pause_annotation()
print(f"Objective at the end of taping:{obj}")

rf = ReducedFunctional(obj, control)
print(f"Objective by calling functional: {rf(F)}")

This produces inconsistent values:

Objective at the end of taping:4.000901957850139
Objective by calling functional: 1.0009019578501521

whereas not checkpointing to disk, i.e. using SingleMemoryStorageSchedule as the scheduler (the commented-out import on line 4 of the reproducer), gives values that are consistent with each other but different from those above:

Objective at the end of taping:9.000901957850123
Objective by calling functional: 9.000901957850061
@angus-g
Contributor

angus-g commented Apr 7, 2025

Just to follow on from this, here's my diagnosis of what's happening:

We have a chain of blocks: T (the BC function) <- AssignBlock (for T += 1) <- DirichletBCBlock <- AssignBlock <- DirichletBCBlock. At the end of timestep 0 (or the beginning of timestep 1), the first 3 of these blocks are on the tape. During the add_dependency phase of the solve, the DirichletBCBlock was also checkpointed, so its associated block variable has a _checkpoint member which is just the T function:

def _ad_create_checkpoint(self):
    deps = self.block.get_dependencies()
    if len(deps) <= 0:
        # We don't have any dependencies so the supplied value was not an OverloadedType.
        # Most probably it was just a float that is immutable so will never change.
        return None
    return deps[0]

Now, when timestep 1 is due to begin, the process_taping method checkpoints all the blocks (https://github.com/dolfin-adjoint/pyadjoint/blob/0b1348021fb3bc69504e3e7ec4f90f21ba1b3822/pyadjoint/checkpointing.py#L203-L204):

            self.tape.timesteps[timestep - 1].checkpoint(
                _store_checkpointable_state, _store_adj_dependencies, self._global_deps)

This process goes through the saved_output property (https://github.com/dolfin-adjoint/pyadjoint/blob/0b1348021fb3bc69504e3e7ec4f90f21ba1b3822/pyadjoint/tape.py#L839-L840):

                    for var in self.adjoint_dependencies.union(self.checkpointable_state):
                        self._checkpoint[var] = var.saved_output._ad_create_checkpoint()

Let's consider when var here is the block variable associated with the DirichletBC. The saved_output property will use the existing checkpoint (https://github.com/dolfin-adjoint/pyadjoint/blob/0b1348021fb3bc69504e3e7ec4f90f21ba1b3822/pyadjoint/block_variable.py#L57-L62):

    @property
    def saved_output(self):
        if self.checkpoint is not None:
            return self.output._ad_restore_at_checkpoint(self.checkpoint)
        else:
            return self.output

This has the effect of calling DirichletBCMixin._ad_restore_at_checkpoint on a checkpointed function:

def _ad_restore_at_checkpoint(self, checkpoint):
    if checkpoint is not None:
        self.set_value(checkpoint.saved_output)
    return self

Because of the use of set_value here, the DirichletBC itself is modified in place (and now refers to the previous value of T). It seems like the weak checkpointing in DirichletBC is probably correct, which points to the access of saved_output during TimeStep.checkpoint being incorrect. I'm really unclear about what's checkpointed and when, and about when to use .output vs. .saved_output, etc. But it looks to me like it should be something like this:

diff --git a/pyadjoint/tape.py b/pyadjoint/tape.py
index d860025..58e25af 100644
--- a/pyadjoint/tape.py
+++ b/pyadjoint/tape.py
@@ -826,18 +826,21 @@ class TimeStep(list):
                         # because the global dependencies do not change.
                         self._checkpoint[var] = var._checkpoint
                     else:
-                        self._checkpoint[var] = var.saved_output._ad_create_checkpoint()
+                        var.save_output(overwrite=True)
+                        self._checkpoint[var] = var._checkpoint
 
             if adj_dependencies:
                 if self._revised_adj_deps:
                     for var in self.adjoint_dependencies:
-                        self._checkpoint[var] = var.saved_output._ad_create_checkpoint()
+                        var.save_output(overwrite=True)
+                        self._checkpoint[var] = var._checkpoint
                 else:
                     # The adjoint dependencies have not been revised yet. At this stage,
                     # the block nodes are not marked in the path because the control variable(s)
                     # are not yet determined.
                     for var in self.adjoint_dependencies.union(self.checkpointable_state):
-                        self._checkpoint[var] = var.saved_output._ad_create_checkpoint()
+                        var.save_output(overwrite=True)
+                        self._checkpoint[var] = var._checkpoint
 
     def restore_from_checkpoint(self, from_storage):
         """Restore the block var checkpoints from the timestep checkpoint."""

But again, I don't understand all the layers of complexity and similarly-named concepts, so this probably subtly breaks something else (adjoints or revolve schedules maybe?)...
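
To make the suspected mechanism concrete, here is a minimal pure-Python sketch of the aliasing problem described above. It does not use Firedrake or pyadjoint: ToyFunction, ToyBC and restore_mutating are hypothetical stand-ins, and the real control flow is more involved. It only shows how a "restore" that mutates the live object and returns it can leave the BC pointing at a stale value.

```python
class ToyFunction:
    """Hypothetical stand-in for a Function: a mutable value holder."""

    def __init__(self, value):
        self.value = value


class ToyBC:
    """Hypothetical stand-in for a DirichletBC holding a boundary value."""

    def __init__(self, g):
        self.g = g  # live reference to the boundary value

    def set_value(self, g):
        self.g = g

    def restore_mutating(self, checkpoint):
        # Same shape as the restore method quoted above:
        # mutate self in place, then return self.
        if checkpoint is not None:
            self.set_value(checkpoint)
        return self


T = ToyFunction(0.0)
bc = ToyBC(T)

# Timestep 0: T is updated and a frozen copy of its value is stored.
T.value += 1.0
snapshot = ToyFunction(T.value)

# Checkpointing the timestep "restores" the BC from that snapshot,
# which rebinds the live BC to the frozen copy.
bc.restore_mutating(snapshot)

# Timestep 1: T is updated again, but the BC no longer follows it.
T.value += 1.0
print(T.value, bc.g.value)
```

In this toy model the BC stays one update behind T after the restore, which is qualitatively the kind of stale boundary value seen in the reproducer.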

@jrmaddison
Contributor

jrmaddison commented Apr 15, 2025

This is my understanding based on some reading of the pyadjoint and Firedrake code, but perhaps others more familiar with the pyadjoint internals might correct me:

BlockVariable.output is a symbolic variable and, during the initial forward run (the one used to build the tape), also references a value. This is natural in UFL where symbolic variables (BaseCoefficients) typically also carry values (Functions and Cofunctions). However when actually using the tape BlockVariable.output is used only as a symbolic variable, and the value is accessed via BlockVariable.saved_output. e.g. you can see this in GenericSolveBlock._create_F_form, which builds a new form using saved_outputs. The symbolic variable never changes, but the value in BlockVariable.saved_output can be freed and then rematerialized using a checkpointing schedule.
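
As a rough illustration of that distinction, the following toy class mimics the output / saved_output split using plain Python lists. ToyBlockVariable is a hypothetical name, and the property is simplified: the real one goes through output._ad_restore_at_checkpoint rather than returning the checkpoint directly.

```python
class ToyBlockVariable:
    """Hypothetical, simplified stand-in for a block variable."""

    def __init__(self, output):
        self.output = output    # object created during taping
        self.checkpoint = None  # stored value; may be freed and restored

    @property
    def saved_output(self):
        # Prefer the stored value; fall back to the live object.
        if self.checkpoint is not None:
            return self.checkpoint
        return self.output


var = ToyBlockVariable(output=[1.0, 2.0])
var.checkpoint = list(var.output)  # store a copy of the current value
var.output[0] = 99.0               # the live object changes later on

with_checkpoint = var.saved_output     # the stored copy
var.checkpoint = None                  # free the checkpoint
without_checkpoint = var.saved_output  # falls back to the live object

print(with_checkpoint, without_checkpoint)
```

The point of the sketch is that code replaying the tape should read values through saved_output, since the value attached to output may have moved on since taping.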

As an aside: this construction is part of what leads to issue dolfin-adjoint/pyadjoint#169. While BlockVariable.output is only used as a symbolic variable by the adjoint, it still references a value, even though in principle the adjoint should no longer need it.

I think this means that an adjoint calculation should not modify any value referenced by BlockVariable.output, and so replacing

        if checkpoint is not None:
            self.set_value(checkpoint.saved_output)

with

        if checkpoint is not None:
            bc = type(self)(self.function_space(), checkpoint.saved_output, self.sub_domain)
            # Hack so that bc._ad_create_checkpoint works
            bc.block = self.block
            return bc

should work around the bug, which it seems to in the reproducer. I'm not sure of the best way to avoid the hack in the above as I'm not very familiar with FloatingTypes.
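
The idea behind this workaround can be sketched without Firedrake. In the toy below, ToyFunction, ToyBC and restore_as_copy are hypothetical names, not the actual API. The restore step builds a fresh object from the checkpoint instead of mutating the live one, so the object used during taping keeps following T.

```python
class ToyFunction:
    """Hypothetical stand-in for a Function: a mutable value holder."""

    def __init__(self, value):
        self.value = value


class ToyBC:
    """Hypothetical stand-in for a DirichletBC holding a boundary value."""

    def __init__(self, g):
        self.g = g

    def restore_as_copy(self, checkpoint):
        # Non-mutating restore: leave self untouched and hand back
        # a new object of the same type built from the checkpoint.
        if checkpoint is None:
            return self
        return type(self)(checkpoint)


T = ToyFunction(1.0)
live_bc = ToyBC(T)
snapshot = ToyFunction(T.value)  # frozen copy of the value at this step

restored_bc = live_bc.restore_as_copy(snapshot)

T.value += 1.0  # the next timestep updates T
print(live_bc.g.value, restored_bc.g.value)
```

Here the live BC sees the updated T while the restored copy keeps the checkpointed value, which is the separation the suggested change aims for. Whether constructing a new object is acceptable for the real DirichletBC, given the block attribute hack noted above, is a separate question this sketch does not address.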

@dham
Member

dham commented Apr 17, 2025

I haven't had time to look at this in detail, but my understanding is that BlockVariable.output should only be used during taping (I think it's very badly named). BlockVariable should be storing everything that is needed during tape evaluation in attributes other than output.

@Ig-dolci Ig-dolci linked a pull request May 5, 2025 that will close this issue