Skip to content

Commit

Permalink
fix bounds in fit_ac and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Sep 1, 2024
1 parent 62fc3b0 commit 6af8c61
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
6 changes: 3 additions & 3 deletions neurodsp/aperiodic/autocorr.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,11 @@ def fit_autocorr(timepoints, autocorrs, fit_function='single_exp', bounds=None):

if not bounds:
if fit_function == 'single_exp':
p_bounds = ([0, 0, 0], [np.inf, np.inf, np.inf])
bounds = ([0, 0, 0], [np.inf, np.inf, np.inf])
elif fit_function == 'double_exp':
p_bounds = ([0, 0, 0, 0, 0], [np.inf, np.inf, np.inf, np.inf, np.inf])
bounds = ([0, 0, 0, 0, 0], [np.inf, np.inf, np.inf, np.inf, np.inf])

popts, _ = curve_fit(AC_FIT_FUNCS[fit_function], timepoints, autocorrs, bounds=p_bounds)
popts, _ = curve_fit(AC_FIT_FUNCS[fit_function], timepoints, autocorrs, bounds=bounds)

return popts

Expand Down
6 changes: 6 additions & 0 deletions neurodsp/tests/aperiodic/test_autocorr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ def test_fit_autocorr(tsig):
fit_vals1 = compute_ac_fit(timepoints, *popts1, fit_function='single_exp')
assert np.all(fit_vals1)

# Test with bounds passed in
bounds = ([0, 0, 0], [10, np.inf, np.inf])
popts1 = fit_autocorr(timepoints, autocorrs, 'single_exp', bounds)
fit_vals1 = compute_ac_fit(timepoints, *popts1, fit_function='single_exp')
assert np.all(fit_vals1)

popts2 = fit_autocorr(timepoints, autocorrs, fit_function='double_exp')
fit_vals2 = compute_ac_fit(timepoints, *popts2, fit_function='double_exp')
assert np.all(fit_vals2)

0 comments on commit 6af8c61

Please sign in to comment.