Skip to content

Commit 2995acc

Browse files
authored
Quartic Operator (#75)
* remove helpers from tests/models/mono/__init__.py * model tests with inheritance structure * parametric model tests with inheritance * bugfix: fit_regselect_continuous() now returns self * restructure operator tests for nonparametric (parametric still needed) * nonparametric quartic operator * v0.5.12 and update changelog * fix interpolation tests * small printing fix to TimedBlock, bug note for continuous regselect
1 parent d9758fa commit 2995acc

20 files changed

+1961
-1791
lines changed

docs/source/opinf/changelog.md

+8-2
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,16 @@
55
New versions may introduce substantial new features or API adjustments.
66
:::
77

8+
## Version 0.5.12
9+
10+
- New `operators.QuarticOperator`, plus unit tests.
11+
- Reorganized some unit tests for models and operators to have OOP structure.
12+
- Bugfix: `fit_regselect_continuous()` now returns `self`
13+
814
## Version 0.5.11
915

10-
- New scaling option for ``pre.ShiftScaleTransformer`` so that training snapshots have at maximum norm 1. Contributed by [@nicolearetz](https://github.com/nicolearetz).
11-
- Small clarifications to ``pre.ShiftScaleTransformer`` and updates to the ``pre`` documentation.
16+
- New scaling option for `pre.ShiftScaleTransformer` so that training snapshots have at maximum norm 1. Contributed by [@nicolearetz](https://github.com/nicolearetz).
17+
- Small clarifications to `pre.ShiftScaleTransformer` and updates to the `pre` documentation.
1218

1319
## Version 0.5.10
1420

src/opinf/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
https://github.com/Willcox-Research-Group/rom-operator-inference-Python3
88
"""
99

10-
__version__ = "0.5.11"
10+
__version__ = "0.5.12"
1111

1212
from . import (
1313
basis,

src/opinf/models/mono/_base.py

+22-24
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,8 @@
77
import warnings
88
import numpy as np
99

10-
from ... import errors, lstsq
11-
from ...operators import (
12-
ConstantOperator,
13-
LinearOperator,
14-
QuadraticOperator,
15-
CubicOperator,
16-
InputOperator,
17-
StateInputOperator,
18-
_utils as oputils,
19-
)
10+
from ... import errors, lstsq, operators as _operators
11+
from ...operators import _utils as oputils
2012

2113

2214
class _Model(abc.ABC):
@@ -135,32 +127,32 @@ def __iter__(self):
135127
@property
136128
def c_(self):
137129
""":class:`opinf.operators.ConstantOperator` (or ``None``)."""
138-
return self._get_operator_of_type(ConstantOperator)
130+
return self._get_operator_of_type(_operators.ConstantOperator)
139131

140132
@property
141133
def A_(self):
142134
""":class:`opinf.operators.LinearOperator` (or ``None``)."""
143-
return self._get_operator_of_type(LinearOperator)
135+
return self._get_operator_of_type(_operators.LinearOperator)
144136

145137
@property
146138
def H_(self):
147139
""":class:`opinf.operators.QuadraticOperator` (or ``None``)."""
148-
return self._get_operator_of_type(QuadraticOperator)
140+
return self._get_operator_of_type(_operators.QuadraticOperator)
149141

150142
@property
151143
def G_(self):
152144
""":class:`opinf.operators.CubicOperator` (or ``None``)."""
153-
return self._get_operator_of_type(CubicOperator)
145+
return self._get_operator_of_type(_operators.CubicOperator)
154146

155147
@property
156148
def B_(self):
157149
""":class:`opinf.operators.InputOperator` (or ``None``)."""
158-
return self._get_operator_of_type(InputOperator)
150+
return self._get_operator_of_type(_operators.InputOperator)
159151

160152
@property
161153
def N_(self):
162154
""":class:`opinf.operators.StateInputOperator` (or ``None``)."""
163-
return self._get_operator_of_type(StateInputOperator)
155+
return self._get_operator_of_type(_operators.StateInputOperator)
164156

165157
# Properties: dimensions --------------------------------------------------
166158
@staticmethod
@@ -362,27 +354,33 @@ def _check_is_trained(self):
362354
raise AttributeError("no state_dimension (call fit())")
363355
if self._has_inputs and (self.input_dimension is None):
364356
raise AttributeError("no input_dimension (call fit())")
365-
366-
for op in self.operators:
367-
if op.entries is None:
368-
raise AttributeError("model not trained (call fit())")
357+
if any(oputils.is_uncalibrated(op) for op in self.operators):
358+
raise AttributeError("model not trained (call fit())")
369359

370360
def __eq__(self, other):
371361
"""Two models are equal if they have equivalent operators."""
372362
if not isinstance(other, self.__class__):
373363
return False
374364
if len(self.operators) != len(other.operators):
375365
return False
376-
for selfop, otherop in zip(self.operators, other.operators):
377-
if selfop != otherop:
378-
return False
379366
if self.state_dimension != other.state_dimension:
380367
return False
381368
if self.input_dimension != other.input_dimension:
382369
return False
370+
marked = set()
371+
for selfop in self.operators:
372+
found = False
373+
for i, otherop in enumerate(other.operators):
374+
if selfop == otherop and i not in marked:
375+
found = True
376+
marked.add(i)
377+
break
378+
if not found:
379+
return False
383380
return True
384381

385382
# Model persistence -------------------------------------------------------
386383
def copy(self):
387384
"""Make a copy of the model."""
388-
return self.__class__([op.copy() for op in self.operators])
385+
sol = self.solver.copy() if self.solver is not None else None
386+
return self.__class__([op.copy() for op in self.operators], solver=sol)

src/opinf/models/mono/_nonparametric.py

+13-34
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,8 @@
1414
import scipy.interpolate as spinterpolate
1515

1616
from ._base import _Model
17-
from ... import errors, utils
18-
from ...operators import (
19-
ConstantOperator,
20-
LinearOperator,
21-
QuadraticOperator,
22-
CubicOperator,
23-
InputOperator,
24-
StateInputOperator,
25-
_utils as oputils,
26-
)
27-
28-
29-
_operator_name2class = {
30-
OpClass.__name__: OpClass
31-
for OpClass in (
32-
ConstantOperator,
33-
LinearOperator,
34-
QuadraticOperator,
35-
CubicOperator,
36-
InputOperator,
37-
StateInputOperator,
38-
)
39-
}
17+
from ... import errors, utils, operators as _operators
18+
from ...operators import _utils as oputils
4019

4120

4221
# Base class ==================================================================
@@ -58,12 +37,12 @@ class _NonparametricModel(_Model):
5837

5938
# Properties: operators ---------------------------------------------------
6039
_operator_abbreviations = {
61-
"c": ConstantOperator,
62-
"A": LinearOperator,
63-
"H": QuadraticOperator,
64-
"G": CubicOperator,
65-
"B": InputOperator,
66-
"N": StateInputOperator,
40+
"c": _operators.ConstantOperator,
41+
"A": _operators.LinearOperator,
42+
"H": _operators.QuadraticOperator,
43+
"G": _operators.CubicOperator,
44+
"B": _operators.InputOperator,
45+
"N": _operators.StateInputOperator,
6746
}
6847

6948
@staticmethod
@@ -460,8 +439,8 @@ def load(cls, loadfile: str):
460439
ops = []
461440
for i in range(num_operators):
462441
gp = hf[f"operator_{i}"]
463-
OpClassName = gp["meta"].attrs["class"]
464-
ops.append(_operator_name2class[OpClassName].load(gp))
442+
OpClass = getattr(_operators, gp["meta"].attrs["class"])
443+
ops.append(OpClass.load(gp))
465444

466445
# Construct the model.
467446
model = cls(ops)
@@ -1105,20 +1084,20 @@ class _FrozenSteadyModel(_FrozenMixin, SteadyModel):
11051084
a parametric model.
11061085
"""
11071086

1108-
pass # pragma: no cover
1087+
pass
11091088

11101089

11111090
class _FrozenDiscreteModel(_FrozenMixin, DiscreteModel):
11121091
"""Nonparametric discrete-time model that is the evaluation of
11131092
a parametric model.
11141093
"""
11151094

1116-
pass # pragma: no cover
1095+
pass
11171096

11181097

11191098
class _FrozenContinuousModel(_FrozenMixin, ContinuousModel):
11201099
"""Nonparametric continuous-time model that is the evaluation of
11211100
a parametric model.
11221101
"""
11231102

1124-
pass # pragma: no cover
1103+
pass

src/opinf/models/mono/_parametric.py

+12-38
Original file line numberDiff line numberDiff line change
@@ -20,31 +20,8 @@
2020
_FrozenDiscreteModel,
2121
_FrozenContinuousModel,
2222
)
23-
from ... import errors, utils
24-
from ...operators import (
25-
OperatorTemplate,
26-
ParametricOperatorTemplate,
27-
InterpConstantOperator,
28-
InterpLinearOperator,
29-
InterpQuadraticOperator,
30-
InterpCubicOperator,
31-
InterpInputOperator,
32-
InterpStateInputOperator,
33-
_utils as oputils,
34-
)
35-
36-
37-
_operator_name2class = {
38-
OpClass.__name__: OpClass
39-
for OpClass in (
40-
InterpConstantOperator,
41-
InterpLinearOperator,
42-
InterpQuadraticOperator,
43-
InterpCubicOperator,
44-
InterpInputOperator,
45-
InterpStateInputOperator,
46-
)
47-
}
23+
from ... import errors, utils, operators as _operators
24+
from ...operators import _utils as oputils
4825

4926

5027
# Base classes ================================================================
@@ -82,8 +59,8 @@ def _isvalidoperator(self, op):
8259
return isinstance(
8360
op,
8461
(
85-
OperatorTemplate,
86-
ParametricOperatorTemplate,
62+
_operators.OperatorTemplate,
63+
_operators.ParametricOperatorTemplate,
8764
),
8865
)
8966

@@ -1124,12 +1101,12 @@ def set_interpolator(self, InterpolatorClass):
11241101

11251102
# Properties: operators ---------------------------------------------------
11261103
_operator_abbreviations = {
1127-
"c": InterpConstantOperator,
1128-
"A": InterpLinearOperator,
1129-
"H": InterpQuadraticOperator,
1130-
"G": InterpCubicOperator,
1131-
"B": InterpInputOperator,
1132-
"N": InterpStateInputOperator,
1104+
"c": _operators.InterpConstantOperator,
1105+
"A": _operators.InterpLinearOperator,
1106+
"H": _operators.InterpQuadraticOperator,
1107+
"G": _operators.InterpCubicOperator,
1108+
"B": _operators.InterpInputOperator,
1109+
"N": _operators.InterpStateInputOperator,
11331110
}
11341111

11351112
def _isvalidoperator(self, op):
@@ -1317,11 +1294,8 @@ def load(cls, loadfile: str, InterpolatorClass: type = None):
13171294
for i in range(num_operators):
13181295
gp = hf[f"operator_{i}"]
13191296
OpClassName = gp["meta"].attrs["class"]
1320-
ops.append(
1321-
_operator_name2class[OpClassName].load(
1322-
gp, InterpolatorClass
1323-
)
1324-
)
1297+
OpClass = getattr(_operators, OpClassName)
1298+
ops.append(OpClass.load(gp, InterpolatorClass))
13251299

13261300
# Construct the model.
13271301
model = cls(ops)

0 commit comments

Comments
 (0)