Skip to content

Commit b8dee52

Browse files
authored
Fix uncertainty as int for EnergyAdjustment (#4426)
* cast uncertainty_per_atom to float * cast uncertainty to float * avoid == for possible float compare * more tests * add test and tweak docstring type
1 parent af567bf commit b8dee52

File tree

2 files changed

+94
-75
lines changed

2 files changed

+94
-75
lines changed

src/pymatgen/entries/computed_entries.py

Lines changed: 86 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from typing_extensions import Self
3232

33+
from pymatgen.analysis.phase_diagram import PhaseDiagram
3334
from pymatgen.core import Structure
3435

3536
__author__ = "Ryan Kingsbury, Matt McDermott, Shyue Ping Ong, Anubhav Jain"
@@ -50,12 +51,12 @@ class EnergyAdjustment(MSONable):
5051

5152
def __init__(
5253
self,
53-
value,
54-
uncertainty=np.nan,
55-
name="Manual adjustment",
56-
cls=None,
57-
description="",
58-
):
54+
value: float,
55+
uncertainty: float = np.nan,
56+
name: str = "Manual adjustment",
57+
cls: dict | None = None,
58+
description: str = "",
59+
) -> None:
5960
"""
6061
Args:
6162
value (float): value of the energy adjustment in eV
@@ -69,20 +70,20 @@ def __init__(
6970
self.cls = cls or {}
7071
self.description = description
7172
self._value = value
72-
self._uncertainty = uncertainty
73+
self._uncertainty = float(uncertainty)
7374

7475
@property
75-
def value(self):
76+
def value(self) -> float:
7677
"""The value of the energy correction in eV."""
7778
return self._value
7879

7980
@property
80-
def uncertainty(self):
81+
def uncertainty(self) -> float:
8182
"""The uncertainty in the value of the energy adjustment in eV."""
8283
return self._uncertainty
8384

8485
@abc.abstractmethod
85-
def normalize(self, factor):
86+
def normalize(self, factor: float) -> None:
8687
"""Scale the value of the current energy adjustment by factor in-place.
8788
8889
This method is utilized in ComputedEntry.normalize() to scale the energies to a formula unit basis
@@ -91,10 +92,10 @@ def normalize(self, factor):
9192

9293
@property
9394
@abc.abstractmethod
94-
def explain(self):
95+
def explain(self) -> str:
9596
"""Return an explanation of how the energy adjustment is calculated."""
9697

97-
def __repr__(self):
98+
def __repr__(self) -> str:
9899
name, value, uncertainty, description = (
99100
self.name,
100101
float(self.value),
@@ -115,28 +116,28 @@ class ConstantEnergyAdjustment(EnergyAdjustment):
115116

116117
def __init__(
117118
self,
118-
value,
119-
uncertainty=np.nan,
120-
name="Constant energy adjustment",
121-
cls=None,
122-
description="Constant energy adjustment",
123-
):
119+
value: float,
120+
uncertainty: float = np.nan,
121+
name: str = "Constant energy adjustment",
122+
cls: dict | None = None,
123+
description: str = "Constant energy adjustment",
124+
) -> None:
124125
"""
125126
Args:
126-
value: float, value of the energy adjustment in eV
127-
uncertainty: float, uncertainty of the energy adjustment in eV. (Default: np.nan)
128-
name: str, human-readable name of the energy adjustment.
127+
value (float): the energy adjustment in eV
128+
uncertainty (float): uncertainty of the energy adjustment in eV. (Default: np.nan)
129+
name (str): human-readable name of the energy adjustment.
129130
(Default: Constant energy adjustment)
130-
cls: dict, Serialized Compatibility class used to generate the energy
131+
cls (dict): Serialized Compatibility class used to generate the energy
131132
adjustment. (Default: None)
132-
description: str, human-readable explanation of the energy adjustment.
133+
description (str): human-readable explanation of the energy adjustment.
133134
"""
134135
super().__init__(value, uncertainty, name=name, cls=cls, description=description)
135136
self._value = value
136-
self._uncertainty = uncertainty
137+
self._uncertainty = float(uncertainty)
137138

138139
@property
139-
def explain(self):
140+
def explain(self) -> str:
140141
"""An explanation of how the energy adjustment is calculated."""
141142
return f"{self.description} ({self.value:.3f} eV)"
142143

@@ -145,7 +146,7 @@ def normalize(self, factor: float) -> None:
145146
factor.
146147
147148
Args:
148-
factor: factor to divide by.
149+
factor (float): factor to divide by.
149150
"""
150151
self._value /= factor
151152
self._uncertainty /= factor
@@ -154,10 +155,10 @@ def normalize(self, factor: float) -> None:
154155
class ManualEnergyAdjustment(ConstantEnergyAdjustment):
155156
"""A manual energy adjustment applied to a ComputedEntry."""
156157

157-
def __init__(self, value):
158+
def __init__(self, value: float) -> None:
158159
"""
159160
Args:
160-
value: float, value of the energy adjustment in eV.
161+
value (float): the energy adjustment in eV.
161162
"""
162163
name = "Manual energy adjustment"
163164
description = "Manual energy adjustment"
@@ -171,44 +172,44 @@ class CompositionEnergyAdjustment(EnergyAdjustment):
171172

172173
def __init__(
173174
self,
174-
adj_per_atom,
175-
n_atoms,
176-
uncertainty_per_atom=np.nan,
177-
name="",
178-
cls=None,
179-
description="Composition-based energy adjustment",
180-
):
175+
adj_per_atom: float,
176+
n_atoms: float,
177+
uncertainty_per_atom: float = np.nan,
178+
name: str = "",
179+
cls: dict | None = None,
180+
description: str = "Composition-based energy adjustment",
181+
) -> None:
181182
"""
182183
Args:
183-
adj_per_atom: float, energy adjustment to apply per atom, in eV/atom
184-
n_atoms: float or int, number of atoms.
185-
uncertainty_per_atom: float, uncertainty in energy adjustment to apply per atom, in eV/atom.
184+
adj_per_atom (float): energy adjustment to apply per atom, in eV/atom
185+
n_atoms (float): number of atoms.
186+
uncertainty_per_atom (float): uncertainty in energy adjustment to apply per atom, in eV/atom.
186187
(Default: np.nan)
187-
name: str, human-readable name of the energy adjustment.
188+
name (str): human-readable name of the energy adjustment.
188189
(Default: "")
189-
cls: dict, Serialized Compatibility class used to generate the energy
190+
cls (dict): Serialized Compatibility class used to generate the energy
190191
adjustment. (Default: None)
191-
description: str, human-readable explanation of the energy adjustment.
192+
description (str): human-readable explanation of the energy adjustment.
192193
"""
193194
self._adj_per_atom = adj_per_atom
194-
self.uncertainty_per_atom = uncertainty_per_atom
195+
self.uncertainty_per_atom = float(uncertainty_per_atom)
195196
self.n_atoms = n_atoms
196197
self.cls = cls or {}
197198
self.name = name
198199
self.description = description
199200

200201
@property
201-
def value(self):
202+
def value(self) -> float:
202203
"""The value of the energy adjustment in eV."""
203204
return self._adj_per_atom * self.n_atoms
204205

205206
@property
206-
def uncertainty(self):
207+
def uncertainty(self) -> float:
207208
"""The value of the energy adjustment in eV."""
208209
return self.uncertainty_per_atom * self.n_atoms
209210

210211
@property
211-
def explain(self):
212+
def explain(self) -> str:
212213
"""An explanation of how the energy adjustment is calculated."""
213214
return f"{self.description} ({self._adj_per_atom:.3f} eV/atom x {self.n_atoms} atoms)"
214215

@@ -229,47 +230,47 @@ class TemperatureEnergyAdjustment(EnergyAdjustment):
229230

230231
def __init__(
231232
self,
232-
adj_per_deg,
233-
temp,
234-
n_atoms,
235-
uncertainty_per_deg=np.nan,
236-
name="",
237-
cls=None,
238-
description="Temperature-based energy adjustment",
239-
):
233+
adj_per_deg: float,
234+
temp: float,
235+
n_atoms: float,
236+
uncertainty_per_deg: float = np.nan,
237+
name: str = "",
238+
cls: dict | None = None,
239+
description: str = "Temperature-based energy adjustment",
240+
) -> None:
240241
"""
241242
Args:
242-
adj_per_deg: float, energy adjustment to apply per degree K, in eV/atom
243-
temp: float, temperature in Kelvin
244-
n_atoms: float or int, number of atoms
245-
uncertainty_per_deg: float, uncertainty in energy adjustment to apply per degree K,
243+
adj_per_deg (float): energy adjustment to apply per degree K, in eV/atom
244+
temp (float): temperature in Kelvin
245+
n_atoms (float): number of atoms
246+
uncertainty_per_deg (float): uncertainty in energy adjustment to apply per degree K,
246247
in eV/atom. (Default: np.nan)
247-
name: str, human-readable name of the energy adjustment.
248+
name (str): human-readable name of the energy adjustment.
248249
(Default: "")
249-
cls: dict, Serialized Compatibility class used to generate the energy
250+
cls (dict): Serialized Compatibility class used to generate the energy
250251
adjustment. (Default: None)
251-
description: str, human-readable explanation of the energy adjustment.
252+
description (str): human-readable explanation of the energy adjustment.
252253
"""
253254
self._adj_per_deg = adj_per_deg
254-
self.uncertainty_per_deg = uncertainty_per_deg
255+
self.uncertainty_per_deg = float(uncertainty_per_deg)
255256
self.temp = temp
256257
self.n_atoms = n_atoms
257258
self.name = name
258259
self.cls = cls or {}
259260
self.description = description
260261

261262
@property
262-
def value(self):
263+
def value(self) -> float:
263264
"""The value of the energy correction in eV."""
264265
return self._adj_per_deg * self.temp * self.n_atoms
265266

266267
@property
267-
def uncertainty(self):
268+
def uncertainty(self) -> float:
268269
"""The value of the energy adjustment in eV."""
269270
return self.uncertainty_per_deg * self.temp * self.n_atoms
270271

271272
@property
272-
def explain(self):
273+
def explain(self) -> str:
273274
"""An explanation of how the energy adjustment is calculated."""
274275
return f"{self.description} ({self._adj_per_deg:.4f} eV/K/atom x {self.temp} K x {self.n_atoms} atoms)"
275276

@@ -298,7 +299,7 @@ def __init__(
298299
parameters: dict | None = None,
299300
data: dict | None = None,
300301
entry_id: str | None = None,
301-
):
302+
) -> None:
302303
"""Initialize a ComputedEntry.
303304
304305
Args:
@@ -322,7 +323,7 @@ def __init__(
322323
super().__init__(composition, energy)
323324
self.energy_adjustments = energy_adjustments or []
324325

325-
if correction != 0.0:
326+
if not math.isclose(correction, 0.0):
326327
if energy_adjustments:
327328
raise ValueError(
328329
f"Argument conflict! Setting correction = {correction:.3f} conflicts "
@@ -395,7 +396,7 @@ def correction_uncertainty(self) -> float:
395396
for ea in self.energy_adjustments
396397
) or ufloat(0.0, np.nan)
397398

398-
if unc.nominal_value != 0 and unc.std_dev == 0:
399+
if not math.isclose(unc.nominal_value, 0) and math.isclose(unc.std_dev, 0):
399400
return np.nan
400401

401402
return unc.std_dev
@@ -498,7 +499,7 @@ def from_dict(cls, dct: dict) -> Self:
498499
# we don't pass correction explicitly because it will be calculated
499500
# on the fly from energy_adjustments
500501
correction = 0
501-
if dct["correction"] != 0 and len(energy_adj) == 0:
502+
if not math.isclose(dct["correction"], 0) and len(energy_adj) == 0:
502503
# this block is for legacy ComputedEntry that were
503504
# serialized before we had the energy_adjustments attribute.
504505
correction = dct["correction"]
@@ -629,7 +630,7 @@ def from_dict(cls, dct: dict) -> Self:
629630
"""
630631
# the first block here is for legacy ComputedEntry that were
631632
# serialized before we had the energy_adjustments attribute.
632-
if dct["correction"] != 0 and not dct.get("energy_adjustments"):
633+
if not math.isclose(dct["correction"], 0) and not dct.get("energy_adjustments"):
633634
struct = MontyDecoder().process_decoded(dct["structure"])
634635
return cls(
635636
struct,
@@ -709,7 +710,7 @@ def __init__(
709710
parameters: dict | None = None,
710711
data: dict | None = None,
711712
entry_id: str | None = None,
712-
):
713+
) -> None:
713714
"""
714715
Args:
715716
structure (Structure): The pymatgen Structure object of an entry.
@@ -891,7 +892,12 @@ def _g_delta_sisso(vol_per_atom, reduced_mass, temp) -> float:
891892
)
892893

893894
@classmethod
894-
def from_pd(cls, pd, temp=300, gibbs_model="SISSO") -> list[Self]:
895+
def from_pd(
896+
cls,
897+
pd: PhaseDiagram,
898+
temp: float = 300,
899+
gibbs_model: Literal["SISSO"] = "SISSO",
900+
) -> list[Self]:
895901
"""Constructor method for initializing a list of GibbsComputedStructureEntry
896902
objects from an existing T = 0 K phase diagram composed of
897903
ComputedStructureEntry objects, as acquired from a thermochemical database;
@@ -900,7 +906,7 @@ def from_pd(cls, pd, temp=300, gibbs_model="SISSO") -> list[Self]:
900906
Args:
901907
pd (PhaseDiagram): T = 0 K phase diagram as created in pymatgen. Must
902908
contain ComputedStructureEntry objects.
903-
temp (int): Temperature [K] for estimating Gibbs free energy of formation.
909+
temp (float): Temperature [K] for estimating Gibbs free energy of formation.
904910
gibbs_model (str): Gibbs model to use; currently the only option is "SISSO".
905911
906912
Returns:
@@ -925,15 +931,20 @@ def from_pd(cls, pd, temp=300, gibbs_model="SISSO") -> list[Self]:
925931
return gibbs_entries
926932

927933
@classmethod
928-
def from_entries(cls, entries, temp=300, gibbs_model="SISSO") -> list[Self]:
934+
def from_entries(
935+
cls,
936+
entries: list,
937+
temp: float = 300,
938+
gibbs_model: Literal["SISSO"] = "SISSO",
939+
) -> list[Self]:
929940
"""Constructor method for initializing GibbsComputedStructureEntry objects from
930941
T = 0 K ComputedStructureEntry objects, as acquired from a thermochemical
931942
database e.g. The Materials Project.
932943
933944
Args:
934945
entries ([ComputedStructureEntry]): List of ComputedStructureEntry objects,
935946
as downloaded from The Materials Project API.
936-
temp (int): Temperature [K] for estimating Gibbs free energy of formation.
947+
temp (float): Temperature [K] for estimating Gibbs free energy of formation.
937948
gibbs_model (str): Gibbs model to use; currently the only option is "SISSO".
938949
939950
Returns:
@@ -978,7 +989,7 @@ def from_dict(cls, dct: dict) -> Self:
978989
entry_id=dct.get("entry_id"),
979990
)
980991

981-
def __repr__(self):
992+
def __repr__(self) -> str:
982993
return (
983994
f"GibbsComputedStructureEntry {self.entry_id} - {self.formula}\n"
984995
f"Gibbs Free Energy (Formation) = {self.energy:.4f}"

0 commit comments

Comments
 (0)