Skip to content

Commit

Permalink
Fix rdf plotly test (#320)
Browse files Browse the repository at this point in the history
* Refactor plot comparison

* Fix radial distribution test
  • Loading branch information
stefsmeets authored May 29, 2024
1 parent f0b4d9b commit e1a165d
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 62 deletions.
49 changes: 49 additions & 0 deletions tests/helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from __future__ import annotations

import inspect
from functools import partial
from pathlib import Path
from typing import Any

from matplotlib.testing.compare import compare_images
from matplotlib.testing.decorators import image_comparison

image_comparison2 = partial(
Expand All @@ -10,3 +14,48 @@
extensions=['png'],
savefig_kwarg={'bbox_inches': 'tight'},
)


def assert_figures_similar(fig, *, name: str, ext: str = 'png', rms: float = 0.0):
"""Compare plotly figures and raise if different."""
# Ensure same font is used on different machines (local/CI)
fig.update_layout(
font_family='Arial',
title_font_family='Arial',
)

# Get path of caller
frame = inspect.stack()[1]
module = inspect.getmodule(frame[0])
modulepath = Path(module.__file__) # type: ignore

results_dir = Path() / 'result_images' / modulepath.stem
results_dir.mkdir(exist_ok=True, parents=True)

filename = f'{name}.{ext}'

actual = results_dir / filename
fig.write_image(actual)

expected_dir = modulepath.parent / 'baseline_images' / modulepath.stem
expected = expected_dir / filename
expected_link = results_dir / f'{name}-expected.{ext}'

if expected_link.exists():
expected_link.unlink()

expected_link.symlink_to(expected)

err: dict[str, Any] = compare_images(
expected=str(expected_link), actual=str(actual), tol=rms, in_decorator=True
) # type: ignore

if err:
for key in ('actual', 'expected', 'diff'):
err[key] = Path(err[key]).relative_to('.')
raise AssertionError(
(
'images not close (RMS {rms:.3f}):'
'\n\t{actual}\n\t{expected}\n\t{diff}'.format(**err)
)
)
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
80 changes: 18 additions & 62 deletions tests/integration/plot_plotly_test.py
Original file line number Diff line number Diff line change
@@ -1,121 +1,77 @@
from __future__ import annotations

from helpers import assert_figures_similar
import numpy as np

from gemdat.io import load_known_material
from gemdat.plots import plotly as plots
from pathlib import Path
import pytest

from typing import Any

from matplotlib.testing.compare import compare_images


def assert_image_similar(fig, *, name: str, ext: str = 'png', rms: float = 0.0):
# Ensure same font is used on different machines (local/CI)
fig.update_layout(
font_family='Arial',
title_font_family='Arial',
)

RESULTS_DIR = Path() / 'result_images' / Path(__file__).stem
RESULTS_DIR.mkdir(exist_ok=True, parents=True)

EXPECTED_DIR = Path(__file__).parent / 'baseline_images' / Path(__file__).stem

filename = f'{name}.{ext}'

actual = RESULTS_DIR / filename
fig.write_image(actual)

expected = EXPECTED_DIR / filename
expected_link = RESULTS_DIR / f'{name}-expected.{ext}'

if expected_link.exists():
expected_link.unlink()

expected_link.symlink_to(expected)

err: dict[str, Any] = compare_images(
expected=str(expected_link), actual=str(actual), tol=rms, in_decorator=True
) # type: ignore

if err:
for key in ('actual', 'expected', 'diff'):
err[key] = Path(err[key]).relative_to('.')
raise AssertionError(
(
'images not close (RMS {rms:.3f}):'
'\n\t{actual}\n\t{expected}\n\t{diff}'.format(**err)
)
)


def test_displacement_per_element(vasp_traj):
fig = plots.displacement_per_element(trajectory=vasp_traj)

assert_image_similar(fig, name='displacement_per_element', rms=0.5)
assert_figures_similar(fig, name='displacement_per_element', rms=0.5)


def test_displacement_per_atom(vasp_traj):
diff_trajectory = vasp_traj.filter('Li')
fig = plots.displacement_per_atom(trajectory=diff_trajectory)

assert_image_similar(fig, name='displacement_per_atom', rms=0.5)
assert_figures_similar(fig, name='displacement_per_atom', rms=0.5)


def test_displacement_histogram(vasp_traj):
diff_trajectory = vasp_traj.filter('Li')
fig = plots.displacement_histogram(trajectory=diff_trajectory)

assert_image_similar(fig, name='displacement_histogram', rms=0.5)
assert_figures_similar(fig, name='displacement_histogram', rms=0.5)


def test_frequency_vs_occurence(vasp_traj):
diff_traj = vasp_traj.filter('Li')
fig = plots.frequency_vs_occurence(trajectory=diff_traj)

assert_image_similar(fig, name='frequency_vs_occurence', rms=0.5)
assert_figures_similar(fig, name='frequency_vs_occurence', rms=0.5)


def test_vibrational_amplitudes(vasp_traj):
diff_traj = vasp_traj.filter('Li')
fig = plots.vibrational_amplitudes(trajectory=diff_traj)

assert_image_similar(fig, name='vibrational_amplitudes', rms=0.5)
assert_figures_similar(fig, name='vibrational_amplitudes', rms=0.5)


def test_jumps_vs_distance(vasp_jumps):
fig = plots.jumps_vs_distance(jumps=vasp_jumps)

assert_image_similar(fig, name='jumps_vs_distance', rms=0.5)
assert_figures_similar(fig, name='jumps_vs_distance', rms=0.5)


def test_jumps_vs_time(vasp_jumps):
fig = plots.jumps_vs_time(jumps=vasp_jumps)

assert_image_similar(fig, name='jumps_vs_time', rms=0.5)
assert_figures_similar(fig, name='jumps_vs_time', rms=0.5)


def test_collective_jumps(vasp_jumps):
fig = plots.collective_jumps(jumps=vasp_jumps)

assert_image_similar(fig, name='collective_jumps', rms=0.5)
assert_figures_similar(fig, name='collective_jumps', rms=0.5)


def test_jumps_3d(vasp_jumps):
fig = plots.jumps_3d(jumps=vasp_jumps)

assert_image_similar(fig, name='jumps_3d', rms=0.5)
assert_figures_similar(fig, name='jumps_3d', rms=0.5)


def test_radial_distribution(vasp_rdf_data):
assert len(vasp_rdf_data) == 3
for rdfs in vasp_rdf_data.values():
for i, rdfs in enumerate(vasp_rdf_data.values()):
fig = plots.radial_distribution(rdfs)

assert_image_similar(fig, name='radial_distribution', rms=0.5)
assert_figures_similar(fig, name=f'radial_distribution_{i}', rms=0.5)


@pytest.mark.xfail(reason='not implemented yet')
Expand All @@ -124,20 +80,20 @@ def test_shape(vasp_shape_data):
for i, shape in vasp_shape_data:
fig = plots.shape(shape)

assert_image_similar(fig, name='shape_{i}', rms=0.5)
assert_figures_similar(fig, name='shape_{i}', rms=0.5)


def test_msd_per_element(vasp_traj):
fig = plots.msd_per_element(trajectory=vasp_traj[-500:])

assert_image_similar(fig, name='msd_per_element', rms=0.5)
assert_figures_similar(fig, name='msd_per_element', rms=0.5)


def test_energy_along_path(vasp_path):
structure = load_known_material('argyrodite')
fig = plots.energy_along_path(path=vasp_path, structure=structure)

assert_image_similar(fig, name='energy_along_path', rms=0.5)
assert_figures_similar(fig, name='energy_along_path', rms=0.5)


def test_rectilinear(vasp_orientations):
Expand All @@ -152,16 +108,16 @@ def test_rectilinear(vasp_orientations):
orientations = vasp_orientations.normalize().transform(matrix=matrix)
fig = plots.rectilinear(orientations=orientations, normalize_histo=False)

assert_image_similar(fig, name='rectilinear', rms=0.5)
assert_figures_similar(fig, name='rectilinear', rms=0.5)


def test_bond_length_distribution(vasp_orientations):
fig = plots.bond_length_distribution(orientations=vasp_orientations, bins=50)

assert_image_similar(fig, name='bond_length_distribution', rms=0.5)
assert_figures_similar(fig, name='bond_length_distribution', rms=0.5)


def test_autocorrelation(vasp_orientations):
fig = plots.autocorrelation(orientations=vasp_orientations)

assert_image_similar(fig, name='autocorrelation', rms=0.5)
assert_figures_similar(fig, name='autocorrelation', rms=0.5)

0 comments on commit e1a165d

Please sign in to comment.