Spm speedup #117

Merged: 4 commits, Mar 13, 2025
1 change: 1 addition & 0 deletions pyproject.toml
@@ -118,6 +118,7 @@ par-geds-raw-blindcal = "legenddataflow.scripts.par.geds.raw.blindcal:par_geds
par-geds-raw-blindcheck = "legenddataflow.scripts.par.geds.raw.blindcheck:par_geds_raw_blindcheck"
par-geds-tcm-pulser = "legenddataflow.scripts.par.geds.tcm.pulser:par_geds_tcm_pulser"
par-spms-dsp-trg-thr = "legenddataflow.scripts.par.spms.dsp.trigger_threshold:par_spms_dsp_trg_thr"
par-spms-dsp-trg-thr-multi = "legenddataflow.scripts.par.spms.dsp.trigger_threshold:par_spms_dsp_trg_thr_multi"

[tool.uv.workspace]
exclude = ["generated", "inputs", "software", "workflow"]
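As a sketch (not part of the diff): the packaging change adds a console-script entry point next to the existing single-channel one. Roughly, the new mapping resolves to the following at install time; that the function parses its own command-line arguments is an assumption based on its par-* siblings:

    # sketch of what the new [project.scripts] entry invokes
    from legenddataflow.scripts.par.spms.dsp.trigger_threshold import (
        par_spms_dsp_trg_thr_multi,  # assumed to parse its own argv
    )

    par_spms_dsp_trg_thr_multi()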
6 changes: 5 additions & 1 deletion workflow/Snakefile
@@ -18,10 +18,11 @@ from datetime import datetime
from collections import OrderedDict
import logging

from dbetto import AttrsDict
from dbetto import AttrsDict, TextDB
from legendmeta import LegendMetadata
from legenddataflow import CalGrouping, execenv, utils
from legenddataflow.patterns import get_pattern_tier
from legenddataflow.pre_compile_catalog import pre_compile_catalog

utils.subst_vars_in_snakemake_config(workflow, config)
config = AttrsDict(config)
@@ -33,6 +34,9 @@ meta = utils.metadata_path(config)
det_status = utils.det_status_path(config)
basedir = workflow.basedir

det_status_textdb = pre_compile_catalog(Path(det_status) / "statuses")
channelmap_textdb = pre_compile_catalog(Path(chan_maps) / "channelmaps")

time = datetime.now().strftime("%Y%m%dT%H%M%SZ")

# NOTE: this will attempt a clone of legend-metadata, if the directory does not exist
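The speedup itself: the detector-status and channel-map TextDBs are compiled into catalogs once, at workflow start, and handed to every rule, rather than being re-read from the YAML metadata on each channel-list lookup. Assuming pre_compile_catalog returns a dbetto Catalog, the later lookups in chanlist_gen.smk and common.smk amount to (timestamp invented):

    # assumed access pattern for the pre-compiled catalogs
    status_map = det_status_textdb.valid_for("20230101T000000Z", system="cal")
    chmap = channelmap_textdb.valid_for("20230101T000000Z")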
69 changes: 44 additions & 25 deletions workflow/rules/chanlist_gen.smk
@@ -3,6 +3,7 @@
import os
import random
import re
from pathlib import Path

from legenddataflow.FileKey import ChannelProcKey
from legenddataflow.patterns import (
@@ -11,31 +12,42 @@
)
from legenddataflow import execenv_pyexe
from legenddataflow.utils import filelist_path
from dbetto import TextDB
from dbetto.catalog import Catalog


# FIXME: the system argument should always be explicitly supplied
def get_chanlist(
setup, keypart, workflow, config, det_status, chan_maps, system="geds"
):
def get_chanlist(config, keypart, workflow, det_status, channelmap, system):
key = ChannelProcKey.parse_keypart(keypart)

flist_path = filelist_path(setup)
os.makedirs(flist_path, exist_ok=True)
output_file = os.path.join(
flist_path,
f"all-{key.experiment}-{key.period}-{key.run}-{key.datatype}-{key.timestamp}-channels.chankeylist.{random.randint(0,99999):05d}",
)
if isinstance(det_status, (str, Path)):
det_status = TextDB(det_status, lazy=True)

os.system(
execenv_pyexe(config, "create-chankeylist")
+ f"--det-status {det_status} --channelmap {chan_maps} --timestamp {key.timestamp} "
f"--datatype {key.datatype} --output-file {output_file} --system {system}"
)
if isinstance(channelmap, (str, Path)):
channelmap = TextDB(channelmap, lazy=True)

if isinstance(det_status, TextDB):
status_map = det_status.statuses.on(key.timestamp, system=key.datatype)
else:
status_map = det_status.valid_for(key.timestamp, system=key.datatype)
if isinstance(channelmap, TextDB):
chmap = channelmap.channelmaps.on(key.timestamp)
else:
chmap = channelmap.valid_for(key.timestamp)

with open(output_file) as r:
chan_list = r.read().splitlines()
os.remove(output_file)
return chan_list
# restrict to a single system (geds, spms, ...)
channels = []
for channel in chmap.map("system", unique=False)[system].map("name"):
if channel not in status_map:
msg = f"{channel} is not found in the status map (on {key.timestamp})"
raise RuntimeError(msg)
if status_map[channel].processable is False:
continue
channels.append(channel)

if len(channels) == 0:
print("WARNING: No channels found") # noqa: T201

return channels
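
For illustration (not part of the diff), a hypothetical call resolving the processable SiPM channels for one run, using the catalogs compiled in the Snakefile:

    chan_list = get_chanlist(
        config,
        "all-l200-p03-r000-cal-20230101T000000Z-channels",  # invented keypart
        workflow,
        det_status_textdb,
        channelmap_textdb,
        system="spms",
    )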


def get_par_chanlist(
@@ -45,15 +57,13 @@
basedir,
det_status,
chan_maps,
system,
datatype="cal",
system="geds",
name=None,
extension="yaml",
):

chan_list = get_chanlist(
setup, keypart, workflow, config, det_status, chan_maps, system
)
chan_list = get_chanlist(setup, keypart, workflow, det_status, chan_maps, system)

par_pattern = get_pattern_pars_tmp_channel(
setup, tier, name, datatype=datatype, extension=extension
@@ -64,9 +74,18 @@
return filenames


def get_plt_chanlist(setup, keypart, tier, basedir, det_status, chan_maps, name=None):
def get_plt_chanlist(
setup,
keypart,
tier,
basedir,
det_status,
chan_maps,
system,
name=None,
):

chan_list = get_chanlist(setup, keypart, workflow, config, det_status, chan_maps)
chan_list = get_chanlist(setup, keypart, workflow, det_status, chan_maps, system)

par_pattern = get_pattern_plts_tmp_channel(setup, tier, name)

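Both helpers now take system as a required positional argument, in line with the FIXME above. A hedged sketch for the plot variant, mirroring how the merge rules below call it (keypart invented):

    plt_files = get_plt_chanlist(
        config,
        "all-l200-p03-r000-cal-20230101T000000Z-channels",  # invented keypart
        "dsp",
        basedir,
        det_status_textdb,
        channelmap_textdb,
        system="geds",
    )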
50 changes: 13 additions & 37 deletions workflow/rules/channel_merge.smk
@@ -5,7 +5,7 @@ from legenddataflow.utils import set_last_rule_name
from legenddataflow.execenv import execenv_pyexe


def build_merge_rules(tier, lh5_merge=False, lh5_tier=None):
def build_merge_rules(tier, lh5_merge=False, lh5_tier=None, system="geds"):
if lh5_tier is None:
lh5_tier = tier
rule:
@@ -15,8 +15,9 @@
f"all-{wildcards.experiment}-{wildcards.period}-{wildcards.run}-cal-{wildcards.timestamp}-channels",
tier,
basedir,
det_status,
chan_maps,
det_status_textdb,
channelmap_textdb,
system=system,
),
output:
patterns.get_pattern_plts(config, tier),
@@ -36,8 +37,9 @@
f"all-{wildcards.experiment}-{wildcards.period}-{wildcards.run}-cal-{wildcards.timestamp}-channels",
tier,
basedir,
det_status,
chan_maps,
det_status_textdb,
channelmap_textdb,
system=system,
name="objects",
extension="pkl",
),
@@ -66,8 +68,9 @@
f"all-{wildcards.experiment}-{wildcards.period}-{wildcards.run}-cal-{wildcards.timestamp}-channels",
tier,
basedir,
det_status,
chan_maps,
det_status_textdb,
channelmap_textdb,
system=system,
),
output:
temp(
@@ -86,34 +89,6 @@

set_last_rule_name(workflow, f"build_pars_{tier}_db")

rule:
"""Merge pars for SiPM channels in a single pars file."""
input:
lambda wildcards: get_par_chanlist(
config,
f"all-{wildcards.experiment}-{wildcards.period}-{wildcards.run}-{wildcards.datatype}-{wildcards.timestamp}-channels",
tier,
basedir,
det_status,
chan_maps,
datatype=wildcards.datatype,
system="spms"
),
output:
patterns.get_pattern_pars(
config,
tier,
name="spms",
datatype="{datatype}",
),
group:
f"merge-{tier}"
shell:
execenv_pyexe(config, "merge-channels") + \
"--input {input} "
"--output {output} "

set_last_rule_name(workflow, f"build_pars_spms_{tier}_db")

rule:
input:
@@ -122,8 +97,9 @@
f"all-{wildcards.experiment}-{wildcards.period}-{wildcards.run}-cal-{wildcards.timestamp}-channels",
lh5_tier,
basedir,
det_status,
chan_maps,
det_status_textdb,
channelmap_textdb,
system=system,
extension="lh5" if lh5_merge is True else inspect.signature(get_par_chanlist).parameters['extension'].default,
),
in_db=patterns.get_pattern_pars_tmp(
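With system now a parameter of build_merge_rules, the dedicated SiPM merge rule is dropped and the same rule set can be generated per subsystem. The call sites live in the per-tier Snakefiles, which this diff does not show; a sketch under that assumption:

    build_merge_rules("hit", lh5_merge=True)                  # geds by default
    build_merge_rules("dsp", lh5_merge=False, system="spms")  # assumed SiPM call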
9 changes: 8 additions & 1 deletion workflow/rules/common.smk
@@ -122,5 +122,12 @@ def get_search_pattern(tier):


def get_table_name(metadata, config, datatype, timestamp, detector, tier):
chmap = metadata.channelmap(timestamp, system=datatype)
if isinstance(metadata, (str, Path)):
chmap = metadata.channelmap(timestamp, system=datatype)
elif isinstance(metadata, Catalog):
chmap = metadata.valid_for(timestamp, system=datatype)
else:
raise ValueError(
f"metadata must be a str, Path, or Catalog object, not {type(metadata)}"
)
return config.table_format[tier].format(ch=chmap[detector].daq.rawid)
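
For illustration (not part of the diff): the dsp rules below pass the pre-compiled channel-map catalog here, so the Catalog branch is the hot path. A hypothetical lookup (detector name and timestamp invented):

    name = get_table_name(
        channelmap_textdb, config, "cal", "20230101T000000Z", "V01234A", "raw"
    )
    # returns config.table_format["raw"] formatted with the detector's daq.rawid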
57 changes: 41 additions & 16 deletions workflow/rules/dsp_pars_geds.smk
@@ -27,7 +27,12 @@ rule build_pars_dsp_tau_geds:
datatype="cal",
channel="{channel}",
raw_table_name=lambda wildcards: get_table_name(
metadata, config, "cal", wildcards.timestamp, wildcards.channel, "raw"
channelmap_textdb,
config,
"cal",
wildcards.timestamp,
wildcards.channel,
"raw",
),
output:
decay_const=temp(get_pattern_pars_tmp_channel(config, "dsp", "decay_constant")),
@@ -57,14 +62,19 @@ rule build_pars_evtsel_geds:
filelist_path(config), "all-{experiment}-{period}-{run}-cal-raw.filelist"
),
pulser_file=get_pattern_pars_tmp_channel(config, "tcm", "pulser_ids"),
database=get_pattern_pars_tmp_channel(config, "dsp", "decay_constant"),
database=rules.build_pars_dsp_tau_geds.output.decay_const,
raw_cal_curve=get_blinding_curve_file,
params:
timestamp="{timestamp}",
datatype="cal",
channel="{channel}",
raw_table_name=lambda wildcards: get_table_name(
metadata, config, "cal", wildcards.timestamp, wildcards.channel, "raw"
channelmap_textdb,
config,
"cal",
wildcards.timestamp,
wildcards.channel,
"raw",
),
output:
peak_file=temp(
@@ -97,14 +107,19 @@ rule build_pars_dsp_nopt_geds:
files=os.path.join(
filelist_path(config), "all-{experiment}-{period}-{run}-fft-raw.filelist"
),
database=get_pattern_pars_tmp_channel(config, "dsp", "decay_constant"),
inplots=get_pattern_plts_tmp_channel(config, "dsp", "decay_constant"),
database=rules.build_pars_dsp_tau_geds.output.decay_const,
inplots=rules.build_pars_dsp_tau_geds.output.plots,
params:
timestamp="{timestamp}",
datatype="cal",
channel="{channel}",
raw_table_name=lambda wildcards: get_table_name(
metadata, config, "cal", wildcards.timestamp, wildcards.channel, "raw"
channelmap_textdb,
config,
"cal",
wildcards.timestamp,
wildcards.channel,
"raw",
),
output:
dsp_pars_nopt=temp(
@@ -137,15 +152,20 @@ rule build_pars_dsp_dplms_geds:
fft_files=os.path.join(
filelist_path(config), "all-{experiment}-{period}-{run}-fft-raw.filelist"
),
peak_file=get_pattern_pars_tmp_channel(config, "dsp", "peaks", extension="lh5"),
database=get_pattern_pars_tmp_channel(config, "dsp", "noise_optimization"),
inplots=get_pattern_plts_tmp_channel(config, "dsp", "noise_optimization"),
peak_file=rules.build_pars_evtsel_geds.output.peak_file,
database=rules.build_pars_dsp_nopt_geds.output.dsp_pars_nopt,
inplots=rules.build_pars_dsp_nopt_geds.output.plots,
params:
timestamp="{timestamp}",
datatype="cal",
channel="{channel}",
raw_table_name=lambda wildcards: get_table_name(
metadata, config, "cal", wildcards.timestamp, wildcards.channel, "raw"
channelmap_textdb,
config,
"cal",
wildcards.timestamp,
wildcards.channel,
"raw",
),
output:
dsp_pars=temp(get_pattern_pars_tmp_channel(config, "dsp", "dplms")),
@@ -176,15 +196,20 @@
# This rule builds the optimal energy filter parameters for the dsp using calibration dsp files
rule build_pars_dsp_eopt_geds:
input:
peak_file=get_pattern_pars_tmp_channel(config, "dsp", "peaks", extension="lh5"),
decay_const=get_pattern_pars_tmp_channel(config, "dsp", "dplms"),
inplots=get_pattern_plts_tmp_channel(config, "dsp", "dplms"),
peak_file=rules.build_pars_evtsel_geds.output.peak_file,
decay_const=rules.build_pars_dsp_dplms_geds.output.dsp_pars,
inplots=rules.build_pars_dsp_dplms_geds.output.plots,
params:
timestamp="{timestamp}",
datatype="cal",
channel="{channel}",
raw_table_name=lambda wildcards: get_table_name(
metadata, config, "cal", wildcards.timestamp, wildcards.channel, "raw"
channelmap_textdb,
config,
"cal",
wildcards.timestamp,
wildcards.channel,
"raw",
),
output:
dsp_pars=temp(get_pattern_pars_tmp_channel(config, "dsp_eopt")),
@@ -246,8 +271,8 @@

rule build_pars_dsp_svm_geds:
input:
dsp_pars=get_pattern_pars_tmp_channel(config, "dsp_eopt"),
svm_file=get_pattern_pars(config, "dsp", "svm", extension="pkl"),
dsp_pars=rules.build_pars_dsp_eopt_geds.output.dsp_pars,
svm_file=rules.build_svm_dsp_geds.output.dsp_pars,
output:
dsp_pars=temp(get_pattern_pars_tmp_channel(config, "dsp")),
log:
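The recurring change in this file replaces re-derived path patterns with references to upstream rule outputs via Snakemake's rules object, making the dependency chain between the parameter-generation steps explicit. A minimal standalone example of the idiom (rule names invented):

    rule tau:
        output:
            decay_const="pars/{channel}-decay_constant.yaml",
        shell:
            "touch {output.decay_const}"

    rule nopt:
        input:
            rules.tau.output.decay_const,  # same path pattern, single source of truth
        output:
            "pars/{channel}-noise_optimization.yaml",
        shell:
            "touch {output}"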