Skip to content

Commit 6c829d7

Browse files
committed
InterpolationDerivativeEstimator
1 parent 7de0b53 commit 6c829d7

File tree

7 files changed

+289
-127
lines changed

7 files changed

+289
-127
lines changed

docs/source/api/missing.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,10 @@ ddt.ipynb
6161
:toctree: _autosummaries
6262
:nosignatures:
6363

64+
DerivativeEstimatorTemplate
6465
UniformFiniteDifferencer
6566
NonuniformFiniteDifferencer
66-
DerivativeEstimatorTemplate
67+
InterpolationDerivativeEstimator
6768
fwd1
6869
fwd2
6970
fwd3

src/opinf/ddt/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33

44
from ._base import *
55
from ._finite_difference import *
6+
from ._interpolation import *

src/opinf/ddt/_base.py

+22-8
Original file line numberDiff line numberDiff line change
@@ -76,20 +76,34 @@ class DerivativeEstimatorTemplate(abc.ABC):
7676
# Constructor -------------------------------------------------------------
7777
def __init__(self, time_domain):
7878
"""Set the time domain."""
79-
self.time_domain = time_domain
79+
if not isinstance(time_domain, np.ndarray) or time_domain.ndim != 1:
80+
raise ValueError("time_domain must be a one-dimensional array")
81+
self.__t = time_domain
8082

8183
# Properties --------------------------------------------------------------
8284
@property
8385
def time_domain(self):
8486
"""Time domain of the snapshot data, a (k,) ndarray."""
8587
return self.__t
8688

87-
@time_domain.setter
88-
def time_domain(self, t):
89-
"""Set the time domain."""
90-
self.__t = t
91-
9289
# Main routine ------------------------------------------------------------
90+
def _check_dimensions(self, states, inputs, check_against_time=True):
91+
"""Check dimensions and alignment of the state and inputs."""
92+
if states.ndim != 2:
93+
raise errors.DimensionalityError("states must be two-dimensional")
94+
if check_against_time and states.shape[-1] != self.time_domain.size:
95+
raise errors.DimensionalityError(
96+
"states not aligned with time_domain"
97+
)
98+
if inputs is not None:
99+
if inputs.ndim == 1:
100+
inputs = inputs.reshape((1, -1))
101+
if inputs.shape[1] != states.shape[1]:
102+
raise errors.DimensionalityError(
103+
"states and inputs not aligned"
104+
)
105+
return states, inputs
106+
93107
@abc.abstractmethod
94108
def estimate(self, states, inputs=None):
95109
"""Estimate the first time derivatives of the states.
@@ -216,7 +230,7 @@ def verify(self, plot: bool = False, return_errors=False):
216230
t = 1 + (dt * t_base)
217231
Q = np.row_stack([test[1](t) for test in self.__tests])
218232
dQdt = np.row_stack([test[2](t) for test in self.__tests])
219-
self.time_domain = t
233+
self.__t = t
220234

221235
# Call the derivative estimator.
222236
Q_est, dQdt_est = self.estimate(Q, None)
@@ -265,6 +279,6 @@ def verify(self, plot: bool = False, return_errors=False):
265279
for dt, err in zip(dts, estimation_errors[name]):
266280
print(f"dt = {dt:.1e}:\terror = {err:.4e}")
267281

268-
self.time_domain = time_domain # Restore original time domain.
282+
self.__t = time_domain # Restore original time domain.
269283
if return_errors:
270284
return estimation_errors

src/opinf/ddt/_finite_difference.py

+25-97
Original file line numberDiff line numberDiff line change
@@ -835,46 +835,33 @@ class UniformFiniteDifferencer(DerivativeEstimatorTemplate):
835835
def __init__(self, time_domain, scheme="ord4"):
836836
"""Store the time domain and set the finite difference scheme."""
837837
DerivativeEstimatorTemplate.__init__(self, time_domain)
838-
self.scheme = scheme
839838

840-
# Properties --------------------------------------------------------------
841-
@DerivativeEstimatorTemplate.time_domain.setter
842-
def time_domain(self, t):
843-
"""Set the time domain---ensuring it is uniform---and the time step."""
844-
if np.isscalar(t[0]):
845-
# Case 1: the time domain is a one-dimensional array.
846-
diffs = np.diff(t)
847-
if not np.allclose(diffs, diffs[0]):
848-
raise ValueError("time domain must be uniformly spaced")
849-
self.__dt = diffs[0]
850-
elif all(np.ndim(tt) == 1 for tt in t):
851-
# Case 2: there are several time domains (multiple trajectories).
852-
for tt in t:
853-
self.time_domain = tt # Check that each has uniform spacing.
854-
self.__dt = None
855-
else:
856-
raise ValueError("time_domain should be a one-dimensional array")
857-
DerivativeEstimatorTemplate.time_domain.fset(self, t)
839+
# Check for uniform spacing.
840+
diffs = np.diff(time_domain)
841+
if not np.allclose(diffs, diffs[0]):
842+
raise ValueError("time domain must be uniformly spaced")
843+
844+
# Set the finite difference scheme.
845+
if not callable(scheme):
846+
if scheme not in self._schemes:
847+
raise ValueError(
848+
f"invalid finite difference scheme '{scheme}'"
849+
)
850+
scheme = self._schemes[scheme]
851+
self.__scheme = scheme
858852

853+
# Properties --------------------------------------------------------------
859854
@property
860855
def dt(self):
861856
"""Time step."""
862-
return self.__dt
857+
t = self.time_domain
858+
return t[1] - t[0]
863859

864860
@property
865861
def scheme(self):
866862
"""Finite difference engine."""
867863
return self.__scheme
868864

869-
@scheme.setter
870-
def scheme(self, f):
871-
"""Set the finite difference scheme."""
872-
if not callable(f):
873-
if f not in self._schemes:
874-
raise ValueError(f"invalid finite difference scheme '{f}'")
875-
f = self._schemes[f]
876-
self.__scheme = f
877-
878865
# Main routine ------------------------------------------------------------
879866
def estimate(self, states, inputs=None):
880867
r"""Estimate the first time derivatives of the states using
@@ -900,19 +887,7 @@ def estimate(self, states, inputs=None):
900887
Inputs corresponding to ``_states``, if applicable.
901888
**Only returned** if ``inputs`` is provided.
902889
"""
903-
if self.dt is None:
904-
raise RuntimeError("dt is None")
905-
906-
if states.ndim != 2:
907-
raise errors.DimensionalityError("states must be two-dimensional")
908-
if inputs is not None:
909-
if inputs.ndim == 1:
910-
inputs = inputs.reshape((1, -1))
911-
if inputs.shape[1] != states.shape[1]:
912-
raise errors.DimensionalityError(
913-
"states and inputs not aligned"
914-
)
915-
890+
states, inputs = self._check_dimensions(states, inputs, False)
916891
return self.scheme(states, self.dt, inputs)
917892

918893

@@ -933,26 +908,15 @@ class NonuniformFiniteDifferencer(DerivativeEstimatorTemplate):
933908

934909
def __init__(self, time_domain):
935910
"""Set the time domain."""
936-
self.__check_uniformity = True
937911
DerivativeEstimatorTemplate.__init__(self, time_domain)
938912

939-
# Properties --------------------------------------------------------------
940-
@DerivativeEstimatorTemplate.time_domain.setter
941-
def time_domain(self, t):
942-
"""Set the time domain. Raise a warning if it is uniform."""
943-
if np.isscalar(t[0]):
944-
# Case 1: the time domain is a one-dimensional array.
945-
if self.__check_uniformity and np.allclose(
946-
diffs := np.diff(t), diffs[0]
947-
):
948-
warnings.warn(
949-
"time_domain is uniformly spaced, consider using "
950-
"UniformFiniteDifferencer",
951-
errors.OpInfWarning,
952-
)
953-
elif not all(np.ndim(tt) == 1 for tt in t): # OK if several domains.
954-
raise ValueError("time_domain should be a one-dimensional array")
955-
DerivativeEstimatorTemplate.time_domain.fset(self, t)
913+
# Warn if time_domain in not uniform.
914+
if np.allclose(diffs := np.diff(time_domain), diffs[0]):
915+
warnings.warn(
916+
"time_domain is uniformly spaced, consider using "
917+
"UniformFiniteDifferencer",
918+
errors.OpInfWarning,
919+
)
956920

957921
# Main routine ------------------------------------------------------------
958922
def estimate(self, states, inputs=None):
@@ -976,20 +940,7 @@ def estimate(self, states, inputs=None):
976940
Inputs corresponding to ``_states``, if applicable.
977941
**Only returned** if ``inputs`` is provided.
978942
"""
979-
# Check dimensions.
980-
if states.ndim != 2:
981-
raise errors.DimensionalityError("states must be two-dimensional")
982-
if states.shape[-1] != self.time_domain.size:
983-
raise errors.DimensionalityError(
984-
"states not aligned with time_domain"
985-
)
986-
if inputs is not None:
987-
if inputs.ndim == 1:
988-
inputs = inputs.reshape((1, -1))
989-
if inputs.shape[1] != self.time_domain.size:
990-
raise errors.DimensionalityError(
991-
"inputs not aligned with time_domain"
992-
)
943+
states, inputs = self._check_dimensions(states, inputs)
993944

994945
# Do the computation.
995946
ddts = np.gradient(states, self.time_domain, edge_order=2, axis=-1)
@@ -998,29 +949,6 @@ def estimate(self, states, inputs=None):
998949
return states, ddts, inputs
999950
return states, ddts
1000951

1001-
# Verification ------------------------------------------------------------
1002-
def verify(self, plot: bool = False):
1003-
"""Verify that :meth:`estimate()` is consistent in the sense that the
1004-
all outputs have the same number of columns and test the accuracy of
1005-
the results on a few test problems.
1006-
1007-
Parameters
1008-
----------
1009-
plot : bool
1010-
If ``True``, plot the relative errors of the derivative estimation
1011-
errors as a function of the time step.
1012-
If ``False`` (default), print a report of the relative errors.
1013-
1014-
Returns
1015-
-------
1016-
errors : dict
1017-
Estimation errors for each test case.
1018-
Time steps are listed as ``errors[dts]``.
1019-
"""
1020-
self.__check_uniformity = False
1021-
DerivativeEstimatorTemplate.verify(self, plot=plot)
1022-
self.__check_uniformity = True
1023-
1024952

1025953
# Old API =====================================================================
1026954
def ddt_uniform(states, dt, order=2):

0 commit comments

Comments
 (0)