Commits on Source (4)
@@ -109,7 +109,6 @@ def calculate_nlm_file(
     if fs_r != FS_GLOBAL:
         reference = soxr.resample(reference, fs_r, fs, quality=resample_quality)
-    # sf.write(r"d:\SVN\Projects\ts104063-nonlinearity-measure\data\TR-41\tmp.wav", np.stack((measured.squeeze(), reference), axis=1), FS_GLOBAL)
     return calculate_nlm(measured, reference, fs=fs, **kwargs)
@@ -163,7 +162,7 @@ def calculate_nlm(
     measurement = compensate_delays(measurement, delays, fs, reference.shape[0])

     # Clause 5.2: ensure same length
-    min_len, max_len = (
+    max_len, min_len = (
         max(measurement.shape[0], reference.shape[0]),
         min(measurement.shape[0], reference.shape[0]),
     )
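For context, a minimal sketch of what the Clause 5.2 step amounts to once delays are compensated; plain NumPy with an illustrative helper name, not the project's own function:

    import numpy as np

    def trim_to_common_length(measurement: np.ndarray, reference: np.ndarray):
        # Clause 5.2 (sketch): truncate both signals to the shorter length so that
        # subsequent frame-wise spectra are computed over the same time span
        min_len = min(measurement.shape[0], reference.shape[0])
        return measurement[:min_len], reference[:min_len]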
@@ -234,13 +233,14 @@ def calculate_nlm(
     for key in ["measurement", "linear"]:
         spectra.set_value(key, spectra.get_value(key).get_aggregate())

-    # Clause 5.5/5.6:
-    # - ensure that linear estimate does not exceed measurement (minus noise estimate)
-    # - subtract estimates of linear and noise spectrum from measurement to obtain non-linearity spectrum
+    # note: subtraction of two SpectrumAvgBase objects (z = a - b) always ensures a result larger than zero -> z = max(a-b,0)!
+    # Clause 5.7, Eq. (14): ensure that linear estimate does not exceed measurement (minus noise estimate)
     spectra.linear = min(spectra.measurement - spectra.noise, spectra.linear)
+    # Clause 5.7, Eq. (15): subtract estimates of linear and noise spectrum from measurement to obtain non-linearity spectrum
     spectra.nonlinear = spectra.measurement - spectra.noise - spectra.linear

-    # Clause 5.7: aggregate to single value vs frequency (from fmin to fmax)
+    # Clause 5.8: aggregate to single value vs frequency (from fmin to fmax)
     levels = NLMDict(None)
     for name, spec in spectra.asdict().items():
         # TODO: apply weighting to non-linear component or to estimated noise?
...
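For context, a hedged NumPy sketch of the clamped spectrum arithmetic that the new comments describe; plain arrays stand in for the SpectrumAvgBase objects, and the equation numbers follow the clause references above:

    import numpy as np

    def nonlinear_spectrum(measurement: np.ndarray, noise: np.ndarray, linear: np.ndarray) -> np.ndarray:
        # clamped subtraction: z = max(a - b, 0) per frequency bin
        linear = np.minimum(np.maximum(measurement - noise, 0.0), linear)  # Eq. (14): linear estimate may not exceed measurement minus noise
        return np.maximum(measurement - noise - linear, 0.0)               # Eq. (15): remaining power is the non-linearity spectrum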
@@ -5,80 +5,71 @@ Created on May 29 2024 16:02
 @author: Jan.Reimes
 """
-import numpy as np
 import logging
-from scipy import signal
-from typing import List, Tuple, Optional
+from enum import auto
+from typing import Optional
+
+import numpy as np
 from numpy.typing import NDArray
+from scipy import signal
 from scipy.signal import freqz
-from ensure import Check

-from . import MIN_LEVEL_DB, MIN_LEVEL_POW, MIN_LEVEL_RMS, FreqRespResult
-from .spectrum import FftSpectrumVsTime, AggregateFn, check_spectrum_arguments # get_active_frames, csd_vs_time, fft_to_octave_bands
-#from .utils.misc import check_spectrum_arguments
+from . import MIN_COHERENCE_RATIO_ERROR, MIN_COHERENCE_RATIO_WARNING, MIN_LEVEL_POW, FreqRespResult, ParsableStrEnum
+from .spectrum import AggregateFn, FftSpectrumVsTime, check_spectrum_arguments

 log = logging.getLogger(__name__)

-MIN_COHERENCE_RATIO = 0.25
-NDIntArray = NDArray[np.int64]
-def smooth_impulse_response(h: NDArray, fs: int, fraction: int, nfft: Optional[int] = None,
-                            window: str = 'flattop', min_phase: bool = False, ntaps_out: Optional[int] = None) -> NDArray:
-    Check(fraction).is_greater_than(0).or_raise(ValueError, f"Fractional bands parameter must be >= 1")
-    Check(fraction).is_an(int).or_raise(ValueError, f"Fractional bands parameter must be integer")
-    nfft = nfft or 2 * h.shape[0]
-    ntaps_out = ntaps_out or h.shape[0]
-    # impulse response -> frequency response
-    w, H2 = freqz(h, 1, worN=nfft, fs=fs)
-    # aggregate (=smooth) frequency response over fractional-octave bands
-    f_bands, H_bands = fft_to_octave_bands(H2, fs, fraction, is_transfer_function=True)
-    # convert back to impulse response
-    if f_bands[0] > 0.0:
-        f_bands = np.insert(f_bands, 0, 0.0)
-        H_bands = np.insert(H_bands, 0, 0.0)
-    if f_bands[-1] < fs / 2:
-        f_bands = np.append(f_bands, fs / 2)
-        H_bands = np.append(H_bands, 0.0)
-    h_smooth = signal.firwin2(ntaps_out, f_bands, H_bands, fs=fs, window=window, antisymmetric=False)
-    # convert to minimum phase if needed
-    if min_phase:
-        h_smooth = signal.minimum_phase(h_smooth)
-    return h_smooth
+NDIntArray = NDArray[np.int64]
+
+
+class TrFuncEstimator(ParsableStrEnum):
+    """Transfer function estimator string-enum."""
+
+    Hs = auto()  # symmetric estimate / geometric mean of H1 and H2 (real-valued result)
+    H0 = auto()  # no noise on input or output -> Y/X
+    H1 = auto()  # noise on output - default estimator
+    H2 = auto()  # noise on input
+    H3 = auto()  # arithmetic mean of H1 and H2 (rarely used)
+    """
+    Hv is for future study, see also:
+    I. J. Leontaritis and S. A. Billings (1987), "Input-output parametric models for non-linear systems"
+    M. J. Roorda and L. Ljung (1994), "Least-squares estimation of transfer functions"
+    """
+    Hv = auto()  # total least squares estimator

-def estimate_ir_csd(y: NDArray, x: NDArray, fs: int, ntaps: int = 512, min_coherence: float = 0.025,
-                    # spectrum parameters
-                    window: str = 'flattop', nperseg: Optional[int] = None,
-                    noverlap: Optional[int] = None, nfft: Optional[int] = None,
-                    fmin: Optional[float] = None, fmax: Optional[float] = None,
+def estimate_ir_csd(
+    y: NDArray,
+    x: NDArray,
+    fs: int,
+    ntaps: int = 512,
+    min_coherence: Optional[float] = None,
+    window: str = "flattop",
+    nperseg: Optional[int] = None,
+    noverlap: Optional[int] = None,
+    nfft: Optional[int] = None,
+    fmin: Optional[float] = None,
+    fmax: Optional[float] = None,
     inactivity_fraction: Optional[float] = None,
     agg_fn: Optional[AggregateFn] = None,
-    **kwargs) -> FreqRespResult: # Tuple[NDArray, bool, NDArray, NDArray]:
+    bandpass_weighting: bool = False,
+    tf_estimator: TrFuncEstimator | str = "H1",
+    **kwargs,
+) -> FreqRespResult:

     # check input arguments
-    axis: int = kwargs.pop('axis', 0)
-    agg_fn = agg_fn or 'mean'
-    direct_irfft: bool = kwargs.get('direct_irfft', False)
-    min_phase: bool = kwargs.get('min_phase', False)
-    fraction: int = kwargs.pop('fraction', None) or -1
-    apply_smoothing: bool = fraction > 0
-    ntaps_smooth: Optional[int] = kwargs.pop('ntaps_smooth', None)
-    min_coherence = min_coherence if min_coherence is not None else 0.0
+    axis: int = kwargs.pop("axis", 0)
+    agg_fn = agg_fn or "mean"
+    direct_irfft: bool = kwargs.get("direct_irfft", False)
+    min_phase: bool = kwargs.get("min_phase", False)
+    min_coherence = min_coherence or MIN_COHERENCE_RATIO_ERROR
+    tf_estimator = TrFuncEstimator.get(tf_estimator)  # validate estimator value
     fmin = fmin or 0.0
     fmax = fmax or (fs / 2)

     # spectrum parameters
     spec_args = check_spectrum_arguments(fs, window, nperseg, noverlap, nfft, **kwargs)
-    ntaps = np.minimum(ntaps, spec_args['nfft'])
+    ntaps = np.minimum(ntaps, spec_args["nfft"])

     # for firwin2: antisymmetric=False & odd ntaps -> type I filter -> no requirement on magnitude for f=0 and f=fs/2
     ntaps_fir = ntaps + 1 if ntaps % 2 == 0 else ntaps
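For readers unfamiliar with the estimator names in TrFuncEstimator, a hedged sketch of the textbook formulas from averaged auto-/cross-spectra; this mirrors the selection logic added in the hunks below, not the module's exact code:

    import numpy as np

    def estimate_transfer_function(pxx, pyy, pxy, estimator: str = "H1", eps: float = 1e-12):
        # classical FRF estimators:
        #   H1 = Pxy / Pxx  (noise on the output),  H2 = Pyy / Pyx  (noise on the input)
        #   Hs = sqrt(H1 * H2)  (geometric mean),   H3 = (H1 + H2) / 2  (arithmetic mean)
        pyx = np.conj(pxy)                                 # Pyx = conj(Pxy)
        h1 = pxy / np.maximum(pxx, eps)
        h2 = pyy / np.where(np.abs(pyx) > eps, pyx, eps)
        if estimator == "H1":
            return h1
        if estimator == "H2":
            return h2
        if estimator == "Hs":
            return np.sqrt(np.abs(h1 * h2))
        if estimator == "H3":
            return 0.5 * (h1 + h2)
        raise ValueError(f"unknown estimator: {estimator}")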
@@ -89,12 +80,9 @@ def estimate_ir_csd(y: NDArray, x: NDArray, fs: int, ntaps: int = 512, min_coher
     spec_xy = FftSpectrumVsTime.from_signals(x, y=y, axis=axis, **spec_args)
     spec_xx = FftSpectrumVsTime.from_signals(x, y=None, axis=axis, **spec_args)
     spec_yy = FftSpectrumVsTime.from_signals(y, y=None, axis=axis, **spec_args)
+    spec_yx = FftSpectrumVsTime.from_signals(x, y=x, axis=axis, **spec_args)

-    t, f = spec_xy.time, spec_xy.get_freqs(all_freqs=True)
-    #f, t, pxy = csd_vs_time(x, y, **spec_args)
-    #_, _, pxx = csd_vs_time(x, x, **spec_args)
-    #_, _, pyy = csd_vs_time(y, y, **spec_args)
+    _, f = spec_xy.time, spec_xy.get_freqs(all_freqs=True)

     # check active frames -> all frames above inactivity-percentile
     # idx_active = np.ones_like(t, dtype=bool)
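For comparison, the averaged auto- and cross-spectra can also be obtained with SciPy's Welch-style estimators; a sketch with illustrative parameters, not the FftSpectrumVsTime pipeline used here (which additionally drops inactive frames):

    import numpy as np
    from scipy import signal

    fs, nperseg = 48000, 1024
    x = np.random.randn(4 * fs)                        # reference / input
    y = np.convolve(x, [1.0, 0.5, 0.25], mode="same")  # measurement / output

    f, pxy = signal.csd(x, y, fs=fs, window="flattop", nperseg=nperseg)
    _, pxx = signal.welch(x, fs=fs, window="flattop", nperseg=nperseg)
    _, pyy = signal.welch(y, fs=fs, window="flattop", nperseg=nperseg)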
@@ -103,62 +91,80 @@ def estimate_ir_csd(y: NDArray, x: NDArray, fs: int, ntaps: int = 512, min_coher
         spec_xx = spec_xx.with_active_frames(idx_active)
         spec_yy = spec_yy.with_active_frames(idx_active)
         spec_xy = spec_xy.with_active_frames(idx_active)
+        spec_yx = spec_yx.with_active_frames(idx_active)
-    #idx_active = spec_xx.get_level_vs_time().get_active_frames(inactivity_fraction, weighting=None)
-    #idx_active = get_active_frames(x, inactivity_fraction, **spec_args)

     # average spectrum over (active) time
-    pxy = spec_xy.get_aggregate(agg_fn).get_values('pow', all_freqs=True)
-    pxx = spec_xx.get_aggregate(agg_fn).get_values('pow', all_freqs=True)
-    pyy = spec_yy.get_aggregate(agg_fn).get_values('pow', all_freqs=True)
+    pxy = spec_xy.get_aggregate(agg_fn).get_values("pow", all_freqs=True)
+    pxx = spec_xx.get_aggregate(agg_fn).get_values("pow", all_freqs=True)
+    pyy = spec_yy.get_aggregate(agg_fn).get_values("pow", all_freqs=True)
+    pyx = spec_yx.get_aggregate(agg_fn).get_values("pow", all_freqs=True)
-    #pxy = pxy[:, idx_active].mean(axis=-1)
-    #pxx = pxx[:, idx_active].mean(axis=-1)
-    #pyy = pyy[:, idx_active].mean(axis=-1)

     # limit frequency range
-    idx_fmin = np.argmin(np.abs(f - fmin))
-    idx_fmax = np.argmin(np.abs(f - fmax))
-    idx_f = (f >= fmin) & (f <= fmax)
-    idx_f[[0, -1]] = False # always exclude first/last frequency (=include for firwin2)
+    idx_f = (f >= fmin) & (f <= fmax)  # frequency indices that are "in band"
# "weighting" function for transfer function H1: # transfer function
H1xy = np.divide(pxy, np.maximum(pxx, MIN_LEVEL_POW))
H1yx = np.divide(pyy, np.maximum(pyx, MIN_LEVEL_POW))
if tf_estimator == TrFuncEstimator.H0:
# no noise on input or output -> simply Y/X
H = np.divide(pyy, np.maximum(pxx, MIN_LEVEL_POW))
elif tf_estimator == TrFuncEstimator.Hs:
# symmetric estimate / geometric mean of H1 and H2
H = np.abs(np.sqrt(H1xy * H1yx))
elif tf_estimator == TrFuncEstimator.H1:
# noise on output - default estimator
H = H1xy
elif tf_estimator == TrFuncEstimator.H2:
# noise on input (rarely used)
H = H1yx
elif tf_estimator == TrFuncEstimator.H3:
# arithmetic mean of H1 and H2 (used even more rarely)
H = (H1xy + H1yx) / 2.0
elif tf_estimator == TrFuncEstimator.Hv:
# total least squares estimator; estimate eta^2 (noise power ratio output/input)
pvv = spec_yy.get_noise_estimate(percentile=0.03).get_values("pow", all_freqs=True) # noise estimate from y
puu = spec_xx.get_noise_estimate(percentile=0.03).get_values("pow", all_freqs=True) # noise estimate from x - typically almost zero (clean speech signal)
eta2 = pvv / np.maximum(puu, MIN_LEVEL_POW)
H = pyy - (eta2 * pxx) + np.sqrt(np.power(pyy - (eta2 * pxx), 2) + 4 * eta2 * np.power(np.abs(pxy), 2))
H /= np.maximum(2 * pyx, MIN_LEVEL_POW)
else:
raise ValueError(f"Unknown transfer function estimator: {tf_estimator.name}")
# optional frequency weighting for the ends of the spectrum
weighting = np.ones_like(f)
if bandpass_weighting:
# "weighting" or "fade" function for transfer function H:
# - ramp up from 0.0 to 1.0 between 0 and fmin # - ramp up from 0.0 to 1.0 between 0 and fmin
# - constant 1.0 between fmin and fmax # - constant 1.0 between fmin and fmax
# - ramp down from 1.0 to 0.0 between fmax and fs/2 # - ramp down from 1.0 to 0.0 between fmax and fs/2
weighting = np.ones_like(f) idx_fmin = np.argmin(np.abs(f - fmin))
weighting[:idx_fmin] = np.linspace(0.0, 1.0, idx_fmin) idx_fmax = np.argmin(np.abs(f - fmax))
weighting[idx_fmax:] = np.linspace(1.0, 0.0, f.shape[0] - idx_fmax) weighting[: (idx_fmin // 2)] = np.linspace(0.0, 1.0, idx_fmin // 2)
weighting[idx_fmax + (f.shape[0] - idx_fmax) // 2 :] = np.linspace(1.0, 0.0, (f.shape[0] - idx_fmax) // 2)
# transfer function H *= weighting
H1 = np.divide(pxy, np.maximum(pxx, MIN_LEVEL_POW))
H1 *= weighting
#H1 = np.sqrt(H1)
     # check coherence within frequency range
-    #cxy = np.abs(pxy) ** 2 / np.maximum(pxx * pyy, MIN_LEVEL_POW) # identical to scipy.signal.coherence(y,x)
-    cxy = np.abs(pxy) / np.maximum(np.sqrt(pxx*pyy), MIN_LEVEL_POW) # identical to scipy.signal.coherence(y,x)
-    idx_coherent = cxy >= min_coherence
-    idx_coherent_f = idx_coherent & idx_f
-
-    # store "raw" frequency response H1 in result
-    result = FreqRespResult(np.zeros(ntaps), f.copy(), H1.copy(), cxy, idx_coherent)
+    cxy = np.power(np.abs(pxy), 2) / np.maximum(pxx * pyy, MIN_LEVEL_POW**2)  # identical to scipy.signal.coherence(y,x)
+    idx_coherent = cxy >= min_coherence  # frequency indices that are considered as coherent
+    idx_coherent_f = idx_coherent & idx_f  # frequency indices that are considered as coherent and "in band"

     # check coherence ratio within frequency range
     coherence_ratio = idx_coherent_f.sum() / idx_f.sum()
-    if coherence_ratio < MIN_COHERENCE_RATIO:
-        log.warning(f"Low coherence ratio vs frequency: {coherence_ratio:.2f} - "
-                    f"consider decreasing min_coherence and/or check signals!")
+    if coherence_ratio < MIN_COHERENCE_RATIO_WARNING:
+        log.warning(f"Low coherence ratio vs frequency: {coherence_ratio:.2f} - consider decreasing min_coherence and/or check signals!")

     # estimate IR
     if direct_irfft:
         # interpolate H1 between incoherent frequencies
-        H1[idx_f] = np.interp(f[idx_f], f[idx_coherent_f], H1[idx_coherent_f])
+        H[idx_f] = np.interp(f[idx_f], f[idx_coherent_f], H[idx_coherent_f])
         # direct inverse real FFT -> always minimum phase
         # min_phase = False
-        h1 = np.fft.irfft(H1)
+        h1 = np.fft.irfft(H)
         h1 = h1[:ntaps]

         # apply halved Hann windows around peak for fading IR in/out
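The coherence gate above is the standard magnitude-squared coherence; a sketch of the equivalent check with scipy.signal.coherence and an illustrative threshold:

    import numpy as np
    from scipy import signal

    fs, nperseg = 48000, 1024
    x = np.random.randn(4 * fs)
    y = x + 0.1 * np.random.randn(4 * fs)

    f, cxy = signal.coherence(y, x, fs=fs, window="flattop", nperseg=nperseg)  # |Pxy|^2 / (Pxx * Pyy)
    idx_coherent = cxy >= 0.25  # keep only frequency bins with sufficient coherence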
@@ -167,31 +173,22 @@ def estimate_ir_csd(y: NDArray, x: NDArray, fs: int, ntaps: int = 512, min_coher
         h1[peak:] *= signal.windows.hann(2 * (ntaps - peak))[(ntaps - peak) :]

         # direct method already results in minimum phase filter -> convert to linear phase if needed
-        # (and if not applied later by fractional smoothing):
-        if not min_phase and (not apply_smoothing):
+        if not min_phase:
             w, H2 = freqz(h1, 1, worN=f, fs=fs)
             h1 = signal.firwin2(ntaps_fir, w, np.abs(H2), fs=fs, window=window, antisymmetric=False)
     else:
         # firwin2() does not need continuous magnitude/frequency pairs -> just discard invalid frequencies
-        #idx_coherent[[0, -1]] = True # always include first / last frequency for firwin2
-        H1 = H1[idx_coherent | ~idx_f]
-        f = f[idx_coherent | ~idx_f]
+        idx_coherent[[0, -1]] = True  # always include first / last frequency for firwin2

         # estimate IR with firwin2
-        h1 = signal.firwin2(ntaps_fir, f, np.abs(H1), fs=fs, window=window, antisymmetric=False)
+        h1 = signal.firwin2(ntaps_fir, f[idx_coherent], np.abs(H[idx_coherent]), fs=fs, window=window, antisymmetric=False)

-        # convert to minimum phase if needed (and if not applied later by fractional smoothing)
-        if min_phase and (not apply_smoothing):
+        # convert to minimum phase if needed
+        if min_phase:
             h1 = signal.minimum_phase(h1)

-    result.idx_coherent[:] = idx_coherent | ~idx_f
-
-    # fractional smoothing of impulse response?
-    if apply_smoothing:
-        h1 = smooth_impulse_response(h1, fs, fraction, min_phase=min_phase, ntaps_out=ntaps_smooth)
-    # update result
-    result.imp_resp = h1
+    # result storage
+    result = FreqRespResult(h1, f.copy(), H.copy(), cxy, idx_coherent_f)

     return result
...
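As background for the firwin2/minimum_phase branch above, a self-contained sketch of turning a sampled magnitude response into an FIR impulse response; frequencies, gains and tap counts are illustrative:

    import numpy as np
    from scipy import signal

    fs, ntaps = 48000, 513                       # odd tap count -> type I linear-phase FIR
    f = np.linspace(0.0, fs / 2, 257)            # frequency grid must include 0 and fs/2
    gain = np.ones_like(f)
    gain[f > 8000.0] = 0.1                       # toy magnitude response

    h = signal.firwin2(ntaps, f, gain, fs=fs, window="flattop", antisymmetric=False)
    h_min = signal.minimum_phase(h)              # optional minimum-phase version (roughly half the taps)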
@@ -77,13 +77,14 @@ class SpectrumAvgBase(ABC, ISpectrumAvg):
             agg_fn = getattr(np, self._aggregate)  # ensure aggregate method exists
             agg_fn = partial(agg_fn, axis=Ax.TIME)
         else:
-            agg_fn = self._aggregate
             # ToDo: check if agg_fn has an argument 'axis' and set it to Ax.TIME (?)
+            agg_fn = self._aggregate

         return self._ref_spec_vs_time.get_values(representation, weighting, all_freqs, value_transform=agg_fn, include_window_correction=include_window_correction)

-    def get_level(self, representation: ValueRepresentation = "dB", weighting: Weighting = None, all_freqs: bool = False, include_window_correction: bool = True) -> float | NDArray:
+    def get_level(
+        self, representation: ValueRepresentation = "dB", weighting: Weighting = None, all_freqs: bool = False, include_window_correction: bool = True
+    ) -> float | NDArray:
         """
         Calculate the level of the spectrum.

         Args:
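A small sketch of the aggregate-function resolution pattern shown in this hunk: a NumPy reduction is looked up by name and bound to the time axis via functools.partial (names here are illustrative, not the class's attributes):

    from functools import partial

    import numpy as np

    TIME_AXIS = -1  # stand-in for Ax.TIME

    def resolve_aggregate(aggregate):
        # strings map to NumPy reductions bound to the time axis; callables pass through unchanged
        if isinstance(aggregate, str):
            return partial(getattr(np, aggregate), axis=TIME_AXIS)
        return aggregate

    spec = np.random.rand(129, 50)           # freq x time
    avg = resolve_aggregate("mean")(spec)    # -> shape (129,)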
@@ -147,6 +148,7 @@ class SpectrumAvgComposed(SpectrumAvgBase):
             right: Right operand spectrum average
             operation: Function to combine the values (e.g., lambda a, b: a + b)
         """
+        # Validate compatibility of frequency bins, time, fs, axis, and spectrum arguments
         if left.fs != right.fs:
             raise ValueError("Spectrum averages must have matching sampling frequencies (fs)")
...
@@ -4,42 +4,38 @@ Created on Sept 05 2024 18:22
 @author: Jan.Reimes
 """
-import unittest
 from pathlib import Path
-from ddt import ddt, data
+
+import pytest

 from nonlinearity import NLMResult
 from nonlinearity.__main__ import main
-from tools.plot_spectra import NLMPlot
-from tests.test_base import BaseTestCase, get_config
 from tests import get_output_dir
+from tests.test_base import get_config
 from tests.testfiles.file_dl import get_p501_file
+from tools.plot_spectra import NLMPlot
-@ddt
-class NlmCalculateTestCase(BaseTestCase):
-    @data(True, False)
+class TestNlmCalculate:
+    @pytest.mark.parametrize("use_activity_threshold", [True, False])
     def test_nlm_calc_main(self, use_activity_threshold: bool):
         # define input files
-        deg: Path = get_p501_file('English_FB_clause_7.3/FB_male_female_single-talk_seq_compressed')
-        ref: Path = get_p501_file('English_FB_clause_7.3/FB_male_female_single-talk_seq.wav')
+        deg: Path = get_p501_file("English_FB_clause_7.3/FB_male_female_single-talk_seq_compressed")
+        ref: Path = get_p501_file("English_FB_clause_7.3/FB_male_female_single-talk_seq.wav")
         overrides = [f'file_measured="{deg}"', f'file_ref="{ref}"', f"use_activity_threshold={use_activity_threshold}"]

         # run with config
         with get_config(overrides=overrides, no_custom_config=False) as cfg:
             res: NLMResult = main(cfg)

+        assert isinstance(res, NLMResult)
+        assert res.nlm is not None
+        assert res.nlm == pytest.approx(9.85, abs=0.01)
+
         # plot results
         nlm_plot = NLMPlot(res, title=f"NLM Test: {deg.name}")
-        nlm_plot.plot(plot_nlm=True, plot_fr=True, plot_components=True, plot_cxy=True,
-                      show_blocking=True,
-                      save_path=get_output_dir(sub_dir="plots")
-                      )
+        nlm_plot.plot(plot_nlm=True, plot_fr=True, plot_components=True, plot_cxy=True, show_blocking=True, save_path=get_output_dir(sub_dir="plots"))


 if __name__ == '__main__':
-    unittest.main()
+    pass