Skip to content

Commit

Permalink
Merge pull request #46 from ToFuProject/Issue041_CurveFit
Browse files Browse the repository at this point in the history
Issue041 curve fit
  • Loading branch information
Didou09 authored Jul 31, 2024
2 parents f5fa4ca + 1aa077f commit c94bcff
Show file tree
Hide file tree
Showing 9 changed files with 451 additions and 143 deletions.
222 changes: 160 additions & 62 deletions spectrally/_class01_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def main(

(
key_model, ref_nx, ref_nf,
key_data,
key_data, key_std,
key_lamb, lamb, ref_lamb,
details, binning,
returnas, store, store_key,
Expand Down Expand Up @@ -68,14 +68,25 @@ def main(
ref_in = coll.ddata[key_data]['ref']
ndim_in = data_in.ndim

# ----------
# details

iref_nx = ref_in.index(ref_nx)
if details is True:
iref_nx_out = iref_nx + 1
else:
iref_nx_out = iref_nx

# ----------
# std_in

if key_std is not None:
std_in = coll.ddata[key_std]['data']
nx = coll.dref[ref_nx]['size']

# -----------------------
# prepare loop on indices
# -----------------------

key_bs = None
if key_bs is None:
Expand All @@ -88,6 +99,7 @@ def main(

# -------------
# initialize
# -------------

# shape_out, ref_out
shape_in = data_in.shape
Expand Down Expand Up @@ -119,29 +131,25 @@ def main(
func='sum',
)['sum']

# ----------------
# compute
# ----------------

# --------------
# prepare slices
# --------------

if ndim_in > 1:

# slices
sli_in = list(shape_in)
sli_out = list(shape_out)
# slices
sli_in = list(shape_in)
sli_out = list(shape_out)

sli_in[iref_nx] = slice(None)
sli_out[iref_nx_out] = slice(None)
if details is True:
sli_out[0] = slice(None)
sli_in[iref_nx] = slice(None)
sli_out[iref_nx_out] = slice(None)
if details is True:
sli_out[0] = slice(None)

# as array
sli_in = np.array(sli_in)
sli_out = np.array(sli_out)
# as array
sli_in = np.array(sli_in)
sli_out = np.array(sli_out)

# indices to change
# indices to change
if ndim_in > 1:
ind0 = np.array(
[ii for ii in range(len(shape_in)) if ii != iref_nx],
dtype=int,
Expand All @@ -156,32 +164,69 @@ def main(
else:
ind0 = None

# -------
# loop
# -------------------------
# loop to compute data_out
# -------------------------

if ind0 is None:
for ind in itt.product(*lind):

# update slices
if ind0 is not None:
sli_in[ind0] = ind
sli_out[ind0_out] = ind

# call func
data_out = func(
x_free=data_in,
data_out[tuple(sli_out)] = func(
x_free=data_in[tuple(sli_in)],
lamb=lamb,
binning=binning,
)

else:
# -----------------------------
# loop on std to get error bar
# -----------------------------

data_min = None
data_max = None
if key_std is not None:

data_min = np.full(data_out.shape, np.inf)
data_max = np.full(data_out.shape, -np.inf)
inc = np.r_[-1, 0, 1]
lind_std = [inc for ii in range(nx)]

for ind in itt.product(*lind):

# update slices
sli_in[ind0] = ind
sli_out[ind0_out] = ind

# call func
data_out[tuple(sli_out)] = func(
x_free=data_in[tuple(sli_in)],
lamb=lamb,
binning=binning,
)
if ind0 is not None:
sli_in[ind0] = ind
sli_out[ind0_out] = ind

datain = data_in[tuple(sli_in)]

for stdi in itt.product(*lind_std):

# data = data_in + std * (-1, 0, 1)
datain = (
data_in[tuple(sli_in)]
+ np.r_[stdi] * std_in[tuple(sli_in)]
)

# call func
datai = func(
x_free=datain,
lamb=lamb,
binning=binning,
)

# update min, max
data_min[tuple(sli_out)] = np.minimum(
data_min[tuple(sli_out)], datai,
)
data_max[tuple(sli_out)] = np.maximum(
data_max[tuple(sli_out)],
datai,
)

# --------------
# return
Expand All @@ -195,6 +240,8 @@ def main(
'lamb': lamb,
'details': details,
'data': data_out,
'data_min': data_min,
'data_max': data_max,
'ref': tuple(ref_out),
'dim': coll.ddata[key_data]['dim'],
'quant': coll.ddata[key_data]['quant'],
Expand All @@ -207,9 +254,16 @@ def main(

if store is True:

lout = ['key_data', 'key_model', 'key_lamb', 'lamb', 'details']
lout = [
'key_data', 'key_model', 'key_lamb',
'lamb', 'details',
'data_min', 'data_max',
]
coll.add_data(
**{k0: v0 for k0, v0 in dout.items() if k0 not in lout},
**{
k0: v0 for k0, v0 in dout.items()
if k0 not in lout
},
)

return dout
Expand All @@ -235,38 +289,22 @@ def _check(
store_key=None,
):

# ----------
# key_model
# ----------
# ---------------------
# key_model, key_data
# ---------------------

wsm = coll._which_model
key_model = ds._generic_check._check_var(
key_model, 'key_model',
types=str,
allowed=list(coll.dobj.get(wsm, {}).keys()),
key_model, key_data, key_std, lamb = _check_keys(
coll=coll,
key=key_model,
key_data=key_data,
lamb=lamb,
)

# derive ref_model
wsm = coll._which_model
ref_nf = coll.dobj[wsm][key_model]['ref_nf']
ref_nx = coll.dobj[wsm][key_model]['ref_nx']

# ----------
# key_data
# ----------

# list of acceptable values
lok = [
k0 for k0, v0 in coll.ddata.items()
if ref_nx in v0['ref']
]

# check
key_data = ds._generic_check._check_var(
key_data, 'key_data',
types=str,
allowed=lok,
)

# -----------------
# lamb
# -----------------
Expand Down Expand Up @@ -366,18 +404,78 @@ def _check(

return (
key_model, ref_nx, ref_nf,
key_data,
key_data, key_std,
key_lamb, lamb, ref_lamb,
details, binning,
returnas, store, store_key,
)


def _check_keys(coll=None, key=None, key_data=None, lamb=None):

# ---------------------
# key_model vs key_fit
# ---------------------

wsm = coll._which_model
wsf = coll._which_fit

lokm = list(coll.dobj.get(wsm, {}).keys())
lokf = list(coll.dobj.get(wsf, {}).keys())

key = ds._generic_check._check_var(
key, 'key',
types=str,
allowed=lokm + lokf,
)

# ---------------
# if key_fit
# ---------------

key_std = None
if key in lokf:
key_fit = key
key_model = coll.dobj[wsf][key_fit]['key_model']

if key_data is None:
key_data = coll.dobj[wsf][key_fit]['key_sol']
key_std = coll.dobj[wsf][key_fit]['key_std']

if lamb is None:
lamb = coll.dobj[wsf][key_fit]['key_lamb']

else:
key_model = key

# ----------
# key_data
# ----------

# derive ref_model
ref_nx = coll.dobj[wsm][key_model]['ref_nx']

# list of acceptable values
lok = [
k0 for k0, v0 in coll.ddata.items()
if ref_nx in v0['ref']
]

# check
key_data = ds._generic_check._check_var(
key_data, 'key_data',
types=str,
allowed=lok,
)

return key_model, key_data, key_std, lamb


def _err_lamb(lamb):
msg = (
"Arg 'lamb' nust be either:\n"
"\t- 1d np.ndarray with finite values only\n"
"\t- str: a key to an existing 1d vector with finite values only\n"
f"Provided:\n{lamb}"
)
raise Exception(msg)
raise Exception(msg)
5 changes: 4 additions & 1 deletion spectrally/_class01_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def main(
# all other variables
(
key_model, ref_nx, ref_nf,
key_data,
key_data, key_std,
key_lamb, lamb, ref_lamb,
details, binning,
returnas, store, store_key,
Expand Down Expand Up @@ -435,6 +435,7 @@ def func(
lntdu = np.log(t_down / t_up)

# position of max
dout[kfunc]['t0'] = t0
dout[kfunc]['argmax'] = t0 + lntdu * t_down*t_up / dtdu

# value at max
Expand All @@ -461,6 +462,7 @@ def func(
t0 = lamb[0] + lambD * tau

# position of max
dout[kfunc]['t0'] = t0
dout[kfunc]['argmax'] = t0

# value at max
Expand All @@ -484,6 +486,7 @@ def func(
t0 = lamb[0] + lambD * tau

# position of max
dout[kfunc]['t0'] = t0
dout[kfunc]['argmax'] = t0 + np.exp(mu - sigma**2)

# value at max
Expand Down
Loading

0 comments on commit c94bcff

Please sign in to comment.