From e0bfbecb3653c85aac77c1c71cde923ce83d75bb Mon Sep 17 00:00:00 2001 From: Tilman Hinnerichs Date: Mon, 17 Feb 2025 14:14:45 +0100 Subject: [PATCH 1/9] Add make_probabilistic --- src/grammar_base.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/grammar_base.jl b/src/grammar_base.jl index 95c8f32..45fe4b5 100644 --- a/src/grammar_base.jl +++ b/src/grammar_base.jl @@ -140,6 +140,20 @@ function log_probability(grammar::AbstractGrammar, index::Int)::Real return grammar.log_probabilities[index] end + +""" + make_probabilistic!(grammar::AbstractGrammar) + +If the grammar is not probabilistic yet, initializes the grammar with uniform probabilities per type. If the grammar is alread probabilistic, no changed are made. +""" +function make_probabilistic!(grammar::AbstractGrammar) + if isprobabilistic(grammar) + @warn "Grammar is already probabilistic. The grammar will not be changed." + else + grammar.log_probabilities = [log(1 / length(grammar.bytype[grammar.types[index]])) for index in 1:length(grammar.rules)] + end +end + """ probability(grammar::AbstractGrammar, index::Int)::Real From 9f302384be177f313f31492c39df70f37a0696e0 Mon Sep 17 00:00:00 2001 From: Tilman Hinnerichs Date: Wed, 19 Feb 2025 16:35:30 +0100 Subject: [PATCH 2/9] Refactor normalize,, add init_probabilities, add tests --- src/HerbGrammar.jl | 2 ++ src/csg/probabilistic_csg.jl | 46 +++++++++++++++++++++++++++++------- src/grammar_base.jl | 13 ---------- src/rulenode_operators.jl | 2 ++ test/test_csg.jl | 39 ++++++++++++++++++++++++++++++ 5 files changed, 81 insertions(+), 21 deletions(-) diff --git a/src/HerbGrammar.jl b/src/HerbGrammar.jl index 32b27d3..9da3ee3 100644 --- a/src/HerbGrammar.jl +++ b/src/HerbGrammar.jl @@ -51,6 +51,8 @@ export @pcsgrammar, expr2pcsgrammar, + normalize!, + init_probabilities!, SymbolTable, grammar2symboltable, diff --git a/src/csg/probabilistic_csg.jl b/src/csg/probabilistic_csg.jl index 92c3987..423cfa2 100644 --- a/src/csg/probabilistic_csg.jl +++ b/src/csg/probabilistic_csg.jl @@ -1,5 +1,7 @@ """ + expr2pcsgrammar(ex::Expr)::ContextSensitiveGrammar + Function for converting an `Expr` to a [`ContextSensitiveGrammar`](@ref) with probabilities. If the expression is hardcoded, you should use the `@pcsgrammar` macro. Only expressions in the correct format (see [`@pcsgrammar`](@ref)) can be converted. @@ -47,6 +49,8 @@ function expr2pcsgrammar(ex::Expr)::ContextSensitiveGrammar end """ + parse_probabilistic_rule(e::Expr) + Parses a single (potentially shorthand) derivation rule of a probabilistic [`ContextSensitiveGrammar`](@ref). Returns `nothing` if the rule is not probabilistic, otherwise a `Tuple` of its type and a `Vector` of probability-rule pairs it expands into. @@ -78,26 +82,52 @@ end """ + normalize!(grammar::ContextSensitiveGrammar, type::Union{Symbol, Nothing}=nothing) + A function for normalizing the probabilities of a probabilistic [`ContextSensitiveGrammar`](@ref). -If the optional `type` argument is provided, only the rules of that type are normalized. +If the optional `type` argument is provided, only the rules of that type are normalized. +If the grammar is not probabilistic, i.e. `grammar.log_probabilities==nothing`, a uniform distribution is initialized. """ -function normalize!(g::ContextSensitiveGrammar, type::Union{Symbol, Nothing}=nothing) - probabilities = map(exp, g.log_probabilities) - types = isnothing(type) ? keys(g.bytype) : [type] +function normalize!(grammar::ContextSensitiveGrammar, type::Union{Symbol, Nothing}=nothing) + if !isprobabilistic(grammar) + @warn "Requesting normalization in a non-probabilistic grammar. Uniform distribution is assumed." + init_probabilities!(grammar) + return grammar + end + + probabilities = map(exp, grammar.log_probabilities) + types = isnothing(type) ? keys(grammar.bytype) : [type] for t ∈ types - total_prob = sum(probabilities[i] for i ∈ g.bytype[t]) + total_prob = sum(probabilities[i] for i ∈ grammar.bytype[t]) if !(total_prob ≈ 1) - for i ∈ g.bytype[t] + for i ∈ grammar.bytype[t] probabilities[i] /= total_prob end end end - g.log_probabilities = map(log, probabilities) - return g + grammar.log_probabilities = map(log, probabilities) + return grammar end + +""" + init_probabilities!(grammar::AbstractGrammar) + +If the grammar is not probabilistic yet, initializes the grammar with uniform probabilities per type. If the grammar is already probabilistic, no changed are made. +""" +function init_probabilities!(grammar::AbstractGrammar) + if isprobabilistic(grammar) + @warn "Tried to init probabilities for grammar, but it is already probabilistic. No changes are made." + else + grammar.log_probabilities = zeros(length(grammar.rules)) + normalize!(grammar) + end + return grammar +end + + """ @pcsgrammar diff --git a/src/grammar_base.jl b/src/grammar_base.jl index 45fe4b5..c7d189a 100644 --- a/src/grammar_base.jl +++ b/src/grammar_base.jl @@ -141,19 +141,6 @@ function log_probability(grammar::AbstractGrammar, index::Int)::Real end -""" - make_probabilistic!(grammar::AbstractGrammar) - -If the grammar is not probabilistic yet, initializes the grammar with uniform probabilities per type. If the grammar is alread probabilistic, no changed are made. -""" -function make_probabilistic!(grammar::AbstractGrammar) - if isprobabilistic(grammar) - @warn "Grammar is already probabilistic. The grammar will not be changed." - else - grammar.log_probabilities = [log(1 / length(grammar.bytype[grammar.types[index]])) for index in 1:length(grammar.rules)] - end -end - """ probability(grammar::AbstractGrammar, index::Int)::Real diff --git a/src/rulenode_operators.jl b/src/rulenode_operators.jl index e28863b..cfac3c4 100644 --- a/src/rulenode_operators.jl +++ b/src/rulenode_operators.jl @@ -42,6 +42,8 @@ rulesoftype(node::RuleNode, ruleset::Set{Int}, ::Hole) = rulesoftype(node, rules rulesoftype(::Hole, ::Set{Int}, ::RuleNode) = Set() rulesoftype(::Hole, ::Set{Int}, ::Hole) = Set() +rulesoftype(node::AbstractRuleNode, index::Int) = rulesoftype(node, Set{Int}(index)) + """ rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol, ignoreNode::RuleNode) diff --git a/test/test_csg.jl b/test/test_csg.jl index 87e064f..13dac2b 100644 --- a/test/test_csg.jl +++ b/test/test_csg.jl @@ -185,6 +185,45 @@ end end end + + @testset "make csg probabilistic" begin + grammar = @csgrammar begin + R = |(1:3) + S = |(1:2) + end + # Test correct initialization + @test !isprobabilistic(grammar) + init_probabilities!(grammar) + @test isprobabilistic(grammar) + + probs = grammar.log_probabilities + + # Test equivalence of probabilities + @test probs[1] == probs[2] == probs[3] + @test probs[end-1] == probs[end] + + # Test values + @test all(x -> isapprox(x, 1/3), exp.(probs)[1:3]) + @test all(x -> isapprox(x, 1/2), exp.(probs)[4:5]) + end + + @testset "normalize! csg" begin + grammar = @csgrammar begin + R = |(1:3) + S = |(1:2) + end + + grammar.log_probabilities = zeros(length(grammar.rules)) + + normalize!(grammar) + + probs = grammar.log_probabilities + + # Test values + @test all(x -> isapprox(x, 1/3), exp.(probs)[1:3]) + @test all(x -> isapprox(x, 1/2), exp.(probs)[4:5]) + end + @testset "Test that strict equality is used during rule creation" begin g₁ = @csgrammar begin From ca4cbabb6c3d8625f99137fda91b403444cda6c8 Mon Sep 17 00:00:00 2001 From: Tilman Hinnerichs Date: Sun, 23 Feb 2025 18:45:12 +0100 Subject: [PATCH 3/9] Moved probabilistic functios to extra file, moved tests to extra files, updated tests for warnings --- src/HerbGrammar.jl | 1 + src/csg/probabilistic_csg.jl | 45 ++++++++++++- src/grammar_base.jl | 43 ------------ src/rulenode_operators.jl | 83 +++++++---------------- test/runtests.jl | 1 + test/test_csg.jl | 116 -------------------------------- test/test_rulenode_operators.jl | 2 +- 7 files changed, 73 insertions(+), 218 deletions(-) diff --git a/src/HerbGrammar.jl b/src/HerbGrammar.jl index 9da3ee3..8698c94 100644 --- a/src/HerbGrammar.jl +++ b/src/HerbGrammar.jl @@ -39,6 +39,7 @@ export nonterminals, iscomplete, parse_rule!, + rulesoftype, @csgrammar, expr2csgrammar, diff --git a/src/csg/probabilistic_csg.jl b/src/csg/probabilistic_csg.jl index 423cfa2..cc3a32d 100644 --- a/src/csg/probabilistic_csg.jl +++ b/src/csg/probabilistic_csg.jl @@ -174,4 +174,47 @@ end macro pcfgrammar(ex) return :(expr2pcsgrammar($(QuoteNode(ex)))) -end \ No newline at end of file +end + + +""" + log_probability(grammar::AbstractGrammar, index::Int)::Real + +Returns the log probability for the rule at `index` in the grammar. + +!!! warning + If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed. +""" +function log_probability(grammar::AbstractGrammar, index::Int)::Real + if !isprobabilistic(grammar) + @warn "Requesting probability in a non-probabilistic grammar.\nUniform distribution is assumed." + # Assume uniform probability + return log(1 / length(grammar.bytype[grammar.types[index]])) + end + return grammar.log_probabilities[index] +end + +""" + probability(grammar::AbstractGrammar, index::Int)::Real + +Return the probability for a rule in the grammar. +Use [`log_probability`](@ref) whenever possible. + +!!! warning + If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed. +""" +function probability(grammar::AbstractGrammar, index::Int)::Real + if !isprobabilistic(grammar) + @warn "Requesting probability in a non-probabilistic grammar.Uniform distribution is assumed." + # Assume uniform probability + return 1 / length(grammar.bytype[grammar.types[index]]) + end + return ℯ^grammar.log_probabilities[index] +end + +""" + isprobabilistic(grammar::AbstractGrammar)::Bool + +Function returns whether a [`AbstractGrammar`](@ref) is probabilistic. +""" +isprobabilistic(grammar::AbstractGrammar)::Bool = !(grammar.log_probabilities ≡ nothing) \ No newline at end of file diff --git a/src/grammar_base.jl b/src/grammar_base.jl index c7d189a..c86cba6 100644 --- a/src/grammar_base.jl +++ b/src/grammar_base.jl @@ -123,49 +123,6 @@ Returns true if the production rule at rule_index contains the special _() eval iseval(grammar::AbstractGrammar, index::Int)::Bool = grammar.iseval[index] -""" - log_probability(grammar::AbstractGrammar, index::Int)::Real - -Returns the log probability for the rule at `index` in the grammar. - -!!! warning - If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed. -""" -function log_probability(grammar::AbstractGrammar, index::Int)::Real - if !isprobabilistic(grammar) - @warn "Requesting probability in a non-probabilistic grammar.\nUniform distribution is assumed." - # Assume uniform probability - return log(1 / length(grammar.bytype[grammar.types[index]])) - end - return grammar.log_probabilities[index] -end - - -""" - probability(grammar::AbstractGrammar, index::Int)::Real - -Return the probability for a rule in the grammar. -Use [`log_probability`](@ref) whenever possible. - -!!! warning - If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed. -""" -function probability(grammar::AbstractGrammar, index::Int)::Real - if !isprobabilistic(grammar) - @warn "Requesting probability in a non-probabilistic grammar.\nUniform distribution is assumed." - # Assume uniform probability - return 1 / length(grammar.bytype[grammar.types[index]]) - end - return ℯ^grammar.log_probabilities[index] -end - -""" - isprobabilistic(grammar::AbstractGrammar)::Bool - -Function returns whether a [`AbstractGrammar`](@ref) is probabilistic. -""" -isprobabilistic(grammar::AbstractGrammar)::Bool = !(grammar.log_probabilities ≡ nothing) - """ nchildren(grammar::AbstractGrammar, rule_index::Int)::Int diff --git a/src/rulenode_operators.jl b/src/rulenode_operators.jl index cfac3c4..7a9e61b 100644 --- a/src/rulenode_operators.jl +++ b/src/rulenode_operators.jl @@ -1,60 +1,3 @@ -rulesoftype(::Hole, ::Set{Int}) = Set{Int}() - -""" - rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol) - -Returns every rule of nonterminal symbol `ruletype` that is also used in the [`AbstractRuleNode`](@ref) tree. -""" -rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol) = rulesoftype(node, Set{Int}(grammar[ruletype])) -rulesoftype(::Hole, ::AbstractGrammar, ::Symbol) = Set{Int}() - - -""" - rulesoftype(node::RuleNode, ruleset::Set{Int}, ignoreNode::RuleNode) - -Returns every rule in the ruleset that is also used in the [`AbstractRuleNode`](@ref) tree, but not in the `ignoreNode` subtree. - -!!! warning - The `ignoreNode` must be a subtree of `node` for it to have an effect. -""" -function rulesoftype(node::RuleNode, ruleset::Set{Int}, ignoreNode::RuleNode) - retval = Set() - - if node == ignoreNode - return retval - end - - if get_rule(node) ∈ ruleset - union!(retval, [get_rule(node)]) - end - - if isempty(node.children) - return retval - else - for child ∈ node.children - union!(retval, rulesoftype(child, ruleset)) - end - - return retval - end -end -rulesoftype(node::RuleNode, ruleset::Set{Int}, ::Hole) = rulesoftype(node, ruleset) -rulesoftype(::Hole, ::Set{Int}, ::RuleNode) = Set() -rulesoftype(::Hole, ::Set{Int}, ::Hole) = Set() - -rulesoftype(node::AbstractRuleNode, index::Int) = rulesoftype(node, Set{Int}(index)) - -""" - rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol, ignoreNode::RuleNode) - -Returns every rule of nonterminal symbol `ruletype` that is also used in the [`AbstractRuleNode`](@ref) tree, but not in the `ignoreNode` subtree. - -!!! warning - The `ignoreNode` must be a subtree of `node` for it to have an effect. -""" -rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol, ignoreNode::RuleNode) = rulesoftype(node, Set(grammar[ruletype]), ignoreNode) -rulesoftype(::Hole, ::AbstractGrammar, ::Symbol, ::RuleNode) = Set() - """ swap_node(expr::AbstractRuleNode, new_expr::AbstractRuleNode, path::Vector{Int}) @@ -385,8 +328,34 @@ function rulenode_log_probability(node::RuleNode, grammar::AbstractGrammar) return log_probability(grammar, get_rule(node)) + sum((rulenode_log_probability(c, grammar) for c ∈ node.children), init=1) end +function rulenode_log_probability(hole::UniformHole, grammar::AbstractGrammar) + if sum(hole.domain) == 1 # only one element + return log_probability(grammar, only(findall(hole.domain))) + else + throw(ArgumentError("Log probability of a UniformHole requested, which has more than 1 element within its domain. This is ambiguous.")) + end +end rulenode_log_probability(::Hole, ::AbstractGrammar) = 1 +""" + max_rulenode_log_probability(rulenode::AbstractRuleNode, grammar::AbstractGrammar) + +Calculates the highest possible probability within an `AbstractRuleNode`. +That is, for each node and its domain, get the highest probability and multiply it with the probabilities of its children, if present. +As we operate with log probabilities, we sum the logarithms. +""" +max_rulenode_log_probability(rulenode::AbstractRuleNode, grammar::AbstractGrammar) = rulenode_log_probability(rulenode, grammar) + +function max_rulenode_log_probability(hole::UniformHole, grammar::AbstractGrammar) + max_index = argmax(i -> grammar.log_probabilities[i], findall(hole.domain)) + return log_probability(grammar, max_index) + sum((max_rulenode_log_probability(c, grammar) for c ∈ node.children), init=1) +end + +function max_rulenode_log_probability(hole::Hole, grammar::AbstractGrammar) + max_index = argmax(i -> grammar.log_probabilities[i], findall(hole.domain)) + return log_probability(grammar, max_index) +end + """ iscomplete(grammar::AbstractGrammar, node::RuleNode) diff --git a/test/runtests.jl b/test/runtests.jl index 52b5a4c..c37833f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,4 +10,5 @@ using Test include("test_rulenode2expr.jl") include("test_expr2rulenode.jl") include("test_utils.jl") + include("test_pcsg.jl") end diff --git a/test/test_csg.jl b/test/test_csg.jl index 13dac2b..20b4609 100644 --- a/test/test_csg.jl +++ b/test/test_csg.jl @@ -108,122 +108,6 @@ # delete file afterwards rm("toy_cfg.grammar") end - - @testset "Writing and loading probabilistic CSG to/from disk" begin - g₁ = @pcsgrammar begin - 0.5 : Real = |(0:3) - 0.5 : Real = x - end - - store_csg(g₁, "toy_pcfg.grammar") - g₂ = read_pcsg("toy_pcfg.grammar") - @test :Real ∈ g₂.types - @test g₂.rules == [0, 1, 2, 3, :x] - @test g₂.log_probabilities == g₁.log_probabilities - - - # delete file afterwards - rm("toy_pcfg.grammar") - end - - @testset "creating probabilistic CSG" begin - g = @pcsgrammar begin - 0.5 : R = |(0:2) - 0.3 : R = x - 0.2 : B = true | false - end - - @test sum(map(exp, g.log_probabilities[g.bytype[:R]])) ≈ 1.0 - @test sum(map(exp, g.log_probabilities[g.bytype[:B]])) ≈ 1.0 - @test g.bytype[:R] == Int[1,2,3,4] - @test g.bytype[:B] == Int[5,6] - @test :R ∈ g.types && :B ∈ g.types - end - - @testset "creating a non-normalized PCSG" begin - g = @pcsgrammar begin - 0.5 : R = |(0:2) - 0.5 : R = x - 0.5 : B = true | false - end - - @test sum(map(exp, g.log_probabilities[g.bytype[:R]])) ≈ 1.0 - @test sum(map(exp, g.log_probabilities[g.bytype[:B]])) ≈ 1.0 - @test g.rules == [0, 1, 2, :x, :true, :false] - @test g.bytype[:R] == Int[1,2,3,4] - @test g.bytype[:B] == Int[5,6] - @test :R ∈ g.types && :B ∈ g.types - end - - @testset "Adding a rule to a probabilistic CSG" begin - g = @pcsgrammar begin - 0.5 : R = x - 0.5 : R = R + R - end - - add_rule!(g, 0.5, :(R = 1 | 2)) - - @test g.rules == [:x, :(R + R), 1, 2] - - add_rule!(g, 0.5, :(B = t | f)) - - @test g.bytype[:B] == Int[5, 6] - @test sum(map(exp, g.log_probabilities[g.bytype[:R]])) ≈ 1.0 - @test sum(map(exp, g.log_probabilities[g.bytype[:B]])) ≈ 1.0 - end - - @testset "Creating a non-probabilistic rule in a PCSG" begin - expected_log = ( - :error, - "Rule without probability encountered in probabilistic grammar. Rule ignored." - ) - - @test_logs expected_log match_mode=:any begin - @pcsgrammar begin - 0.5 : R = x - R = R + R - end - end - end - - @testset "make csg probabilistic" begin - grammar = @csgrammar begin - R = |(1:3) - S = |(1:2) - end - # Test correct initialization - @test !isprobabilistic(grammar) - init_probabilities!(grammar) - @test isprobabilistic(grammar) - - probs = grammar.log_probabilities - - # Test equivalence of probabilities - @test probs[1] == probs[2] == probs[3] - @test probs[end-1] == probs[end] - - # Test values - @test all(x -> isapprox(x, 1/3), exp.(probs)[1:3]) - @test all(x -> isapprox(x, 1/2), exp.(probs)[4:5]) - end - - @testset "normalize! csg" begin - grammar = @csgrammar begin - R = |(1:3) - S = |(1:2) - end - - grammar.log_probabilities = zeros(length(grammar.rules)) - - normalize!(grammar) - - probs = grammar.log_probabilities - - # Test values - @test all(x -> isapprox(x, 1/3), exp.(probs)[1:3]) - @test all(x -> isapprox(x, 1/2), exp.(probs)[4:5]) - end - @testset "Test that strict equality is used during rule creation" begin g₁ = @csgrammar begin diff --git a/test/test_rulenode_operators.jl b/test/test_rulenode_operators.jl index dee0c14..1ce2faa 100644 --- a/test/test_rulenode_operators.jl +++ b/test/test_rulenode_operators.jl @@ -15,4 +15,4 @@ end @test !isvariable(g₁, RuleNode(7, g₁), SomeDefinitions) @test isvariable(g₁, RuleNode(7, g₁)) end -end +end \ No newline at end of file From 067e7222f4d784d10ef8e249d068689d8c1545b1 Mon Sep 17 00:00:00 2001 From: Tilman Hinnerichs Date: Sun, 23 Feb 2025 18:45:53 +0100 Subject: [PATCH 4/9] Added file for probabilistic tests only --- test/test_pcsg.jl | 120 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 test/test_pcsg.jl diff --git a/test/test_pcsg.jl b/test/test_pcsg.jl new file mode 100644 index 0000000..4fcb384 --- /dev/null +++ b/test/test_pcsg.jl @@ -0,0 +1,120 @@ +@testset verbose=true "Probabilistic CSGs" begin + @testset "Writing and loading probabilistic CSG to/from disk" begin + g₁ = @pcsgrammar begin + 0.5 : Real = |(0:3) + 0.5 : Real = x + end + + store_csg(g₁, "toy_pcfg.grammar") + g₂ = read_pcsg("toy_pcfg.grammar") + @test :Real ∈ g₂.types + @test g₂.rules == [0, 1, 2, 3, :x] + @test g₂.log_probabilities == g₁.log_probabilities + + # delete file afterwards + rm("toy_pcfg.grammar") + end + + @testset "Creating probabilistic CSG" begin + g = @pcsgrammar begin + 0.5 : R = |(0:2) + 0.3 : R = x + 0.2 : B = true | false + end + + @test sum(map(exp, g.log_probabilities[g.bytype[:R]])) ≈ 1.0 + @test sum(map(exp, g.log_probabilities[g.bytype[:B]])) ≈ 1.0 + @test g.bytype[:R] == Int[1,2,3,4] + @test g.bytype[:B] == Int[5,6] + @test :R ∈ g.types && :B ∈ g.types + end + + @testset "Creating a non-normalized PCSG" begin + g = @pcsgrammar begin + 0.5 : R = |(0:2) + 0.5 : R = x + 0.5 : B = true | false + end + + @test sum(map(exp, g.log_probabilities[g.bytype[:R]])) ≈ 1.0 + @test sum(map(exp, g.log_probabilities[g.bytype[:B]])) ≈ 1.0 + @test g.rules == [0, 1, 2, :x, :true, :false] + @test g.bytype[:R] == Int[1,2,3,4] + @test g.bytype[:B] == Int[5,6] + @test :R ∈ g.types && :B ∈ g.types + end + + @testset "Adding a rule to a probabilistic CSG" begin + + g = @pcsgrammar begin + 0.5 : R = x + 0.5 : R = R + R + end + + add_rule!(g, 0.5, :(R = 1 | 2)) + + @test g.rules == [:x, :(R + R), 1, 2] + + add_rule!(g, 0.5, :(B = t | f)) + + @test g.bytype[:B] == Int[5, 6] + @test sum(map(exp, g.log_probabilities[g.bytype[:R]])) ≈ 1.0 + @test sum(map(exp, g.log_probabilities[g.bytype[:B]])) ≈ 1.0 + end + + @testset "Creating a non-probabilistic rule in a PCSG" begin + expected_log = ( + :error, + "Rule without probability encountered in probabilistic grammar. Rule ignored." + ) + + @test_logs expected_log match_mode=:any begin + @pcsgrammar begin + 0.5 : R = x + R = R + R + end + end + end + + @testset "Make csg probabilistic" begin + grammar = @csgrammar begin + R = |(1:3) + S = |(1:2) + end + # Test correct initialization + @test !isprobabilistic(grammar) + init_probabilities!(grammar) + @test isprobabilistic(grammar) + + probs = grammar.log_probabilities + + # Test equivalence of probabilities + @test probs[1] == probs[2] == probs[3] + @test probs[end-1] == probs[end] + + # Test values + @test all(x -> isapprox(x, 1/3), exp.(probs)[1:3]) + @test all(x -> isapprox(x, 1/2), exp.(probs)[4:5]) + + @test_logs (:warn, "Tried to init probabilities for grammar, but it is already probabilistic. No changes are made.") init_probabilities!(grammar) + end + + @testset "Test normalize! csg" begin + grammar = @csgrammar begin + R = |(1:3) + S = |(1:2) + end + + @test_logs (:warn, "Requesting normalization in a non-probabilistic grammar. Uniform distribution is assumed.") normalize!(grammar) + + grammar.log_probabilities = zeros(length(grammar.rules)) + + normalize!(grammar) + + probs = grammar.log_probabilities + + # Test values + @test all(x -> isapprox(x, 1/3), exp.(probs)[1:3]) + @test all(x -> isapprox(x, 1/2), exp.(probs)[4:5]) + end +end \ No newline at end of file From b42a4d26b1a62da6624678d9d2338778c3fd8abf Mon Sep 17 00:00:00 2001 From: Tilman Hinnerichs Date: Sun, 23 Feb 2025 19:30:00 +0100 Subject: [PATCH 5/9] Added max_rulenode_log_probabilities, fixed gnarly bug, added proper tests --- src/HerbGrammar.jl | 1 + src/rulenode_operators.jl | 6 +++--- test/test_pcsg.jl | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/src/HerbGrammar.jl b/src/HerbGrammar.jl index 8698c94..3b50c16 100644 --- a/src/HerbGrammar.jl +++ b/src/HerbGrammar.jl @@ -60,6 +60,7 @@ export rulenode2expr, rulenode_log_probability, + max_rulenode_log_probability, mindepth_map, mindepth, diff --git a/src/rulenode_operators.jl b/src/rulenode_operators.jl index 7a9e61b..15ee241 100644 --- a/src/rulenode_operators.jl +++ b/src/rulenode_operators.jl @@ -325,7 +325,7 @@ end Calculates the log probability associated with a rulenode in a probabilistic grammar. """ function rulenode_log_probability(node::RuleNode, grammar::AbstractGrammar) - return log_probability(grammar, get_rule(node)) + sum((rulenode_log_probability(c, grammar) for c ∈ node.children), init=1) + return log_probability(grammar, get_rule(node)) + sum((rulenode_log_probability(c, grammar) for c ∈ node.children), init=0) end function rulenode_log_probability(hole::UniformHole, grammar::AbstractGrammar) @@ -335,7 +335,7 @@ function rulenode_log_probability(hole::UniformHole, grammar::AbstractGrammar) throw(ArgumentError("Log probability of a UniformHole requested, which has more than 1 element within its domain. This is ambiguous.")) end end -rulenode_log_probability(::Hole, ::AbstractGrammar) = 1 +rulenode_log_probability(::Hole, ::AbstractGrammar) = 0 """ max_rulenode_log_probability(rulenode::AbstractRuleNode, grammar::AbstractGrammar) @@ -348,7 +348,7 @@ max_rulenode_log_probability(rulenode::AbstractRuleNode, grammar::AbstractGramma function max_rulenode_log_probability(hole::UniformHole, grammar::AbstractGrammar) max_index = argmax(i -> grammar.log_probabilities[i], findall(hole.domain)) - return log_probability(grammar, max_index) + sum((max_rulenode_log_probability(c, grammar) for c ∈ node.children), init=1) + return log_probability(grammar, max_index) + sum((max_rulenode_log_probability(c, grammar) for c ∈ hole.children), init=0) end function max_rulenode_log_probability(hole::Hole, grammar::AbstractGrammar) diff --git a/test/test_pcsg.jl b/test/test_pcsg.jl index 4fcb384..3c5c803 100644 --- a/test/test_pcsg.jl +++ b/test/test_pcsg.jl @@ -117,4 +117,36 @@ @test all(x -> isapprox(x, 1/3), exp.(probs)[1:3]) @test all(x -> isapprox(x, 1/2), exp.(probs)[4:5]) end + + @testset "Test rulenode_log_probability" begin + grammar = @pcsgrammar begin + 0.5 : R = |(1:7) + 0.5 : R = x + end + rn = @rulenode 2{4,8} + + @test isapprox(exp(rulenode_log_probability(rn, grammar)), 1/14 * 1/14 * 0.5) + + hole1 = UniformHole(BitVector((0, 1, 0)), [RuleNode(3), RuleNode(4)]) + hole2 = UniformHole(BitVector((1, 1, 1)), [RuleNode(3), RuleNode(4)]) + generic_hole = Hole([1,1,1,1,0]) + + @test rulenode_log_probability(hole1, grammar) == log_probability(grammar, 2) + @test_throws ArgumentError rulenode_log_probability(hole2, grammar) + @test rulenode_log_probability(generic_hole, grammar) == 0 + end + + @testset "Test max_rulenode_log_probability" begin + grammar = @pcsgrammar begin + 0.5 : R = |(1:2) + 0.5 : R = x + end + + hole = UniformHole(BitVector((1, 1, 1)), [RuleNode(2), RuleNode(3)]) + generic_hole = Hole([1,1,0]) + + @test exp(max_rulenode_log_probability(hole, grammar)) == 1/2 * 1/2 * 1/4 + @test exp(max_rulenode_log_probability(generic_hole, grammar)) == 0.25 + end + end \ No newline at end of file From 0fb467784110d7ac0ddc4e05b825e8861c442148 Mon Sep 17 00:00:00 2001 From: Tilman Hinnerichs Date: Sun, 23 Feb 2025 19:31:28 +0100 Subject: [PATCH 6/9] Fix documentation --- src/rulenode_operators.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/rulenode_operators.jl b/src/rulenode_operators.jl index 15ee241..e39dbc6 100644 --- a/src/rulenode_operators.jl +++ b/src/rulenode_operators.jl @@ -322,6 +322,8 @@ function expr2rulenode(expr::Union{Symbol,Number}, grammar::AbstractGrammar) end """ + rulenode_log_probability(node::RuleNode, grammar::AbstractGrammar) + Calculates the log probability associated with a rulenode in a probabilistic grammar. """ function rulenode_log_probability(node::RuleNode, grammar::AbstractGrammar) From 2074867abbc686f902000be3405aac86f5a6408b Mon Sep 17 00:00:00 2001 From: Tilman Hinnerichs Date: Mon, 24 Feb 2025 12:37:07 +0100 Subject: [PATCH 7/9] Fixed rulenode_log_probability typing to include StateHoles --- src/rulenode_operators.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/rulenode_operators.jl b/src/rulenode_operators.jl index e39dbc6..4b23dd5 100644 --- a/src/rulenode_operators.jl +++ b/src/rulenode_operators.jl @@ -330,7 +330,7 @@ function rulenode_log_probability(node::RuleNode, grammar::AbstractGrammar) return log_probability(grammar, get_rule(node)) + sum((rulenode_log_probability(c, grammar) for c ∈ node.children), init=0) end -function rulenode_log_probability(hole::UniformHole, grammar::AbstractGrammar) +function rulenode_log_probability(hole::AbstractHole, grammar::AbstractGrammar) if sum(hole.domain) == 1 # only one element return log_probability(grammar, only(findall(hole.domain))) else @@ -346,9 +346,11 @@ Calculates the highest possible probability within an `AbstractRuleNode`. That is, for each node and its domain, get the highest probability and multiply it with the probabilities of its children, if present. As we operate with log probabilities, we sum the logarithms. """ -max_rulenode_log_probability(rulenode::AbstractRuleNode, grammar::AbstractGrammar) = rulenode_log_probability(rulenode, grammar) +function max_rulenode_log_probability(rulenode::RuleNode, grammar::AbstractGrammar) + return log_probability(grammar, get_rule(rulenode)) + sum((max_rulenode_log_probability(c, grammar) for c ∈ rulenode.children), init=0) +end -function max_rulenode_log_probability(hole::UniformHole, grammar::AbstractGrammar) +function max_rulenode_log_probability(hole::AbstractHole, grammar::AbstractGrammar) max_index = argmax(i -> grammar.log_probabilities[i], findall(hole.domain)) return log_probability(grammar, max_index) + sum((max_rulenode_log_probability(c, grammar) for c ∈ hole.children), init=0) end From df87308dcbd9d942ba5fb9ba12c2ddfd806425f5 Mon Sep 17 00:00:00 2001 From: Tilman Hinnerichs Date: Mon, 24 Feb 2025 14:21:10 +0100 Subject: [PATCH 8/9] Simplify and optimize --- src/rulenode_operators.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/rulenode_operators.jl b/src/rulenode_operators.jl index 4b23dd5..8bcad87 100644 --- a/src/rulenode_operators.jl +++ b/src/rulenode_operators.jl @@ -351,13 +351,11 @@ function max_rulenode_log_probability(rulenode::RuleNode, grammar::AbstractGramm end function max_rulenode_log_probability(hole::AbstractHole, grammar::AbstractGrammar) - max_index = argmax(i -> grammar.log_probabilities[i], findall(hole.domain)) - return log_probability(grammar, max_index) + sum((max_rulenode_log_probability(c, grammar) for c ∈ hole.children), init=0) + return maximum(grammar.log_probabilities[findall(hole.domain)]) + sum((max_rulenode_log_probability(c, grammar) for c ∈ hole.children), init=0) end function max_rulenode_log_probability(hole::Hole, grammar::AbstractGrammar) - max_index = argmax(i -> grammar.log_probabilities[i], findall(hole.domain)) - return log_probability(grammar, max_index) + return maximum(grammar.log_probabilities[findall(hole.domain)]) end From 3f546a5f7332ad5979076d049e1ee3f96813918a Mon Sep 17 00:00:00 2001 From: Reuben Gardos Reid <5456207+ReubenJ@users.noreply.github.com> Date: Tue, 25 Feb 2025 16:57:21 +0100 Subject: [PATCH 9/9] Bump patch version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 64d65da..59bc579 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "HerbGrammar" uuid = "4ef9e186-2fe5-4b24-8de7-9f7291f24af7" authors = ["Sebastijan Dumancic ", "Jaap de Jong ", "Nicolae Filat ", "Piotr Cichoń "] -version = "0.5.0" +version = "0.5.1" [deps] HerbCore = "2b23ba43-8213-43cb-b5ea-38c12b45bd45"