Skip to content

Commit 8b241f5

Browse files
author
Simon Berger
committed
Update users/berger
1 parent 7a7e810 commit 8b241f5

File tree

3 files changed

+333
-124
lines changed

3 files changed

+333
-124
lines changed

users/berger/network/helpers/label_context.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ def add_context_1_decoder(
263263
def add_context_1_decoder_recog(
264264
network: Dict,
265265
num_outputs: int,
266+
blank_idx: int = 0,
266267
encoder: str = "encoder",
267268
embedding_size: int = 128,
268269
dec_mlp_args: Dict = {},
@@ -351,9 +352,21 @@ def add_context_1_decoder_recog(
351352
"reuse_params": "output",
352353
}
353354

355+
assert blank_idx == 0, "Blank idx != 0 not implemented for ilm"
356+
# Set p(blank) = 1 and re-normalize the non-blank probs
357+
# so we want P'[b, 0] = 1, sum(P'[b, 1:]) = 1, given a normalized tensor P, i.e. sum(P[b, :]) = 1
358+
# in log space logP'[b, 0] = 0, sum(exp(logP'[b, 1:])) = 1
359+
# so set logP'[b, 1:] <- logP[b, 1:] - log(1 - exp(logP[b, 0]))
360+
# then sum(exp(logP'[b, 1:])) = sum(P[b, 1:] / (1 - P[b, 0])) = sum(P[b, 1:]) / sum(P[b, 1:]) = 1
361+
output_unit["ilm_renorm"] = {
362+
"class": "eval",
363+
"from": ["ilm"],
364+
"eval": "tf.concat([tf.zeros(tf.shape(source(0)[:, :1])), source(0)[:, 1:] - tf.math.log(1.0 - tf.exp(source(0)[:, :1]))], axis=-1)",
365+
}
366+
354367
output_unit["output_sub_ilm"] = {
355368
"class": "eval",
356-
"from": ["output", "ilm"],
369+
"from": ["output", "ilm_renorm"],
357370
"eval": f"source(0) - {ilm_scale} * source(1)",
358371
}
359372

users/berger/recipe/rasr/label_tree_and_scorer.py

+54-3
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __init__(
7878
use_prior: bool = False,
7979
prior_scale: float = 0.6,
8080
prior_file: Optional[tk.Path] = None,
81-
extra_args: Dict = {},
81+
extra_args: Optional[Dict] = None,
8282
):
8383
self.config = rasr.RasrConfig()
8484
self.post_config = rasr.RasrConfig()
@@ -102,13 +102,64 @@ def __init__(
102102
self.config.priori_scale = prior_scale
103103

104104
# sprint key values #
105-
for key, value in extra_args.items():
106-
self.config[key.replace("_", "-")] = value
105+
if extra_args is not None:
106+
for key, value in extra_args.items():
107+
self.config[key.replace("_", "-")] = value
107108

108109
@property
109110
def scorer_type(self):
110111
return self.config.label_scorer_type
111112

113+
@property
114+
def scale(self):
115+
return self.config.scale
116+
117+
@property
118+
def label_file(self):
119+
if self.config._get("label-file") is not None:
120+
return self.config.label_file
121+
return None
122+
123+
@property
124+
def num_classes(self):
125+
if self.config._get("number-of-classes") is not None:
126+
return self.config.number_of_classes
127+
return None
128+
129+
@property
130+
def use_prior(self):
131+
if self.config._get("use-prior") is not None:
132+
return self.config["use-prior"]
133+
return False
134+
135+
@property
136+
def prior_scale(self):
137+
if self.config._get("priori-scale") is not None:
138+
return self.config["priori-scale"]
139+
return 1.0
140+
141+
@property
142+
def prior_file(self):
143+
if self.config._get("prior-file") is not None:
144+
return self.config["prior-file"]
145+
return None
146+
147+
@property
148+
def extra_args(self):
149+
return {
150+
key: val
151+
for key, val in self.config._items()
152+
if key not in [
153+
"label-scorer-type",
154+
"scale",
155+
"label-file",
156+
"number-of-classes",
157+
"use-prior",
158+
"priori-scale",
159+
"prior-file",
160+
]
161+
}
162+
112163
def apply_config(
113164
self,
114165
path: str,

0 commit comments

Comments
 (0)