Skip to content

Commit

Permalink
Merge pull request #339 from utf/oxi-featurizer
Browse files Browse the repository at this point in the history
Allow oxidation conversion featurizers to return original object
  • Loading branch information
computron authored Dec 14, 2018
2 parents d471e18 + 2b9d081 commit 1692c66
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 10 deletions.
45 changes: 35 additions & 10 deletions matminer/featurizers/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,6 @@ class StructureToOxidStructure(ConversionFeaturizer):
but instead can be applied to pre-process data or as part of a Pipeline.
Args:
**kwargs: Parameters to control the settings for
`pymatgen.io.structure.Structure.add_oxidation_state_by_guess()`.
target_col_id (str or None): The column in which the converted data will
be written. If the column already exists then an error will be
thrown unless `overwrite_data` is set to `True`. If `target_col_id`
Expand All @@ -386,12 +384,18 @@ class StructureToOxidStructure(ConversionFeaturizer):
will only work if `overwrite_data=True`).
overwrite_data (bool): Overwrite any data in `target_column` if it
exists.
return_original_on_error: If the oxidation states cannot be
guessed and set to True, the structure without oxidation states will
be returned. If set to False, an error will be thrown.
**kwargs: Parameters to control the settings for
`pymatgen.io.structure.Structure.add_oxidation_state_by_guess()`.
"""

def __init__(self, target_col_id='structure_oxid', overwrite_data=False,
**kwargs):
return_original_on_error=False, **kwargs):
super().__init__(target_col_id, overwrite_data)
self.oxi_guess_params = kwargs
self.return_original_on_error = return_original_on_error

def featurize(self, structure):
"""Add oxidation states to a Structure using pymatgen's guessing routines.
Expand All @@ -403,7 +407,17 @@ def featurize(self, structure):
(`pymatgen.core.structure.Structure`): A Structure object decorated
with oxidation states.
"""
structure.add_oxidation_state_by_guess(**self.oxi_guess_params)
els_have_oxi_states = [hasattr(s, "oxi_state") for s in
structure.composition.elements]
if all(els_have_oxi_states):
return [structure]

try:
structure.add_oxidation_state_by_guess(**self.oxi_guess_params)
except ValueError as e:
if not self.return_original_on_error:
raise e

return [structure]

def citations(self):
Expand All @@ -429,8 +443,6 @@ class CompositionToOxidComposition(ConversionFeaturizer):
but instead can be applied to pre-process data or as part of a Pipeline.
Args:
**kwargs: Parameters to control the settings for
`pymatgen.io.structure.Structure.add_oxidation_state_by_guess()`.
target_col_id (str or None): The column in which the converted data will
be written. If the column already exists then an error will be
thrown unless `overwrite_data` is set to `True`. If `target_col_id`
Expand All @@ -444,14 +456,21 @@ class CompositionToOxidComposition(ConversionFeaturizer):
coerce_mixed (bool): If a composition has both species containing
oxid states and not containing oxid states, strips all of the
oxid states and guesses the entire composition's oxid states.
return_original_on_error: If the oxidation states cannot be
guessed and set to True, the composition without oxidation states
will be returned. If set to False, an error will be thrown.
**kwargs: Parameters to control the settings for
`pymatgen.io.structure.Structure.add_oxidation_state_by_guess()`.
"""

def __init__(self, target_col_id='composition_oxid', overwrite_data=False,
coerce_mixed=True, **kwargs):
coerce_mixed=True, return_original_on_error=False,
**kwargs):
super().__init__(target_col_id, overwrite_data)
self.oxi_guess_params = kwargs
self.coerce_mixed = coerce_mixed
self.return_original_on_error = return_original_on_error

def featurize(self, comp):
"""Add oxidation states to a Structure using pymatgen's guessing routines.
Expand All @@ -464,8 +483,10 @@ def featurize(self, comp):
decorated with oxidation states.
"""
els_have_oxi_states = [hasattr(s, "oxi_state") for s in comp.elements]

if all(els_have_oxi_states):
return [comp]

elif any(els_have_oxi_states):
if self.coerce_mixed:
comp = comp.element_composition
Expand All @@ -474,8 +495,13 @@ def featurize(self, comp):
"and without oxidation states. Please enable "
"coercion to all oxidation states with "
"coerce_mixed.".format(comp))
return [comp.add_charges_from_oxi_state_guesses(
**self.oxi_guess_params)]
try:
comp = comp.add_charges_from_oxi_state_guesses(
**self.oxi_guess_params)
except ValueError as e:
if not self.return_original_on_error:
raise e
return [comp]

def citations(self):
return [(
Expand All @@ -488,4 +514,3 @@ def citations(self):

def implementors(self):
return ["Anubhav Jain", "Alex Ganose", "Alex Dunn"]

30 changes: 30 additions & 0 deletions matminer/featurizers/tests/test_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,43 @@ def test_structure_to_oxidstructure(self):
df = sto.featurize_dataframe(df, 'structure')
self.assertEqual(df["structure"].tolist()[0][0].specie.oxi_state, -1)

# test error handling
test_struct = Structure([5, 0, 0, 0, 5, 0, 0, 0, 5], ['Sb', 'F', 'O'],
[[0, 0, 0], [0.2, 0.2, 0.2], [0.5, 0.5, 0.5]])
df = DataFrame(data={'structure': [test_struct]})
sto = StructureToOxidStructure(return_original_on_error=False,
max_sites=2)
self.assertRaises(ValueError, sto.featurize_dataframe, df,
'structure')

# check non oxi state structure returned correctly
sto = StructureToOxidStructure(return_original_on_error=True,
max_sites=2)
df = sto.featurize_dataframe(df, 'structure')
self.assertEqual(df["structure_oxid"].tolist()[0][0].specie,
Element("Sb"))

def test_composition_to_oxidcomposition(self):
df = DataFrame(data={"composition": [Composition("Fe2O3")]})
cto = CompositionToOxidComposition()
df = cto.featurize_dataframe(df, 'composition')
self.assertEqual(df["composition_oxid"].tolist()[0],
Composition({"Fe3+": 2, "O2-": 3}))

# test error handling
df = DataFrame(data={"composition": [Composition("Fe2O3")]})
cto = CompositionToOxidComposition(
return_original_on_error=False, max_sites=2)
self.assertRaises(ValueError, cto.featurize_dataframe, df,
'composition')

# check non oxi state structure returned correctly
cto = CompositionToOxidComposition(
return_original_on_error=True, max_sites=2)
df = cto.featurize_dataframe(df, 'composition')
self.assertEqual(df["composition_oxid"].tolist()[0],
Composition({"Fe": 2, "O": 3}))

def test_to_istructure(self):
cscl = Structure(Lattice([[4.209, 0, 0], [0, 4.209, 0], [0, 0, 4.209]]),
["Cl", "Cs"], [[0.45, 0.5, 0.5], [0, 0, 0]])
Expand Down

0 comments on commit 1692c66

Please sign in to comment.