Skip to content

Commit

Permalink
Moved probabilistic functios to extra file, moved tests to extra file…
Browse files Browse the repository at this point in the history
…s, updated tests for warnings
  • Loading branch information
THinnerichs committed Feb 23, 2025
1 parent 9f30238 commit ca4cbab
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 218 deletions.
1 change: 1 addition & 0 deletions src/HerbGrammar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ export
nonterminals,
iscomplete,
parse_rule!,
rulesoftype,

@csgrammar,
expr2csgrammar,
Expand Down
45 changes: 44 additions & 1 deletion src/csg/probabilistic_csg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,47 @@ end

macro pcfgrammar(ex)
return :(expr2pcsgrammar($(QuoteNode(ex))))
end
end


"""
log_probability(grammar::AbstractGrammar, index::Int)::Real
Returns the log probability for the rule at `index` in the grammar.
!!! warning
If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed.
"""
function log_probability(grammar::AbstractGrammar, index::Int)::Real
if !isprobabilistic(grammar)
@warn "Requesting probability in a non-probabilistic grammar.\nUniform distribution is assumed."
# Assume uniform probability
return log(1 / length(grammar.bytype[grammar.types[index]]))
end
return grammar.log_probabilities[index]
end

"""
probability(grammar::AbstractGrammar, index::Int)::Real
Return the probability for a rule in the grammar.
Use [`log_probability`](@ref) whenever possible.
!!! warning
If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed.
"""
function probability(grammar::AbstractGrammar, index::Int)::Real
if !isprobabilistic(grammar)
@warn "Requesting probability in a non-probabilistic grammar.Uniform distribution is assumed."
# Assume uniform probability
return 1 / length(grammar.bytype[grammar.types[index]])
end
return^grammar.log_probabilities[index]
end

"""
isprobabilistic(grammar::AbstractGrammar)::Bool
Function returns whether a [`AbstractGrammar`](@ref) is probabilistic.
"""
isprobabilistic(grammar::AbstractGrammar)::Bool = !(grammar.log_probabilities nothing)
43 changes: 0 additions & 43 deletions src/grammar_base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,49 +123,6 @@ Returns true if the production rule at rule_index contains the special _() eval
iseval(grammar::AbstractGrammar, index::Int)::Bool = grammar.iseval[index]


"""
log_probability(grammar::AbstractGrammar, index::Int)::Real
Returns the log probability for the rule at `index` in the grammar.
!!! warning
If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed.
"""
function log_probability(grammar::AbstractGrammar, index::Int)::Real
if !isprobabilistic(grammar)
@warn "Requesting probability in a non-probabilistic grammar.\nUniform distribution is assumed."
# Assume uniform probability
return log(1 / length(grammar.bytype[grammar.types[index]]))
end
return grammar.log_probabilities[index]
end


"""
probability(grammar::AbstractGrammar, index::Int)::Real
Return the probability for a rule in the grammar.
Use [`log_probability`](@ref) whenever possible.
!!! warning
If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed.
"""
function probability(grammar::AbstractGrammar, index::Int)::Real
if !isprobabilistic(grammar)
@warn "Requesting probability in a non-probabilistic grammar.\nUniform distribution is assumed."
# Assume uniform probability
return 1 / length(grammar.bytype[grammar.types[index]])
end
return^grammar.log_probabilities[index]
end

"""
isprobabilistic(grammar::AbstractGrammar)::Bool
Function returns whether a [`AbstractGrammar`](@ref) is probabilistic.
"""
isprobabilistic(grammar::AbstractGrammar)::Bool = !(grammar.log_probabilities nothing)


"""
nchildren(grammar::AbstractGrammar, rule_index::Int)::Int
Expand Down
83 changes: 26 additions & 57 deletions src/rulenode_operators.jl
Original file line number Diff line number Diff line change
@@ -1,60 +1,3 @@
rulesoftype(::Hole, ::Set{Int}) = Set{Int}()

"""
rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol)
Returns every rule of nonterminal symbol `ruletype` that is also used in the [`AbstractRuleNode`](@ref) tree.
"""
rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol) = rulesoftype(node, Set{Int}(grammar[ruletype]))
rulesoftype(::Hole, ::AbstractGrammar, ::Symbol) = Set{Int}()


"""
rulesoftype(node::RuleNode, ruleset::Set{Int}, ignoreNode::RuleNode)
Returns every rule in the ruleset that is also used in the [`AbstractRuleNode`](@ref) tree, but not in the `ignoreNode` subtree.
!!! warning
The `ignoreNode` must be a subtree of `node` for it to have an effect.
"""
function rulesoftype(node::RuleNode, ruleset::Set{Int}, ignoreNode::RuleNode)
retval = Set()

if node == ignoreNode
return retval
end

if get_rule(node) ruleset
union!(retval, [get_rule(node)])
end

if isempty(node.children)
return retval
else
for child node.children
union!(retval, rulesoftype(child, ruleset))
end

return retval
end
end
rulesoftype(node::RuleNode, ruleset::Set{Int}, ::Hole) = rulesoftype(node, ruleset)
rulesoftype(::Hole, ::Set{Int}, ::RuleNode) = Set()
rulesoftype(::Hole, ::Set{Int}, ::Hole) = Set()

rulesoftype(node::AbstractRuleNode, index::Int) = rulesoftype(node, Set{Int}(index))

"""
rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol, ignoreNode::RuleNode)
Returns every rule of nonterminal symbol `ruletype` that is also used in the [`AbstractRuleNode`](@ref) tree, but not in the `ignoreNode` subtree.
!!! warning
The `ignoreNode` must be a subtree of `node` for it to have an effect.
"""
rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol, ignoreNode::RuleNode) = rulesoftype(node, Set(grammar[ruletype]), ignoreNode)
rulesoftype(::Hole, ::AbstractGrammar, ::Symbol, ::RuleNode) = Set()

"""
swap_node(expr::AbstractRuleNode, new_expr::AbstractRuleNode, path::Vector{Int})
Expand Down Expand Up @@ -385,8 +328,34 @@ function rulenode_log_probability(node::RuleNode, grammar::AbstractGrammar)
return log_probability(grammar, get_rule(node)) + sum((rulenode_log_probability(c, grammar) for c node.children), init=1)
end

function rulenode_log_probability(hole::UniformHole, grammar::AbstractGrammar)
if sum(hole.domain) == 1 # only one element
return log_probability(grammar, only(findall(hole.domain)))
else
throw(ArgumentError("Log probability of a UniformHole requested, which has more than 1 element within its domain. This is ambiguous."))
end
end
rulenode_log_probability(::Hole, ::AbstractGrammar) = 1

"""
max_rulenode_log_probability(rulenode::AbstractRuleNode, grammar::AbstractGrammar)
Calculates the highest possible probability within an `AbstractRuleNode`.
That is, for each node and its domain, get the highest probability and multiply it with the probabilities of its children, if present.
As we operate with log probabilities, we sum the logarithms.
"""
max_rulenode_log_probability(rulenode::AbstractRuleNode, grammar::AbstractGrammar) = rulenode_log_probability(rulenode, grammar)

function max_rulenode_log_probability(hole::UniformHole, grammar::AbstractGrammar)
max_index = argmax(i -> grammar.log_probabilities[i], findall(hole.domain))
return log_probability(grammar, max_index) + sum((max_rulenode_log_probability(c, grammar) for c node.children), init=1)
end

function max_rulenode_log_probability(hole::Hole, grammar::AbstractGrammar)
max_index = argmax(i -> grammar.log_probabilities[i], findall(hole.domain))
return log_probability(grammar, max_index)
end


"""
iscomplete(grammar::AbstractGrammar, node::RuleNode)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ using Test
include("test_rulenode2expr.jl")
include("test_expr2rulenode.jl")
include("test_utils.jl")
include("test_pcsg.jl")
end
116 changes: 0 additions & 116 deletions test/test_csg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,122 +108,6 @@
# delete file afterwards
rm("toy_cfg.grammar")
end

@testset "Writing and loading probabilistic CSG to/from disk" begin
g₁ = @pcsgrammar begin
0.5 : Real = |(0:3)
0.5 : Real = x
end

store_csg(g₁, "toy_pcfg.grammar")
g₂ = read_pcsg("toy_pcfg.grammar")
@test :Real g₂.types
@test g₂.rules == [0, 1, 2, 3, :x]
@test g₂.log_probabilities == g₁.log_probabilities


# delete file afterwards
rm("toy_pcfg.grammar")
end

@testset "creating probabilistic CSG" begin
g = @pcsgrammar begin
0.5 : R = |(0:2)
0.3 : R = x
0.2 : B = true | false
end

@test sum(map(exp, g.log_probabilities[g.bytype[:R]])) 1.0
@test sum(map(exp, g.log_probabilities[g.bytype[:B]])) 1.0
@test g.bytype[:R] == Int[1,2,3,4]
@test g.bytype[:B] == Int[5,6]
@test :R g.types && :B g.types
end

@testset "creating a non-normalized PCSG" begin
g = @pcsgrammar begin
0.5 : R = |(0:2)
0.5 : R = x
0.5 : B = true | false
end

@test sum(map(exp, g.log_probabilities[g.bytype[:R]])) 1.0
@test sum(map(exp, g.log_probabilities[g.bytype[:B]])) 1.0
@test g.rules == [0, 1, 2, :x, :true, :false]
@test g.bytype[:R] == Int[1,2,3,4]
@test g.bytype[:B] == Int[5,6]
@test :R g.types && :B g.types
end

@testset "Adding a rule to a probabilistic CSG" begin
g = @pcsgrammar begin
0.5 : R = x
0.5 : R = R + R
end

add_rule!(g, 0.5, :(R = 1 | 2))

@test g.rules == [:x, :(R + R), 1, 2]

add_rule!(g, 0.5, :(B = t | f))

@test g.bytype[:B] == Int[5, 6]
@test sum(map(exp, g.log_probabilities[g.bytype[:R]])) 1.0
@test sum(map(exp, g.log_probabilities[g.bytype[:B]])) 1.0
end

@testset "Creating a non-probabilistic rule in a PCSG" begin
expected_log = (
:error,
"Rule without probability encountered in probabilistic grammar. Rule ignored."
)

@test_logs expected_log match_mode=:any begin
@pcsgrammar begin
0.5 : R = x
R = R + R
end
end
end

@testset "make csg probabilistic" begin
grammar = @csgrammar begin
R = |(1:3)
S = |(1:2)
end
# Test correct initialization
@test !isprobabilistic(grammar)
init_probabilities!(grammar)
@test isprobabilistic(grammar)

probs = grammar.log_probabilities

# Test equivalence of probabilities
@test probs[1] == probs[2] == probs[3]
@test probs[end-1] == probs[end]

# Test values
@test all(x -> isapprox(x, 1/3), exp.(probs)[1:3])
@test all(x -> isapprox(x, 1/2), exp.(probs)[4:5])
end

@testset "normalize! csg" begin
grammar = @csgrammar begin
R = |(1:3)
S = |(1:2)
end

grammar.log_probabilities = zeros(length(grammar.rules))

normalize!(grammar)

probs = grammar.log_probabilities

# Test values
@test all(x -> isapprox(x, 1/3), exp.(probs)[1:3])
@test all(x -> isapprox(x, 1/2), exp.(probs)[4:5])
end


@testset "Test that strict equality is used during rule creation" begin
g₁ = @csgrammar begin
Expand Down
2 changes: 1 addition & 1 deletion test/test_rulenode_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ end
@test !isvariable(g₁, RuleNode(7, g₁), SomeDefinitions)
@test isvariable(g₁, RuleNode(7, g₁))
end
end
end

0 comments on commit ca4cbab

Please sign in to comment.