Skip to content

Commit

Permalink
viz.spectogram refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmadtourei committed Jan 23, 2025
1 parent 8977e1e commit 91a4c8b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 18 deletions.
29 changes: 15 additions & 14 deletions dascore/viz/spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,10 @@
from collections.abc import Sequence
from typing import Literal

import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import spectrogram as scipy_spectrogram

from dascore.constants import PatchType
from dascore.utils.patch import patch_function
from dascore.utils.plotting import _get_ax, _get_cmap


def _get_other_dim(dim, dims):
Expand All @@ -25,6 +21,7 @@ def _get_other_dim(dim, dims):
else:
return dims[0] if dims[1] == dim else dims[1]


@patch_function()
def spectrogram(
patch: PatchType,
Expand All @@ -47,11 +44,11 @@ def spectrogram(
ax
A matplotlib object, if None create one.
dim
Dimension along which spectogram is being plotted.
Dimension along which spectogram is being plotted.
Default is "time"
aggr_domain
"time" or "frequency" in which the mean value of the other
dimension is caluclated. No need to specify if other dimension's
"time" or "frequency" in which the mean value of the other
dimension is caluclated. No need to specify if other dimension's
coord size is 1.
Default is "frequency"
cmap
Expand All @@ -75,18 +72,22 @@ def spectrogram(
"""
dims = patch.dims
if len(dims) > 2 or len(dims) < 1:
raise ValueError(f"Can only make spectogram of 1D or 2D patches.")
other_dim = _get_other_dim(dim, dims)
raise ValueError("Can only make spectogram of 1D or 2D patches.")

other_dim = _get_other_dim(dim, dims)
if other_dim is not None:
if aggr_domain=="time":
if aggr_domain == "time":
patch_aggr = patch.aggregate(other_dim, method="mean", dim_reduce="squeeze")
spec = patch_aggr.spectrogram(dim)
elif aggr_domain=="frequency":
elif aggr_domain == "frequency":
_spec = patch.spectrogram(dim).squeeze()
spec = _spec.aggregate(other_dim, method="mean").squeeze()
else:
raise ValueError(f"The aggr_domain '{aggr_domain}' should be either 'time' or 'frequency'.")
raise ValueError(
f"The aggr_domain '{aggr_domain}' should be either 'time' or 'frequency'."
)
else:
spec = patch.spectrogram(dim)
return spec.viz.waterfall(ax=ax, cmap=cmap, scale=scale, scale_type=scale_type, log=log, show=show)
return spec.viz.waterfall(
ax=ax, cmap=cmap, scale=scale, scale_type=scale_type, log=log, show=show
)
13 changes: 9 additions & 4 deletions tests/test_viz/test_spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@
import matplotlib.pyplot as plt
import pytest

from dascore.viz.spectrogram import _get_other_dim
from dascore.viz.spectrogram import _get_other_dim


def test_get_other_dim_valid():
dims = ("time", "distance")
assert _get_other_dim("time", dims) == "distance"
assert _get_other_dim("distance", dims) == "time"


def test_get_other_dim_invalid():
dims = ("time", "distance")
with pytest.raises(ValueError, match="not in patch's dimensions"):
_get_other_dim("frequency", dims)


def test_get_other_dim_invalid_dim_type():
dims = ("time", "distance")
with pytest.raises(TypeError, match="Expected 'dim' to be a string"):
Expand All @@ -31,7 +34,7 @@ def spectro_axis(self, random_patch):
"""Return the axis from the spectrogram function."""
patch = random_patch.aggregate(dim="distance")
return patch.viz.spectrogram()

def test_axis_returned(self, random_patch):
"""Ensure a matplotlib axis is returned."""
axis = random_patch.viz.spectrogram(dim="time")
Expand All @@ -50,8 +53,10 @@ def test_invalid_aggr_domain(self, random_patch):

def test_invalid_patch_dims(self, random_patch):
"""Ensure ValueError is raised for patches with invalid dimensions."""
patch_3D = random_patch.correlate(distance=[0,1])
with pytest.raises(ValueError, match="Can only make spectogram of 1D or 2D patches"):
patch_3D = random_patch.correlate(distance=[0, 1])
with pytest.raises(
ValueError, match="Can only make spectogram of 1D or 2D patches"
):
patch_3D.viz.spectrogram(dim="distance")

def test_1d_patch(self, random_patch):
Expand Down

0 comments on commit 91a4c8b

Please sign in to comment.