Skip to content

Commit

Permalink
conslidate use of 'function'
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Sep 8, 2024
1 parent f02869c commit a9a393a
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 84 deletions.
22 changes: 11 additions & 11 deletions neurodsp/sim/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@
###################################################################################################
###################################################################################################

def sig_yielder(sim_func, params, n_sims):
def sig_yielder(function, params, n_sims):
"""Generator to yield simulated signals from a given simulation function and parameters.
Parameters
----------
sim_func : str or callable
function : str or callable
Function to create the simulated time series.
If string, should be the name of the desired simulation function.
params : dict
The parameters for the simulated signal, passed into `sim_func`.
The parameters for the simulated signal, passed into `function`.
n_sims : int, optional
Number of simulations to set as the max.
If None, creates an infinite generator.
Expand All @@ -28,21 +28,21 @@ def sig_yielder(sim_func, params, n_sims):
Simulated time series.
"""

sim_func = get_sim_func(sim_func)
function = get_sim_func(function)
for _ in counter(n_sims):
yield sim_func(**params)
yield function(**params)


def sig_sampler(sim_func, params, return_params=False, n_sims=None):
def sig_sampler(function, params, return_params=False, n_sims=None):
"""Generator to yield simulated signals from a parameter sampler.
Parameters
----------
sim_func : str or callable
function : str or callable
Function to create the simulated time series.
If string, should be the name of the desired simulation function.
params : iterable
The parameters for the simulated signal, passed into `sim_func`.
The parameters for the simulated signal, passed into `function`.
return_params : bool, optional, default: False
Whether to yield the simulation parameters as well as the simulated time series.
n_sims : int, optional
Expand All @@ -58,7 +58,7 @@ def sig_sampler(sim_func, params, return_params=False, n_sims=None):
Only returned if `return_params` is True.
"""

sim_func = get_sim_func(sim_func)
function = get_sim_func(function)

# If `params` has a size, and `n_sims` is defined, check that they are compatible
# To do so, we first check if the iterable has a __len__ attr, and if so check values
Expand All @@ -69,9 +69,9 @@ def sig_sampler(sim_func, params, return_params=False, n_sims=None):
for ind, sample_params in zip(counter(n_sims), params):

if return_params:
yield sim_func(**sample_params), sample_params
yield function(**sample_params), sample_params
else:
yield sim_func(**sample_params)
yield function(**sample_params)

if n_sims and ind >= n_sims:
break

Check warning on line 77 in neurodsp/sim/generators.py

View check run for this annotation

Codecov / codecov/patch

neurodsp/sim/generators.py#L77

Added line #L77 was not covered by tests
56 changes: 38 additions & 18 deletions neurodsp/sim/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,39 +9,40 @@

SIM_MODULES = ['periodic', 'aperiodic', 'cycles', 'transients', 'combined']

def get_sim_funcs(module_name):
def get_sim_funcs(module):
"""Get the available sim functions from a specified sub-module.
Parameters
----------
module_name : {'periodic', 'aperiodic', 'cycles', 'transients', 'combined'}
module : {'periodic', 'aperiodic', 'cycles', 'transients', 'combined'}
Simulation sub-module to get sim functions from.
Returns
-------
funcs : dictionary
functions : dictionary
A dictionary containing the available sim functions from the requested sub-module.
"""

check_param_options(module_name, 'module_name', SIM_MODULES)
check_param_options(module, 'module', SIM_MODULES)

# Note: imports done within function to avoid circular import
from neurodsp.sim import periodic, aperiodic, transients, combined, cycles

module = eval(module_name)
module = eval(module)

funcs = {name : func for name, func in getmembers(module, isfunction) \
if name[0:4] == 'sim_' and func.__module__.split('.')[-1] == module.__name__.split('.')[-1]}
module_name = module.__name__.split('.')[-1]
functions = {name : function for name, function in getmembers(module, isfunction) \
if name[0:4] == 'sim_' and function.__module__.split('.')[-1] == module_name}

return funcs
return functions


def get_sim_names(module_name):
def get_sim_names(module):
"""Get the names of the available sim functions from a specified sub-module.
Parameters
----------
module_name : {'periodic', 'aperiodic', 'transients', 'combined'}
module : {'periodic', 'aperiodic', 'transients', 'combined'}
Simulation sub-module to get sim functions from.
Returns
Expand All @@ -50,15 +51,15 @@ def get_sim_names(module_name):
The names of the available functions in the requested sub-module.
"""

return list(get_sim_funcs(module_name).keys())
return list(get_sim_funcs(module).keys())


def get_sim_func(function_name, modules=SIM_MODULES):
def get_sim_func(function, modules=SIM_MODULES):
"""Get a specified sim function.
Parameters
----------
function_name : str or callabe
function : str or callabe
Name of the sim function to retrieve.
If callable, returns input.
If string searches for corresponding callable sim function.
Expand All @@ -67,21 +68,40 @@ def get_sim_func(function_name, modules=SIM_MODULES):
Returns
-------
func : callable
function : callable
Requested sim function.
"""

if callable(function_name):
return function_name
if callable(function):
return function

for module in modules:
try:
func = get_sim_funcs(module)[function_name]
function = get_sim_funcs(module)[function]
break
except KeyError:
continue

else:
raise ValueError('Requested simulation function not found.') from None

return func
return function


def get_sim_func_name(function):
"""Get the name of a simulation function.
Parameters
----------
function : str or callabe
Function to get name for.
Returns
-------
name : str
Name of the function.
"""

name = function.__name__ if callable(function) else function

return name
10 changes: 5 additions & 5 deletions neurodsp/sim/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def save_sims(sims, label, file_path=None, replace=False):

assert '_' not in label, 'Cannot have underscores in simulation label.'

save_path_items = ['sim_unknown' if not sims.sim_func else sims.sim_func]
save_path_items = ['sim_unknown' if not sims.function else sims.function]
if isinstance(sims, (VariableSimulations, MultiSimulations)):
if sims.component:
save_path_items.append(sims.component)

Check warning on line 131 in neurodsp/sim/io.py

View check run for this annotation

Codecov / codecov/patch

neurodsp/sim/io.py#L131

Added line #L131 was not covered by tests
Expand Down Expand Up @@ -178,7 +178,7 @@ def load_sims(load_name, file_path=None):
load_name = matches[0]

splits = load_name.split('_')
sim_func = '_'.join(splits[0:2]) if splits[1] != 'unknown' else None
function = '_'.join(splits[0:2]) if splits[1] != 'unknown' else None

update, component = None, None
if len(splits) > 3:
Expand All @@ -192,18 +192,18 @@ def load_sims(load_name, file_path=None):
if 'signals.npy' not in load_files:

msims = [load_sims(load_file, load_folder) for load_file in load_files]
sims = MultiSimulations(msims, None, sim_func, update, component)
sims = MultiSimulations(msims, None, function, update, component)

else:

sigs = np.load(load_folder / 'signals.npy')

if 'params.json' in load_files:
params = load_json(load_folder / 'params.json')
sims = Simulations(sigs, params, sim_func)
sims = Simulations(sigs, params, function)

elif 'params.jsonlines' in load_files:
params = load_jsonlines(load_folder / 'params.jsonlines')
sims = VariableSimulations(sigs, params, sim_func, update, component)
sims = VariableSimulations(sigs, params, function, update, component)

return sims
38 changes: 19 additions & 19 deletions neurodsp/sim/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@
###################################################################################################
###################################################################################################

def sim_multiple(sim_func, params, n_sims):
def sim_multiple(function, params, n_sims):
"""Simulate multiple samples of a specified simulation.
Parameters
----------
sim_func : callable
function : str or callable
Function to create the simulated time series.
If string, should be the name of the desired simulation function.
params : dict
The parameters for the simulated signal, passed into `sim_func`.
The parameters for the simulated signal, passed into `function`.
n_sims : int
Number of simulations to create.
Expand All @@ -36,23 +36,23 @@ def sim_multiple(sim_func, params, n_sims):
>>> sims = sim_multiple(sim_powerlaw, params, n_sims=3)
"""

sims = Simulations(n_sims, params, sim_func)
for ind, sig in enumerate(sig_yielder(sim_func, params, n_sims)):
sims = Simulations(n_sims, params, function)
for ind, sig in enumerate(sig_yielder(function, params, n_sims)):
sims.add_signal(sig, index=ind)

return sims


def sim_across_values(sim_func, params):
def sim_across_values(function, params):
"""Simulate signals across different parameter values.
Parameters
----------
sim_func : callable
function : str or callable
Function to create the simulated time series.
If string, should be the name of the desired simulation function.
params : ParamIter or iterable or list of dict
Simulation parameters for `sim_func`.
Simulation parameters for `function`.
Returns
-------
Expand All @@ -77,28 +77,28 @@ def sim_across_values(sim_func, params):
>>> sims = sim_across_values(sim_powerlaw, params)
"""

sims = VariableSimulations(len(params), get_base_params(params), sim_func,
sims = VariableSimulations(len(params), get_base_params(params), function,
update=getattr(params, 'update', None),
component=getattr(params, 'component', None))

sim_func = get_sim_func(sim_func)
function = get_sim_func(function)

for ind, cur_params in enumerate(params):
sims.add_signal(sim_func(**cur_params), cur_params, index=ind)
sims.add_signal(function(**cur_params), cur_params, index=ind)

return sims


def sim_multi_across_values(sim_func, params, n_sims):
def sim_multi_across_values(function, params, n_sims):
"""Simulate multiple signals across different parameter values.
Parameters
----------
sim_func : callable
function : str or callable
Function to create the simulated time series.
If string, should be the name of the desired simulation function.
params : ParamIter or iterable or list of dict
Simulation parameters for `sim_func`.
Simulation parameters for `function`.
n_sims : int
Number of simulations to create per parameter definition.
Expand Down Expand Up @@ -128,17 +128,17 @@ def sim_multi_across_values(sim_func, params, n_sims):
sims = MultiSimulations(update=getattr(params, 'update', None),
component=getattr(params, 'component', None))
for cur_params in params:
sims.add_signals(sim_multiple(sim_func, cur_params, n_sims))
sims.add_signals(sim_multiple(function, cur_params, n_sims))

return sims


def sim_from_sampler(sim_func, sampler, n_sims):
def sim_from_sampler(function, sampler, n_sims):
"""Simulate a set of signals from a parameter sampler.
Parameters
----------
sim_func : str callable
function : str or callable
Function to create the simulated time series.
If string, should be the name of the desired simulation function.
sampler : ParamSampler
Expand All @@ -164,8 +164,8 @@ def sim_from_sampler(sim_func, sampler, n_sims):
>>> sims = sim_from_sampler(sim_powerlaw, param_sampler, n_sims=2)
"""

sims = VariableSimulations(n_sims, get_base_params(sampler), sim_func)
for ind, (sig, params) in enumerate(sig_sampler(sim_func, sampler, True, n_sims)):
sims = VariableSimulations(n_sims, get_base_params(sampler), function)
for ind, (sig, params) in enumerate(sig_sampler(function, sampler, True, n_sims)):
sims.add_signal(sig, params, index=ind)

return sims
Loading

0 comments on commit a9a393a

Please sign in to comment.