Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PCSG handling #104

Merged
merged 9 commits into from
Feb 26, 2025
2 changes: 2 additions & 0 deletions src/HerbGrammar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ export

@pcsgrammar,
expr2pcsgrammar,
normalize!,
init_probabilities!,

SymbolTable,
grammar2symboltable,
Expand Down
46 changes: 38 additions & 8 deletions src/csg/probabilistic_csg.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@

"""
expr2pcsgrammar(ex::Expr)::ContextSensitiveGrammar

Function for converting an `Expr` to a [`ContextSensitiveGrammar`](@ref) with probabilities.
If the expression is hardcoded, you should use the `@pcsgrammar` macro.
Only expressions in the correct format (see [`@pcsgrammar`](@ref)) can be converted.
Expand Down Expand Up @@ -47,6 +49,8 @@
end

"""
parse_probabilistic_rule(e::Expr)

Parses a single (potentially shorthand) derivation rule of a probabilistic [`ContextSensitiveGrammar`](@ref).
Returns `nothing` if the rule is not probabilistic, otherwise a `Tuple` of its type and a
`Vector` of probability-rule pairs it expands into.
Expand Down Expand Up @@ -78,26 +82,52 @@


"""
normalize!(grammar::ContextSensitiveGrammar, type::Union{Symbol, Nothing}=nothing)

A function for normalizing the probabilities of a probabilistic [`ContextSensitiveGrammar`](@ref).
If the optional `type` argument is provided, only the rules of that type are normalized.
If the optional `type` argument is provided, only the rules of that type are normalized.
If the grammar is not probabilistic, i.e. `grammar.log_probabilities==nothing`, a uniform distribution is initialized.
"""
function normalize!(g::ContextSensitiveGrammar, type::Union{Symbol, Nothing}=nothing)
probabilities = map(exp, g.log_probabilities)
types = isnothing(type) ? keys(g.bytype) : [type]
function normalize!(grammar::ContextSensitiveGrammar, type::Union{Symbol, Nothing}=nothing)
if !isprobabilistic(grammar)
@warn "Requesting normalization in a non-probabilistic grammar. Uniform distribution is assumed."
init_probabilities!(grammar)
return grammar

Check warning on line 95 in src/csg/probabilistic_csg.jl

View check run for this annotation

Codecov / codecov/patch

src/csg/probabilistic_csg.jl#L93-L95

Added lines #L93 - L95 were not covered by tests
end

probabilities = map(exp, grammar.log_probabilities)
types = isnothing(type) ? keys(grammar.bytype) : [type]

for t ∈ types
total_prob = sum(probabilities[i] for i ∈ g.bytype[t])
total_prob = sum(probabilities[i] for i ∈ grammar.bytype[t])
if !(total_prob ≈ 1)
for i ∈ g.bytype[t]
for i ∈ grammar.bytype[t]
probabilities[i] /= total_prob
end
end
end

g.log_probabilities = map(log, probabilities)
return g
grammar.log_probabilities = map(log, probabilities)
return grammar
end


"""
init_probabilities!(grammar::AbstractGrammar)

If the grammar is not probabilistic yet, initializes the grammar with uniform probabilities per type. If the grammar is already probabilistic, no changed are made.
"""
function init_probabilities!(grammar::AbstractGrammar)
if isprobabilistic(grammar)
@warn "Tried to init probabilities for grammar, but it is already probabilistic. No changes are made."

Check warning on line 122 in src/csg/probabilistic_csg.jl

View check run for this annotation

Codecov / codecov/patch

src/csg/probabilistic_csg.jl#L122

Added line #L122 was not covered by tests
else
grammar.log_probabilities = zeros(length(grammar.rules))
normalize!(grammar)
end
return grammar
end


"""
@pcsgrammar

Expand Down
1 change: 1 addition & 0 deletions src/grammar_base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ function log_probability(grammar::AbstractGrammar, index::Int)::Real
return grammar.log_probabilities[index]
end


"""
probability(grammar::AbstractGrammar, index::Int)::Real

Expand Down
2 changes: 2 additions & 0 deletions src/rulenode_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
rulesoftype(::Hole, ::Set{Int}, ::RuleNode) = Set()
rulesoftype(::Hole, ::Set{Int}, ::Hole) = Set()

rulesoftype(node::AbstractRuleNode, index::Int) = rulesoftype(node, Set{Int}(index))

Check warning on line 45 in src/rulenode_operators.jl

View check run for this annotation

Codecov / codecov/patch

src/rulenode_operators.jl#L45

Added line #L45 was not covered by tests

"""
rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol, ignoreNode::RuleNode)

Expand Down
39 changes: 39 additions & 0 deletions test/test_csg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,45 @@
end
end
end

@testset "make csg probabilistic" begin
grammar = @csgrammar begin
R = |(1:3)
S = |(1:2)
end
# Test correct initialization
@test !isprobabilistic(grammar)
init_probabilities!(grammar)
@test isprobabilistic(grammar)

probs = grammar.log_probabilities

# Test equivalence of probabilities
@test probs[1] == probs[2] == probs[3]
@test probs[end-1] == probs[end]

# Test values
@test all(x -> isapprox(x, 1/3), exp.(probs)[1:3])
@test all(x -> isapprox(x, 1/2), exp.(probs)[4:5])
end

@testset "normalize! csg" begin
grammar = @csgrammar begin
R = |(1:3)
S = |(1:2)
end

grammar.log_probabilities = zeros(length(grammar.rules))

normalize!(grammar)

probs = grammar.log_probabilities

# Test values
@test all(x -> isapprox(x, 1/3), exp.(probs)[1:3])
@test all(x -> isapprox(x, 1/2), exp.(probs)[4:5])
end


@testset "Test that strict equality is used during rule creation" begin
g₁ = @csgrammar begin
Expand Down