diff --git a/qpretrieve/fourier/__init__.py b/qpretrieve/fourier/__init__.py index e1d4383..c135f66 100644 --- a/qpretrieve/fourier/__init__.py +++ b/qpretrieve/fourier/__init__.py @@ -1,5 +1,6 @@ # flake8: noqa: F401 import warnings +from typing import Type from .base import FFTFilter from .ff_numpy import FFTFilterNumpy @@ -12,7 +13,7 @@ PREFERRED_INTERFACE = None -def get_available_interfaces() -> list: +def get_available_interfaces() -> list[Type[FFTFilter]]: """Return a list of available FFT algorithms""" interfaces = [ FFTFilterPyFFTW, @@ -25,7 +26,7 @@ def get_available_interfaces() -> list: return interfaces_available -def get_best_interface(): +def get_best_interface() -> Type[FFTFilter]: """Return the fastest refocusing interface available If `pyfftw` is installed, :class:`.FFTFilterPyFFTW` diff --git a/qpretrieve/interfere/base.py b/qpretrieve/interfere/base.py index 05a1c44..7bd4323 100644 --- a/qpretrieve/interfere/base.py +++ b/qpretrieve/interfere/base.py @@ -1,5 +1,6 @@ import warnings from abc import ABC, abstractmethod +from typing import Type import numpy as np @@ -25,7 +26,7 @@ class BaseInterferogram(ABC): } def __init__(self, data: np.ndarray, - fft_interface: str | FFTFilter = "auto", + fft_interface: str | Type[FFTFilter] = "auto", subtract_mean=True, padding=2, copy=True, **pipeline_kws) -> None: """ diff --git a/qpretrieve/interfere/if_qlsi.py b/qpretrieve/interfere/if_qlsi.py index a8af42f..ed1ceaf 100644 --- a/qpretrieve/interfere/if_qlsi.py +++ b/qpretrieve/interfere/if_qlsi.py @@ -24,12 +24,11 @@ class QLSInterferogram(BaseInterferogram): def __init__(self, data, reference=None, *args, **kwargs): super(QLSInterferogram, self).__init__(data, *args, **kwargs) - ff_iface = get_best_interface() if reference is not None: - self.fft_ref = ff_iface(data=reference, - subtract_mean=self.fft.subtract_mean, - padding=self.fft.padding) + self.fft_ref = self.ff_iface(data=reference, + subtract_mean=self.fft.subtract_mean, + padding=self.fft.padding) else: self.fft_ref = None