diff --git a/Project.toml b/Project.toml index 64d65da..59bc579 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "HerbGrammar" uuid = "4ef9e186-2fe5-4b24-8de7-9f7291f24af7" authors = ["Sebastijan Dumancic ", "Jaap de Jong ", "Nicolae Filat ", "Piotr Cichoń "] -version = "0.5.0" +version = "0.5.1" [deps] HerbCore = "2b23ba43-8213-43cb-b5ea-38c12b45bd45" diff --git a/src/HerbGrammar.jl b/src/HerbGrammar.jl index 32b27d3..3b50c16 100644 --- a/src/HerbGrammar.jl +++ b/src/HerbGrammar.jl @@ -39,6 +39,7 @@ export nonterminals, iscomplete, parse_rule!, + rulesoftype, @csgrammar, expr2csgrammar, @@ -51,12 +52,15 @@ export @pcsgrammar, expr2pcsgrammar, + normalize!, + init_probabilities!, SymbolTable, grammar2symboltable, rulenode2expr, rulenode_log_probability, + max_rulenode_log_probability, mindepth_map, mindepth, diff --git a/src/csg/probabilistic_csg.jl b/src/csg/probabilistic_csg.jl index 92c3987..cc3a32d 100644 --- a/src/csg/probabilistic_csg.jl +++ b/src/csg/probabilistic_csg.jl @@ -1,5 +1,7 @@ """ + expr2pcsgrammar(ex::Expr)::ContextSensitiveGrammar + Function for converting an `Expr` to a [`ContextSensitiveGrammar`](@ref) with probabilities. If the expression is hardcoded, you should use the `@pcsgrammar` macro. Only expressions in the correct format (see [`@pcsgrammar`](@ref)) can be converted. @@ -47,6 +49,8 @@ function expr2pcsgrammar(ex::Expr)::ContextSensitiveGrammar end """ + parse_probabilistic_rule(e::Expr) + Parses a single (potentially shorthand) derivation rule of a probabilistic [`ContextSensitiveGrammar`](@ref). Returns `nothing` if the rule is not probabilistic, otherwise a `Tuple` of its type and a `Vector` of probability-rule pairs it expands into. @@ -78,26 +82,52 @@ end """ + normalize!(grammar::ContextSensitiveGrammar, type::Union{Symbol, Nothing}=nothing) + A function for normalizing the probabilities of a probabilistic [`ContextSensitiveGrammar`](@ref). -If the optional `type` argument is provided, only the rules of that type are normalized. +If the optional `type` argument is provided, only the rules of that type are normalized. +If the grammar is not probabilistic, i.e. `grammar.log_probabilities==nothing`, a uniform distribution is initialized. """ -function normalize!(g::ContextSensitiveGrammar, type::Union{Symbol, Nothing}=nothing) - probabilities = map(exp, g.log_probabilities) - types = isnothing(type) ? keys(g.bytype) : [type] +function normalize!(grammar::ContextSensitiveGrammar, type::Union{Symbol, Nothing}=nothing) + if !isprobabilistic(grammar) + @warn "Requesting normalization in a non-probabilistic grammar. Uniform distribution is assumed." + init_probabilities!(grammar) + return grammar + end + + probabilities = map(exp, grammar.log_probabilities) + types = isnothing(type) ? keys(grammar.bytype) : [type] for t ∈ types - total_prob = sum(probabilities[i] for i ∈ g.bytype[t]) + total_prob = sum(probabilities[i] for i ∈ grammar.bytype[t]) if !(total_prob ≈ 1) - for i ∈ g.bytype[t] + for i ∈ grammar.bytype[t] probabilities[i] /= total_prob end end end - g.log_probabilities = map(log, probabilities) - return g + grammar.log_probabilities = map(log, probabilities) + return grammar end + +""" + init_probabilities!(grammar::AbstractGrammar) + +If the grammar is not probabilistic yet, initializes the grammar with uniform probabilities per type. If the grammar is already probabilistic, no changed are made. +""" +function init_probabilities!(grammar::AbstractGrammar) + if isprobabilistic(grammar) + @warn "Tried to init probabilities for grammar, but it is already probabilistic. No changes are made." + else + grammar.log_probabilities = zeros(length(grammar.rules)) + normalize!(grammar) + end + return grammar +end + + """ @pcsgrammar @@ -144,4 +174,47 @@ end macro pcfgrammar(ex) return :(expr2pcsgrammar($(QuoteNode(ex)))) -end \ No newline at end of file +end + + +""" + log_probability(grammar::AbstractGrammar, index::Int)::Real + +Returns the log probability for the rule at `index` in the grammar. + +!!! warning + If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed. +""" +function log_probability(grammar::AbstractGrammar, index::Int)::Real + if !isprobabilistic(grammar) + @warn "Requesting probability in a non-probabilistic grammar.\nUniform distribution is assumed." + # Assume uniform probability + return log(1 / length(grammar.bytype[grammar.types[index]])) + end + return grammar.log_probabilities[index] +end + +""" + probability(grammar::AbstractGrammar, index::Int)::Real + +Return the probability for a rule in the grammar. +Use [`log_probability`](@ref) whenever possible. + +!!! warning + If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed. +""" +function probability(grammar::AbstractGrammar, index::Int)::Real + if !isprobabilistic(grammar) + @warn "Requesting probability in a non-probabilistic grammar.Uniform distribution is assumed." + # Assume uniform probability + return 1 / length(grammar.bytype[grammar.types[index]]) + end + return ℯ^grammar.log_probabilities[index] +end + +""" + isprobabilistic(grammar::AbstractGrammar)::Bool + +Function returns whether a [`AbstractGrammar`](@ref) is probabilistic. +""" +isprobabilistic(grammar::AbstractGrammar)::Bool = !(grammar.log_probabilities ≡ nothing) \ No newline at end of file diff --git a/src/grammar_base.jl b/src/grammar_base.jl index 95c8f32..c86cba6 100644 --- a/src/grammar_base.jl +++ b/src/grammar_base.jl @@ -123,48 +123,6 @@ Returns true if the production rule at rule_index contains the special _() eval iseval(grammar::AbstractGrammar, index::Int)::Bool = grammar.iseval[index] -""" - log_probability(grammar::AbstractGrammar, index::Int)::Real - -Returns the log probability for the rule at `index` in the grammar. - -!!! warning - If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed. -""" -function log_probability(grammar::AbstractGrammar, index::Int)::Real - if !isprobabilistic(grammar) - @warn "Requesting probability in a non-probabilistic grammar.\nUniform distribution is assumed." - # Assume uniform probability - return log(1 / length(grammar.bytype[grammar.types[index]])) - end - return grammar.log_probabilities[index] -end - -""" - probability(grammar::AbstractGrammar, index::Int)::Real - -Return the probability for a rule in the grammar. -Use [`log_probability`](@ref) whenever possible. - -!!! warning - If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed. -""" -function probability(grammar::AbstractGrammar, index::Int)::Real - if !isprobabilistic(grammar) - @warn "Requesting probability in a non-probabilistic grammar.\nUniform distribution is assumed." - # Assume uniform probability - return 1 / length(grammar.bytype[grammar.types[index]]) - end - return ℯ^grammar.log_probabilities[index] -end - -""" - isprobabilistic(grammar::AbstractGrammar)::Bool - -Function returns whether a [`AbstractGrammar`](@ref) is probabilistic. -""" -isprobabilistic(grammar::AbstractGrammar)::Bool = !(grammar.log_probabilities ≡ nothing) - """ nchildren(grammar::AbstractGrammar, rule_index::Int)::Int diff --git a/src/rulenode_operators.jl b/src/rulenode_operators.jl index e28863b..8bcad87 100644 --- a/src/rulenode_operators.jl +++ b/src/rulenode_operators.jl @@ -1,58 +1,3 @@ -rulesoftype(::Hole, ::Set{Int}) = Set{Int}() - -""" - rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol) - -Returns every rule of nonterminal symbol `ruletype` that is also used in the [`AbstractRuleNode`](@ref) tree. -""" -rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol) = rulesoftype(node, Set{Int}(grammar[ruletype])) -rulesoftype(::Hole, ::AbstractGrammar, ::Symbol) = Set{Int}() - - -""" - rulesoftype(node::RuleNode, ruleset::Set{Int}, ignoreNode::RuleNode) - -Returns every rule in the ruleset that is also used in the [`AbstractRuleNode`](@ref) tree, but not in the `ignoreNode` subtree. - -!!! warning - The `ignoreNode` must be a subtree of `node` for it to have an effect. -""" -function rulesoftype(node::RuleNode, ruleset::Set{Int}, ignoreNode::RuleNode) - retval = Set() - - if node == ignoreNode - return retval - end - - if get_rule(node) ∈ ruleset - union!(retval, [get_rule(node)]) - end - - if isempty(node.children) - return retval - else - for child ∈ node.children - union!(retval, rulesoftype(child, ruleset)) - end - - return retval - end -end -rulesoftype(node::RuleNode, ruleset::Set{Int}, ::Hole) = rulesoftype(node, ruleset) -rulesoftype(::Hole, ::Set{Int}, ::RuleNode) = Set() -rulesoftype(::Hole, ::Set{Int}, ::Hole) = Set() - -""" - rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol, ignoreNode::RuleNode) - -Returns every rule of nonterminal symbol `ruletype` that is also used in the [`AbstractRuleNode`](@ref) tree, but not in the `ignoreNode` subtree. - -!!! warning - The `ignoreNode` must be a subtree of `node` for it to have an effect. -""" -rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol, ignoreNode::RuleNode) = rulesoftype(node, Set(grammar[ruletype]), ignoreNode) -rulesoftype(::Hole, ::AbstractGrammar, ::Symbol, ::RuleNode) = Set() - """ swap_node(expr::AbstractRuleNode, new_expr::AbstractRuleNode, path::Vector{Int}) @@ -377,13 +322,41 @@ function expr2rulenode(expr::Union{Symbol,Number}, grammar::AbstractGrammar) end """ + rulenode_log_probability(node::RuleNode, grammar::AbstractGrammar) + Calculates the log probability associated with a rulenode in a probabilistic grammar. """ function rulenode_log_probability(node::RuleNode, grammar::AbstractGrammar) - return log_probability(grammar, get_rule(node)) + sum((rulenode_log_probability(c, grammar) for c ∈ node.children), init=1) + return log_probability(grammar, get_rule(node)) + sum((rulenode_log_probability(c, grammar) for c ∈ node.children), init=0) +end + +function rulenode_log_probability(hole::AbstractHole, grammar::AbstractGrammar) + if sum(hole.domain) == 1 # only one element + return log_probability(grammar, only(findall(hole.domain))) + else + throw(ArgumentError("Log probability of a UniformHole requested, which has more than 1 element within its domain. This is ambiguous.")) + end +end +rulenode_log_probability(::Hole, ::AbstractGrammar) = 0 + +""" + max_rulenode_log_probability(rulenode::AbstractRuleNode, grammar::AbstractGrammar) + +Calculates the highest possible probability within an `AbstractRuleNode`. +That is, for each node and its domain, get the highest probability and multiply it with the probabilities of its children, if present. +As we operate with log probabilities, we sum the logarithms. +""" +function max_rulenode_log_probability(rulenode::RuleNode, grammar::AbstractGrammar) + return log_probability(grammar, get_rule(rulenode)) + sum((max_rulenode_log_probability(c, grammar) for c ∈ rulenode.children), init=0) end -rulenode_log_probability(::Hole, ::AbstractGrammar) = 1 +function max_rulenode_log_probability(hole::AbstractHole, grammar::AbstractGrammar) + return maximum(grammar.log_probabilities[findall(hole.domain)]) + sum((max_rulenode_log_probability(c, grammar) for c ∈ hole.children), init=0) +end + +function max_rulenode_log_probability(hole::Hole, grammar::AbstractGrammar) + return maximum(grammar.log_probabilities[findall(hole.domain)]) +end """ diff --git a/test/runtests.jl b/test/runtests.jl index 52b5a4c..c37833f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,4 +10,5 @@ using Test include("test_rulenode2expr.jl") include("test_expr2rulenode.jl") include("test_utils.jl") + include("test_pcsg.jl") end diff --git a/test/test_csg.jl b/test/test_csg.jl index 87e064f..20b4609 100644 --- a/test/test_csg.jl +++ b/test/test_csg.jl @@ -108,83 +108,6 @@ # delete file afterwards rm("toy_cfg.grammar") end - - @testset "Writing and loading probabilistic CSG to/from disk" begin - g₁ = @pcsgrammar begin - 0.5 : Real = |(0:3) - 0.5 : Real = x - end - - store_csg(g₁, "toy_pcfg.grammar") - g₂ = read_pcsg("toy_pcfg.grammar") - @test :Real ∈ g₂.types - @test g₂.rules == [0, 1, 2, 3, :x] - @test g₂.log_probabilities == g₁.log_probabilities - - - # delete file afterwards - rm("toy_pcfg.grammar") - end - - @testset "creating probabilistic CSG" begin - g = @pcsgrammar begin - 0.5 : R = |(0:2) - 0.3 : R = x - 0.2 : B = true | false - end - - @test sum(map(exp, g.log_probabilities[g.bytype[:R]])) ≈ 1.0 - @test sum(map(exp, g.log_probabilities[g.bytype[:B]])) ≈ 1.0 - @test g.bytype[:R] == Int[1,2,3,4] - @test g.bytype[:B] == Int[5,6] - @test :R ∈ g.types && :B ∈ g.types - end - - @testset "creating a non-normalized PCSG" begin - g = @pcsgrammar begin - 0.5 : R = |(0:2) - 0.5 : R = x - 0.5 : B = true | false - end - - @test sum(map(exp, g.log_probabilities[g.bytype[:R]])) ≈ 1.0 - @test sum(map(exp, g.log_probabilities[g.bytype[:B]])) ≈ 1.0 - @test g.rules == [0, 1, 2, :x, :true, :false] - @test g.bytype[:R] == Int[1,2,3,4] - @test g.bytype[:B] == Int[5,6] - @test :R ∈ g.types && :B ∈ g.types - end - - @testset "Adding a rule to a probabilistic CSG" begin - g = @pcsgrammar begin - 0.5 : R = x - 0.5 : R = R + R - end - - add_rule!(g, 0.5, :(R = 1 | 2)) - - @test g.rules == [:x, :(R + R), 1, 2] - - add_rule!(g, 0.5, :(B = t | f)) - - @test g.bytype[:B] == Int[5, 6] - @test sum(map(exp, g.log_probabilities[g.bytype[:R]])) ≈ 1.0 - @test sum(map(exp, g.log_probabilities[g.bytype[:B]])) ≈ 1.0 - end - - @testset "Creating a non-probabilistic rule in a PCSG" begin - expected_log = ( - :error, - "Rule without probability encountered in probabilistic grammar. Rule ignored." - ) - - @test_logs expected_log match_mode=:any begin - @pcsgrammar begin - 0.5 : R = x - R = R + R - end - end - end @testset "Test that strict equality is used during rule creation" begin g₁ = @csgrammar begin diff --git a/test/test_pcsg.jl b/test/test_pcsg.jl new file mode 100644 index 0000000..3c5c803 --- /dev/null +++ b/test/test_pcsg.jl @@ -0,0 +1,152 @@ +@testset verbose=true "Probabilistic CSGs" begin + @testset "Writing and loading probabilistic CSG to/from disk" begin + g₁ = @pcsgrammar begin + 0.5 : Real = |(0:3) + 0.5 : Real = x + end + + store_csg(g₁, "toy_pcfg.grammar") + g₂ = read_pcsg("toy_pcfg.grammar") + @test :Real ∈ g₂.types + @test g₂.rules == [0, 1, 2, 3, :x] + @test g₂.log_probabilities == g₁.log_probabilities + + # delete file afterwards + rm("toy_pcfg.grammar") + end + + @testset "Creating probabilistic CSG" begin + g = @pcsgrammar begin + 0.5 : R = |(0:2) + 0.3 : R = x + 0.2 : B = true | false + end + + @test sum(map(exp, g.log_probabilities[g.bytype[:R]])) ≈ 1.0 + @test sum(map(exp, g.log_probabilities[g.bytype[:B]])) ≈ 1.0 + @test g.bytype[:R] == Int[1,2,3,4] + @test g.bytype[:B] == Int[5,6] + @test :R ∈ g.types && :B ∈ g.types + end + + @testset "Creating a non-normalized PCSG" begin + g = @pcsgrammar begin + 0.5 : R = |(0:2) + 0.5 : R = x + 0.5 : B = true | false + end + + @test sum(map(exp, g.log_probabilities[g.bytype[:R]])) ≈ 1.0 + @test sum(map(exp, g.log_probabilities[g.bytype[:B]])) ≈ 1.0 + @test g.rules == [0, 1, 2, :x, :true, :false] + @test g.bytype[:R] == Int[1,2,3,4] + @test g.bytype[:B] == Int[5,6] + @test :R ∈ g.types && :B ∈ g.types + end + + @testset "Adding a rule to a probabilistic CSG" begin + + g = @pcsgrammar begin + 0.5 : R = x + 0.5 : R = R + R + end + + add_rule!(g, 0.5, :(R = 1 | 2)) + + @test g.rules == [:x, :(R + R), 1, 2] + + add_rule!(g, 0.5, :(B = t | f)) + + @test g.bytype[:B] == Int[5, 6] + @test sum(map(exp, g.log_probabilities[g.bytype[:R]])) ≈ 1.0 + @test sum(map(exp, g.log_probabilities[g.bytype[:B]])) ≈ 1.0 + end + + @testset "Creating a non-probabilistic rule in a PCSG" begin + expected_log = ( + :error, + "Rule without probability encountered in probabilistic grammar. Rule ignored." + ) + + @test_logs expected_log match_mode=:any begin + @pcsgrammar begin + 0.5 : R = x + R = R + R + end + end + end + + @testset "Make csg probabilistic" begin + grammar = @csgrammar begin + R = |(1:3) + S = |(1:2) + end + # Test correct initialization + @test !isprobabilistic(grammar) + init_probabilities!(grammar) + @test isprobabilistic(grammar) + + probs = grammar.log_probabilities + + # Test equivalence of probabilities + @test probs[1] == probs[2] == probs[3] + @test probs[end-1] == probs[end] + + # Test values + @test all(x -> isapprox(x, 1/3), exp.(probs)[1:3]) + @test all(x -> isapprox(x, 1/2), exp.(probs)[4:5]) + + @test_logs (:warn, "Tried to init probabilities for grammar, but it is already probabilistic. No changes are made.") init_probabilities!(grammar) + end + + @testset "Test normalize! csg" begin + grammar = @csgrammar begin + R = |(1:3) + S = |(1:2) + end + + @test_logs (:warn, "Requesting normalization in a non-probabilistic grammar. Uniform distribution is assumed.") normalize!(grammar) + + grammar.log_probabilities = zeros(length(grammar.rules)) + + normalize!(grammar) + + probs = grammar.log_probabilities + + # Test values + @test all(x -> isapprox(x, 1/3), exp.(probs)[1:3]) + @test all(x -> isapprox(x, 1/2), exp.(probs)[4:5]) + end + + @testset "Test rulenode_log_probability" begin + grammar = @pcsgrammar begin + 0.5 : R = |(1:7) + 0.5 : R = x + end + rn = @rulenode 2{4,8} + + @test isapprox(exp(rulenode_log_probability(rn, grammar)), 1/14 * 1/14 * 0.5) + + hole1 = UniformHole(BitVector((0, 1, 0)), [RuleNode(3), RuleNode(4)]) + hole2 = UniformHole(BitVector((1, 1, 1)), [RuleNode(3), RuleNode(4)]) + generic_hole = Hole([1,1,1,1,0]) + + @test rulenode_log_probability(hole1, grammar) == log_probability(grammar, 2) + @test_throws ArgumentError rulenode_log_probability(hole2, grammar) + @test rulenode_log_probability(generic_hole, grammar) == 0 + end + + @testset "Test max_rulenode_log_probability" begin + grammar = @pcsgrammar begin + 0.5 : R = |(1:2) + 0.5 : R = x + end + + hole = UniformHole(BitVector((1, 1, 1)), [RuleNode(2), RuleNode(3)]) + generic_hole = Hole([1,1,0]) + + @test exp(max_rulenode_log_probability(hole, grammar)) == 1/2 * 1/2 * 1/4 + @test exp(max_rulenode_log_probability(generic_hole, grammar)) == 0.25 + end + +end \ No newline at end of file diff --git a/test/test_rulenode_operators.jl b/test/test_rulenode_operators.jl index dee0c14..1ce2faa 100644 --- a/test/test_rulenode_operators.jl +++ b/test/test_rulenode_operators.jl @@ -15,4 +15,4 @@ end @test !isvariable(g₁, RuleNode(7, g₁), SomeDefinitions) @test isvariable(g₁, RuleNode(7, g₁)) end -end +end \ No newline at end of file