from collections.abc import Callable
import warnings
import numpy as np
from scipy.optimize import curve_fit, OptimizeWarning
import sympy as sp
from uncertainties import ufloat
from uncertainties.core import Variable
[docs]
def fit_data(
xs: list[float],
ys: list[float],
es: list[float],
fit_ansatz: Callable,
fit_param_bounds: tuple[list, list],
fit_y_scale: str,
SNR_threshold: float = 5,
Abs_threshold: float = np.inf,
verbose: bool = False,
) -> tuple[Variable, ...]:
"""Fit given data using the given ansatz.
Uses scipy's curve_fit with method='trf'.
Args:
xs (list[float]): The independent variable data.
ys (list[float]): The dependent variable data.
es (list[float]): The uncertainity in the dependent variable.
fit_ansatz (Callable): The fit function.
fit_param_bounds (tuple[list, list]): The bounds for the fit parameters.
fit_y_scale (str): Either 'linear' or 'log'. If 'log' then ys is scaled to log(ys) before fit.
SNR_threshold (float, optional): The signal-to-noise threshold below which data is discarded. Defaults to 5.
Abs_threshold (float, optional): The error rate above which the data is discarded. Defaults to infinity.
verbose (bool, optional): Whether to print info. Defaults to False.
Raises:
ValueError: If xs, ys and es don't have the same length.
ValueError: If at least two points not provided, or less than two points remain after filtering.
Returns:
tuple[Variable, ...]: The fit parameters as uncertainities Variable types. Use p.nominal_value and p.std_dev to
access the stored values.
"""
if len(xs) != len(ys) or len(xs) != len(es):
raise ValueError("xs, ys and es don't have the same length.")
if len(xs) < 2:
raise ValueError("Fitting cannot work without at least two points.")
# convert data into np array and sort
xs, ys, es = (np.array(_) for _ in [xs, ys, es])
inds = np.argsort(xs)
xs, ys, es = (_[inds] for _ in [xs, ys, es])
# Do SNR threshold filtering
if SNR_threshold is not None:
inds = np.where(ys / es > SNR_threshold)[0]
if verbose:
print(
"Fit is ignoring (SNR-based): ",
xs[np.where(ys / es <= SNR_threshold)[0]],
)
xs, ys, es = (_[inds] for _ in [xs, ys, es])
# Do Abs threshold filtering
if Abs_threshold is not None:
inds = np.where(ys < Abs_threshold)[0]
if verbose:
print("Fit is ignoring (Abs-based): ", xs[np.where(ys >= Abs_threshold)[0]])
xs, ys, es = (_[inds] for _ in [xs, ys, es])
if len(xs) < 2:
raise ValueError(
"After filtering, less than two points remain. Fitting not possible."
)
if fit_y_scale == "linear":
ys_scaled = ys
elif fit_y_scale == "log":
ys_scaled = np.log(ys)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
try:
p_opt, p_cov = curve_fit(
fit_ansatz,
xs,
ys_scaled,
sigma=es / ys,
bounds=fit_param_bounds,
method="trf",
)
if len(w):
raise RuntimeError
except RuntimeError:
raise RuntimeError("Covariance of the parameters can not be estimated.")
mean = np.array(p_opt)
std = np.sqrt(np.diagonal(p_cov))
return tuple(ufloat(m, s) for m, s in zip(mean, std))