diff --git a/specparam/algorithms/algorithm.py b/specparam/algorithms/algorithm.py index 15c10587..9f46e23a 100644 --- a/specparam/algorithms/algorithm.py +++ b/specparam/algorithms/algorithm.py @@ -1,13 +1,15 @@ """Define object to manage algorithm implementations.""" +import numpy as np + from specparam.utils.checks import check_input_options -from specparam.algorithms.settings import SettingsDefinition +from specparam.algorithms.settings import SettingsDefinition, SettingsValues +from specparam.modutils.docs import docs_get_section, replace_docstring_sections ################################################################################################### ################################################################################################### -FORMATS = ['spectrum', 'spectra', 'spectrogram', 'spectrograms'] - +DATA_FORMATS = ['spectrum', 'spectra', 'spectrogram', 'spectrograms'] class Algorithm(): """Template object for defining a fit algorithm. @@ -18,25 +20,43 @@ class Algorithm(): Name of the fitting algorithm. description : str Description of the fitting algorithm. - settings : dict - Name and description of settings for the fitting algorithm. - format : {'spectrum', 'spectra', 'spectrogram', 'spectrograms'} - Set base format of data model can be applied to. + public_settings : SettingsDefinition or dict + Name and description of public settings for the fitting algorithm. + private_settings : SettingsDefinition or dict, optional + Name and description of private settings for the fitting algorithm. + data_format : {'spectrum', 'spectra', 'spectrogram', 'spectrograms'} + Set base data format the model can be applied to. + modes : Modes + Modes object with fit mode definitions. + data : Data + Data object with spectral data and metadata. + results : Results + Results object with model fit results and metrics. + debug : bool + Whether to run in debug state, raising an error if encountered during fitting. """ - def __init__(self, name, description, settings, format, - modes=None, data=None, results=None, debug=False): + def __init__(self, name, description, public_settings, private_settings=None, + data_format='spectrum', modes=None, data=None, results=None, debug=False): """Initialize Algorithm object.""" self.name = name self.description = description - if not isinstance(settings, SettingsDefinition): - settings = SettingsDefinition(settings) - self.settings = settings + if not isinstance(public_settings, SettingsDefinition): + public_settings = SettingsDefinition(public_settings) + self.public_settings = public_settings + self.settings = SettingsValues(self.public_settings.names) + + if private_settings is None: + private_settings = {} + if not isinstance(private_settings, SettingsDefinition): + private_settings = SettingsDefinition(private_settings) + self.private_settings = private_settings + self._settings = SettingsValues(self.private_settings.names) - check_input_options(format, FORMATS, 'format') - self.format = format + check_input_options(data_format, DATA_FORMATS, 'data_format') + self.data_format = data_format self.modes = None self.data = None @@ -60,13 +80,11 @@ def add_settings(self, settings): Parameters ---------- settings : ModelSettings - A data object containing the settings for a power spectrum model. + A data object containing model settings. """ for setting in settings._fields: - setattr(self, setting, getattr(settings, setting)) - - self._check_loaded_settings(settings._asdict()) + setattr(self.settings, setting, getattr(settings, setting)) def get_settings(self): @@ -78,8 +96,8 @@ def get_settings(self): Object containing the settings from the current object. """ - return self.settings.make_model_settings()(\ - **{key : getattr(self, key) for key in self.settings.names}) + return self.public_settings.make_model_settings()(\ + **{key : getattr(self.settings, key) for key in self.public_settings.names}) def get_debug(self): @@ -100,32 +118,6 @@ def set_debug(self, debug): self._debug = debug - def _check_loaded_settings(self, data): - """Check if settings added, and update the object as needed. - - Parameters - ---------- - data : dict - A dictionary of data that has been added to the object. - """ - - # If settings not loaded from file, clear from object, so that default - # settings, which are potentially wrong for loaded data, aren't kept - if not set(self.settings.names).issubset(set(data.keys())): - - # Reset all public settings to None - for setting in self.settings.names: - setattr(self, setting, None) - - # Reset internal settings so that they are consistent with what was loaded - # Note that this will set internal settings to None, if public settings unavailable - self._reset_internal_settings() - - - def _reset_internal_settings(self): - """"Can be overloaded if any resetting needed for internal settings.""" - - def _reset_subobjects(self, modes=None, data=None, results=None): """Reset links to sub-objects (mode / data / results). @@ -145,3 +137,83 @@ def _reset_subobjects(self, modes=None, data=None, results=None): self.data = data if results is not None: self.results = results + + +## AlgorithmCF + +CURVE_FIT_SETTINGS = SettingsDefinition({ + 'maxfev' : { + 'type' : 'int', + 'description' : 'The maximum number of calls to the curve fitting function.', + }, + 'tol' : { + 'type' : 'float', + 'description' : \ + 'The tolerance setting for curve fitting (see scipy.curve_fit: ftol / xtol / gtol).' + }, +}) + +@replace_docstring_sections([docs_get_section(Algorithm.__doc__, 'Parameters')]) +class AlgorithmCF(Algorithm): + """Template object for defining a fit algorithm that uses `curve_fit`. + + Parameters + ---------- + % copied in from Algorithm + """ + + def __init__(self, name, description, public_settings, private_settings=None, + data_format='spectrum', modes=None, data=None, results=None, debug=False): + """Initialize Algorithm object.""" + + Algorithm.__init__(self, name=name, description=description, + public_settings=public_settings, private_settings=private_settings, + data_format=data_format, modes=modes, data=data, results=results, + debug=debug) + + self._cf_settings_desc = CURVE_FIT_SETTINGS + self._cf_settings = SettingsValues(self._cf_settings_desc.names) + + + def _initialize_bounds(self, mode): + """Initialize a bounds definition. + + Parameters + ---------- + mode : {'aperiodic', 'periodic'} + Which mode to initialize for. + + Returns + ------- + bounds : tuple of array + Bounds values. + + Notes + ----- + Output follows the needed bounds definition for curve_fit, which is: + ([low_bound_param1, low_bound_param2], + [high_bound_param1, high_bound_param2]) + """ + + n_params = getattr(self.modes, mode).n_params + bounds = (np.array([-np.inf] * n_params), np.array([np.inf] * n_params)) + + return bounds + + def _initialize_guess(self, mode): + """Initialize a guess definition. + + Parameters + ---------- + mode : {'aperiodic', 'periodic'} + Which mode to initialize for. + + Returns + ------- + guess : 1d array + Guess values. + """ + + guess = np.zeros([getattr(self.modes, mode).n_params]) + + return guess diff --git a/specparam/algorithms/settings.py b/specparam/algorithms/settings.py index 1fa601c0..14743188 100644 --- a/specparam/algorithms/settings.py +++ b/specparam/algorithms/settings.py @@ -5,48 +5,132 @@ ################################################################################################### ################################################################################################### +class SettingsValues(): + """Defines a set of algorithm settings values. + + Parameters + ---------- + names : list of str + Names of the settings to hold values for. + + Attributes + ---------- + values : dict of {str : object} + Settings values. + """ + + __slots__ = ('values',) + + def __init__(self, names): + """Initialize settings values.""" + + self.values = {name : None for name in names} + + + def __getattr__(self, name): + """Allow for accessing settings values as attributes.""" + + try: + return self.values[name] + except KeyError: + raise AttributeError(name) + + + def __setattr__(self, name, value): + """Allow for setting settings values as attributes.""" + + if name == 'values': + super().__setattr__(name, value) + else: + getattr(self, name) + self.values[name] = value + + + def __getstate__(self): + """Define how to get object state - for pickling.""" + + return self.values + + + def __setstate__(self, state): + """Define how to set object state - for pickling.""" + + self.values = state + + + @property + def names(self): + """Property attribute for settings names.""" + + return list(self.values.keys()) + + + def clear(self): + """Clear all settings - resetting to None.""" + + for setting in self.names: + self.values[setting] = None + + class SettingsDefinition(): """Defines a set of algorithm settings. Parameters ---------- - settings : dict + definitions : dict Settings definition. Each key should be a str name of a setting. Each value should be a dictionary with keys 'type' and 'description', with str values. + + Attributes + ---------- + names : list of str + Names of the settings defined in the object. + descriptions : dict of {str : str} + Description of each setting. + types : dict of {str : str} + Type for each setting. + values : dict of {str : object} + Value of each setting. """ - def __init__(self, settings): + def __init__(self, definitions): """Initialize settings definition.""" - self._settings = settings + self._definitions = definitions + + def __len__(self): + """Define the length of the object as the number of settings.""" - def _get_settings_subdict(self, field): - """Helper function to select from settings dictionary.""" + return len(self._definitions) - return {label : self._settings[label][field] for label in self._settings.keys()} + + def _get_definitions_subdict(self, field): + """Helper function to select from definitions dictionary.""" + + return {label : self._definitions[label][field] for label in self._definitions.keys()} @property def names(self): """Make property alias for setting names.""" - return list(self._settings.keys()) + return list(self._definitions.keys()) @property def types(self): """Make property alias for setting types.""" - return self._get_settings_subdict('type') + return self._get_definitions_subdict('type') @property def descriptions(self): """Make property alias for setting descriptions.""" - return self._get_settings_subdict('description') + return self._get_definitions_subdict('description') def make_setting_str(self, name): @@ -91,6 +175,11 @@ def make_model_settings(self): class ModelSettings(namedtuple('ModelSettings', self.names)): __slots__ = () + + @property + def names(self): + return list(self._fields) + ModelSettings.__doc__ = self.make_docstring() return ModelSettings diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index 4d76918f..763e78de 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -1,6 +1,7 @@ """Define original spectral fitting algorithm object.""" import warnings +from itertools import repeat import numpy as np from numpy.linalg import LinAlgError @@ -10,13 +11,13 @@ from specparam.utils.select import groupby from specparam.reports.strings import gen_width_warning_str from specparam.measures.params import compute_gauss_std -from specparam.algorithms.algorithm import Algorithm +from specparam.algorithms.algorithm import AlgorithmCF from specparam.algorithms.settings import SettingsDefinition ################################################################################################### ################################################################################################### -SPECTRAL_FIT_SETTINGS = SettingsDefinition({ +SPECTRAL_FIT_SETTINGS_DEF = SettingsDefinition({ 'peak_width_limits' : { 'type' : 'tuple of (float, float), optional, default: (0.5, 12.0)', 'description' : 'Limits on possible peak width, in Hz, as (lower_bound, upper_bound).', @@ -28,64 +29,75 @@ 'min_peak_height' : { 'type' : 'float, optional, default: 0', 'description' : \ - 'Absolute threshold for detecting peaks.\n ' \ + 'Absolute threshold for detecting peaks.' + '\n ' 'This threshold is defined in absolute units of the power spectrum (log power).', }, 'peak_threshold' : { 'type' : 'float, optional, default: 2.0', 'description' : \ - 'Relative threshold for detecting peaks.\n ' \ + 'Relative threshold for detecting peaks.' + '\n ' 'Threshold is defined in relative units of the power spectrum (standard deviation).', }, }) -class SpectralFitAlgorithm(Algorithm): +SPECTRAL_FIT_PRIVATE_SETTINGS_DEF = SettingsDefinition({ + 'ap_percentile_thresh' : { + 'type' : 'float', + 'description' : \ + 'Percentile threshold to select data from flat spectrum for an initial aperiodic fit.' + '\n ' + 'Points are selected at a low percentile value to restrict to non-peak points.', + }, + 'ap_guess' : { + 'type' : 'list of float', + 'description' : \ + 'Guess parameters for fitting the aperiodic component.' + '\n ' + 'The guess parameters should match the length and order of the aperiodic parameters.' + '\n ' + 'If \'offset\' is a parameter, default guess is the first value of the power spectrum.' + '\n ' + 'If \'exponent\' is a parameter, ' + 'default guess is the abs(log-log slope) of first & last points.' + }, + 'ap_bounds' : { + 'type' : 'tuple of tuple of float', + 'description' : \ + 'Bounds for aperiodic fitting, as ((param1_low_bound, ...) (param1_high_bound, ...)).' + '\n ' + 'By default, aperiodic fitting is unbound, but can be restricted here.', + }, + 'cf_bound' : { + 'type' : 'float', + 'description' : \ + 'Parameter bounds for center frequency when fitting peaks, as +/- std dev.', + }, + 'bw_std_edge' : { + 'type' : 'float', + 'description' : \ + 'Threshold for how far a peak has to be from edge to keep.' + '\n ' + 'This is defined in units of peak standard deviation.', + }, + 'gauss_overlap_thresh' : { + 'type' : 'float', + 'description' : \ + 'Degree of overlap between peak guesses for one to be dropped.' + '\n ' + 'This is defined in units of peak standard deviation.', + }, +}) + + +class SpectralFitAlgorithm(AlgorithmCF): """Base object defining model & algorithm for spectral parameterization. Parameters ---------- % public settings described in Spectral Fit Algorithm Settings - _ap_percentile_thresh : float - Percentile threshold, to select points from a flat spectrum for an initial aperiodic fit - Points are selected at a low percentile value to restrict to non-peak points. - _ap_guess : list of [float, float, float] - Guess parameters for fitting the aperiodic component, as [offset, knee, exponent]. - If offset guess is None, the first value of the power spectrum is used as offset guess - If exponent guess is None, the abs(log-log slope) of first & last points is used - _ap_bounds : tuple of tuple of float - Bounds for aperiodic fitting, as: ((offset_low_bound, knee_low_bound, exp_low_bound), - (offset_high_bound, knee_high_bound, exp_high_bound)) - By default, aperiodic fitting is unbound, but can be restricted here. - Even if fitting without knee, leave bounds for knee (they are dropped later). - _cf_bound : float - Parameter bounds for center frequency when fitting gaussians, in terms of +/- std dev. - _bw_std_edge : float - Threshold for how far a peak has to be from edge to keep. - This is defined in units of gaussian standard deviation. - _gauss_overlap_thresh : float - Degree of overlap between gaussian guesses for one to be dropped. - This is defined in units of gaussian standard deviation. - _maxfev : int - The maximum number of calls to the curve fitting function. - _tol : float - The tolerance setting for curve fitting (see scipy.curve_fit - ftol / xtol / gtol). - The default value reduce tolerance to speed fitting (as compared to curve_fit's default). - Set value to 1e-8 to match curve_fit default. - - Attributes - ---------- - _gauss_std_limits : list of [float, float] - Settings attribute: peak width limits, to use for gaussian standard deviation parameter. - This attribute is computed based on `peak_width_limits` and should not be updated directly. - _spectrum_flat : 1d array - Data attribute: flattened power spectrum, with the aperiodic component removed. - _spectrum_peak_rm : 1d array - Data attribute: power spectrum, with peaks removed. - _ap_fit : 1d array - Model attribute: values of the isolated aperiodic fit. - _peak_fit : 1d array - Model attribute: values of the isolated peak fit. """ # pylint: disable=attribute-defined-outside-init @@ -99,29 +111,29 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h super().__init__( name='spectral fit', description='Original parameterizing neural power spectra algorithm.', - settings=SPECTRAL_FIT_SETTINGS, format='spectrum', + public_settings=SPECTRAL_FIT_SETTINGS_DEF, + private_settings=SPECTRAL_FIT_PRIVATE_SETTINGS_DEF, modes=modes, data=data, results=results, debug=debug) ## Public settings - self.peak_width_limits = peak_width_limits - self.max_n_peaks = max_n_peaks - self.min_peak_height = min_peak_height - self.peak_threshold = peak_threshold + self.settings.peak_width_limits = peak_width_limits + self.settings.max_n_peaks = max_n_peaks + self.settings.min_peak_height = min_peak_height + self.settings.peak_threshold = peak_threshold ## Private settings: model parameters related settings - self._ap_percentile_thresh = ap_percentile_thresh - self._ap_guess = ap_guess - self._set_ap_bounds(ap_bounds) - self._cf_bound = cf_bound - self._bw_std_edge = bw_std_edge - self._gauss_overlap_thresh = gauss_overlap_thresh + self._settings.ap_percentile_thresh = ap_percentile_thresh + self._settings.ap_guess = ap_guess + self._settings.ap_bounds = self._get_ap_bounds(ap_bounds) + self._settings.cf_bound = cf_bound + self._settings.bw_std_edge = bw_std_edge + self._settings.gauss_overlap_thresh = gauss_overlap_thresh - ## Private setting: curve_fit related settings - self._maxfev = maxfev - self._tol = tol - - ## Set internal settings, based on inputs, and initialize data & results attributes - self._reset_internal_settings() + ## curve_fit settings + # Note - default reduces tolerance to speed fitting (as compared to curve_fit's default). + # Set value to 1e-8 to match curve_fit default. + self._cf_settings.maxfev = maxfev + self._cf_settings.tol = tol def _fit_prechecks(self, verbose=True): @@ -134,8 +146,9 @@ def _fit_prechecks(self, verbose=True): """ if verbose: - if 1.5 * self.data.freq_res >= self.peak_width_limits[0]: - print(gen_width_warning_str(self.data.freq_res, self.peak_width_limits[0])) + if 1.5 * self.data.freq_res >= self.settings.peak_width_limits[0]: + print(gen_width_warning_str(self.data.freq_res, + self.settings.peak_width_limits[0])) def _fit(self): @@ -144,56 +157,37 @@ def _fit(self): ## FIT PROCEDURES # Take an initial fit of the aperiodic component - temp_aperiodic_params_ = self._robust_ap_fit(self.data.freqs, self.data.power_spectrum) - temp_ap_fit = self.modes.aperiodic.func(self.data.freqs, *temp_aperiodic_params_) + temp_aperiodic_params = self._robust_ap_fit(self.data.freqs, self.data.power_spectrum) + temp_ap_fit = self.modes.aperiodic.func(self.data.freqs, *temp_aperiodic_params) - # Find peaks from the flattened power spectrum, and fit them with gaussians + # Find peaks from the flattened power spectrum, and fit them temp_spectrum_flat = self.data.power_spectrum - temp_ap_fit - self.results.gaussian_params_ = self._fit_peaks(temp_spectrum_flat) + self.results.params.gaussian = self._fit_peaks(temp_spectrum_flat) # Calculate the peak fit # Note: if no peaks are found, this creates a flat (all zero) peak fit - self.results._peak_fit = self.modes.periodic.func(\ - self.data.freqs, *np.ndarray.flatten(self.results.gaussian_params_)) + self.results.model._peak_fit = self.modes.periodic.func(\ + self.data.freqs, *np.ndarray.flatten(self.results.params.gaussian)) # Create peak-removed (but not flattened) power spectrum - self.results._spectrum_peak_rm = self.data.power_spectrum - self.results._peak_fit + self.results.model._spectrum_peak_rm = \ + self.data.power_spectrum - self.results.model._peak_fit # Run final aperiodic fit on peak-removed power spectrum - self.results.aperiodic_params_ = self._simple_ap_fit(\ - self.data.freqs, self.results._spectrum_peak_rm) - self.results._ap_fit = self.modes.aperiodic.func(\ - self.data.freqs, *self.results.aperiodic_params_) + self.results.params.aperiodic = self._simple_ap_fit(\ + self.data.freqs, self.results.model._spectrum_peak_rm) + self.results.model._ap_fit = self.modes.aperiodic.func(\ + self.data.freqs, *self.results.params.aperiodic) # Create remaining model components: flatspec & full power_spectrum model fit - self.results._spectrum_flat = self.data.power_spectrum - self.results._ap_fit - self.results.modeled_spectrum_ = self.results._peak_fit + self.results._ap_fit + self.results.model._spectrum_flat = self.data.power_spectrum - self.results.model._ap_fit + self.results.model.modeled_spectrum = \ + self.results.model._peak_fit + self.results.model._ap_fit ## PARAMETER UPDATES - # Convert gaussian definitions to peak parameters - self.results.peak_params_ = self._create_peak_params(self.results.gaussian_params_) - - - def _reset_internal_settings(self): - """Set, or reset, internal settings, based on what is provided in init. - - Notes - ----- - These settings are for internal use, based on what is provided to, or set in `__init__`. - They should not be altered by the user. - """ - - # Only update these settings if other relevant settings are available - if self.peak_width_limits: - - # Bandwidth limits are given in 2-sided peak bandwidth - # Convert to gaussian std parameter limits - self._gauss_std_limits = tuple(bwl / 2 for bwl in self.peak_width_limits) - - # Otherwise, assume settings are unknown (have been cleared) and set to None - else: - self._gauss_std_limits = None + # Convert fit peak parameters to updated values + self.results.params.peak = self._create_peak_params(self.results.params.gaussian) def _get_ap_guess(self, freqs, power_spectrum): @@ -206,32 +200,35 @@ def _get_ap_guess(self, freqs, power_spectrum): ToDo - Could be updated to fill in missing guesses. """ - if not self._ap_guess: + if not self._settings.ap_guess: + + ap_guess = self._initialize_guess('aperiodic') - ap_guess = [] - for label in self.modes.aperiodic.params.labels: + for label, ind in self.modes.aperiodic.params.indices.items(): if label == 'offset': # Offset guess is the power value for lowest available frequency - ap_guess.append(power_spectrum[0]) + ap_guess[ind] = power_spectrum[0] elif 'exponent' in label: # Exponent guess is a quick calculation of the log-log slope - ap_guess.append(np.abs((power_spectrum[-1] - power_spectrum[0]) / - (np.log10(freqs[-1]) - np.log10(freqs[0])))) - elif 'knee' in label: - # Knee guess set to zero (no real guess) - ap_guess.append(0) - else: - # Any other (un-anticipated) parameter set to guess of 0 - ap_guess.append(0) - - ap_guess = np.array(ap_guess) + ap_guess[ind] = np.abs((power_spectrum[-1] - power_spectrum[0]) / + (np.log10(freqs[-1]) - np.log10(freqs[0]))) return ap_guess - def _set_ap_bounds(self, ap_bounds): + def _get_ap_bounds(self, ap_bounds): """Set the default bounds for the aperiodic fit. + Parameters + ---------- + bounds : tuple of tuple or None + Bounds definition. If None, creates default bounds. + + Returns + ------- + bounds : tuple of tuple + Bounds definition. + Notes ----- The bounds for aperiodic parameters are set in general, and currently do not update @@ -240,12 +237,11 @@ def _set_ap_bounds(self, ap_bounds): if ap_bounds: msg = 'Provided aperiodic bounds do not have right length for fit function.' - assert len(self._ap_bounds[0]) == len(self._ap_bounds[1]) == \ - self.modes.aperiodic.n_params, msg - self._ap_bounds = ap_bounds + assert len(ap_bounds[0]) == len(ap_bounds[1]) == self.modes.aperiodic.n_params, msg else: - self._ap_bounds = (tuple([-np.inf] * self.modes.aperiodic.n_params), - tuple([np.inf] * self.modes.aperiodic.n_params)) + ap_bounds = self._initialize_bounds('aperiodic') + + return ap_bounds def _simple_ap_fit(self, freqs, power_spectrum): @@ -275,9 +271,12 @@ def _simple_ap_fit(self, freqs, power_spectrum): with warnings.catch_warnings(): warnings.simplefilter("ignore") aperiodic_params, _ = curve_fit(self.modes.aperiodic.func, freqs, power_spectrum, - p0=ap_guess, bounds=self._ap_bounds, - maxfev=self._maxfev, check_finite=False, - ftol=self._tol, xtol=self._tol, gtol=self._tol) + p0=ap_guess, bounds=self._settings.ap_bounds, + maxfev=self._cf_settings.maxfev, + check_finite=False, + ftol=self._cf_settings.tol, + xtol=self._cf_settings.tol, + gtol=self._cf_settings.tol) except RuntimeError as excp: error_msg = ("Model fitting failed due to not finding parameters in " "the simple aperiodic component fit.") @@ -318,7 +317,7 @@ def _robust_ap_fit(self, freqs, power_spectrum): flatspec[flatspec < 0] = 0 # Use percentile threshold, in terms of # of points, to extract and re-fit - perc_thresh = np.percentile(flatspec, self._ap_percentile_thresh) + perc_thresh = np.percentile(flatspec, self._settings.ap_percentile_thresh) perc_mask = flatspec <= perc_thresh freqs_ignore = freqs[perc_mask] spectrum_ignore = power_spectrum[perc_mask] @@ -330,9 +329,12 @@ def _robust_ap_fit(self, freqs, power_spectrum): warnings.simplefilter("ignore") aperiodic_params, _ = curve_fit(self.modes.aperiodic.func, freqs_ignore, spectrum_ignore, - p0=popt, bounds=self._ap_bounds, - maxfev=self._maxfev, check_finite=False, - ftol=self._tol, xtol=self._tol, gtol=self._tol) + p0=popt, bounds=self._settings.ap_bounds, + maxfev=self._cf_settings.maxfev, + check_finite=False, + ftol=self._cf_settings.tol, + xtol=self._cf_settings.tol, + gtol=self._cf_settings.tol) except RuntimeError as excp: error_msg = ("Model fitting failed due to not finding " "parameters in the robust aperiodic fit.") @@ -355,9 +357,8 @@ def _fit_peaks(self, flatspec): Returns ------- - gaussian_params : 2d array - Parameters that define the gaussian fit(s). - Each row is a gaussian, as [mean, height, standard deviation]. + peak_params : 2d array + Parameters that define the peak fit(s). """ # Take a copy of the flattened spectrum to iterate across @@ -368,22 +369,22 @@ def _fit_peaks(self, flatspec): # Find peak: loop through, finding a candidate peak, & fit with a guess peak # Stopping procedures: limit on # of peaks, or relative or absolute height thresholds - while len(guess) < self.max_n_peaks: + while len(guess) < self.settings.max_n_peaks: # Find candidate peak - the maximum point of the flattened spectrum max_ind = np.argmax(flat_iter) max_height = flat_iter[max_ind] # Stop searching for peaks once height drops below height threshold - if max_height <= self.peak_threshold * np.std(flat_iter): + if max_height <= self.settings.peak_threshold * np.std(flat_iter): break - # Set the guess parameters for gaussian fitting, specifying the mean and height + # Set the guess parameters for peak fitting, specifying the mean and height guess_freq = self.data.freqs[max_ind] guess_height = max_height # Halt fitting process if candidate peak drops below minimum height - if not guess_height > self.min_peak_height: + if not guess_height > self.settings.min_peak_height: break # Data-driven first guess at standard deviation @@ -403,33 +404,33 @@ def _fit_peaks(self, flatspec): for ind in [le_ind, ri_ind] if ind is not None]) # Use the shortest side to estimate full-width, half max (converted to Hz) - # and use this to estimate that guess for gaussian standard deviation + # and use this to estimate that guess for peak standard deviation fwhm = short_side * 2 * self.data.freq_res guess_std = compute_gauss_std(fwhm) except ValueError: # This procedure can fail (very rarely), if both left & right inds end up as None # In this case, default the guess to the average of the peak width limits - guess_std = np.mean(self.peak_width_limits) + guess_std = np.mean(self.settings.peak_width_limits) # Check that guess value isn't outside preset limits - restrict if so + # This also converts the peak_width_limits from 2-sided BW to 1-sided std # Note: without this, curve_fitting fails if given guess > or < bounds - if guess_std < self._gauss_std_limits[0]: - guess_std = self._gauss_std_limits[0] - if guess_std > self._gauss_std_limits[1]: - guess_std = self._gauss_std_limits[1] - - # Collect guess parameters and subtract this guess gaussian from the data - current_guess_params = (guess_freq, guess_height, guess_std) - - ## TEMP - if self.modes.periodic.name == 'skewnorm': - guess_skew = 0 - current_guess_params = (guess_freq, guess_height, guess_std, guess_skew) - - guess = np.vstack((guess, current_guess_params)) - peak_gauss = self.modes.periodic.func(self.data.freqs, *current_guess_params) - flat_iter = flat_iter - peak_gauss + if guess_std < self.settings.peak_width_limits[0] / 2: + guess_std = self.settings.peak_width_limits[0] / 2 + if guess_std > self.settings.peak_width_limits[1] / 2: + guess_std = self.settings.peak_width_limits[0] / 2 + + # Collect guess parameters + cur_guess = [0] * self.modes.periodic.n_params + cur_guess[self.modes.periodic.params.indices['cf']] = guess_freq + cur_guess[self.modes.periodic.params.indices['pw']] = guess_height + cur_guess[self.modes.periodic.params.indices['bw']] = guess_std + + # Fit and subtract guess peak from the spectrum + guess = np.vstack((guess, cur_guess)) + peak_fit = self.modes.periodic.func(self.data.freqs, *cur_guess) + flat_iter = flat_iter - peak_fit # Check peaks based on edges, and on overlap, dropping any that violate requirements guess = self._drop_peak_cf(guess) @@ -437,42 +438,64 @@ def _fit_peaks(self, flatspec): # If there are peak guesses, fit the peaks, and sort results if len(guess) > 0: - gaussian_params = self._fit_peak_guess(flatspec, guess) - gaussian_params = gaussian_params[gaussian_params[:, 0].argsort()] + peak_params = self._fit_peak_guess(flatspec, guess) + peak_params = peak_params[peak_params[:, 0].argsort()] else: - gaussian_params = np.empty([0, self.modes.periodic.n_params]) + peak_params = np.empty([0, self.modes.periodic.n_params]) - return gaussian_params + return peak_params - ## TO GENERALIZE FOR MODES def _get_pe_bounds(self, guess): - """Get the bound for the peak fit.""" - - # Set the bounds for CF, enforce positive height value, and set bandwidth limits - # Note that 'guess' is in terms of gaussian std, so +/- BW is 2 * the guess_gauss_std - # This set of list comprehensions is a way to end up with bounds in the form: - # ((cf_low_peak1, height_low_peak1, bw_low_peak1, *repeated for n_peaks*), - # (cf_high_peak1, height_high_peak1, bw_high_peak, *repeated for n_peaks*)) - # ^where each value sets the bound on the specified parameter - lo_bound = [[peak[0] - 2 * self._cf_bound * peak[2], 0, self._gauss_std_limits[0]] - for peak in guess] - hi_bound = [[peak[0] + 2 * self._cf_bound * peak[2], np.inf, self._gauss_std_limits[1]] - for peak in guess] - - # Check that CF bounds are within frequency range - # If they are not, update them to be restricted to frequency range - lo_bound = [bound if bound[0] > self.data.freq_range[0] else \ - [self.data.freq_range[0], *bound[1:]] for bound in lo_bound] - hi_bound = [bound if bound[0] < self.data.freq_range[1] else \ - [self.data.freq_range[1], *bound[1:]] for bound in hi_bound] - - # Unpacks the embedded lists into flat tuples - # This is what the fit function requires as input - gaus_param_bounds = (tuple(item for sublist in lo_bound for item in sublist), - tuple(item for sublist in hi_bound for item in sublist)) - - return gaus_param_bounds + """Get the bound for the peak fit. + + Parameters + ---------- + guess : list + Guess parameters from initial peak search. + + Returns + ------- + pe_bounds : tuple of array + Bounds for periodic fit. + """ + + n_pe_params = self.modes.periodic.n_params + bounds = repeat(self._initialize_bounds('periodic')) + bounds_lo = np.empty(len(guess) * n_pe_params) + bounds_hi = np.empty(len(guess) * n_pe_params) + + for p_ind, peak in enumerate(guess): + for label, ind in self.modes.periodic.params.indices.items(): + + pbounds_lo, pbounds_hi = next(bounds) + + if label == 'cf': + # Set boundaries on CF, weighted by the bandwidth + peak_bw = peak[self.modes.periodic.params.indices['bw']] + lcf = peak[ind] - 2 * self._settings.cf_bound * peak_bw + hcf = peak[ind] + 2 * self._settings.cf_bound * peak_bw + # Check that CF bounds are within frequency range - if not restrict to range + pbounds_lo[ind] = lcf if lcf > self.data.freq_range[0] \ + else self.data.freq_range[0] + pbounds_hi[ind] = hcf if hcf < self.data.freq_range[1] \ + else self.data.freq_range[1] + + if label == 'pw': + # Enforce positive values for height + pbounds_lo[ind] = 0 + + if label == 'bw': + # Set bandwidth limits, converting limits from Hz to guess params in std + pbounds_lo[ind] = self.settings.peak_width_limits[0] / 2 + pbounds_hi[ind] = self.settings.peak_width_limits[1] / 2 + + bounds_lo[p_ind*n_pe_params:(p_ind+1)*n_pe_params] = pbounds_lo + bounds_hi[p_ind*n_pe_params:(p_ind+1)*n_pe_params] = pbounds_hi + + pe_bounds = (bounds_lo, bounds_hi) + + return pe_bounds def _fit_peak_guess(self, flatspec, guess): @@ -498,8 +521,11 @@ def _fit_peak_guess(self, flatspec, guess): p0=np.ndarray.flatten(guess), bounds=self._get_pe_bounds(guess), jac=self.modes.periodic.jacobian, - maxfev=self._maxfev, check_finite=False, - ftol=self._tol, xtol=self._tol, gtol=self._tol) + maxfev=self._cf_settings.maxfev, + check_finite=False, + ftol=self._cf_settings.tol, + xtol=self._cf_settings.tol, + gtol=self._cf_settings.tol) except RuntimeError as excp: error_msg = ("Model fitting failed due to not finding " @@ -517,7 +543,6 @@ def _fit_peak_guess(self, flatspec, guess): return pe_params - ## TO GENERALIZE FOR MODES def _drop_peak_cf(self, guess): """Check whether to drop peaks based on center's proximity to the edge of the spectrum. @@ -532,8 +557,8 @@ def _drop_peak_cf(self, guess): Guess parameters for periodic peak fits. Shape: [n_peaks, n_params_per_peak]. """ - cf_params = guess[:, 0] - bw_params = guess[:, 2] * self._bw_std_edge + cf_params = guess[:, self.modes.periodic.params.indices['cf']] + bw_params = guess[:, self.modes.periodic.params.indices['bw']] * self._settings.bw_std_edge # Check if peaks within drop threshold from the edge of the frequency range keep_peak = \ @@ -547,7 +572,7 @@ def _drop_peak_cf(self, guess): def _drop_peak_overlap(self, guess): - """Checks whether to drop gaussians based on amount of overlap. + """Checks whether to drop peaks based on amount of overlap. Parameters ---------- @@ -564,14 +589,17 @@ def _drop_peak_overlap(self, guess): For any peaks with an overlap > threshold, the lowest height guess peak is dropped. """ + inds = self.modes.periodic.params.indices + # Sort the peak guesses by increasing frequency # This is so adjacent peaks can be compared from right to left - guess = sorted(guess, key=lambda x: float(x[0])) + guess = sorted(guess, key=lambda x: float(x[inds['cf']])) # Calculate standard deviation bounds for checking amount of overlap - # The bounds are the gaussian frequency +/- gaussian standard deviation - bounds = [[peak[0] - peak[2] * self._gauss_overlap_thresh, - peak[0] + peak[2] * self._gauss_overlap_thresh] for peak in guess] + # The bounds are the center frequency +/- width (standard deviation) + bounds = [[peak[inds['cf']] - peak[inds['bw']] * self._settings.gauss_overlap_thresh, + peak[inds['cf']] + peak[inds['bw']] * self._settings.gauss_overlap_thresh]\ + for peak in guess] # Loop through peak bounds, comparing current bound to that of next peak # If the left peak's upper bound extends pass the right peaks lower bound, @@ -583,7 +611,7 @@ def _drop_peak_overlap(self, guess): # Check if bound of current peak extends into next peak if b_0[1] > b_1[0]: - # If so, get the index of the gaussian with the lowest height (to drop) + # If so, get the index of the peak with the lowest height (to drop) drop_inds.append([ind, ind + 1][np.argmin([guess[ind][1], guess[ind + 1][1]])]) # Drop any peaks guesses that overlap too much, based on threshold @@ -593,52 +621,50 @@ def _drop_peak_overlap(self, guess): return guess - ## TO GENERALIZE FOR MODES - def _create_peak_params(self, gaus_params): - """Copies over the gaussian params to peak outputs, updating as appropriate. + def _create_peak_params(self, fit_peak_params): + """Copies over the fit peak parameters output parameters, updating as appropriate. Parameters ---------- - gaus_params : 2d array - Parameters that define the gaussian fit(s), as gaussian parameters. + fit_peak_params : 2d array + Parameters that define the peak parameters directly fit to the spectrum. Returns ------- peak_params : 2d array - Fitted parameter values for the peaks, with each row as [CF, PW, BW]. + Updated parameter values for the peaks. Notes ----- - The gaussian center is unchanged as the peak center frequency. + The center frequency estimate is unchanged as the peak center frequency. - The gaussian height is updated to reflect the height of the peak above - the aperiodic fit. This is returned instead of the gaussian height, as - the gaussian height is harder to interpret, due to peak overlaps. + The peak height is updated to reflect the height of the peak above + the aperiodic fit. This is returned instead of the fit peak height, as + the fit height is harder to interpret, due to peak overlaps. - The gaussian standard deviation is updated to be 'both-sided', to reflect the - 'bandwidth' of the peak, as opposed to the gaussian parameter, which is 1-sided. + The peak bandwidth is updated to be 'both-sided', to reflect the overal width + of the peak, as opposed to the fit parameter, which is 1-sided standard deviation. Performing this conversion requires that the model has been run, - with `freqs`, `modeled_spectrum_` and `_ap_fit` all required to be available. + with `freqs`, `modeled_spectrum` and `_ap_fit` all required to be available. """ - peak_params = np.empty((len(gaus_params), self.modes.periodic.n_params)) + inds = self.modes.periodic.params.indices + + peak_params = np.empty((len(fit_peak_params), self.modes.periodic.n_params)) - for ii, peak in enumerate(gaus_params): + for ii, peak in enumerate(fit_peak_params): + + cpeak = peak.copy() # Gets the index of the power_spectrum at the frequency closest to the CF of the peak - ind = np.argmin(np.abs(self.data.freqs - peak[0])) - - # Collect peak parameter data - if self.modes.periodic.name == 'gaussian': ## TEMP - peak_params[ii] = [peak[0], - self.results.modeled_spectrum_[ind] - self.results._ap_fit[ind], - peak[2] * 2] - - ## TEMP: - if self.modes.periodic.name == 'skewnorm': - peak_params[ii] = [peak[0], - self.results.modeled_spectrum_[ind] - self.results._ap_fit[ind], - peak[2] * 2, peak[3]] + cf_ind = np.argmin(np.abs(self.data.freqs - peak[inds['cf']])) + cpeak[inds['pw']] = \ + self.results.model.modeled_spectrum[cf_ind] - self.results.model._ap_fit[cf_ind] + + # Bandwidth is updated to be 'two-sided' (as opposed to one-sided std dev) + cpeak[inds['bw']] = peak[inds['bw']] * 2 + + peak_params[ii] = cpeak return peak_params diff --git a/specparam/bands/bands.py b/specparam/bands/bands.py index 1d3ba29a..6e04c5f6 100644 --- a/specparam/bands/bands.py +++ b/specparam/bands/bands.py @@ -60,7 +60,7 @@ def __getitem__(self, label): raise ValueError(message) from None - def __repr__(self): + def __str__(self): """Define the string representation as a printout of the band information.""" return '\n'.join(['{:8} : {:2} - {:2} Hz'.format(key, *val) \ diff --git a/specparam/data/conversions.py b/specparam/data/conversions.py index 98258cbf..39f6a385 100644 --- a/specparam/data/conversions.py +++ b/specparam/data/conversions.py @@ -2,7 +2,7 @@ import numpy as np -from specparam.bands.bands import Bands, check_bands +from specparam.bands.bands import check_bands from specparam.modutils.dependencies import safe_import, check_dependency from specparam.data.periodic import get_band_peak_arr from specparam.data.utils import flatten_results_dict diff --git a/specparam/data/periodic.py b/specparam/data/periodic.py index 4be5ae22..753e7386 100644 --- a/specparam/data/periodic.py +++ b/specparam/data/periodic.py @@ -6,7 +6,7 @@ ################################################################################################### def get_band_peak(model, band, select_highest=True, threshold=None, - thresh_param='PW', attribute='peak_params'): + thresh_param='PW', attribute='peak'): """Extract peaks from a band of interest from a model object. Parameters @@ -23,7 +23,7 @@ def get_band_peak(model, band, select_highest=True, threshold=None, A minimum threshold value to apply. thresh_param : {'PW', 'BW'} Which parameter to threshold on. 'PW' is power and 'BW' is bandwidth. - attribute : {'peak_params', 'gaussian_params'} + attribute : {'peak', 'gaussian'} Which attribute of peak data to extract data from. Returns @@ -42,11 +42,11 @@ def get_band_peak(model, band, select_highest=True, threshold=None, >>> betas = get_band_peak(model, [13, 30], select_highest=False) # doctest:+SKIP """ - return get_band_peak_arr(getattr(model.results, attribute + '_'), band, + return get_band_peak_arr(getattr(model.results.params, attribute), band, select_highest, threshold, thresh_param) -def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribute='peak_params'): +def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribute='peak'): """Extract peaks from a band of interest from a group model object. Parameters @@ -60,7 +60,7 @@ def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribut A minimum threshold value to apply. thresh_param : {'PW', 'BW'} Which parameter to threshold on. 'PW' is power and 'BW' is bandwidth. - attribute : {'peak_params', 'gaussian_params'} + attribute : {'peak', 'gaussian'} Which attribute of peak data to extract data from. Returns @@ -99,7 +99,7 @@ def get_band_peak_group(group, band, threshold=None, thresh_param='PW', attribut threshold, thresh_param) -def get_band_peak_event(event, band, threshold=None, thresh_param='PW', attribute='peak_params'): +def get_band_peak_event(event, band, threshold=None, thresh_param='PW', attribute='peak'): """Extract peaks from a band of interest from an event model object. Parameters @@ -116,7 +116,7 @@ def get_band_peak_event(event, band, threshold=None, thresh_param='PW', attribut A minimum threshold value to apply. thresh_param : {'PW', 'BW'} Which parameter to threshold on. 'PW' is power and 'BW' is bandwidth. - attribute : {'peak_params', 'gaussian_params'} + attribute : {'peak', 'gaussian'} Which attribute of peak data to extract data from. Returns diff --git a/specparam/data/utils.py b/specparam/data/utils.py index 30efa160..3c872926 100644 --- a/specparam/data/utils.py +++ b/specparam/data/utils.py @@ -77,7 +77,7 @@ def get_group_params(group_results, modes, name, field=None): List of FitResults objects, reflecting model results across a group of power spectra. modes : Modes Model modes definition. - name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'} + name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'metrics'} Name of the data field to extract across the group. field : str or int, optional Column name / index to extract from selected data, if requested. diff --git a/specparam/io/files.py b/specparam/io/files.py index dde0166b..6b0b6f88 100644 --- a/specparam/io/files.py +++ b/specparam/io/files.py @@ -69,7 +69,7 @@ def load_json(file_name, file_path): # Get dictionary of available attributes, and convert specified lists back into arrays arrays_to_convert = ['freqs', 'power_spectrum', - 'aperiodic_params_', 'peak_params_', 'gaussian_params_'] + 'aperiodic_params', 'peak_params', 'gaussian_params'] data = dict_lst_to_array(data, arrays_to_convert) return data diff --git a/specparam/io/models.py b/specparam/io/models.py index f0ccb75d..6018f3f2 100644 --- a/specparam/io/models.py +++ b/specparam/io/models.py @@ -48,17 +48,12 @@ def save_model(model, file_name, file_path=None, append=False, If the save file is not understood. """ - # Convert object to dictionary & convert all arrays to lists, for JSON serializing - # This 'flattens' the object, getting all relevant attributes in the same dictionary - obj_dict = dict_array_to_lst(model.__dict__) - data_dict = dict_array_to_lst(model.data.__dict__) - results_dict = dict_array_to_lst(model.results.__dict__) - algo_dict = dict_array_to_lst(model.algorithm.__dict__) - obj_dict = {**obj_dict, **data_dict, **results_dict, **algo_dict} + # 'Flatten' the model object by extracting relevant attributes to a dictionary + obj_dict = {**model.data.__dict__, **model.algorithm.settings.values} # Convert modes object to their saveable string name - obj_dict['aperiodic_mode'] = obj_dict['modes'].aperiodic.name - obj_dict['periodic_mode'] = obj_dict['modes'].periodic.name + obj_dict['aperiodic_mode'] = model.modes.aperiodic.name + obj_dict['periodic_mode'] = model.modes.periodic.name mode_labels = ['aperiodic_mode', 'periodic_mode'] # Add bands information to saveable information @@ -66,8 +61,15 @@ def save_model(model, file_name, file_path=None, append=False, if not model.results.bands._n_bands else model.results.bands._n_bands bands_label = ['bands'] if model.results.bands else [] - # Convert metrics results to saveable information - obj_dict['metrics'] = obj_dict['metrics'].results + # Convert results & metrics to saveable information + results_labels = [] + for rfield in model.results.params.fields: + results_labels.append(rfield + '_params') + obj_dict[rfield + '_params'] = getattr(model.results.params, rfield) + obj_dict['metrics'] = model.results.metrics.results + + # Convert all arrays to list for JSON serialization + obj_dict = dict_array_to_lst(obj_dict) # Check for saving out base information / check if base only if save_base is None: @@ -79,7 +81,7 @@ def save_model(model, file_name, file_path=None, append=False, keep = set(\ (mode_labels + bands_label if save_base else []) + \ (model.data._meta_fields if save_base or base_only else []) + \ - (model.results._fields + ['metrics'] if save_results else []) + \ + (results_labels + ['metrics'] if save_results else []) + \ (model.algorithm.settings.names if save_settings else []) + \ (model.data._fields if save_data else [])) diff --git a/specparam/measures/metrics.py b/specparam/measures/metrics.py index 9d1f5106..1352be7c 100644 --- a/specparam/measures/metrics.py +++ b/specparam/measures/metrics.py @@ -20,5 +20,5 @@ 'gof_rsquared' : Metric('gof', 'rsquared', compute_r_squared), 'gof_adjrsquared' : Metric('gof', 'adjrsquared', compute_adj_r_squared, \ {'n_params' : lambda data, results: \ - results.peak_params_.size + results.aperiodic_params_.size}) + results.params.peak.size + results.params.aperiodic.size}) } diff --git a/specparam/measures/pointwise.py b/specparam/measures/pointwise.py index b5553636..75be4d33 100644 --- a/specparam/measures/pointwise.py +++ b/specparam/measures/pointwise.py @@ -43,7 +43,7 @@ def compute_pointwise_error(model, plot_errors=True, return_errors=False, **plt_ raise NoModelError("No model is available to use, can not proceed.") errors = compute_pointwise_error_arr(\ - model.results.modeled_spectrum_, model.data.power_spectrum) + model.results.model.modeled_spectrum, model.data.power_spectrum) if plot_errors: plot_spectral_error(model.data.freqs, errors, **plt_kwargs) diff --git a/specparam/models/base.py b/specparam/models/base.py index fe630d11..aadd1981 100644 --- a/specparam/models/base.py +++ b/specparam/models/base.py @@ -3,6 +3,7 @@ from copy import deepcopy from specparam.utils.array import unlog +from specparam.utils.checks import check_array_dim from specparam.modes.modes import Modes from specparam.modutils.errors import NoDataError from specparam.reports.strings import gen_modes_str, gen_settings_str, gen_issue_str @@ -11,7 +12,24 @@ ################################################################################################### class BaseModel(): - """Define BaseModel object.""" + """Define BaseModel object. + + Parameters + ---------- + aperiodic_mode : Mode or str + Mode for aperiodic component, or string specifying which mode to use. + periodic_mode : Mode or str + Mode for periodic component, or string specifying which mode to use. + verbose : bool + Whether to print out updates from the object. + + Attributes + ---------- + modes : Modes + Fit modes definitions. + verbose : bool + Verbosity status. + """ def __init__(self, aperiodic_mode, periodic_mode, verbose): """Initialize object.""" @@ -84,11 +102,11 @@ def get_data(self, component='full', space='log'): output = self.data.power_spectrum if space == 'log' \ else unlog(self.data.power_spectrum) elif component == 'aperiodic': - output = self.results._spectrum_peak_rm if space == 'log' else \ - unlog(self.data.power_spectrum) / unlog(self.results._peak_fit) + output = self.results.model._spectrum_peak_rm if space == 'log' else \ + unlog(self.data.power_spectrum) / unlog(self.results.model._peak_fit) elif component == 'peak': - output = self.results._spectrum_flat if space == 'log' else \ - unlog(self.data.power_spectrum) - unlog(self.results._ap_fit) + output = self.results.model._spectrum_flat if space == 'log' else \ + unlog(self.data.power_spectrum) - unlog(self.results.model._ap_fit) else: raise ValueError('Input for component invalid.') @@ -155,12 +173,14 @@ def _add_from_dict(self, data): tmetrics = data.pop('metrics') self.results.add_metrics(list(tmetrics.keys())) self.results.metrics.add_results(tmetrics) + for label, params in {key : vals for key, vals in data.items() if 'params' in key}.items(): + if 'peak' in label or 'gaussian' in label: + params = check_array_dim(params) + setattr(self.results.params, label.split('_')[0], params) # Add additional attributes directly to object for key in data.keys(): - if getattr(self, key, False) is not False: - setattr(self, key, data[key]) + if getattr(self.algorithm.settings, key, False) is not False: + setattr(self.algorithm.settings, key, data[key]) elif getattr(self.data, key, False) is not False: setattr(self.data, key, data[key]) - elif getattr(self.results, key, False) is not False: - setattr(self.results, key, data[key]) diff --git a/specparam/models/event.py b/specparam/models/event.py index aa19a5d6..319b74cb 100644 --- a/specparam/models/event.py +++ b/specparam/models/event.py @@ -27,10 +27,8 @@ class SpectralTimeEventModel(SpectralTimeModel): """Model a set of event as a combination of aperiodic and periodic components. - WARNING: frequency and power values inputs must be in linear space. - - Passing in logged frequencies and/or power spectra is not detected, - and will silently produce incorrect results. + WARNING: frequency and power values inputs must be in linear space. Passing in logged + frequencies and/or power spectra is not detected, and will silently produce incorrect results. Parameters ---------- @@ -43,9 +41,12 @@ class SpectralTimeEventModel(SpectralTimeModel): Notes ----- % copied in from SpectralModel object - - The event object inherits from the time model, which in turn inherits from the - group object, etc. As such it also has data attributes defined on the underlying - objects (see notes and attribute lists in inherited objects for details). + - The event object inherits from the time model, overwriting the `data` and + `results` objects with versions for fitting models across events. + Event related, temporally organized results are collected into the + `results.event_time_results` attribute, which may include sub-selecting peaks + per band (depending on settings). Note that the `results.event_group_results` attribute + is also available, which maintains the full model results. """ def __init__(self, *args, **kwargs): diff --git a/specparam/models/group.py b/specparam/models/group.py index 40e62172..9966c528 100644 --- a/specparam/models/group.py +++ b/specparam/models/group.py @@ -31,10 +31,8 @@ class SpectralGroupModel(SpectralModel): """Model a group of power spectra as a combination of aperiodic and periodic components. - WARNING: frequency and power values inputs must be in linear space. - - Passing in logged frequencies and/or power spectra is not detected, - and will silently produce incorrect results. + WARNING: frequency and power values inputs must be in linear space. Passing in logged + frequencies and/or power spectra is not detected, and will silently produce incorrect results. Parameters ---------- @@ -47,14 +45,10 @@ class SpectralGroupModel(SpectralModel): Notes ----- % copied in from SpectralModel object - - The group object inherits from the model object. As such it also has data - attributes (`power_spectrum` & `modeled_spectrum_`), and parameter attributes - (`aperiodic_params_`, `peak_params_`, `gaussian_params_`, `r_squared_`, `error_`) - which are defined in the context of individual model fits. These attributes are - used during the fitting process, but in the group context do not store results - post-fitting. Rather, all model fit results are collected and stored into the - `group_results` attribute. To access individual parameters of the fit, use - the `get_params` method. + - The group object inherits from the model object, and in doing so overwrites the + `data` and `results` objects with versions for fitting groups of power spectra. + All model fit results are collected and stored in the `results.group_results` attribute. + To access individual parameters of the fit, use the `get_params` method. """ def __init__(self, *args, **kwargs): @@ -225,15 +219,15 @@ def load(self, file_name, file_path=None): if 'power_spectrum' in data.keys(): power_spectra.append(data.pop('power_spectrum')) + data_keys = set(data.keys()) self._add_from_dict(data) - # If settings are loaded, check and update based on the first line - if ind == 0: - self.algorithm._check_loaded_settings(data) + # For hearder line, check if settings are loaded and clear defaults if not + if ind == 0 and not set(self.algorithm.settings.names).issubset(data_keys): + self.algorithm.settings.clear() # If results part of current data added, check and update object results - if set(self.results._fields).issubset(set(data.keys())): - self.results._check_loaded_results(data) + if set([el + '_params' for el in self.results.params.fields]).issubset(data_keys): self.results.group_results.append(self.results._get_results()) # Reconstruct frequency vector, if information is available to do so diff --git a/specparam/models/model.py b/specparam/models/model.py index 85813839..26913cda 100644 --- a/specparam/models/model.py +++ b/specparam/models/model.py @@ -10,7 +10,7 @@ from specparam.models.base import BaseModel from specparam.objs.data import Data from specparam.objs.results import Results -from specparam.algorithms.spectral_fit import SpectralFitAlgorithm, SPECTRAL_FIT_SETTINGS +from specparam.algorithms.spectral_fit import SpectralFitAlgorithm, SPECTRAL_FIT_SETTINGS_DEF from specparam.reports.save import save_model_report from specparam.reports.strings import gen_model_results_str from specparam.modutils.errors import NoDataError, FitError @@ -24,14 +24,12 @@ ################################################################################################### ################################################################################################### -@replace_docstring_sections([SPECTRAL_FIT_SETTINGS.make_docstring()]) +@replace_docstring_sections([SPECTRAL_FIT_SETTINGS_DEF.make_docstring()]) class SpectralModel(BaseModel): """Model a power spectrum as a combination of aperiodic and periodic components. - WARNING: frequency and power values inputs must be in linear space. - - Passing in logged frequencies and/or power spectra is not detected, - and will silently produce incorrect results. + WARNING: frequency and power values inputs must be in linear space. Passing in logged + frequencies and/or power spectra is not detected, and will silently produce incorrect results. Parameters ---------- @@ -64,12 +62,6 @@ class SpectralModel(BaseModel): For example, raw FFT inputs are not appropriate. Where possible and appropriate, use longer time segments for power spectrum calculation to get smoother power spectra, as this will give better model fits. - - Commonly used abbreviations used in this module include: - CF: center frequency, PW: power, BW: Bandwidth, AP: aperiodic - - The gaussian params are those that define the gaussian of the fit, where as the peak - params are a modified version, in which the CF of the peak is the mean of the gaussian, - the PW of the peak is the height of the gaussian over and above the aperiodic component, - and the BW of the peak, is 2*std of the gaussian (as 'two sided' bandwidth). """ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0, @@ -160,7 +152,7 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None, prechecks=True): # If not set to fail on NaN or Inf data at add time, check data here # This serves as a catch all for curve_fits which will fail given NaN or Inf # Because FitError's are by default caught, this allows fitting to continue - if not self.data._check_data: + if not self.data.checks['data']: if np.any(np.isinf(self.data.power_spectrum)) or \ np.any(np.isnan(self.data.power_spectrum)): raise FitError("Model fitting was skipped because there are NaN or Inf " @@ -276,14 +268,16 @@ def load(self, file_name, file_path=None, regenerate=True): # Add loaded data to object and check loaded data self._add_from_dict(data) - self.algorithm._check_loaded_settings(data) - self.results._check_loaded_results(data) + + # If settings are not loaded, clear defaults to not have potentially incorrect values + if not set(self.algorithm.settings.names).issubset(set(data.keys())): + self.algorithm.settings.clear() # Regenerate model components, based on what is available if regenerate: if self.data.freq_res: self.data._regenerate_freqs() - if np.all(self.data.freqs) and np.all(self.results.aperiodic_params_): + if np.all(self.data.freqs) and np.all(self.results.params.aperiodic): self.results._regenerate_model(self.data.freqs) diff --git a/specparam/models/time.py b/specparam/models/time.py index fd626407..5540cc0e 100644 --- a/specparam/models/time.py +++ b/specparam/models/time.py @@ -22,10 +22,8 @@ class SpectralTimeModel(SpectralGroupModel): """Model a spectrogram as a combination of aperiodic and periodic components. - WARNING: frequency and power values inputs must be in linear space. - - Passing in logged frequencies and/or power spectra is not detected, - and will silently produce incorrect results. + WARNING: frequency and power values inputs must be in linear space. Passing in logged + frequencies and/or power spectra is not detected, and will silently produce incorrect results. Parameters ---------- @@ -38,14 +36,12 @@ class SpectralTimeModel(SpectralGroupModel): Notes ----- % copied in from SpectralModel object - - The time object inherits from the group model, which in turn inherits from the - model object. As such it also has data attributes defined on the model object, - as well as additional attributes that are added to the group object (see notes - and attribute list in SpectralGroupModel). - - Notably, while this object organizes the results into the `time_results` - attribute, which may include sub-selecting peaks per band (depending on settings) - the `group_results` attribute is also available, which maintains the full - model results. + - The time object inherits from the group model, overwriting the `data` and + `results` objects with versions for fitting models across time. Temporally + organized results are collected into the `results.time_results` attribute, + which may include sub-selecting peaks per band (depending on settings). + Note that the `results.group_results` attribute is also available, which maintains + the full model results. """ def __init__(self, *args, **kwargs): diff --git a/specparam/models/utils.py b/specparam/models/utils.py index b3b39731..0e8e7082 100644 --- a/specparam/models/utils.py +++ b/specparam/models/utils.py @@ -39,7 +39,7 @@ def initialize_model_from_source(source, target): """ model = MODELS[target](**source.modes.get_modes()._asdict(), - **source.algorithm.get_settings()._asdict(), + **source.algorithm.settings.values, metrics=source.results.metrics.labels, bands=source.results.bands, verbose=source.verbose) @@ -72,14 +72,19 @@ def compare_model_objs(model_objs, aspect): outputs.append(compare_model_objs(model_objs, caspect)) return np.all(outputs) - check_input_options(aspect, ['settings', 'meta_data', 'metrics'], 'aspect') + aspects = ['modes', 'settings', 'meta_data', 'bands', 'metrics'] + check_input_options(aspect, aspects, 'aspect') # Check specified aspect of the objects are the same across instances for m_obj_1, m_obj_2 in zip(model_objs[:-1], model_objs[1:]): + if aspect == 'modes': + consistent = m_obj_1.modes.get_modes() == m_obj_2.modes.get_modes() if aspect == 'settings': consistent = m_obj_1.algorithm.get_settings() == m_obj_2.algorithm.get_settings() if aspect == 'meta_data': consistent = m_obj_1.data.get_meta_data() == m_obj_2.data.get_meta_data() + if aspect == 'bands': + consistent = m_obj_1.results.bands == m_obj_2.results.bands if aspect == 'metrics': consistent = m_obj_1.results.metrics.labels == m_obj_2.results.metrics.labels @@ -188,7 +193,7 @@ def average_reconstructions(group, avg_method='mean'): models = np.zeros(shape=group.data.power_spectra.shape) for ind in range(len(group.results)): - models[ind, :] = group.get_model(ind, regenerate=True).results.modeled_spectrum_ + models[ind, :] = group.get_model(ind, regenerate=True).results.model.modeled_spectrum avg_model = avg_funcs[avg_method](models, 0) @@ -262,8 +267,8 @@ def combine_model_objs(model_objs): # Set the status for freqs & data checking # Check states gets set as True if any of the inputs have it on, False otherwise group.data.set_checks(\ - check_freqs=any(getattr(m_obj.data, '_check_freqs') for m_obj in model_objs), - check_data=any(getattr(m_obj.data, '_check_data') for m_obj in model_objs)) + check_freqs=any(m_obj.data.checks['freqs'] for m_obj in model_objs), + check_data=any(m_obj.data.checks['data'] for m_obj in model_objs)) # Add data information information group.data.add_meta_data(model_objs[0].data.get_meta_data()) diff --git a/specparam/modes/modes.py b/specparam/modes/modes.py index f2752ed8..43b03a0b 100644 --- a/specparam/modes/modes.py +++ b/specparam/modes/modes.py @@ -56,9 +56,8 @@ def check_mode_definition(mode, options): if isinstance(mode, str): assert mode in list(options.keys()), 'Specific Mode not found.' mode = options[mode] - elif isinstance(mode, Mode): - mode = mode - else: + + if not isinstance(mode, Mode): raise ValueError('Mode input not understood.') return mode diff --git a/specparam/objs/components.py b/specparam/objs/components.py new file mode 100644 index 00000000..25b6ee58 --- /dev/null +++ b/specparam/objs/components.py @@ -0,0 +1,90 @@ +"""Define model components object.""" + +from specparam.utils.array import unlog +from specparam.modutils.errors import NoModelError + +################################################################################################### +################################################################################################### + +class ModelComponents(): + """Object for managing model components. + + Attributes + ---------- + modeled_spectrum : 1d array + Modeled spectrum. + _spectrum_flat : 1d array + Data attribute: flattened power spectrum, with the aperiodic component removed. + _spectrum_peak_rm : 1d array + Data attribute: power spectrum, with peaks removed. + _ap_fit : 1d array + Model attribute: values of the isolated aperiodic fit. + _peak_fit : 1d array + Model attribute: values of the isolated peak fit. + """ + + def __init__(self): + """Initialize ModelComponents object.""" + + self.reset() + + + def reset(self): + """Reset model components attributes.""" + + # Full model + self.modeled_spectrum = None + + # Model components + self._ap_fit = None + self._peak_fit = None + + # Data components + self._spectrum_flat = None + self._spectrum_peak_rm = None + + + def get_component(self, component='full', space='log'): + """Get a model component. + + Parameters + ---------- + component : {'full', 'aperiodic', 'peak'} + Which model component to return. + 'full' - full model + 'aperiodic' - isolated aperiodic model component + 'peak' - isolated peak model component + space : {'log', 'linear'} + Which space to return the model component in. + 'log' - returns in log10 space. + 'linear' - returns in linear space. + + Returns + ------- + output : 1d array + Specified model component, in specified spacing. + + Notes + ----- + The 'space' parameter doesn't just define the spacing of the model component + values, but rather defines the space of the additive model such that + `model = aperiodic_component + peak_component`. + With space set as 'log', this combination holds in log space. + With space set as 'linear', this combination holds in linear space. + """ + + if self.modeled_spectrum is None: + raise NoModelError("No model fit results are available, can not proceed.") + assert space in ['linear', 'log'], "Input for 'space' invalid." + + if component == 'full': + output = self.modeled_spectrum if space == 'log' else unlog(self.modeled_spectrum) + elif component == 'aperiodic': + output = self._ap_fit if space == 'log' else unlog(self._ap_fit) + elif component == 'peak': + output = self._peak_fit if space == 'log' else \ + unlog(self.modeled_spectrum) - unlog(self._ap_fit) + else: + raise ValueError('Input for component invalid.') + + return output diff --git a/specparam/objs/data.py b/specparam/objs/data.py index e7d593dc..cd3546b4 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -1,5 +1,6 @@ -"""Define base data objects.""" +"""Define data objects.""" +from warnings import warn from functools import wraps import numpy as np @@ -9,6 +10,7 @@ from specparam.utils.spectral import trim_spectrum from specparam.utils.checks import check_input_options from specparam.modutils.errors import DataError, InconsistentDataError +from specparam.modutils.docs import docs_get_section, replace_docstring_sections from specparam.plts.settings import PLT_COLORS from specparam.plts.spectra import plot_spectra, plot_spectrogram from specparam.plts.utils import check_plot_kwargs @@ -28,24 +30,28 @@ class Data(): Parameters ---------- check_freqs : bool - Whether to check the frequency values. - If True, checks the frequency values, and raises an error for uneven spacing. + Whether to check the frequency values. If so, raises an error for uneven spacing. check_data : bool - Whether to check the power spectrum values. - If True, checks the power values and raises an error for any NaN / Inf values. + Whether to check the spectral data. If so, raises an error for any NaN / Inf values. format : {'power'} The representation format of the data. Attributes ---------- + checks : dict + Specifiers for which aspects of the data to run checks on. freqs : 1d array - Frequency values for the power spectrum. - power_spectrum : 1d array - Power values, stored internally in log10 scale. + Frequency values for the spectral data. freq_range : list of [float, float] - Frequency range of the power spectrum, as [lowest_freq, highest_freq]. + Frequency range of the spectral data, as [lowest_freq, highest_freq]. freq_res : float - Frequency resolution of the power spectrum. + Frequency resolution of the spectral data. + power_spectrum : 1d array + Power values. + + Notes + ----- + All power values are stored internally in log10 scale. """ def __init__(self, check_freqs=True, check_data=True, format='power'): @@ -55,9 +61,10 @@ def __init__(self, check_freqs=True, check_data=True, format='power'): self._fields = DATA_FIELDS self._meta_fields = META_DATA_FIELDS - # Define data check run statuses - self._check_freqs = check_freqs - self._check_data = check_data + self.checks = { + 'freqs' : check_freqs, + 'data' : check_data, + } check_input_options(format, FORMATS, 'format') self.format = format @@ -120,7 +127,7 @@ def get_checks(self): Object containing the check statuses from the current object. """ - return ModelChecks(**{key : getattr(self, '_' + key) for key in ModelChecks._fields}) + return ModelChecks(**{'check_' + key : value for key, value in self.checks.items()}) def get_meta_data(self): @@ -156,9 +163,9 @@ def set_checks(self, check_freqs=None, check_data=None): """ if check_freqs is not None: - self._check_freqs = check_freqs + self.checks['freqs'] = check_freqs if check_data is not None: - self._check_data = check_data + self.checks['data'] = check_data def _reset_data(self, clear_freqs=False, clear_spectrum=False): @@ -256,10 +263,10 @@ def _prepare_data(self, freqs, powers, freq_range, spectra_dim=1): # Check if freqs start at 0 and move up one value if so # Aperiodic fit gets an inf if freq of 0 is included, which leads to an error if freqs[0] == 0.0: + msg = "specparam fit warning - skipping frequency == 0, " \ + "as this causes a problem with fitting." + warn(msg, category=RuntimeWarning) freqs, powers = trim_spectrum(freqs, powers, [freqs[1], freqs.max()]) - if self.verbose: - print("\nFITTING WARNING: Skipping frequency == 0, " - "as this causes a problem with fitting.") # Calculate frequency resolution, and actual frequency range of the data freq_range = [freqs.min(), freqs.max()] @@ -270,13 +277,13 @@ def _prepare_data(self, freqs, powers, freq_range, spectra_dim=1): ## Data checks - run checks on inputs based on check statuses - if self._check_freqs: + if self.checks['freqs']: # Check if the frequency data is unevenly spaced, and raise an error if so freq_diffs = np.diff(freqs) if not np.all(np.isclose(freq_diffs, freq_res)): raise DataError("The input frequency values are not evenly spaced. " "The model expects equidistant frequency values in linear space.") - if self._check_data: + if self.checks['data']: # Check if there are any infs / nans, and raise an error if so if np.any(np.isinf(powers)) or np.any(np.isnan(powers)): error_msg = ("The input power spectra data, after logging, contains NaNs or Infs. " @@ -288,20 +295,24 @@ def _prepare_data(self, freqs, powers, freq_range, spectra_dim=1): return freqs, powers, freq_range, freq_res +@replace_docstring_sections([docs_get_section(Data.__doc__, 'Parameters'), + docs_get_section(Data.__doc__, 'Attributes')]) class Data2D(Data): """Base object for managing data for spectral parameterization - for 2D data. + Parameters + ---------- + % copied in from Data + Attributes ---------- - freqs : 1d array - Frequency values for the power spectra. + % copied in from Data power_spectra : 2d array Power values for the group of power spectra, as [n_power_spectra, n_freqs]. - Power values are stored internally in log10 scale. - freq_range : list of [float, float] - Frequency range of the power spectra, as [lowest_freq, highest_freq]. - freq_res : float - Frequency resolution of the power spectra. + + Notes + ----- + All power values are stored internally in log10 scale. """ def __init__(self): @@ -385,20 +396,24 @@ def decorated(*args, **kwargs): return decorated +@replace_docstring_sections([docs_get_section(Data.__doc__, 'Parameters'), + docs_get_section(Data2D.__doc__, 'Attributes')]) class Data2DT(Data2D): """Base object for managing data for spectral parameterization - for 2D transposed data. + Parameters + ---------- + % copied in from Data + Attributes ---------- - freqs : 1d array - Frequency values for the spectrogram. + % copied in from Data2D spectrogram : 2d array Power values for the spectrogram, as [n_freqs, n_time_windows]. - Power values are stored internally in log10 scale. - freq_range : list of [float, float] - Frequency range of the spectrogram, as [lowest_freq, highest_freq]. - freq_res : float - Frequency resolution of the spectrogram. + + Notes + ----- + All power values are stored internally in log10 scale. """ def __init__(self): @@ -451,20 +466,24 @@ def plot(self, **plt_kwargs): plot_spectrogram(self.freqs, self.spectrogram, **plot_kwargs) +@replace_docstring_sections([docs_get_section(Data.__doc__, 'Parameters'), + docs_get_section(Data2DT.__doc__, 'Attributes')]) class Data3D(Data2DT): """Base object for managing data for spectral parameterization - for 3D data. + Parameters + ---------- + % copied in from Data + Attributes ---------- - freqs : 1d array - Frequency values for the power spectra. + % copied in from Data2DT spectrograms : 3d array Power values for the spectrograms, organized as [n_events, n_freqs, n_time_windows]. - Power values are stored internally in log10 scale. - freq_range : list of [float, float] - Frequency range of the power spectra, as [lowest_freq, highest_freq]. - freq_res : float - Frequency resolution of the power spectra. + + Notes + ----- + All power values are stored internally in log10 scale. """ def __init__(self): diff --git a/specparam/objs/metrics.py b/specparam/objs/metrics.py index 9170f041..d3068db3 100644 --- a/specparam/objs/metrics.py +++ b/specparam/objs/metrics.py @@ -12,8 +12,8 @@ class Metric(): Parameters ---------- - type : str - The type of measure, e.g. 'error' or 'gof'. + category : str + The category of measure, e.g. 'error' or 'gof'. measure : str The specific measure, e.g. 'r_squared'. func : callable @@ -25,10 +25,10 @@ class Metric(): and returns the desired parameter / computed value. """ - def __init__(self, type, measure, func, kwargs=None): + def __init__(self, category, measure, func, kwargs=None): """Initialize metric.""" - self.type = type + self.category = category self.measure = measure self.func = func self.result = np.nan @@ -45,17 +45,17 @@ def __repr__(self): def label(self): """Define label property.""" - return self.type + '_' + self.measure + return self.category + '_' + self.measure @property def flabel(self): """Define formatted label property.""" - if self.type == 'error': - flabel = '{} ({})'.format(self.type.capitalize(), self.measure.upper()) - if self.type == 'gof': - flabel = '{} ({})'.format(self.type.upper(), self.measure) + if self.category == 'error': + flabel = '{} ({})'.format(self.category.capitalize(), self.measure.upper()) + if self.category == 'gof': + flabel = '{} ({})'.format(self.category.upper(), self.measure) return flabel @@ -75,7 +75,7 @@ def compute_metric(self, data, results): for key, lfunc in self.kwargs.items(): kwargs[key] = lfunc(data, results) - self.result = self.func(data.power_spectrum, results.modeled_spectrum_, **kwargs) + self.result = self.func(data.power_spectrum, results.model.modeled_spectrum, **kwargs) class Metrics(): @@ -162,10 +162,10 @@ def compute_metrics(self, data, results): @property - def types(self): - """Define alias for metric type of all currently defined metrics.""" + def categories(self): + """Define alias for metric categories of all currently defined metrics.""" - return [metric.type for metric in self.metrics] + return [metric.category for metric in self.metrics] @property diff --git a/specparam/objs/params.py b/specparam/objs/params.py new file mode 100644 index 00000000..185a11f4 --- /dev/null +++ b/specparam/objs/params.py @@ -0,0 +1,58 @@ +"""Define model parameters object.""" + +import numpy as np + +################################################################################################### +################################################################################################### + +class ModelParameters(): + """Object to manage model fit parameters. + + Parameters + ---------- + modes : Modes + Fit modes defintion. + If provided, used to initialize parameter arrays to correct sizes. + + Attributes + ---------- + aperiodic : 1d array + Aperiodic parameters of the model fit. + peak : 1d array + Peak parameters of the model fit. + gaussian : 1d array + Gaussian parameters of the model fit. + """ + + def __init__(self, modes=None): + """Initialize ModelParameters object.""" + + self.aperiodic = np.nan + self.peak = np.nan + self.gaussian = np.nan + + self.reset(modes) + + def reset(self, modes=None): + """Reset parameters.""" + + # Aperiodic parameters + if modes: + self.aperiodic = np.array([np.nan] * modes.aperiodic.n_params) + else: + self.aperiodic = np.nan + + # Periodic parameters + if modes: + self.gaussian = np.empty([0, modes.periodic.n_params]) + self.peak = np.empty([0, modes.periodic.n_params]) + else: + self.gaussian = np.nan + self.peak = np.nan + + + @property + def fields(self): + """Alias as a property attribute the list of fields.""" + + return list(vars(self).keys()) diff --git a/specparam/objs/results.py b/specparam/objs/results.py index 66ea6194..cdcecf74 100644 --- a/specparam/objs/results.py +++ b/specparam/objs/results.py @@ -1,4 +1,4 @@ -"""Define base results objects.""" +"""Define results objects.""" from copy import deepcopy from itertools import repeat @@ -7,8 +7,9 @@ from specparam.bands.bands import check_bands from specparam.objs.metrics import Metrics +from specparam.objs.params import ModelParameters +from specparam.objs.components import ModelComponents from specparam.measures.metrics import METRICS -from specparam.utils.array import unlog from specparam.utils.checks import check_inds, check_array_dim from specparam.modutils.errors import NoModelError from specparam.modutils.docs import docs_get_section, replace_docstring_sections @@ -22,7 +23,6 @@ ################################################################################################### # Define set of results fields & default metrics to use -RESULTS_FIELDS = ['aperiodic_params_', 'gaussian_params_', 'peak_params_'] DEFAULT_METRICS = ['error_mae', 'gof_rsquared'] @@ -35,8 +35,21 @@ class Results(): Modes object with fit mode definitions. metrics : Metrics Metrics object with metric definitions. - bands : bands + bands : Bands Bands object with band definitions. + + Attributes + ---------- + modes : Modes + Modes object with fit mode definitions. + bands : Bands + Bands object with band definitions. + model : ModelComponents + Manages the model fit and components. + params : ModelParameters + Manages the model fit parameters. + metrics : Metrics + Metrics object with metric definitions. """ # pylint: disable=attribute-defined-outside-init, arguments-differ @@ -48,9 +61,11 @@ def __init__(self, modes=None, metrics=None, bands=None): self.add_bands(bands) self.add_metrics(metrics) + self.model = ModelComponents() + self.params = ModelParameters() + # Initialize results attributes self._reset_results(True) - self._fields = RESULTS_FIELDS @property @@ -65,27 +80,27 @@ def has_model(self): - necessarily defined, as floats, if model has been fit """ - return not np.all(np.isnan(self.aperiodic_params_)) + return not np.all(np.isnan(self.params.aperiodic)) @property - def n_peaks_(self): + def n_peaks(self): """How many peaks were fit in the model.""" n_peaks = None if self.has_model: - n_peaks = self.peak_params_.shape[0] + n_peaks = self.params.peak.shape[0] return n_peaks @property - def n_params_(self): + def n_params(self): """The total number of parameters fit in the model.""" n_params = None if self.has_model: - n_peak_params = self.modes.periodic.n_params * self.n_peaks_ + n_peak_params = self.modes.periodic.n_params * self.n_peaks n_params = n_peak_params + self.modes.aperiodic.n_params return n_params @@ -133,14 +148,14 @@ def add_results(self, results): A data object containing the results from fitting a power spectrum model. """ - # Add parameter fields and then select and add metrics results - for pfield in self._fields: - setattr(self, pfield, getattr(results, pfield.strip('_'))) + for pfield in self.params.fields: + params = getattr(results, pfield + '_params') + if 'peak' in pfield or 'gaussian' in pfield: + params = check_array_dim(params) + setattr(self.params, pfield, params) self.metrics.add_results(results.metrics) - self._check_loaded_results(results._asdict()) - def get_results(self): """Return model fit parameters and goodness of fit metrics. @@ -152,58 +167,12 @@ def get_results(self): """ results = FitResults( - **{key.strip('_') : getattr(self, key) for key in self._fields}, + **{key + '_params' : getattr(self.params, key) for key in self.params.fields}, metrics=self.metrics.results) return results - def get_component(self, component='full', space='log'): - """Get a model component. - - Parameters - ---------- - component : {'full', 'aperiodic', 'peak'} - Which model component to return. - 'full' - full model - 'aperiodic' - isolated aperiodic model component - 'peak' - isolated peak model component - space : {'log', 'linear'} - Which space to return the model component in. - 'log' - returns in log10 space. - 'linear' - returns in linear space. - - Returns - ------- - output : 1d array - Specified model component, in specified spacing. - - Notes - ----- - The 'space' parameter doesn't just define the spacing of the model component - values, but rather defines the space of the additive model such that - `model = aperiodic_component + peak_component`. - With space set as 'log', this combination holds in log space. - With space set as 'linear', this combination holds in linear space. - """ - - if not self.has_model: - raise NoModelError("No model fit results are available, can not proceed.") - assert space in ['linear', 'log'], "Input for 'space' invalid." - - if component == 'full': - output = self.modeled_spectrum_ if space == 'log' else unlog(self.modeled_spectrum_) - elif component == 'aperiodic': - output = self._ap_fit if space == 'log' else unlog(self._ap_fit) - elif component == 'peak': - output = self._peak_fit if space == 'log' else \ - unlog(self.modeled_spectrum_) - unlog(self._ap_fit) - else: - raise ValueError('Input for component invalid.') - - return output - - def get_params(self, name, field=None): """Return model fit parameters for specified feature(s). @@ -236,22 +205,6 @@ def get_params(self, name, field=None): return get_model_params(self.get_results(), self.modes, name, field) - def _check_loaded_results(self, data): - """Check if results have been added and check data. - - Parameters - ---------- - data : dict - A dictionary of data that has been added to the object. - """ - - # If results loaded, check dimensions of peak parameters - # This fixes an issue where they end up the wrong shape if they are empty (no peaks) - if set(self._fields).issubset(set(data.keys())): - self.peak_params_ = check_array_dim(self.peak_params_) - self.gaussian_params_ = check_array_dim(self.gaussian_params_) - - def _reset_results(self, clear_results=False): """Set, or reset, results attributes to empty. @@ -262,29 +215,8 @@ def _reset_results(self, clear_results=False): """ if clear_results: - - # Aperiodic parameters - if self.modes: - self.aperiodic_params_ = np.array([np.nan] * self.modes.aperiodic.n_params) - else: - self.aperiodic_params_ = np.nan - - # Periodic parameters - if self.modes: - self.gaussian_params_ = np.empty([0, self.modes.periodic.n_params]) - self.peak_params_ = np.empty([0, self.modes.periodic.n_params]) - else: - self.gaussian_params_ = np.nan - self.peak_params_ = np.nan - - # Data components - self._spectrum_flat = None - self._spectrum_peak_rm = None - - # Modeled spectrum components - self.modeled_spectrum_ = None - self._ap_fit = None - self._peak_fit = None + self.params.reset(self.modes) + self.model.reset() def _regenerate_model(self, freqs): @@ -296,19 +228,25 @@ def _regenerate_model(self, freqs): Frequency values for the power_spectrum, in linear scale. """ - self.modeled_spectrum_, self._peak_fit, self._ap_fit = gen_model(freqs, \ - self.modes.aperiodic, self.aperiodic_params_, - self.modes.periodic, self.gaussian_params_, - return_components=True) + self.model.modeled_spectrum, self.model._peak_fit, self.model._ap_fit = \ + gen_model(freqs, self.modes.aperiodic, self.params.aperiodic, + self.modes.periodic, self.params.gaussian, return_components=True) -@replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters')]) +@replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters'), + docs_get_section(Results.__doc__, 'Attributes')]) class Results2D(Results): """Object for managing results - 2D version. Parameters ---------- % copied in from Results + + Attributes + ---------- + % copied in from Results + group_results : list of FitResults + Results of the model fit for each power spectrum. """ def __init__(self, modes=None, metrics=None, bands=None): @@ -364,7 +302,7 @@ def has_model(self): @property - def n_peaks_(self): + def n_peaks(self): """How many peaks were fit for each model.""" n_peaks = None @@ -375,7 +313,7 @@ def n_peaks_(self): @property - def n_null_(self): + def n_null(self): """How many model fits are null.""" n_null = None @@ -386,7 +324,7 @@ def n_null_(self): @property - def null_inds_(self): + def null_inds(self): """The indices for model fits that are null.""" null_inds = None @@ -468,13 +406,20 @@ def get_params(self, name, field=None): return get_group_params(self.group_results, self.modes, name, field) -@replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters')]) +@replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters'), + docs_get_section(Results2D.__doc__, 'Attributes')]) class Results2DT(Results2D): """Object for managing results - 2D transpose version. Parameters ---------- % copied in from Results + + Attributes + ---------- + % copied in from Results2D + time_results : dict + Results of the model fit across each time window. """ def __init__(self, modes=None, metrics=None, bands=None): @@ -527,13 +472,23 @@ def convert_results(self): self.time_results = group_to_dict(self.group_results, self.modes, self.bands) -@replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters')]) +@replace_docstring_sections([docs_get_section(Results.__doc__, 'Parameters'), + docs_get_section(Results2DT.__doc__, 'Attributes')]) class Results3D(Results2DT): """Object for managing results - 3D version. Parameters ---------- % copied in from Results + + Attributes + ---------- + % copied in from Results2DT + event_group_results : list of list of FitResults + Full model results collected across all events and models. + event_time_results : dict + Results of the model fit across each time window, collected across events. + Each value in the dictionary stores a model fit parameter, as [n_events, n_time_windows]. """ def __init__(self, modes=None, metrics=None, bands=None): @@ -571,7 +526,7 @@ def has_model(self): @property - def n_peaks_(self): + def n_peaks(self): """How many peaks were fit for each model, for each event.""" n_peaks = None @@ -666,7 +621,8 @@ def get_params(self, name, field=None): column is appended to the returned array, indicating the index that the peak came from. """ - return [get_group_params(gres, self.modes, name, field) for gres in self.event_group_results] + return [get_group_params(gres, self.modes, name, field) \ + for gres in self.event_group_results] def convert_results(self): diff --git a/specparam/plts/annotate.py b/specparam/plts/annotate.py index 5a8df135..20ae0d70 100644 --- a/specparam/plts/annotate.py +++ b/specparam/plts/annotate.py @@ -33,17 +33,17 @@ def plot_annotated_peak_search(model): # is the same as the one that is used in the peak fitting procedure flatspec = model.data.power_spectrum - \ model.modes.aperiodic.func(model.data.freqs, \ - *model.algorithm._robust_ap_fit(model.data.freqs, model.data.power_spectrum),) + *model.algorithm._robust_ap_fit(model.data.freqs, model.data.power_spectrum)) # Calculate ylims of the plot that are scaled to the range of the data ylims = [min(flatspec) - 0.1 * np.abs(min(flatspec)), max(flatspec) + 0.1 * max(flatspec)] # Sort parameters by peak height - gaussian_params = model.results.gaussian_params_[\ - model.results.gaussian_params_[:, 1].argsort()][::-1] + gaussian_params = model.results.params.gaussian[\ + model.results.params.gaussian[:, 1].argsort()][::-1] # Loop through the iterative search for each peak - for ind in range(model.results.n_peaks_ + 1): + for ind in range(model.results.n_peaks + 1): # This forces the creation of a new plotting axes per iteration ax = check_ax(None, PLT_FIGSIZES['spectral']) @@ -51,10 +51,12 @@ def plot_annotated_peak_search(model): plot_spectra(model.data.freqs, flatspec, linewidth=2.5, label='Flattened Spectrum', color=PLT_COLORS['data'], ax=ax) plot_spectra(model.data.freqs, - [model.algorithm.peak_threshold * np.std(flatspec)] * len(model.data.freqs), + [model.algorithm.settings.peak_threshold * np.std(flatspec)] \ + * len(model.data.freqs), label='Relative Threshold', color='orange', linewidth=2.5, linestyle='dashed', ax=ax) - plot_spectra(model.data.freqs, [model.algorithm.min_peak_height]*len(model.data.freqs), + plot_spectra(model.data.freqs, + [model.algorithm.settings.min_peak_height] * len(model.data.freqs), label='Absolute Threshold', color='red', linewidth=2.5, linestyle='dashed', ax=ax) @@ -65,7 +67,7 @@ def plot_annotated_peak_search(model): ax.set_ylim(ylims) ax.set_title('Iteration #' + str(ind+1), fontsize=16) - if ind < model.results.n_peaks_: + if ind < model.results.n_peaks: gauss = model.modes.periodic.func(model.data.freqs, *gaussian_params[ind, :]) plot_spectra(model.data.freqs, gauss, ax=ax, label='Gaussian Fit', @@ -136,10 +138,10 @@ def plot_annotated_model(model, plt_log=False, annotate_peaks=True, # See: https://github.com/matplotlib/matplotlib/issues/12820. Fixed in 3.2.1. bug_buff = 0.000001 - if annotate_peaks and model.results.n_peaks_: + if annotate_peaks and model.results.n_peaks: # Extract largest peak, to annotate, grabbing gaussian params - gauss = get_band_peak(model, model.data.freq_range, attribute='gaussian_params') + gauss = get_band_peak(model, model.data.freq_range, attribute='gaussian') peak_ctr, peak_hgt, peak_wid = gauss bw_freqs = [peak_ctr - 0.5 * compute_fwhm(peak_wid), @@ -183,7 +185,7 @@ def plot_annotated_model(model, plt_log=False, annotate_peaks=True, # Annotate Aperiodic Offset # Add a line to indicate offset, without adjusting plot limits below it ax.set_autoscaley_on(False) - ax.plot([freqs[0], freqs[0]], [ax.get_ylim()[0], model.results.modeled_spectrum_[0]], + ax.plot([freqs[0], freqs[0]], [ax.get_ylim()[0], model.results.model.modeled_spectrum[0]], color=PLT_COLORS['aperiodic'], linewidth=lw2, alpha=0.5) ax.annotate('Offset', xy=(freqs[0]+bug_buff, model.data.power_spectrum[0]-y_buff1), diff --git a/specparam/plts/event.py b/specparam/plts/event.py index 97e2df39..3b99df68 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -86,5 +86,5 @@ def plot_event_model(event, **plot_kwargs): title='Fit Quality' if ind == 0 else None, drop_xticks=ind < len(event.results.metrics), add_xlabel=ind == len(event.results.metrics), - color=PARAM_COLORS[event.results.metrics.types[ind]], + color=PARAM_COLORS[event.results.metrics.categories[ind]], xlim=xlim, ax=next(axes)) diff --git a/specparam/plts/model.py b/specparam/plts/model.py index 6a2ca633..2499b1c1 100644 --- a/specparam/plts/model.py +++ b/specparam/plts/model.py @@ -8,7 +8,6 @@ import numpy as np from specparam.modutils.dependencies import safe_import, check_dependency -from specparam.sim.gen import gen_periodic from specparam.utils.select import nearest_ind from specparam.utils.spectral import trim_spectrum from specparam.measures.params import compute_fwhm @@ -87,7 +86,7 @@ def plot_model(model, plot_peaks=None, plot_aperiodic=True, freqs=None, power_sp model_defaults = {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5, 'label' : 'Full Model Fit' if add_legend else None} model_kwargs = check_plot_kwargs(model_kwargs, model_defaults) - plot_spectra(model.data.freqs, model.results.modeled_spectrum_, + plot_spectra(model.data.freqs, model.results.model.modeled_spectrum, log_freqs, log_powers, ax=ax, **model_kwargs) # Plot the aperiodic component of the model fit @@ -96,7 +95,7 @@ def plot_model(model, plot_peaks=None, plot_aperiodic=True, freqs=None, power_sp 'alpha' : 0.5, 'linestyle' : 'dashed', 'label' : 'Aperiodic Fit' if add_legend else None} aperiodic_kwargs = check_plot_kwargs(aperiodic_kwargs, aperiodic_defaults) - plot_spectra(model.data.freqs, model.results._ap_fit, + plot_spectra(model.data.freqs, model.results.model._ap_fit, log_freqs, log_powers, ax=ax, **aperiodic_kwargs) # Plot the periodic components of the model fit @@ -169,13 +168,12 @@ def _add_peaks_shade(model, plt_log, ax, **plot_kwargs): defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) - for peak in model.results.gaussian_params_: + for peak in model.results.params.gaussian: peak_freqs = np.log10(model.data.freqs) if plt_log else model.data.freqs - #peak_line = model.results._ap_fit + gen_periodic(model.data.freqs, peak) - peak_line = model.results._ap_fit + model.modes.periodic.func(model.data.freqs, *peak) + peak_line = model.results.model._ap_fit + model.modes.periodic.func(model.data.freqs, *peak) - ax.fill_between(peak_freqs, peak_line, model.results._ap_fit, **plot_kwargs) + ax.fill_between(peak_freqs, peak_line, model.results.model._ap_fit, **plot_kwargs) def _add_peaks_dot(model, plt_log, ax, **plot_kwargs): @@ -196,9 +194,9 @@ def _add_peaks_dot(model, plt_log, ax, **plot_kwargs): defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) - for peak in model.results.peak_params_: + for peak in model.results.params.peak: - ap_point = np.interp(peak[0], model.data.freqs, model.results._ap_fit) + ap_point = np.interp(peak[0], model.data.freqs, model.results.model._ap_fit) freq_point = np.log10(peak[0]) if plt_log else peak[0] # Add the line from the aperiodic fit up the tip of the peak @@ -226,14 +224,13 @@ def _add_peaks_outline(model, plt_log, ax, **plot_kwargs): defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.7, 'lw' : 1.5} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) - for peak in model.results.gaussian_params_: + for peak in model.results.params.gaussian: # Define the frequency range around each peak to plot - peak bandwidth +/- 3 peak_range = [peak[0] - peak[2]*3, peak[0] + peak[2]*3] # Generate a peak reconstruction for each peak, and trim to desired range - #peak_line = model.results._ap_fit + gen_periodic(model.data.freqs, peak) - peak_line = model.results._ap_fit + model.modes.periodic.func(model.data.freqs, *peak) + peak_line = model.results.model._ap_fit + model.modes.periodic.func(model.data.freqs, *peak) peak_freqs, peak_line = trim_spectrum(model.data.freqs, peak_line, peak_range) # Plot the peak outline @@ -261,7 +258,7 @@ def _add_peaks_line(model, plt_log, ax, **plot_kwargs): ylims = ax.get_ylim() - for peak in model.results.peak_params_: + for peak in model.results.params.peak: freq_point = np.log10(peak[0]) if plt_log else peak[0] ax.plot([freq_point, freq_point], ylims, '-', **plot_kwargs) @@ -291,7 +288,7 @@ def _add_peaks_width(model, plt_log, ax, **plot_kwargs): defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) - for peak in model.results.gaussian_params_: + for peak in model.results.params.gaussian: peak_top = model.data.power_spectrum[nearest_ind(model.data.freqs, peak[0])] bw_freqs = [peak[0] - 0.5 * compute_fwhm(peak[2]), diff --git a/specparam/plts/time.py b/specparam/plts/time.py index eece420f..6fa435d0 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -74,6 +74,6 @@ def plot_time_model(time, **plot_kwargs): time.results.time_results[time.results.metrics.labels[gof_ind]]], labels=[time.results.metrics.flabels[err_ind], time.results.metrics.flabels[gof_ind]], - colors=[PARAM_COLORS[time.results.metrics.types[err_ind]], - PARAM_COLORS[time.results.metrics.types[gof_ind]]], + colors=[PARAM_COLORS[time.results.metrics.categories[err_ind]], + PARAM_COLORS[time.results.metrics.categories[gof_ind]]], xlim=xlim, title='Fit Quality', ax=next(axes)) diff --git a/specparam/plts/utils.py b/specparam/plts/utils.py index 83d2f018..29078d09 100644 --- a/specparam/plts/utils.py +++ b/specparam/plts/utils.py @@ -93,7 +93,7 @@ def add_shades(ax, shades, colors='r', shade_alpha=0.2, shades = [shades] colors = repeat(colors) if not isinstance(colors, list) else colors - shade_alphas = repeat(shade_alpha) if not isinstance(shade_alpha, list) else alpha + shade_alphas = repeat(shade_alpha) if not isinstance(shade_alpha, list) else shade_alpha for shade, color, alpha in zip(shades, colors, shade_alphas): diff --git a/specparam/reports/strings.py b/specparam/reports/strings.py index cea9261b..b3c07bad 100644 --- a/specparam/reports/strings.py +++ b/specparam/reports/strings.py @@ -211,9 +211,9 @@ def gen_settings_str(model, description=False, concise=False): # Loop through algorithm settings, and add information for name in model.algorithm.settings.names: - str_lst.append(name + ' : ' + str(getattr(model.algorithm, name))) + str_lst.append(name + ' : ' + str(getattr(model.algorithm.settings, name))) if description: - str_lst.append(model.algorithm.settings.descriptions[name].split('\n ')[0]) + str_lst.append(model.algorithm.public_settings.descriptions[name].split('\n ')[0]) # Add footer to string str_lst.extend([ @@ -337,10 +337,10 @@ def gen_methods_text_str(model=None): methods_str = template.format(MODULE_VERSION, model.modes.aperiodic.name if model else 'XX', model.modes.periodic.name if model else 'XX', - model.algorithm.peak_width_limits if model else 'XX', - model.algorithm.max_n_peaks if model else 'XX', - model.algorithm.min_peak_height if model else 'XX', - model.algorithm.peak_threshold if model else 'XX', + model.algorithm.settings.peak_width_limits if model else 'XX', + model.algorithm.settings.max_n_peaks if model else 'XX', + model.algorithm.settings.min_peak_height if model else 'XX', + model.algorithm.settings.peak_threshold if model else 'XX', *freq_range) return methods_str @@ -388,13 +388,13 @@ def gen_model_results_str(model, concise=False): 'Aperiodic Parameters (\'{}\' mode)'.format(model.modes.aperiodic.name), '(' + ', '.join(model.modes.aperiodic.params.labels) + ')', ', '.join(['{:2.4f}'] * \ - len(model.results.aperiodic_params_)).format(*model.results.aperiodic_params_), + len(model.results.params.aperiodic)).format(*model.results.params.aperiodic), '', # Peak parameters 'Peak Parameters (\'{}\' mode) {} peaks found'.format(\ - model.modes.periodic.name, model.results.n_peaks_), - *[peak_str.format(*op) for op in model.results.peak_params_], + model.modes.periodic.name, model.results.n_peaks), + *[peak_str.format(*op) for op in model.results.params.peak], '', # Metrics @@ -455,7 +455,7 @@ def gen_group_results_str(group, concise=False): # Peak Parameters 'Peak Parameters (\'{}\' mode) {} total peaks found'.format(\ - group.modes.periodic.name, sum(group.results.n_peaks_)), + group.modes.periodic.name, sum(group.results.n_peaks)), '', # Metrics @@ -651,7 +651,7 @@ def _report_str_n_null(model): output = \ [el for el in ['{} power spectra failed to fit'.format(\ - model.results.n_null_)] if model.results.n_null_] + model.results.n_null)] if model.results.n_null] return output diff --git a/specparam/sim/params.py b/specparam/sim/params.py index 04223559..d5720d90 100644 --- a/specparam/sim/params.py +++ b/specparam/sim/params.py @@ -7,7 +7,6 @@ from specparam.data import SimParams from specparam.modes.modes import check_mode_definition from specparam.modes.definitions import AP_MODES -from specparam.utils.select import groupby from specparam.utils.checks import check_flat from specparam.modutils.errors import InconsistentDataError @@ -33,7 +32,6 @@ def collect_sim_params(aperiodic_params, periodic_params, nlv): """ return SimParams(deepcopy(aperiodic_params), - #sorted(groupby(check_flat(periodic_params), 3)), deepcopy(periodic_params), nlv) diff --git a/specparam/tests/algorithms/test_algorithm.py b/specparam/tests/algorithms/test_algorithm.py index d16f0668..f46f4cb3 100644 --- a/specparam/tests/algorithms/test_algorithm.py +++ b/specparam/tests/algorithms/test_algorithm.py @@ -1,5 +1,6 @@ """Tests for specparam.algorthms.algorithm.""" +from specparam.modes.modes import Modes from specparam.algorithms.settings import SettingsDefinition from specparam.algorithms.algorithm import * @@ -16,12 +17,14 @@ def test_algorithm(): 'b' : {'type' : 'b type desc', 'description' : 'b desc'}, }) - algo = Algorithm(name=tname, description=tdescription, settings=tsettings, format='spectrum') + algo = Algorithm(name=tname, description=tdescription, public_settings=tsettings) assert algo assert algo.name == tname assert algo.description == tdescription - assert isinstance(algo.settings, SettingsDefinition) - assert algo.settings == tsettings + assert isinstance(algo.public_settings, SettingsDefinition) + assert algo.public_settings == tsettings + for setting in algo.public_settings.names: + assert getattr(algo.settings, setting) is None def test_algorithm_settings(): @@ -32,14 +35,46 @@ def test_algorithm_settings(): 'b' : {'type' : 'b type desc', 'description' : 'b desc'}, }) - talgo = Algorithm(name=tname, description=tdescription, settings=tsettings, format='spectrum') + talgo = Algorithm(name=tname, description=tdescription, public_settings=tsettings) - model_settings = talgo.settings.make_model_settings() + model_settings = talgo.public_settings.make_model_settings() settings = model_settings(a=1, b=2) talgo.add_settings(settings) for setting in settings._fields: - assert getattr(talgo, setting) == getattr(settings, setting) + assert getattr(talgo.settings, setting) == getattr(settings, setting) settings_out = talgo.get_settings() assert isinstance(settings, model_settings) assert settings_out == settings + +def test_algorithm_cf(): + + tname = 'test_algo' + tdescription = 'Test algorithm description' + tsettings = SettingsDefinition({ + 'a' : {'type' : 'a type desc', 'description' : 'a desc'}, + 'b' : {'type' : 'b type desc', 'description' : 'b desc'}, + }) + + algo = AlgorithmCF(name=tname, description=tdescription, public_settings=tsettings) + + assert isinstance(algo._cf_settings_desc, SettingsDefinition) + assert algo._cf_settings + for setting in algo._cf_settings.names: + assert getattr(algo._cf_settings, setting) is None + +def test_algorithm_cf_initialize(): + + algo = AlgorithmCF(name='test_algo', description='desc', + public_settings={'a' : {'type' : 'a type desc', 'description' : 'a desc'}}, + modes=Modes('fixed', 'gaussian')) + + ap_bounds = algo._initialize_bounds('aperiodic') + assert len(ap_bounds[0]) == algo.modes.aperiodic.n_params + pe_bounds = algo._initialize_bounds('periodic') + assert len(pe_bounds[0]) == algo.modes.periodic.n_params + + ap_guess = algo._initialize_guess('aperiodic') + assert len(ap_guess) == algo.modes.aperiodic.n_params + pe_guess = algo._initialize_guess('periodic') + assert len(pe_guess) == algo.modes.periodic.n_params diff --git a/specparam/tests/algorithms/test_settings.py b/specparam/tests/algorithms/test_settings.py index 30501a86..4e896fbb 100644 --- a/specparam/tests/algorithms/test_settings.py +++ b/specparam/tests/algorithms/test_settings.py @@ -5,18 +5,36 @@ ################################################################################################### ################################################################################################### +def test_settings_values(): + + tsettings_names = ['a', 'b'] + settings_vals = SettingsValues(tsettings_names) + assert isinstance(settings_vals.values, dict) + assert settings_vals.names == tsettings_names + assert settings_vals.a is None + assert settings_vals.b is None + settings_vals.a = 1 + settings_vals.b = 2 + assert settings_vals.a == 1 + assert settings_vals.b == 2 + + settings_vals.clear() + assert settings_vals.a is None + assert settings_vals.b is None + def test_settings_definition(): - tsettings = { + tdefinitions = { 'a' : {'type' : 'a type desc', 'description' : 'a desc'}, 'b' : {'type' : 'b type desc', 'description' : 'b desc'}, } - settings = SettingsDefinition(tsettings) - assert settings._settings == tsettings - assert settings.names == list(tsettings.keys()) - assert settings.types - assert settings.descriptions - for label in tsettings.keys(): - assert settings.make_setting_str(label) - assert settings.make_docstring() + settings_def = SettingsDefinition(tdefinitions) + assert settings_def._definitions == tdefinitions + assert len(settings_def) == len(tdefinitions) + assert settings_def.names == list(tdefinitions.keys()) + assert settings_def.types + assert settings_def.descriptions + for label in tdefinitions.keys(): + assert settings_def.make_setting_str(label) + assert settings_def.make_docstring() diff --git a/specparam/tests/io/test_models.py b/specparam/tests/io/test_models.py index a22c7594..0160fe40 100644 --- a/specparam/tests/io/test_models.py +++ b/specparam/tests/io/test_models.py @@ -151,17 +151,14 @@ def test_load_file_contents(tfm): """Check that loaded model files contain the contents they should.""" # Loads file saved from `test_save_model_str` - file_name = 'test_model_all' - - loaded_data = load_json(file_name, TEST_DATA_PATH) - + loaded_data = load_json('test_model_all', TEST_DATA_PATH) for mode in tfm.modes.get_modes()._fields: assert mode in loaded_data.keys() assert 'bands' in loaded_data.keys() for setting in tfm.algorithm.settings.names: assert setting in loaded_data.keys() - for result in tfm.results._fields: - assert result in loaded_data.keys() + for result in tfm.results.params.fields: + assert result + '_params' in loaded_data.keys() assert 'metrics' in loaded_data.keys() for datum in tfm.data._fields: assert datum in loaded_data.keys() @@ -169,99 +166,94 @@ def test_load_file_contents(tfm): def test_load_model(tfm): # Loads file saved from `test_save_model_str` - file_name = 'test_model_all' - ntfm = load_model(file_name, TEST_DATA_PATH) + ntfm = load_model('test_model_all', TEST_DATA_PATH) assert isinstance(ntfm, SpectralModel) - - # Check that all elements get loaded - assert tfm.modes.get_modes() == ntfm.modes.get_modes() - assert tfm.results.bands == ntfm.results.bands - for meta_dat in tfm.data._meta_fields: - assert getattr(ntfm.data, meta_dat) is not None - for setting in ntfm.algorithm.settings.names: - assert getattr(ntfm.algorithm, setting) is not None - for result in tfm.results._fields: - assert not np.all(np.isnan(getattr(ntfm.results, result))) - assert tfm.results.metrics.results == ntfm.results.metrics.results + compare_model_objs([tfm, ntfm], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) for data in tfm.data._fields: - assert getattr(ntfm.data, data) is not None + assert np.array_equal(getattr(tfm.data, data), getattr(ntfm.data, data)) + for result in tfm.results.params.fields: + assert not np.all(np.isnan(getattr(ntfm.results.params, result))) + assert tfm.results.metrics.results == ntfm.results.metrics.results # Check directory matches (loading didn't add any unexpected attributes) cfm = SpectralModel() assert dir(cfm) == dir(ntfm) + assert dir(cfm.algorithm) == dir(ntfm.algorithm) assert dir(cfm.data) == dir(ntfm.data) assert dir(cfm.results) == dir(ntfm.results) + assert dir(cfm.results.params) == dir(ntfm.results.params) def test_load_model2(tfm2): # Loads file saved from `test_save_model_str2` - file_name = 'test_model_all2' - ntfm2 = load_model(file_name, TEST_DATA_PATH) - assert tfm2.modes.get_modes() == ntfm2.modes.get_modes() - compare_model_objs([tfm2, ntfm2], ['settings', 'meta_data', 'metrics']) + ntfm2 = load_model('test_model_all2', TEST_DATA_PATH) + compare_model_objs([tfm2, ntfm2], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) def test_load_group(tfg): # Loads file saved from `test_save_group` - file_name = 'test_group_all' - ntfg = load_group(file_name, TEST_DATA_PATH) + ntfg = load_group('test_group_all', TEST_DATA_PATH) assert isinstance(ntfg, SpectralGroupModel) - - # Check that all elements get loaded - assert tfg.modes.get_modes() == ntfg.modes.get_modes() - assert tfg.results.bands == ntfg.results.bands - for setting in tfg.algorithm.settings.names: - assert getattr(ntfg.algorithm, setting) is not None + compare_model_objs([tfg, ntfg], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) + for data in tfg.data._fields: + assert np.array_equal(getattr(tfg.data, data), getattr(ntfg.data, data)) assert len(ntfg.results.group_results) > 0 for metric in tfg.results.metrics.labels: - assert tfg.results.metrics.results[metric] is not None - assert ntfg.data.power_spectra is not None - for meta_dat in tfg.data._meta_fields: - assert getattr(ntfg.data, meta_dat) is not None + assert tfg.results.metrics.results[metric] == ntfg.results.metrics.results[metric] # Check directory matches (loading didn't add any unexpected attributes) cfg = SpectralGroupModel() assert dir(cfg) == dir(ntfg) + assert dir(cfg.algorithm) == dir(ntfg.algorithm) assert dir(cfg.data) == dir(ntfg.data) assert dir(cfg.results) == dir(ntfg.results) + assert dir(cfg.results.params) == dir(ntfg.results.params) def test_load_group2(tfg2): # Loads file saved from `test_save_group_str2` - file_name = 'test_group_all2' - ntfg2 = load_group(file_name, TEST_DATA_PATH) - assert tfg2.modes.get_modes() == ntfg2.modes.get_modes() - compare_model_objs([tfg2, ntfg2], ['settings', 'meta_data', 'metrics']) + ntfg2 = load_group('test_group_all2', TEST_DATA_PATH) + compare_model_objs([tfg2, ntfg2], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) -def test_load_time(): +def test_load_time(tft): # Loads file saved from `test_save_time` - file_name = 'test_time_all' - - # Load without bands definition - tft = load_time(file_name, TEST_DATA_PATH) + ntft = load_time('test_time_all', TEST_DATA_PATH) assert isinstance(tft, SpectralTimeModel) - assert tft.results.time_results + compare_model_objs([tft, ntft], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) + for data in tft.data._fields: + assert np.array_equal(getattr(tft.data, data), getattr(ntft.data, data)) + assert tft.results.time_results.keys() == ntft.results.time_results.keys() + for key in tft.results.time_results: + assert np.array_equal(\ + tft.results.time_results[key], ntft.results.time_results[key], equal_nan=True) # Check directory matches (loading didn't add any unexpected attributes) cft = SpectralTimeModel() - assert dir(cft) == dir(tft) - assert dir(cft.data) == dir(tft.data) - assert dir(cft.results) == dir(tft.results) + assert dir(cft) == dir(ntft) + assert dir(cft.algorithm) == dir(ntft.algorithm) + assert dir(cft.data) == dir(ntft.data) + assert dir(cft.results) == dir(ntft.results) + assert dir(cft.results.params) == dir(ntft.results.params) -def test_load_event(): +def test_load_event(tfe): # Loads file saved from `test_save_event` - file_name = 'test_event_all' - - # Load without bands definition - tfe = load_event(file_name, TEST_DATA_PATH) + ntfe = load_event('test_event_all', TEST_DATA_PATH) assert isinstance(tfe, SpectralTimeEventModel) + compare_model_objs([tfe, ntfe], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) + for data in tfe.data._fields: + assert np.array_equal(getattr(tfe.data, data), getattr(ntfe.data, data)) assert len(tfe.results) > 1 - assert tfe.results.event_time_results + assert tfe.results.time_results.keys() == ntfe.results.time_results.keys() + for key in tfe.results.time_results: + assert np.array_equal(\ + tfe.results.time_results[key], ntfe.results.time_results[key], equal_nan=True) # Check directory matches (loading didn't add any unexpected attributes) cfe = SpectralTimeEventModel() - assert dir(cfe) == dir(tfe) - assert dir(cfe.data) == dir(tfe.data) - assert dir(cfe.results) == dir(tfe.results) + assert dir(cfe) == dir(ntfe) + assert dir(cfe.algorithm) == dir(ntfe.algorithm) + assert dir(cfe.data) == dir(ntfe.data) + assert dir(cfe.results) == dir(ntfe.results) + assert dir(cfe.results.params) == dir(ntfe.results.params) diff --git a/specparam/tests/measures/test_error.py b/specparam/tests/measures/test_error.py index 7b3c0650..34bed15a 100644 --- a/specparam/tests/measures/test_error.py +++ b/specparam/tests/measures/test_error.py @@ -7,26 +7,29 @@ def test_compute_mean_abs_error(tfm): - error = compute_mean_abs_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + error = compute_mean_abs_error(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum) assert isinstance(error, float) def test_compute_mean_squared_error(tfm): - error = compute_mean_squared_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + error = compute_mean_squared_error(tfm.data.power_spectrum, + tfm.results.model.modeled_spectrum) assert isinstance(error, float) def test_compute_root_mean_squared_error(tfm): - error = compute_root_mean_squared_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + error = compute_root_mean_squared_error(tfm.data.power_spectrum, + tfm.results.model.modeled_spectrum) assert isinstance(error, float) def test_compute_median_abs_error(tfm): - error = compute_median_abs_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + error = compute_median_abs_error(tfm.data.power_spectrum, + tfm.results.model.modeled_spectrum) assert isinstance(error, float) def test_compute_error(tfm): for metric in ['mae', 'mse', 'rmse', 'medae']: - error = compute_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + error = compute_error(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum) assert isinstance(error, float) diff --git a/specparam/tests/measures/test_gof.py b/specparam/tests/measures/test_gof.py index a1990749..6b54a4ee 100644 --- a/specparam/tests/measures/test_gof.py +++ b/specparam/tests/measures/test_gof.py @@ -7,16 +7,17 @@ def test_compute_r_squared(tfm): - r_squared = compute_r_squared(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + r_squared = compute_r_squared(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum) assert isinstance(r_squared, float) def test_compute_adj_r_squared(tfm): - r_squared = compute_adj_r_squared(tfm.data.power_spectrum, tfm.results.modeled_spectrum_, 5) + r_squared = compute_adj_r_squared(tfm.data.power_spectrum, + tfm.results.model.modeled_spectrum, 5) assert isinstance(r_squared, float) def test_compute_gof(tfm): for metric in ['r_squared', 'adj_r_squared']: - gof = compute_gof(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) + gof = compute_gof(tfm.data.power_spectrum, tfm.results.model.modeled_spectrum) assert isinstance(gof, float) diff --git a/specparam/tests/models/test_event.py b/specparam/tests/models/test_event.py index 149827be..24730b9c 100644 --- a/specparam/tests/models/test_event.py +++ b/specparam/tests/models/test_event.py @@ -9,6 +9,7 @@ import numpy as np from specparam.models import SpectralGroupModel, SpectralTimeModel +from specparam.models.utils import compare_model_objs from specparam.sim import sim_spectrogram from specparam.modutils.dependencies import safe_import @@ -41,8 +42,8 @@ def test_event_iter(tfe): def test_event_n_properties(tfe): - assert np.all(tfe.results.n_peaks_) - assert np.all(tfe.results.n_params_) + assert np.all(tfe.results.n_peaks) + assert np.all(tfe.results.n_params) def test_event_fit(): @@ -95,26 +96,27 @@ def test_event_report(skip_if_no_mpl): assert tfe -def test_event_load(): - - file_name_res = 'test_event_res' - file_name_set = 'test_event_set' - file_name_dat = 'test_event_dat' +def test_event_load(tfe): # Test loading results - tfe = SpectralTimeEventModel(verbose=False) - tfe.load(file_name_res, TEST_DATA_PATH) - assert tfe.results.event_time_results + ntfe = SpectralTimeEventModel(verbose=False) + ntfe.load('test_event_res', TEST_DATA_PATH) + assert ntfe.results.event_time_results # Test loading settings - tfe = SpectralTimeEventModel(verbose=False) - tfe.load(file_name_set, TEST_DATA_PATH) - assert tfe.algorithm.get_settings() + ntfe = SpectralTimeEventModel(verbose=False) + ntfe.load('test_event_set', TEST_DATA_PATH) + assert ntfe.algorithm.get_settings() # Test loading data - tfe = SpectralTimeEventModel(verbose=False) - tfe.load(file_name_dat, TEST_DATA_PATH) - assert np.all(tfe.data.spectrograms) + ntfe = SpectralTimeEventModel(verbose=False) + ntfe.load('test_event_dat', TEST_DATA_PATH) + assert np.all(ntfe.data.spectrograms) + + # Test loading all elements + ntfe = SpectralTimeEventModel(verbose=False) + ntfe.load('test_event_all', TEST_DATA_PATH) + assert compare_model_objs([tfe, ntfe], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) def test_event_get_model(tfe): @@ -122,8 +124,7 @@ def test_event_get_model(tfe): tfm_null = tfe.get_model() assert tfm_null # Check that settings are copied over properly, but data and results are empty - for setting in tfe.algorithm.settings.names: - assert getattr(tfe.algorithm, setting) == getattr(tfm_null.algorithm, setting) + assert tfe.algorithm.settings.values == tfm_null.algorithm.settings.values assert not tfm_null.data.has_data assert not tfm_null.results.has_model @@ -138,7 +139,7 @@ def test_event_get_model(tfe): assert tfm1 assert tfm1.data.has_data assert tfm1.results.has_model - assert np.all(tfm1.results.modeled_spectrum_) + assert np.all(tfm1.results.model.modeled_spectrum) def test_event_get_params(tfe): diff --git a/specparam/tests/models/test_group.py b/specparam/tests/models/test_group.py index 3d918fe3..6a9c559e 100644 --- a/specparam/tests/models/test_group.py +++ b/specparam/tests/models/test_group.py @@ -12,6 +12,7 @@ from numpy.testing import assert_equal from specparam.measures.metrics import METRICS +from specparam.models.utils import compare_model_objs from specparam.modutils.dependencies import safe_import from specparam.sim import sim_group_power_spectra @@ -64,20 +65,20 @@ def test_has_model(tfg): def test_n_properties(tfg): """Test the n_peaks & n_params property attributes.""" - assert np.all(tfg.results.n_peaks_) - assert np.all(tfg.results.n_params_) + assert np.all(tfg.results.n_peaks) + assert np.all(tfg.results.n_params) def test_n_null(tfg): - """Test the n_null_ property attribute.""" + """Test the n_null property attribute.""" # Since there should have been no failed fits, this should return 0 - assert tfg.results.n_null_ == 0 + assert tfg.results.n_null == 0 def test_null_inds(tfg): - """Test the null_inds_ property attribute.""" + """Test the null_inds property attribute.""" # Since there should be no failed fits, this should return an empty list - assert tfg.results.null_inds_ == [] + assert tfg.results.null_inds == [] def test_fit_nk(): """Test group fit, no knee.""" @@ -152,7 +153,7 @@ def test_fg_fail(): # Use a fg with the max iterations set so low that it will fail to converge ntfg = SpectralGroupModel() - ntfg.algorithm._maxfev = 5 + ntfg.algorithm._cf_settings.maxfev = 5 # Fit models, where some will fail, to see if it completes cleanly ntfg.fit(fs, ps) @@ -174,8 +175,8 @@ def test_fg_fail(): # Test the property attributes related to null model fits # This checks that they do the right thing when there are null fits (failed fits) - assert ntfg.results.n_null_ > 0 - assert ntfg.results.null_inds_ + assert ntfg.results.n_null > 0 + assert ntfg.results.null_inds def test_drop(): """Test function to drop results from group object.""" @@ -273,49 +274,39 @@ def test_load(tfg): """Test load into group object. Note: loads files from test_save_group in specparam/tests/io/test_models.py.""" - file_name_res = 'test_group_res' - file_name_set = 'test_group_set' - file_name_dat = 'test_group_dat' - # Test loading just results ntfg = SpectralGroupModel(verbose=False) - ntfg.load(file_name_res, TEST_DATA_PATH) + ntfg.load('test_group_res', TEST_DATA_PATH) assert len(ntfg.results.group_results) > 0 # Test that settings and data are None for setting in tfg.algorithm.settings.names: - assert getattr(ntfg.algorithm, setting) is None + assert getattr(ntfg.algorithm.settings, setting) is None assert ntfg.data.power_spectra is None # Test loading just settings ntfg = SpectralGroupModel(verbose=False) - ntfg.load(file_name_set, TEST_DATA_PATH) - for setting in tfg.algorithm.settings.names: - assert getattr(ntfg.algorithm, setting) is not None + ntfg.load('test_group_set', TEST_DATA_PATH) + assert tfg.algorithm.settings.values == ntfg.algorithm.settings.values # Test that results and data are None - for result in tfg.results._fields: - assert np.all(np.isnan(getattr(ntfg.results, result))) + for result in tfg.results.params.fields: + assert np.all(np.isnan(getattr(ntfg.results.params, result))) assert ntfg.data.power_spectra is None # Test loading just data ntfg = SpectralGroupModel(verbose=False) - ntfg.load(file_name_dat, TEST_DATA_PATH) + ntfg.load('test_group_dat', TEST_DATA_PATH) assert ntfg.data.has_data # Test that settings and results are None for setting in tfg.algorithm.settings.names: - assert getattr(ntfg.algorithm, setting) is None - for result in tfg.results._fields: - assert np.all(np.isnan(getattr(ntfg.results, result))) + assert getattr(ntfg.algorithm.settings, setting) is None + for result in tfg.results.params.fields: + assert np.all(np.isnan(getattr(ntfg.results.params, result))) # Test loading all elements ntfg = SpectralGroupModel(verbose=False) - file_name_all = 'test_group_all' - ntfg.load(file_name_all, TEST_DATA_PATH) + ntfg.load('test_group_all', TEST_DATA_PATH) + assert compare_model_objs([tfg, ntfg], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) assert len(ntfg.results.group_results) > 0 - for setting in tfg.algorithm.settings.names: - assert getattr(ntfg.algorithm, setting) is not None - assert ntfg.data.has_data - for meta_dat in tfg.data._meta_fields: - assert getattr(ntfg.data, meta_dat) is not None def test_report(skip_if_no_mpl): """Check that running the top level model method runs.""" @@ -335,8 +326,7 @@ def test_get_model(tfg): tfm_null = tfg.get_model() assert tfm_null # Check that settings are copied over properly, but data and results are empty - for setting in tfg.algorithm.settings.names: - assert getattr(tfg.algorithm, setting) == getattr(tfm_null.algorithm, setting) + assert tfg.algorithm.settings.values == tfm_null.algorithm.settings.values assert not tfm_null.data.has_data assert not tfm_null.results.has_model @@ -344,15 +334,14 @@ def test_get_model(tfg): tfm0 = tfg.get_model(0, False) assert tfm0 # Check that settings are copied over properly - for setting in tfg.algorithm.settings.names: - assert getattr(tfg.algorithm, setting) == getattr(tfm0.algorithm, setting) + assert tfg.algorithm.settings.values == tfm0.algorithm.settings.values # Check with regenerating tfm1 = tfg.get_model(1, True) assert tfm1 # Check that regenerated model is created - for result in tfg.results._fields: - assert np.all(getattr(tfm1.results, result)) + for result in tfg.results.params.fields: + assert np.all(getattr(tfm1.results.params, result)) # Test when object has no data (clear a copy of tfg) new_tfg = tfg.copy() @@ -383,9 +372,8 @@ def test_get_group(tfg): assert isinstance(nfg2, SpectralGroupModel) # Check that settings are copied over properly - for setting in tfg.algorithm.settings.names: - assert getattr(tfg.algorithm, setting) == getattr(nfg1.algorithm, setting) - assert getattr(tfg.algorithm, setting) == getattr(nfg2.algorithm, setting) + assert tfg.algorithm.settings.values == nfg1.algorithm.settings.values + assert tfg.algorithm.settings.values == nfg2.algorithm.settings.values # Check that data info is copied over properly for meta_dat in tfg.data._meta_fields: diff --git a/specparam/tests/models/test_model.py b/specparam/tests/models/test_model.py index 23210a6e..05dea871 100644 --- a/specparam/tests/models/test_model.py +++ b/specparam/tests/models/test_model.py @@ -14,6 +14,7 @@ from specparam.measures.metrics import METRICS from specparam.sim import gen_freqs, sim_power_spectrum from specparam.modes.definitions import AP_MODES, PE_MODES +from specparam.models.utils import compare_model_objs from specparam.modutils.dependencies import safe_import from specparam.modutils.errors import DataError, NoDataError, InconsistentDataError @@ -56,8 +57,8 @@ def test_has_model(tfm): def test_n_properties(tfm): - assert tfm.results.n_peaks_ - assert tfm.results.n_params_ + assert tfm.results.n_peaks + assert tfm.results.n_params def test_fit_nk(): """Test fit, no knee.""" @@ -71,11 +72,11 @@ def test_fit_nk(): tfm.fit(xs, ys) # Check model results - aperiodic parameters - assert np.allclose(ap_params, tfm.results.aperiodic_params_, [0.5, 0.1]) + assert np.allclose(ap_params, tfm.results.params.aperiodic, [0.5, 0.1]) # Check model results - gaussian parameters for ii, gauss in enumerate(groupby(gauss_params, 3)): - assert np.allclose(gauss, tfm.results.gaussian_params_[ii], [2.0, 0.5, 1.0]) + assert np.allclose(gauss, tfm.results.params.gaussian[ii], [2.0, 0.5, 1.0]) def test_fit_nk_noise(): """Test fit on noisy data, to make sure nothing breaks.""" @@ -102,11 +103,11 @@ def test_fit_knee(): tfm.fit(xs, ys) # Check model results - aperiodic parameters - assert np.allclose(ap_params, tfm.results.aperiodic_params_, [1, 2, 0.2]) + assert np.allclose(ap_params, tfm.results.params.aperiodic, [1, 2, 0.2]) # Check model results - gaussian parameters for ii, gauss in enumerate(groupby(gauss_params, 3)): - assert np.allclose(gauss, tfm.results.gaussian_params_[ii], [2.0, 0.5, 1.0]) + assert np.allclose(gauss, tfm.results.params.gaussian[ii], [2.0, 0.5, 1.0]) def test_fit_default_metrics(): """Test goodness of fit & error metrics, post model fitting.""" @@ -115,7 +116,7 @@ def test_fit_default_metrics(): # Hack fake data with known properties: total error magnitude 2 tfm.data.power_spectrum = np.array([1, 2, 3, 4, 5]) - tfm.results.modeled_spectrum_ = np.array([1, 2, 5, 4, 5]) + tfm.results.model.modeled_spectrum = np.array([1, 2, 5, 4, 5]) # Check default goodness of fit and error measures tfm.results.metrics.compute_metrics(tfm.data, tfm.results) @@ -194,50 +195,43 @@ def test_load(tfm): # Test loading just results ntfm = SpectralModel(verbose=False) - file_name_res = 'test_model_res' - ntfm.load(file_name_res, TEST_DATA_PATH) + ntfm.load('test_model_res', TEST_DATA_PATH) # Check that result attributes get filled - for result in tfm.results._fields: - assert not np.all(np.isnan(getattr(ntfm.results, result))) + for result in tfm.results.params.fields: + assert not np.all(np.isnan(getattr(ntfm.results.params, result))) # Test that settings and data are None for setting in tfm.algorithm.settings.names: - assert getattr(ntfm.algorithm, setting) is None + assert getattr(ntfm.algorithm.settings, setting) is None assert ntfm.data.power_spectrum is None # Test loading just settings ntfm = SpectralModel(verbose=False) - file_name_set = 'test_model_set' - ntfm.load(file_name_set, TEST_DATA_PATH) - for setting in tfm.algorithm.settings.names: - assert getattr(ntfm.algorithm, setting) is not None + ntfm.load('test_model_set', TEST_DATA_PATH) + assert tfm.algorithm.settings.values == ntfm.algorithm.settings.values # Test that results and data are None - for result in tfm.results._fields: - assert np.all(np.isnan(getattr(ntfm.results, result))) + for result in tfm.results.params.fields: + assert np.all(np.isnan(getattr(ntfm.results.params, result))) assert ntfm.data.power_spectrum is None # Test loading just data ntfm = SpectralModel(verbose=False) - file_name_dat = 'test_model_dat' - ntfm.load(file_name_dat, TEST_DATA_PATH) - assert ntfm.data.power_spectrum is not None + ntfm.load('test_model_dat', TEST_DATA_PATH) + assert ntfm.data.has_data + assert np.array_equal(tfm.data.power_spectrum, ntfm.data.power_spectrum) # Test that settings and results are None for setting in tfm.algorithm.settings.names: - assert getattr(ntfm.algorithm, setting) is None - for result in tfm.results._fields: - assert np.all(np.isnan(getattr(ntfm.results, result))) + assert getattr(ntfm.algorithm.settings, setting) is None + for result in tfm.results.params.fields: + assert np.all(np.isnan(getattr(ntfm.results.params, result))) # Test loading all elements ntfm = SpectralModel(verbose=False) - file_name_all = 'test_model_all' - ntfm.load(file_name_all, TEST_DATA_PATH) - for result in tfm.results._fields: - assert not np.all(np.isnan(getattr(ntfm.results, result))) - for setting in tfm.algorithm.settings.names: - assert getattr(ntfm.algorithm, setting) is not None + ntfm.load('test_model_all', TEST_DATA_PATH) + assert compare_model_objs([tfm, ntfm], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) for data in tfm.data._fields: - assert getattr(ntfm.data, data) is not None - for meta_dat in tfm.data._meta_fields: - assert getattr(ntfm.data, meta_dat) is not None + assert np.array_equal(getattr(tfm.data, data), getattr(ntfm.data, data)) + for result in tfm.results.params.fields: + assert not np.all(np.isnan(getattr(ntfm.results.params, result))) def test_add_data(tresults): """Tests method to add data to model objects.""" @@ -296,7 +290,7 @@ def test_get_component(tfm): for comp in ['full', 'aperiodic', 'peak']: for space in ['log', 'linear']: - assert isinstance(tfm.results.get_component(comp, space), np.ndarray) + assert isinstance(tfm.results.model.get_component(comp, space), np.ndarray) def test_prints(tfm): """Test methods that print (alias and pass through methods). @@ -320,19 +314,14 @@ def test_resets(): # Note: uses it's own tfm, to not clear the global one tfm = get_tfm() - tfm._reset_data_results(True, True, True) - tfm.algorithm._reset_internal_settings() - for field in tfm.data._fields: assert getattr(tfm.data, field) is None - model_components = ['modeled_spectrum_', '_spectrum_flat', - '_spectrum_peak_rm', '_ap_fit', '_peak_fit'] - for field in model_components: - assert getattr(tfm.results, field) is None - for field in tfm.results._fields: - assert np.all(np.isnan(getattr(tfm.results, field))) - assert tfm.data.freqs is None and tfm.results.modeled_spectrum_ is None + for key, value in tfm.results.model.__dict__.items(): + assert value is None + for field in tfm.results.params.fields: + assert np.all(np.isnan(getattr(tfm.results.params, field))) + assert tfm.data.freqs is None and tfm.results.model.modeled_spectrum is None def test_report(skip_if_no_mpl): """Check that running the top level model method runs.""" @@ -347,13 +336,13 @@ def test_fit_failure(): ## Induce a runtime error, and check it runs through tfm = SpectralModel(verbose=False) - tfm.algorithm._maxfev = 2 + tfm.algorithm._cf_settings.maxfev = 2 tfm.fit(*sim_power_spectrum(*default_spectrum_params())) # Check after failing out of fit, all results are reset - for result in tfm.results._fields: - assert np.all(np.isnan(getattr(tfm.results, result))) + for result in tfm.results.params.fields: + assert np.all(np.isnan(getattr(tfm.results.params, result))) ## Monkey patch to check errors in general # This mimics the main fit-failure, without requiring bad data / waiting for it to fail. @@ -366,14 +355,14 @@ def raise_runtime_error(*args, **kwargs): tfm.fit(*sim_power_spectrum(*default_spectrum_params())) # Check after failing out of fit, all results are reset - for result in tfm.results._fields: - assert np.all(np.isnan(getattr(tfm.results, result))) + for result in tfm.results.params.fields: + assert np.all(np.isnan(getattr(tfm.results.params, result))) def test_debug(): """Test model object in debug state, including with fit failures.""" tfm = SpectralModel(verbose=False) - tfm.algorithm._maxfev = 2 + tfm.algorithm._cf_settings.maxfev = 2 tfm.algorithm.set_debug(True) assert tfm.algorithm._debug is True @@ -406,8 +395,8 @@ def test_set_checks(): # Reset checks to true tfm.data.set_checks(True, True) - assert tfm.data._check_freqs is True - assert tfm.data._check_data is True + assert tfm.data.checks['freqs'] is True + assert tfm.data.checks['data'] is True def test_to_df(tfm, tbands, skip_if_no_pandas): diff --git a/specparam/tests/models/test_time.py b/specparam/tests/models/test_time.py index 8a5e7cf3..14d10bd7 100644 --- a/specparam/tests/models/test_time.py +++ b/specparam/tests/models/test_time.py @@ -9,6 +9,7 @@ import numpy as np from specparam.sim import sim_spectrogram +from specparam.models.utils import compare_model_objs from specparam.modutils.dependencies import safe_import pd = safe_import('pandas') @@ -40,8 +41,8 @@ def test_time_iter(tft): def test_time_n_properties(tft): - assert np.all(tft.results.n_peaks_) - assert np.all(tft.results.n_params_) + assert np.all(tft.results.n_peaks) + assert np.all(tft.results.n_params) def test_time_fit(): @@ -78,26 +79,27 @@ def test_time_report(skip_if_no_mpl): assert tft -def test_time_load(): - - file_name_res = 'test_time_res' - file_name_set = 'test_time_set' - file_name_dat = 'test_time_dat' +def test_time_load(tft): # Test loading results - tft = SpectralTimeModel(verbose=False) - tft.load(file_name_res, TEST_DATA_PATH) - assert tft.results.time_results + ntft = SpectralTimeModel(verbose=False) + ntft.load('test_time_res', TEST_DATA_PATH) + assert ntft.results.time_results # Test loading settings - tft = SpectralTimeModel(verbose=False) - tft.load(file_name_set, TEST_DATA_PATH) - assert tft.algorithm.get_settings() + ntft = SpectralTimeModel(verbose=False) + ntft.load('test_time_set', TEST_DATA_PATH) + assert ntft.algorithm.get_settings() # Test loading data - tft = SpectralTimeModel(verbose=False) - tft.load(file_name_dat, TEST_DATA_PATH) - assert np.all(tft.data.power_spectra) + ntft = SpectralTimeModel(verbose=False) + ntft.load('test_time_dat', TEST_DATA_PATH) + assert np.all(ntft.data.spectrogram) + + # Test loading all elements + ntft = SpectralTimeModel(verbose=False) + ntft.load('test_time_all', TEST_DATA_PATH) + assert compare_model_objs([tft, ntft], ['modes', 'settings', 'meta_data', 'bands', 'metrics']) def test_time_drop(): diff --git a/specparam/tests/models/test_utils.py b/specparam/tests/models/test_utils.py index b34e0399..8e2d9903 100644 --- a/specparam/tests/models/test_utils.py +++ b/specparam/tests/models/test_utils.py @@ -33,17 +33,25 @@ def test_compare_model_objs(tfm, tfg): f_obj2 = f_obj.copy() - assert compare_model_objs([f_obj, f_obj2], ['settings', 'meta_data', 'metrics']) + assert compare_model_objs([f_obj, f_obj2], + ['modes', 'settings', 'meta_data', 'bands', 'metrics']) + + assert compare_model_objs([f_obj, f_obj2], 'modes') + f_obj2.add_modes('knee', 'cauchy') + assert not compare_model_objs([f_obj, f_obj2], 'modes') assert compare_model_objs([f_obj, f_obj2], 'settings') - f_obj2.algorithm.peak_width_limits = [2, 4] - f_obj2.algorithm._reset_internal_settings() + f_obj2.algorithm.settings.peak_width_limits = [2, 4] assert not compare_model_objs([f_obj, f_obj2], 'settings') assert compare_model_objs([f_obj, f_obj2], 'meta_data') f_obj2.data.freq_range = [5, 25] assert not compare_model_objs([f_obj, f_obj2], 'meta_data') + assert compare_model_objs([f_obj, f_obj2], 'bands') + f_obj2.results.add_bands({'new' : [1, 4]}) + assert not compare_model_objs([f_obj, f_obj2], 'bands') + assert compare_model_objs([f_obj, f_obj2], 'metrics') f_obj2.results.metrics.add_metric(METRICS['error_rmse']) assert not compare_model_objs([f_obj, f_obj2], 'metrics') @@ -128,9 +136,7 @@ def test_combine_errors(tfm, tfg): # Incompatible settings for f_obj in [tfm, tfg]: f_obj2 = f_obj.copy() - f_obj2.algorithm.peak_width_limits = [2, 4] - f_obj2.algorithm._reset_internal_settings() - + f_obj2.algorithm.settings.peak_width_limits = [2, 4] with raises(IncompatibleSettingsError): combine_model_objs([f_obj, f_obj2]) diff --git a/specparam/tests/objs/test_components.py b/specparam/tests/objs/test_components.py new file mode 100644 index 00000000..b9fa3d32 --- /dev/null +++ b/specparam/tests/objs/test_components.py @@ -0,0 +1,13 @@ +"""Tests for specparam.objs.components.""" + +from specparam.objs.components import * + +################################################################################################### +################################################################################################### + +## ModelComponents object + +def test_model_components(): + + mc = ModelComponents() + assert mc diff --git a/specparam/tests/objs/test_data.py b/specparam/tests/objs/test_data.py index 87aeab46..9b8593c0 100644 --- a/specparam/tests/objs/test_data.py +++ b/specparam/tests/objs/test_data.py @@ -1,4 +1,4 @@ -"""Tests for specparam.objs.data, including the data object and it's methods.""" +"""Tests for specparam.objs.data.""" from specparam.data import SpectrumMetaData, ModelChecks @@ -44,15 +44,14 @@ def test_data_get_set_checks(tdata): tdata.set_checks(False, False) tchecks1 = tdata.get_checks() assert isinstance(tchecks1, ModelChecks) - assert tdata._check_freqs == tchecks1.check_freqs == False - assert tdata._check_data == tchecks1.check_data == False + assert tdata.checks['freqs'] == tchecks1.check_freqs == False + assert tdata.checks['data'] == tchecks1.check_data == False tdata.set_checks(True, True) tchecks2 = tdata.get_checks() assert isinstance(tchecks2, ModelChecks) - assert tdata._check_freqs == tchecks2.check_freqs == True - assert tdata._check_data == tchecks2.check_data == True - + assert tdata.checks['freqs'] == tchecks2.check_freqs == True + assert tdata.checks['data'] == tchecks2.check_data == True @plot_test def test_data_plot(tdata, skip_if_no_mpl): diff --git a/specparam/tests/objs/test_metrics.py b/specparam/tests/objs/test_metrics.py index 0885453a..f06e0a95 100644 --- a/specparam/tests/objs/test_metrics.py +++ b/specparam/tests/objs/test_metrics.py @@ -23,7 +23,7 @@ def test_metric_kwargs(tfm): metric = Metric('gof', 'ar2', compute_adj_r_squared, {'n_params' : lambda data, results: \ - results.peak_params_.size + results.aperiodic_params_.size}) + results.params.peak.size + results.params.aperiodic.size}) assert isinstance(metric, Metric) assert isinstance(metric.label, str) @@ -55,8 +55,8 @@ def test_metrics_obj(tfm): def test_metrics_dict(tfm): - er_met_def = {'type' : 'error', 'measure' : 'mae', 'func' : compute_mean_abs_error} - gof_met_def = {'type' : 'gof', 'measure' : 'rsquared', 'func' : compute_r_squared} + er_met_def = {'category' : 'error', 'measure' : 'mae', 'func' : compute_mean_abs_error} + gof_met_def = {'category' : 'gof', 'measure' : 'rsquared', 'func' : compute_r_squared} metrics = Metrics([er_met_def, gof_met_def]) assert isinstance(metrics, Metrics) @@ -73,11 +73,11 @@ def test_metrics_dict(tfm): def test_metrics_kwargs(tfm): - er_met_def = {'type' : 'error', 'measure' : 'mae', 'func' : compute_mean_abs_error} - ar2_met_def = {'type' : 'gof', 'measure' : 'arsquared', + er_met_def = {'category' : 'error', 'measure' : 'mae', 'func' : compute_mean_abs_error} + ar2_met_def = {'category' : 'gof', 'measure' : 'arsquared', 'func' : compute_adj_r_squared, 'kwargs' : {'n_params' : lambda data, results: \ - results.peak_params_.size + results.aperiodic_params_.size}} + results.params.peak.size + results.params.aperiodic.size}} metrics = Metrics([er_met_def, ar2_met_def]) assert isinstance(metrics, Metrics) diff --git a/specparam/tests/objs/test_params.py b/specparam/tests/objs/test_params.py new file mode 100644 index 00000000..5a19426e --- /dev/null +++ b/specparam/tests/objs/test_params.py @@ -0,0 +1,13 @@ +"""Tests for specparam.objs.params.""" + +from specparam.objs.params import * + +################################################################################################### +################################################################################################### + +## ModelParameters object + +def test_model_parameters(): + + mp = ModelParameters() + assert mp diff --git a/specparam/tests/objs/test_results.py b/specparam/tests/objs/test_results.py index d2317590..c34bee29 100644 --- a/specparam/tests/objs/test_results.py +++ b/specparam/tests/objs/test_results.py @@ -1,4 +1,4 @@ -"""Tests for specparam.objs.results, including the data object and it's methods.""" +"""Tests for specparam.objs.results.""" from specparam.objs.results import * @@ -18,8 +18,8 @@ def test_results_results(tresults): tres.add_results(tresults) assert tres.has_model - for result in tres._fields: - assert np.array_equal(getattr(tres, result), getattr(tresults, result.strip('_'))) + for result in tres.params.fields: + assert np.array_equal(getattr(tres.params, result), getattr(tresults, result + '_params')) results_out = tres.get_results() assert results_out == tresults diff --git a/specparam/tests/tdata.py b/specparam/tests/tdata.py index 6791d2d0..2b585d3a 100644 --- a/specparam/tests/tdata.py +++ b/specparam/tests/tdata.py @@ -53,7 +53,8 @@ def get_tdata2d(): def get_tfm(): """Get a model object, with a fit power spectrum, for testing.""" - tfm = SpectralModel(bands=Bands({'alpha' : (7, 14)}), verbose=False) + tfm = SpectralModel(bands=Bands({'alpha' : (7, 14)}), + min_peak_height=0.05, peak_width_limits=[1, 8]) tfm.fit(*sim_power_spectrum(*default_spectrum_params())) return tfm @@ -62,6 +63,7 @@ def get_tfm2(): """Get a model object, with a fit power spectrum, for testing - custom metrics & modes.""" tfm2 = SpectralModel(bands=Bands({'alpha' : (7, 14), 'beta' : [15, 30]}), + min_peak_height=0.05, peak_width_limits=[1, 8], metrics=['error_mse', 'gof_adjrsquared'], aperiodic_mode='knee', periodic_mode='gaussian') tfm2.fit(*sim_power_spectrum(*default_spectrum_params())) @@ -72,7 +74,8 @@ def get_tfg(): """Get a group object, with some fit power spectra, for testing.""" n_spectra = 3 - tfg = SpectralGroupModel(bands=Bands({'alpha' : (7, 14)}), verbose=False) + tfg = SpectralGroupModel(bands=Bands({'alpha' : (7, 14)}), + min_peak_height=0.05, peak_width_limits=[1, 8]) tfg.fit(*sim_group_power_spectra(n_spectra, *default_group_params())) return tfg @@ -82,6 +85,7 @@ def get_tfg2(): n_spectra = 3 tfg2 = SpectralGroupModel(bands=Bands({'alpha' : (7, 14), 'beta' : [15, 30]}), + min_peak_height=0.05, peak_width_limits=[1, 8], metrics=['error_mse', 'gof_adjrsquared'], aperiodic_mode='knee', periodic_mode='gaussian') tfg2.fit(*sim_group_power_spectra(n_spectra, *default_group_params())) @@ -94,7 +98,8 @@ def get_tft(): n_spectra = 3 xs, ys = sim_spectrogram(n_spectra, *default_group_params()) - tft = SpectralTimeModel(bands=Bands({'alpha' : (7, 14)}), verbose=False) + tft = SpectralTimeModel(bands=Bands({'alpha' : (7, 14)}), \ + min_peak_height=0.05, peak_width_limits=[1, 8],) tft.fit(xs, ys) return tft @@ -106,7 +111,8 @@ def get_tfe(): xs, ys = sim_spectrogram(n_spectra, *default_group_params()) ys = [ys, ys] - tfe = SpectralTimeEventModel(bands=Bands({'alpha' : (7, 14)}), verbose=False) + tfe = SpectralTimeEventModel(bands=Bands({'alpha' : (7, 14)}), + min_peak_height=0.05, peak_width_limits=[1, 8],) tfe.fit(xs, ys) return tfe