diff --git a/spectrally/_class01_SpectralModel.py b/spectrally/_class01_SpectralModel.py index 334f2bd..3f28f87 100644 --- a/spectrally/_class01_SpectralModel.py +++ b/spectrally/_class01_SpectralModel.py @@ -202,6 +202,7 @@ def get_spectral_model_moments( key=None, key_data=None, lamb=None, + returnas=None, ): """ @@ -221,6 +222,10 @@ def get_spectral_model_moments( key to the data to be used as input for the model's free variables lamb:str or 1d np.ndarray, optional wavelenth vector to be used for computing limited integrals + returnas: str + Flag indicating whether to return dout as: + - 'dict_arrays': nested dict of {'ktype': {'kvar': array}} + - 'dict_varnames': dict of {'kfunc_var': values} Returns ------- @@ -234,6 +239,7 @@ def get_spectral_model_moments( key=key, key_data=key_data, lamb=lamb, + returnas=returnas, ) return dout diff --git a/spectrally/_class01_moments.py b/spectrally/_class01_moments.py index 584e5d0..610b446 100644 --- a/spectrally/_class01_moments.py +++ b/spectrally/_class01_moments.py @@ -11,6 +11,7 @@ import numpy as np import scipy.constants as scpct +import astropy.units as asunits import datastock as ds @@ -30,6 +31,8 @@ def main( key_data=None, lamb=None, dmz=None, + # return + returnas=None, ): # ------------ @@ -37,11 +40,8 @@ def main( # ------------ # key_model vs key_fit - key_model, key_data, lamb = _check( - coll=coll, - key=key, - key_data=key_data, - lamb=lamb, + returnas = _check( + returnas=returnas, ) # all other variables @@ -50,10 +50,10 @@ def main( key_data, key_std, key_lamb, lamb, ref_lamb, details, binning, - returnas, store, store_key, + _, store, store_key, ) = _interpolate._check( coll=coll, - key_model=key_model, + key_model=key, key_data=key_data, lamb=lamb, # others @@ -90,6 +90,7 @@ def main( # ------------ dind = _check_mz(dmz, dind=dind) + axis = coll.ddata[key_data]['ref'].index(ref_nx) # ------------ # get func @@ -101,23 +102,38 @@ def main( c2=c2, dind=dind, param_val=param_val, - axis=coll.ddata[key_data]['ref'].index(ref_nx), + axis=axis, ) - # ------------ - # compute - # ------------ + # ------------------- + # extract / compute + # ------------------- dout = func( x_free=coll.ddata[key_data]['data'], + x_std=None if key_std is None else coll.ddata[key_std]['data'], lamb=lamb, - scale=None, ) # ------------- # format output # ------------- + if returnas == 'dict_varnames': + + ref = tuple([rr for rr in coll.ddata[key_data]['ref'] if rr != ref_nx]) + + dout = _format( + coll=coll, + key_model=key_model, + key_data=key_data, + key_lamb=key_lamb, + din=dout, + dind=dind, + axis=axis, + ref=ref, + ) + return dout @@ -127,35 +143,22 @@ def main( ############################################# -def _check(coll=None, key=None, key_data=None, lamb=None): - - # --------------------- - # key_model vs key_fit - # --------------------- - - wsm = coll._which_model - wsf = coll._which_fit +def _check( + returnas=None, +): - lokm = list(coll.dobj.get(wsm, {}).keys()) - lokf = list(coll.dobj.get(wsf, {}).keys()) + # ------------- + # returnas + # ------------- - key = ds._generic_check._check_var( - key, 'key', + returnas = ds._generic_check._check_var( + returnas, 'returnas', types=str, - allowed=lokm + lokf, + default='dict_varnames', + allowed=['dict_ftypes', 'dict_varnames'], ) - if key in lokf: - key_fit = key - key = coll.dobj[wsf][key_fit]['key_model'] - - if key_data is None: - key_data = coll.dobj[wsf][key_fit]['key_sol'] - - if lamb is None: - lamb = coll.dobj[wsf][key_fit]['key_lamb'] - - return key, key_data, lamb + return returnas def _check_mz( @@ -170,12 +173,12 @@ def _check_mz( # add mz if user-provided if dmz is not None: - for kfunc in ['gauss', 'pvoigt', 'voigt']: + for ktype in ['gauss', 'pvoigt', 'voigt']: - if dind.get(kfunc) is None: + if dind.get(ktype) is None: continue - dind[kfunc]['mz'] + dind[ktype]['mz'] return dind @@ -200,15 +203,19 @@ def _get_func_moments( # prepare # -------------- + # -------------- + # prepare + # -------------- + def func( x_free=None, + x_std=None, lamb=None, param_val=param_val, c0=c0, c1=c1, c2=c2, dind=dind, - scale=None, axis=axis, ): @@ -231,81 +238,102 @@ def func( if c0 is None: x_full = x_free + if x_std is not None: + xf_min = x_full - x_std + xf_max = x_full + x_std + else: + if x_free.ndim > 1: shape = list(x_free.shape) shape[axis] = c0.size x_full = np.full(shape, np.nan) + if x_std is not None: + xf_min = np.full(shape, np.nan) + xf_max = np.full(shape, np.nan) + sli = list(shape) sli[axis] = slice(None) sli = np.array(sli) ich = np.array([ii for ii in range(len(shape)) if ii != axis]) linds = [range(shape[ii]) for ii in ich] - for ind in itt.product(*linds): + + for ii, ind in enumerate(itt.product(*linds)): sli[ich] = ind slii = tuple(sli) x_full[slii] = ( c2.dot(x_free[slii]**2) + c1.dot(x_free[slii]) + c0 ) + if x_std is not None: + xf_min[slii] = ( + c2.dot((x_free[slii] - x_std[slii])**2) + + c1.dot(x_free[slii] - x_std[slii]) + c0 + ) + xf_max[slii] = ( + c2.dot((x_free[slii] + x_std[slii])**2) + + c1.dot(x_free[slii] + x_std[slii]) + c0 + ) + else: x_full = c2.dot(x_free**2) + c1.dot(x_free) + c0 + if x_std is not None: + xf_min = c2.dot((x_free - x_std)**2) + c1.dot(x_free - x_std) + c0 + xf_max = c2.dot((x_free + x_std)**2) + c1.dot(x_free + x_std) + c0 sli = [None if ii == axis else slice(None) for ii in range(x_free.ndim)] - extract = _get_var_extract_func(x_full, dind, axis, sli) - - # ------------------- - # rescale - - if scale is not None: - pass + extract = _get_var_extract_func(dind, axis, sli) # --------------------- # extract all variables - for kfunc, v0 in _model_dict._DMODEL.items(): - if dind.get(kfunc) is not None: + for ktype, v0 in _model_dict._DMODEL.items(): + if dind.get(ktype) is not None: for kvar in v0['var']: - dout[kfunc][kvar] = extract(kfunc, kvar) + dout[ktype][kvar] = extract(ktype, kvar, x_full) + if x_std is not None: + dout[ktype][f"{kvar}_min"] = extract(ktype, kvar, xf_min) + dout[ktype][f"{kvar}_max"] = extract(ktype, kvar, xf_max) # ------------------ # sum all poly - kfunc = 'poly' - if dind.get(kfunc) is not None: + ktype = 'poly' + if dind.get(ktype) is not None: - a0 = dout[kfunc]['a0'] - a1 = dout[kfunc]['a1'] - a2 = dout[kfunc]['a2'] + a0 = dout[ktype]['a0'] + a1 = dout[ktype]['a1'] + a2 = dout[ktype]['a2'] # integral if lamb is not None: - dout[kfunc]['integ'] = ( + dout[ktype]['integ'] = ( a0 * (lamb[-1] - lamb[0]) + a1 * (lamb[-1]**2 - lamb[0]**2)/2 + + a2 * (lamb[-1]**3 - lamb[0]**3)/3 ) # argmax, max - dout[kfunc]['argmax'] = np.full(a0.shape, np.nan) - dout[kfunc]['max'] = np.full(a1.shape, np.nan) + dout[ktype]['argmax'] = np.full(a0.shape, np.nan) + dout[ktype]['max'] = np.full(a1.shape, np.nan) iok = a2 != 0 - dout[kfunc]['argmax'][iok] = lambm - lambD * a1[iok]/(2*a2[iok]) - dout[kfunc]['max'][iok] = a0[iok] - a1[iok]**2 / (4*a2[iok]) + dout[ktype]['argmax'][iok] = lambm - lambD * a1[iok]/(2*a2[iok]) + dout[ktype]['max'][iok] = a0[iok] - a1[iok]**2 / (4*a2[iok]) # -------------------- # sum all exponentials - kfunc = 'exp_lamb' - if dind.get(kfunc) is not None: + ktype = 'exp_lamb' + if dind.get(ktype) is not None: # physics - rate = dout[kfunc]['rate'] - dout[kfunc]['Te'] = (scpct.h * scpct.c / rate) / scpct.e + rate = dout[ktype]['rate'] + dout[ktype]['Te'] = (scpct.h * scpct.c / rate) / scpct.e # integral if lamb is not None: - amp = dout[kfunc]['rate'] - dout[kfunc]['integ'] = ( + amp = dout[ktype]['rate'] + dout[ktype]['integ'] = ( (amp / rate) * (np.exp(lamb[-1] * rate) - np.exp(lamb[0] * rate)) ) @@ -313,28 +341,28 @@ def func( # ----------------- # sum all gaussians - kfunc = 'gauss' - if dind.get(kfunc) is not None: + ktype = 'gauss' + if dind.get(ktype) is not None: - amp = dout[kfunc]['amp'] - sigma = dout[kfunc]['sigma'] - vccos = dout[kfunc]['vccos'] + amp = dout[ktype]['amp'] + sigma = dout[ktype]['sigma'] + vccos = dout[ktype]['vccos'] # argmax - dout[kfunc]['argmax'] = _get_line_argmax( - vccos, param_val, dind, kfunc, amp.shape, axis, + dout[ktype]['argmax'] = _get_line_argmax( + vccos, param_val, dind, ktype, amp.shape, axis, ) # integral - dout[kfunc]['integ'] = amp * sigma * np.sqrt(2 * np.pi) + dout[ktype]['integ'] = amp * sigma * np.sqrt(2 * np.pi) # physics - if dind[kfunc].get('mz') is not None: - dout[kfunc]['Ti'] = _get_Ti( + if dind[ktype].get('mz') is not None: + dout[ktype]['Ti'] = _get_Ti( sigma, param_val, dind, - kfunc, + ktype, sigma.shape, axis, ) @@ -342,46 +370,46 @@ def func( # ------------------- # sum all Lorentzians - kfunc = 'lorentz' - if dind.get(kfunc) is not None: + ktype = 'lorentz' + if dind.get(ktype) is not None: - amp = dout[kfunc]['amp'] - gam = dout[kfunc]['gam'] - vccos = dout[kfunc]['vccos'] + amp = dout[ktype]['amp'] + gam = dout[ktype]['gam'] + vccos = dout[ktype]['vccos'] # argmax - dout[kfunc]['argmax'] = _get_line_argmax( - vccos, param_val, dind, kfunc, amp.shape, axis, + dout[ktype]['argmax'] = _get_line_argmax( + vccos, param_val, dind, ktype, amp.shape, axis, ) # integral - dout[kfunc]['integ'] = amp * np.pi * gam + dout[ktype]['integ'] = amp * np.pi * gam # -------------------- # sum all pseudo-voigt - kfunc = 'pvoigt' - if dind.get(kfunc) is not None: + ktype = 'pvoigt' + if dind.get(ktype) is not None: - amp = dout[kfunc]['amp'] - sigma = dout[kfunc]['sigma'] - vccos = dout[kfunc]['vccos'] + amp = dout[ktype]['amp'] + sigma = dout[ktype]['sigma'] + vccos = dout[ktype]['vccos'] # argmax - dout[kfunc]['argmax'] = _get_line_argmax( - vccos, param_val, dind, kfunc, amp.shape, axis, + dout[ktype]['argmax'] = _get_line_argmax( + vccos, param_val, dind, ktype, amp.shape, axis, ) # integral - dout[kfunc]['integ'] = np.full(sigma.shape, np.nan) + dout[ktype]['integ'] = np.full(sigma.shape, np.nan) # physics - if dind[kfunc].get('mz') is not None: - dout[kfunc]['Ti'] = _get_Ti( + if dind[ktype].get('mz') is not None: + dout[ktype]['Ti'] = _get_Ti( sigma, param_val, dind, - kfunc, + ktype, sigma.shape, axis, ) @@ -389,28 +417,28 @@ def func( # -------------------- # sum all voigt - kfunc = 'voigt' - if dind.get(kfunc) is not None: + ktype = 'voigt' + if dind.get(ktype) is not None: - amp = dout[kfunc]['amp'] - sigma = dout[kfunc]['sigma'] - vccos = dout[kfunc]['vccos'] + amp = dout[ktype]['amp'] + sigma = dout[ktype]['sigma'] + vccos = dout[ktype]['vccos'] # argmax - dout[kfunc]['argmax'] = _get_line_argmax( - vccos, param_val, dind, kfunc, amp.shape, axis, + dout[ktype]['argmax'] = _get_line_argmax( + vccos, param_val, dind, ktype, amp.shape, axis, ) # integral - dout[kfunc]['integ'] = amp + dout[ktype]['integ'] = amp # physics - if dind[kfunc].get('mz') is not None: - dout[kfunc]['Ti'] = _get_Ti( + if dind[ktype].get('mz') is not None: + dout[ktype]['Ti'] = _get_Ti( sigma, param_val, dind, - kfunc, + ktype, sigma.shape, axis, ) @@ -418,16 +446,16 @@ def func( # ------------------ # sum all pulse_exp - kfunc = 'pulse_exp' - if dind.get(kfunc) is not None: + ktype = 'pulse_exp' + if dind.get(ktype) is not None: - amp = dout[kfunc]['amp'] - tau = dout[kfunc]['tau'] - t_down = dout[kfunc]['t_down'] - t_up = dout[kfunc]['t_up'] + amp = dout[ktype]['amp'] + tau = dout[ktype]['tau'] + t_down = dout[ktype]['t_down'] + t_up = dout[ktype]['t_up'] # integral - dout[kfunc]['integ'] = amp * (t_down - t_up) + dout[ktype]['integ'] = amp * (t_down - t_up) # prepare t0 = lamb[0] + lambD * tau @@ -435,11 +463,11 @@ def func( lntdu = np.log(t_down / t_up) # position of max - dout[kfunc]['t0'] = t0 - dout[kfunc]['argmax'] = t0 + lntdu * t_down*t_up / dtdu + dout[ktype]['t0'] = t0 + dout[ktype]['argmax'] = t0 + lntdu * t_down*t_up / dtdu # value at max - dout[kfunc]['max'] = amp * ( + dout[ktype]['max'] = amp * ( np.exp(-lntdu * t_up / dtdu) - np.exp(-lntdu * t_down / dtdu) ) @@ -447,50 +475,50 @@ def func( # ------------------ # sum all pulse_gauss - kfunc = 'pulse_gauss' - if dind.get(kfunc) is not None: + ktype = 'pulse_gauss' + if dind.get(ktype) is not None: - amp = dout[kfunc]['amp'] - tau = dout[kfunc]['tau'] - t_down = dout[kfunc]['t_down'] - t_up = dout[kfunc]['t_up'] + amp = dout[ktype]['amp'] + tau = dout[ktype]['tau'] + t_down = dout[ktype]['t_down'] + t_up = dout[ktype]['t_up'] # integral - dout[kfunc]['integ'] = amp/2 * np.sqrt(np.pi) * (t_up + t_down) + dout[ktype]['integ'] = amp/2 * np.sqrt(np.pi) * (t_up + t_down) # prepare t0 = lamb[0] + lambD * tau # position of max - dout[kfunc]['t0'] = t0 - dout[kfunc]['argmax'] = t0 + dout[ktype]['t0'] = t0 + dout[ktype]['argmax'] = t0 # value at max - dout[kfunc]['max'] = amp + dout[ktype]['max'] = amp # ------------------ # sum all lognorm - kfunc = 'lognorm' - if dind.get(kfunc) is not None: + ktype = 'lognorm' + if dind.get(ktype) is not None: - amp = dout[kfunc]['amp'] - tau = dout[kfunc]['tau'] - sigma = dout[kfunc]['sigma'] - mu = dout[kfunc]['mu'] + amp = dout[ktype]['amp'] + tau = dout[ktype]['tau'] + sigma = dout[ktype]['sigma'] + mu = dout[ktype]['mu'] # integral - dout[kfunc]['integ'] = np.full(mu.shape, np.nan) + dout[ktype]['integ'] = np.full(mu.shape, np.nan) # prepare t0 = lamb[0] + lambD * tau # position of max - dout[kfunc]['t0'] = t0 - dout[kfunc]['argmax'] = t0 + np.exp(mu - sigma**2) + dout[ktype]['t0'] = t0 + dout[ktype]['argmax'] = t0 + np.exp(mu - sigma**2) # value at max - dout[kfunc]['max'] = amp * np.exp(0.5*sigma**2 - mu) + dout[ktype]['max'] = amp * np.exp(0.5*sigma**2 - mu) return dout @@ -503,16 +531,17 @@ def func( # ##################################################################### -def _get_var_extract_func(x_full, dind, axis, sli): - def func(kfunc, kvar, dind=dind, axis=axis, sli=sli, x_full=x_full): - sli[axis] = dind[kfunc][kvar]['ind'] +def _get_var_extract_func(dind, axis, sli): + def func(ktype, kvar, x_full, dind=dind, axis=axis, sli=sli): + sli[axis] = dind[ktype][kvar]['ind'] return x_full[tuple(sli)] return func -def _get_line_argmax(vccos, param_val, dind, kfunc, shape, axis): - # extract lamb0 - lamb0 = param_val[dind[kfunc]['lamb0']] +def _get_line_argmax(vccos, param_val, dind, ktype, shape, axis): + + # lamb0 + lamb0 = param_val[dind[ktype]['lamb0']] # reshape lamb0 reshape = [1 for ii in shape] @@ -521,11 +550,12 @@ def _get_line_argmax(vccos, param_val, dind, kfunc, shape, axis): return lamb0 * (1 + vccos) -def _get_Ti(sigma, param_val, dind, kfunc, shape, axis): - # extract lamb0, mz - lamb0 = param_val[dind[kfunc]['lamb0']] - mz = param_val[dind[kfunc]['mz']] +def _get_Ti(sigma, param_val, dind, ktype, shape, axis): + + # lamb0, mz + lamb0 = param_val[dind[ktype]['lamb0']] + mz = param_val[dind[ktype]['mz']] # reshape lamb0 and mz reshape = [1 for ii in shape] @@ -534,4 +564,208 @@ def _get_Ti(sigma, param_val, dind, kfunc, shape, axis): lamb0 = lamb0.reshape(reshape) mz = mz.reshape(reshape) - return (sigma / lamb0)**2 * mz * scpct.c**2 / scpct.e \ No newline at end of file + return (sigma / lamb0)**2 * mz * scpct.c**2 / scpct.e + + +############################################# +############################################# +# format +############################################# + + +def _format( + coll=None, + key_model=None, + key_data=None, + key_lamb=None, + din=None, + dind=None, + axis=None, + ref=None, +): + + # --------------- + # prepare + # --------------- + + dout = {} + wsm = coll._which_model + sli = [slice(None) for ii in range(len(ref)+1)] + + # ------------- + # units + # ------------- + + # data + try: + units_data = asunits.Unit(coll.ddata[key_data]['units']) + except Exception as err: + units_data = coll.ddata[key_data]['units'] + + # lamb + try: + units_lamb = asunits.Unit(coll.ddata[key_lamb]['units']) + except Exception as err: + units_lamb = coll.ddata[key_lamb]['units'] + + # -------------- + # loop + # -------------- + + for ktype, vtype in din.items(): + + lkfunc = [ + kfunc for kfunc in coll.dobj[wsm][key_model]['keys'] + if coll.dobj[wsm][key_model]['dmodel'][kfunc]['type'] == ktype + ] + + for kvar in vtype.keys(): + + for ii, kfunc in enumerate(lkfunc): + + key = f"{kfunc}_{kvar}" + sli[axis] = ii + + # get units + units = _units( + coll, ktype, kvar, + units_data=units_data, + units_lamb=units_lamb, + ) + + # store + dout[key] = { + 'data': din[ktype][kvar][tuple(sli)], + 'ref': ref, + 'units': units, + } + + return dout + + +############################################# +############################################# +# units +############################################# + + +def _units(coll, ktype, kvar, units_data=None, units_lamb=None): + + if kvar.endswith('_min') or kvar.endswith('_max'): + kvar = kvar[:-4] + units = None + + # ------------ + # trivial + # ------------ + + if kvar == 'integ': + try: + units = units_data * units_lamb + except Exception as err: + units = f"{units_data} x {units_lamb}" + + elif kvar == 'max': + units = units_data + + elif kvar == 'argmax': + units = units_lamb + + # ---------------------------------- + # non-trivial but common to several + # ---------------------------------- + + elif kvar == 'Ti': + + # assuming mz is provided in kg + units = 'eV' + + # ---------------- + # specific to each + # ---------------- + else: + + if ktype == 'poly': + + units = units_data + + elif ktype == 'exp_lamb': + + if kvar == 'amp': + try: + units = units_data * units_lamb + except Exception as err: + units = f"{units_data} x {units_lamb}" + + elif kvar == 'rate': + units = units_lamb + + elif kvar == 'Te': + if units_lamb == 'm': + units = 'eV' + else: + units = asunits('eV') * (asunits.Units('m') / units_lamb) + + elif ktype in ['gauss', 'lorentz', 'pvoigt', 'voigt']: + + if kvar == 'amp': + units = units_data + + elif kvar == 'sigma': + units = units_lamb + + elif kvar == 'gam': + units = units_lamb + + elif kvar == 'vccos': + units = '' + + elif ktype in ['pulse_exp', 'pulse_gauss']: + + if kvar == 'amp': + units = units_data + + elif kvar == 'tau': + units = '' + + elif kvar == 't0': + units = units_lamb + + elif kvar == 't_up': + units = units_lamb + + elif kvar == 't_down': + units = units_lamb + + elif ktype == 'lognorm': + + if kvar == 'amp': + try: + units = units_data * units_lamb + except Exception as err: + units = f"{units_data} x {units_lamb}" + + elif kvar == 'tau': + units = '' + + elif kvar == 't0': + units = units_lamb + + elif kvar == 'mu': + units = 'TBD' + + elif kvar == 'sigma': + units = '' + + # -------------- + # safety check + # -------------- + + if units is None: + msg = f"units not reckognized for {ktype} {kvar}" + raise Exception(msg) + + try: + return asunits.Unit(units) + except Exception as err: + return units \ No newline at end of file diff --git a/spectrally/tests/_hxr_pulses.py b/spectrally/tests/_hxr_pulses.py index 9cc9b14..678fc2f 100644 --- a/spectrally/tests/_hxr_pulses.py +++ b/spectrally/tests/_hxr_pulses.py @@ -93,6 +93,15 @@ def main( # spectral models # -------------- + # spectral constrainst + dconstraints = { + 'gbck': { + 'ref': 'bck_a0', + 'bck_a1': [0, 0, 0], + 'bck_a2': [0, 0, 0], + }, + } + # single exponential pulse coll.add_spectral_model( key='sm_exp', @@ -100,9 +109,7 @@ def main( 'bck': 'poly', 'pulse': 'pulse_exp', }, - dconstraints={ - 'gbck': {'ref': 'bck_a0', 'bck_a1': [0, 0, 0]}, - }, + dconstraints=dconstraints, ) # single gaussian pulse @@ -112,9 +119,7 @@ def main( 'bck': 'poly', 'pulse': 'pulse_gauss', }, - dconstraints={ - 'gbck': {'ref': 'bck_a0', 'bck_a1': [0, 0, 0]}, - }, + dconstraints=dconstraints, ) # single lognorm pulse @@ -124,9 +129,7 @@ def main( 'bck': 'poly', 'pulse': 'lognorm', }, - dconstraints={ - 'gbck': {'ref': 'bck_a0', 'bck_a1': [0, 0, 0]}, - }, + dconstraints=dconstraints, ) # -------------- diff --git a/spectrally/tutorials/tutorial.py b/spectrally/tutorials/tutorial.py index 148a063..7a32bb9 100644 --- a/spectrally/tutorials/tutorial.py +++ b/spectrally/tutorials/tutorial.py @@ -435,7 +435,7 @@ def _add_spectral_fit(coll=None, data=None): for k0 in coll.dobj['spect_model'].keys(): coll.add_spectral_fit( - key=f"sf_{k0.replace('sm_', '')}", + key=k0.replace('sm_', 'sf_'), key_model=k0, key_data='current', key_sigma=None,