From 6f8375e2b4a0a978f6ae8818d870b7c0f57e89f0 Mon Sep 17 00:00:00 2001 From: Stef Smeets Date: Mon, 18 Nov 2024 10:52:05 +0100 Subject: [PATCH] Fix all mypy issues (#341) * Fix all mypy issues * Fix kaleido version 0.4.1 contains an issue that breaks the tests: https://github.com/plotly/Kaleido/issues/223 * Fix tests * Fix version limit --- pyproject.toml | 2 +- scripts/analyse_md.py | 2 +- src/gemdat/io.py | 2 +- src/gemdat/jumps.py | 2 +- src/gemdat/orientations.py | 3 ++ src/gemdat/path.py | 3 +- src/gemdat/plots/_shared.py | 3 ++ .../plots/matplotlib/_msd_per_element.py | 7 +++-- src/gemdat/plots/plotly/_msd_per_element.py | 3 ++ src/gemdat/plots/plotly/_plot3d.py | 27 +++++++++-------- src/gemdat/rdf.py | 4 +-- src/gemdat/shape.py | 5 ++-- src/gemdat/trajectory.py | 29 ++++++++++++++----- src/gemdat/volume.py | 3 +- 14 files changed, 63 insertions(+), 32 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 38668231..814231ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ changelog = "https://github.com/GEMDAT-repos/GEMDAT/releases" [project.optional-dependencies] develop = [ - "kaleido", + "kaleido < 0.4", # 0.4: https://github.com/plotly/Kaleido/issues/223 "bump-my-version", "coverage[toml]", "mypy", diff --git a/scripts/analyse_md.py b/scripts/analyse_md.py index 24dfc367..131dde47 100644 --- a/scripts/analyse_md.py +++ b/scripts/analyse_md.py @@ -81,7 +81,7 @@ def analyse_md( """ trajectory = Trajectory.from_vasprun(vasp_xml) - equilibration_steps = round(equil_time / trajectory.time_step) + equilibration_steps = round(equil_time / trajectory.time_step) # type: ignore trajectory = trajectory[equilibration_steps:] diff --git a/src/gemdat/io.py b/src/gemdat/io.py index c6b7e40e..bf7aff7e 100644 --- a/src/gemdat/io.py +++ b/src/gemdat/io.py @@ -31,7 +31,7 @@ def write_cif(structure: Structure, filename: Path | str): filename : Path | str Filename to write to """ - filename = Path(filename).with_suffix('.cif') + filename = str(Path(filename).with_suffix('.cif')) structure.to_file(filename) diff --git a/src/gemdat/jumps.py b/src/gemdat/jumps.py index aa18efac..81b0994c 100644 --- a/src/gemdat/jumps.py +++ b/src/gemdat/jumps.py @@ -160,7 +160,7 @@ def site_pairs(self) -> list[tuple[str, str]]: """Return list of all unique site pairs.""" labels = self.sites.labels site_pairs = product(labels, repeat=2) - return [pair for pair in site_pairs] + return [pair for pair in site_pairs] # type: ignore @property def jump_names(self) -> list[str]: diff --git a/src/gemdat/orientations.py b/src/gemdat/orientations.py index 4957aa12..91f54fe2 100644 --- a/src/gemdat/orientations.py +++ b/src/gemdat/orientations.py @@ -58,6 +58,7 @@ def __post_init__(self, in_vectors: np.ndarray | None = None): @property def _time_step(self) -> float: """Return the time step of the trajectory.""" + assert self.trajectory.time_step return self.trajectory.time_step @property @@ -75,7 +76,9 @@ def _distances(self) -> np.ndarray: """Calculate distances between every central atom and all satellite atoms.""" central_start_coord = self._trajectory_cent.base_positions + assert central_start_coord is not None satellite_start_coord = self._trajectory_sat.base_positions + assert satellite_start_coord is not None lattice = self.trajectory.get_lattice() distance = np.array( [ diff --git a/src/gemdat/path.py b/src/gemdat/path.py index 7207bdfd..a8b418a2 100644 --- a/src/gemdat/path.py +++ b/src/gemdat/path.py @@ -69,9 +69,10 @@ def total_length(self, lattice: Lattice) -> FloatWithUnit: length : FloatWithUnit Total distance in Ã…ngstrom """ - length = 0 + length = 0.0 for a, b in pairwise(self.frac_sites()): dist, _ = lattice.get_distance_and_image(a, b) + assert dist length += dist return FloatWithUnit(length, 'ang') diff --git a/src/gemdat/plots/_shared.py b/src/gemdat/plots/_shared.py index ec3f427b..55e4b583 100644 --- a/src/gemdat/plots/_shared.py +++ b/src/gemdat/plots/_shared.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd +from pymatgen.core import Element, Species from scipy.optimize import curve_fit from scipy.stats import skewnorm @@ -25,6 +26,8 @@ def _mean_displacements_per_element( grouped = defaultdict(list) for sp, distances in zip(species, trajectory.distances_from_base_position()): + assert isinstance(sp, (Species, Element)), f'got {type(sp)}' + grouped[sp.symbol].append(distances) means = {} diff --git a/src/gemdat/plots/matplotlib/_msd_per_element.py b/src/gemdat/plots/matplotlib/_msd_per_element.py index 99ad42dc..c8855266 100644 --- a/src/gemdat/plots/matplotlib/_msd_per_element.py +++ b/src/gemdat/plots/matplotlib/_msd_per_element.py @@ -4,6 +4,7 @@ import matplotlib.pyplot as plt import numpy as np +from pymatgen.core import Element, Species if TYPE_CHECKING: import matplotlib.figure @@ -41,6 +42,8 @@ def msd_per_element( t_values = np.arange(len(trajectory)) * time_ps for sp in species: + assert isinstance(sp, (Species, Element)), f'got {type(sp)}' + traj = trajectory.filter(sp.symbol) msd = traj.mean_squared_displacement() @@ -52,9 +55,9 @@ def msd_per_element( last_color = ax.lines[-1].get_color() if show_traces: - for i, traj in enumerate(msd): + for i, y_values in enumerate(msd): label = f'{sp.symbol} trajectories' if (i == 0) else None - ax.plot(t_values, traj, lw=0.1, c=last_color, label=label) + ax.plot(t_values, y_values, lw=0.1, c=last_color, label=label) if show_shaded: ax.fill_between( diff --git a/src/gemdat/plots/plotly/_msd_per_element.py b/src/gemdat/plots/plotly/_msd_per_element.py index 79ce9538..c7647f58 100644 --- a/src/gemdat/plots/plotly/_msd_per_element.py +++ b/src/gemdat/plots/plotly/_msd_per_element.py @@ -4,6 +4,7 @@ import numpy as np import plotly.graph_objects as go +from pymatgen.core import Element, Species from gemdat.plots._shared import hex2rgba @@ -31,6 +32,8 @@ def msd_per_element(*, trajectory: Trajectory) -> go.Figure: species = list(set(trajectory.species)) for i, sp in enumerate(species): + assert isinstance(sp, (Species, Element)), f'got {type(sp)}' + color_hex = fig.layout['template']['layout']['colorway'][i] color_rgba = hex2rgba(color_hex, opacity=0.3) diff --git a/src/gemdat/plots/plotly/_plot3d.py b/src/gemdat/plots/plotly/_plot3d.py index 1e1dce84..7537f8f3 100644 --- a/src/gemdat/plots/plotly/_plot3d.py +++ b/src/gemdat/plots/plotly/_plot3d.py @@ -246,35 +246,34 @@ def plot_jumps(jumps: Jumps, *, fig: go.Figure): fig : plotly.graph_objects.Figure Plotly figure to add traces too """ - coords = jumps.sites.frac_coords + site_coords = jumps.sites.frac_coords lattice = jumps.trajectory.get_lattice() - for i, j in zip(*np.triu_indices(len(coords), k=1)): + for i, j in zip(*np.triu_indices(len(site_coords), k=1)): count = jumps.matrix()[i, j] + jumps.matrix()[j, i] if count == 0: continue - coord_i = tuple(coords[i].tolist()) - coord_j = tuple(coords[j].tolist()) + site_coord_i = tuple(site_coords[i].tolist()) + site_coord_j = tuple(site_coords[j].tolist()) lw = 1 + np.log(count) - length, image = lattice.get_distance_and_image(coord_i, coord_j) + length, image = lattice.get_distance_and_image(site_coord_i, site_coord_j) if np.any(image != 0): - lines = [(coord_i, coord_j + image), (coord_i - image, coord_j)] + lines = [(site_coord_i, site_coord_j + image), (site_coord_i - image, site_coord_j)] else: - lines = [(coord_i, coord_j)] + lines = [(site_coord_i, site_coord_j)] for line in lines: - line = lattice.get_cartesian_coords(line) - line_t = [_ for _ in zip(*line)] # transpose, but pythonic + x, y, z = lattice.get_cartesian_coords(line).T fig.add_trace( go.Scatter3d( - x=line_t[0], - y=line_t[1], - z=line_t[2], + x=x, + y=y, + z=z, mode='lines', showlegend=False, line_dash='dashdot' if any(image) != 0 else 'solid', @@ -356,6 +355,10 @@ def plot_3d( lattice = structure.lattice elif jumps: lattice = jumps.trajectory.get_lattice() + else: + raise ValueError( + 'Lattice cannot be determined form volume, structure, or jumps object.' + ) else: raise ValueError('Cannot derive lattice from input.') diff --git a/src/gemdat/rdf.py b/src/gemdat/rdf.py index a65a1d15..c2b58dbd 100644 --- a/src/gemdat/rdf.py +++ b/src/gemdat/rdf.py @@ -143,8 +143,8 @@ def radial_distribution( coords = trajectory.positions sp_coords = trajectory.filter(floating_specie).positions - states2str = _get_states(sites.labels) - states_array = _get_states_array(transitions, sites.labels) + states2str = _get_states(sites.labels) # type: ignore + states_array = _get_states_array(transitions, sites.labels) # type: ignore symbol_indices = _get_symbol_indices(base_structure) bins = np.arange(0, max_dist + resolution, resolution) diff --git a/src/gemdat/shape.py b/src/gemdat/shape.py index 4df60873..aff4234c 100644 --- a/src/gemdat/shape.py +++ b/src/gemdat/shape.py @@ -12,6 +12,7 @@ from .utils import warn_lattice_not_close if TYPE_CHECKING: + from pymatgen.symmetry.analyzer import SpacegroupOperations from pymatgen.symmetry.groups import SpaceGroup from pymatgen.symmetry.structure import SymmetrizedStructure @@ -95,7 +96,7 @@ def __init__( *, sites: Collection[PeriodicSite], lattice: Lattice, - spacegroup: SpaceGroup, + spacegroup: SpaceGroup | SpacegroupOperations, ): """Set up shape analyzer from a collection of unique periodic sites, the lattice, and spacegroup. @@ -400,7 +401,7 @@ def to_structure(self) -> Structure: sg=self.spacegroup.int_number, lattice=self.lattice, species=[site.specie for site in self.sites], - coords=[site.frac_coords for site in self.sites], + coords=[site.frac_coords for site in self.sites], # type: ignore labels=[site.label for site in self.sites], ) return structure diff --git a/src/gemdat/trajectory.py b/src/gemdat/trajectory.py index 0192a60d..b88a6ce9 100644 --- a/src/gemdat/trajectory.py +++ b/src/gemdat/trajectory.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Collection, Optional import numpy as np -from pymatgen.core import Element, Lattice +from pymatgen.core import Element, Lattice, Species from pymatgen.core.trajectory import Trajectory as PymatgenTrajectory from pymatgen.io import vasp @@ -133,16 +133,19 @@ def to_volume(self, resolution: float = 0.2) -> Volume: @property def time_step_ps(self) -> float: """Return time step in picoseconds.""" + assert self.time_step return self.time_step * 1e12 @property def total_time(self) -> float: """Return total time for trajectory.""" + assert self.time_step return len(self) * self.time_step @property def sampling_frequency(self) -> float: """Return number of time steps per second.""" + assert self.time_step return 1 / self.time_step @property @@ -469,9 +472,9 @@ def get_lattice(self, idx: int | None = None) -> Lattice: Pymatgen Lattice object """ if self.constant_lattice: - return Lattice(self.lattice) + return Lattice(self.lattice) # type: ignore - latt = self.lattices[idx] + latt = self.lattices[idx] # type: ignore return Lattice(latt) @property @@ -503,7 +506,10 @@ def distances_from_base_position(self) -> np.ndarray: def center_of_mass(self) -> Trajectory: """Return trajectory with center of mass for positions.""" - weights = [s.atomic_mass for s in self.species] + weights = [] + for s in self.species: + assert isinstance(s, (Species, Element)), f'got {type(s)=}' + weights.append(s.atomic_mass) positions_no_pbc = self.base_positions + self.cumulative_displacements @@ -547,8 +553,13 @@ def drift( if fixed_species: displacements = self.filter(species=fixed_species).displacements elif floating_species: - species = {sp.symbol for sp in self.species if sp.symbol not in floating_species} - displacements = self.filter(species=species).displacements + species = set() + for sp in self.species: + assert isinstance(sp, Species), f'got {type(sp)=}' + if sp.symbol not in floating_species: + species.add(sp) + + displacements = self.filter(species=species).displacements # type: ignore else: displacements = self.displacements @@ -609,7 +620,11 @@ def filter(self, species: str | Collection[str]) -> Trajectory: if isinstance(species, str): species = [species] - idx = [sp.symbol in species for sp in self.species] + idx = [] + for sp in self.species: + assert isinstance(sp, (Species, Element)) + idx.append(sp.symbol in species) + new_coords = self.positions[:, idx] new_species = list(compress(self.species, idx)) diff --git a/src/gemdat/volume.py b/src/gemdat/volume.py index c23e455e..4b712b57 100644 --- a/src/gemdat/volume.py +++ b/src/gemdat/volume.py @@ -93,7 +93,7 @@ def from_volumetric_data(cls, volume: VolumetricData): Input volumetric data """ return cls( - data=volume.data, + data=volume.data['total'], lattice=volume.structure.lattice, ) @@ -506,5 +506,4 @@ def trajectory_to_volume( data=data, lattice=lattice, label='trajectory', - units=None, )