Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PCSG handling #104

Merged
merged 9 commits into from
Feb 26, 2025
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "HerbGrammar"
uuid = "4ef9e186-2fe5-4b24-8de7-9f7291f24af7"
authors = ["Sebastijan Dumancic <s.dumancic@tudelft.nl>", "Jaap de Jong <J.deJong-18@student.tudelft.nl>", "Nicolae Filat <N.Filat@student.tudelft.nl>", "Piotr Cichoń <gitlab@gitlab.ewi.tudelft.nl>"]
version = "0.5.0"
version = "0.5.1"

[deps]
HerbCore = "2b23ba43-8213-43cb-b5ea-38c12b45bd45"
Expand Down
4 changes: 4 additions & 0 deletions src/HerbGrammar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ export
nonterminals,
iscomplete,
parse_rule!,
rulesoftype,

@csgrammar,
expr2csgrammar,
Expand All @@ -51,12 +52,15 @@ export

@pcsgrammar,
expr2pcsgrammar,
normalize!,
init_probabilities!,

SymbolTable,
grammar2symboltable,

rulenode2expr,
rulenode_log_probability,
max_rulenode_log_probability,

mindepth_map,
mindepth,
Expand Down
91 changes: 82 additions & 9 deletions src/csg/probabilistic_csg.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@

"""
expr2pcsgrammar(ex::Expr)::ContextSensitiveGrammar

Function for converting an `Expr` to a [`ContextSensitiveGrammar`](@ref) with probabilities.
If the expression is hardcoded, you should use the `@pcsgrammar` macro.
Only expressions in the correct format (see [`@pcsgrammar`](@ref)) can be converted.
Expand Down Expand Up @@ -47,6 +49,8 @@
end

"""
parse_probabilistic_rule(e::Expr)

Parses a single (potentially shorthand) derivation rule of a probabilistic [`ContextSensitiveGrammar`](@ref).
Returns `nothing` if the rule is not probabilistic, otherwise a `Tuple` of its type and a
`Vector` of probability-rule pairs it expands into.
Expand Down Expand Up @@ -78,26 +82,52 @@


"""
normalize!(grammar::ContextSensitiveGrammar, type::Union{Symbol, Nothing}=nothing)

A function for normalizing the probabilities of a probabilistic [`ContextSensitiveGrammar`](@ref).
If the optional `type` argument is provided, only the rules of that type are normalized.
If the optional `type` argument is provided, only the rules of that type are normalized.
If the grammar is not probabilistic, i.e. `grammar.log_probabilities==nothing`, a uniform distribution is initialized.
"""
function normalize!(grammar::ContextSensitiveGrammar, type::Union{Symbol, Nothing}=nothing)
    # A non-probabilistic grammar has nothing to rescale: give it a uniform
    # distribution instead. `init_probabilities!` normalizes all types itself.
    if !isprobabilistic(grammar)
        @warn "Requesting normalization in a non-probabilistic grammar. Uniform distribution is assumed."
        init_probabilities!(grammar)
        return grammar
    end

    # Work in probability space and convert back to log space at the end.
    probabilities = map(exp, grammar.log_probabilities)
    types = isnothing(type) ? keys(grammar.bytype) : [type]

    for t ∈ types
        rule_indices = grammar.bytype[t]
        total_prob = sum(probabilities[i] for i ∈ rule_indices)
        # Only rescale when the rules of this type do not already sum to one.
        if !(total_prob ≈ 1)
            for i ∈ rule_indices
                probabilities[i] /= total_prob
            end
        end
    end

    grammar.log_probabilities = map(log, probabilities)
    return grammar
end


"""
init_probabilities!(grammar::AbstractGrammar)

If the grammar is not probabilistic yet, initializes the grammar with uniform probabilities per type. If the grammar is already probabilistic, no changes are made.
"""
function init_probabilities!(grammar::AbstractGrammar)
    if !isprobabilistic(grammar)
        # Log probability 0 for every rule (i.e. probability 1); normalizing
        # afterwards spreads the mass evenly over the rules of each type.
        grammar.log_probabilities = zeros(length(grammar.rules))
        normalize!(grammar)
        return grammar
    end

    # Existing probabilities are left untouched.
    @warn "Tried to init probabilities for grammar, but it is already probabilistic. No changes are made."
    return grammar
end


"""
@pcsgrammar

Expand Down Expand Up @@ -144,4 +174,47 @@

macro pcfgrammar(ex)
return :(expr2pcsgrammar($(QuoteNode(ex))))
end
end


"""
log_probability(grammar::AbstractGrammar, index::Int)::Real

Returns the log probability for the rule at `index` in the grammar.

!!! warning
If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed.
"""
function log_probability(grammar::AbstractGrammar, index::Int)::Real
    # Probabilistic grammars store the value directly.
    isprobabilistic(grammar) && return grammar.log_probabilities[index]

    @warn "Requesting probability in a non-probabilistic grammar.\nUniform distribution is assumed."
    # Assume uniform probability over all rules sharing the type of rule `index`.
    same_type_rules = grammar.bytype[grammar.types[index]]
    return log(1 / length(same_type_rules))
end

"""
probability(grammar::AbstractGrammar, index::Int)::Real

Return the probability for a rule in the grammar.
Use [`log_probability`](@ref) whenever possible.

!!! warning
If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed.
"""
function probability(grammar::AbstractGrammar, index::Int)::Real
    if !isprobabilistic(grammar)
        # Same message as `log_probability`; the two sentences need a separator.
        @warn "Requesting probability in a non-probabilistic grammar.\nUniform distribution is assumed."
        # Assume uniform probability over all rules sharing the type of rule `index`.
        return 1 / length(grammar.bytype[grammar.types[index]])
    end
    # Stored values are log probabilities; exponentiate to get the probability.
    return exp(grammar.log_probabilities[index])
end

"""
isprobabilistic(grammar::AbstractGrammar)::Bool

Function returns whether an [`AbstractGrammar`](@ref) is probabilistic.
"""
# A grammar counts as probabilistic as soon as `log_probabilities` is set.
isprobabilistic(grammar::AbstractGrammar)::Bool = !isnothing(grammar.log_probabilities)
42 changes: 0 additions & 42 deletions src/grammar_base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,48 +123,6 @@ Returns true if the production rule at rule_index contains the special _() eval
iseval(grammar::AbstractGrammar, index::Int)::Bool = grammar.iseval[index]


"""
log_probability(grammar::AbstractGrammar, index::Int)::Real

Returns the log probability for the rule at `index` in the grammar.

!!! warning
If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed.
"""
function log_probability(grammar::AbstractGrammar, index::Int)::Real
    if isprobabilistic(grammar)
        return grammar.log_probabilities[index]
    end

    @warn "Requesting probability in a non-probabilistic grammar.\nUniform distribution is assumed."
    # Uniform fallback: one over the number of rules with the same type as rule `index`.
    ruletype = grammar.types[index]
    return log(1 / length(grammar.bytype[ruletype]))
end

"""
probability(grammar::AbstractGrammar, index::Int)::Real

Return the probability for a rule in the grammar.
Use [`log_probability`](@ref) whenever possible.

!!! warning
If the grammar is not probabilistic, a warning is displayed, and a uniform probability is assumed.
"""
function probability(grammar::AbstractGrammar, index::Int)::Real
    # Probabilistic grammars store log probabilities; exponentiate the stored value.
    isprobabilistic(grammar) && return ℯ^grammar.log_probabilities[index]

    @warn "Requesting probability in a non-probabilistic grammar.\nUniform distribution is assumed."
    # Uniform fallback: one over the number of rules with the same type as rule `index`.
    same_type_rules = grammar.bytype[grammar.types[index]]
    return 1 / length(same_type_rules)
end

"""
isprobabilistic(grammar::AbstractGrammar)::Bool

Function returns whether an [`AbstractGrammar`](@ref) is probabilistic.
"""
# Probabilistic means that `log_probabilities` holds something other than `nothing`.
isprobabilistic(grammar::AbstractGrammar)::Bool = grammar.log_probabilities !== nothing


"""
nchildren(grammar::AbstractGrammar, rule_index::Int)::Int
Expand Down
87 changes: 30 additions & 57 deletions src/rulenode_operators.jl
Original file line number Diff line number Diff line change
@@ -1,58 +1,3 @@
# Holes contribute no rules: the result is always an empty `Set{Int}`.
function rulesoftype(::Hole, ::Set{Int})
    return Set{Int}()
end

"""
rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol)

Returns every rule of nonterminal symbol `ruletype` that is also used in the [`AbstractRuleNode`](@ref) tree.
"""
function rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol)
    # Look up the rule indices of `ruletype` once, then delegate to the set-based method.
    candidates = Set{Int}(grammar[ruletype])
    return rulesoftype(node, candidates)
end
# Holes contribute no rules, whatever the grammar or type.
function rulesoftype(::Hole, ::AbstractGrammar, ::Symbol)
    return Set{Int}()
end


"""
rulesoftype(node::RuleNode, ruleset::Set{Int}, ignoreNode::RuleNode)

Returns every rule in the ruleset that is also used in the [`AbstractRuleNode`](@ref) tree, but not in the `ignoreNode` subtree.

!!! warning
The `ignoreNode` must be a subtree of `node` for it to have an effect.
"""
function rulesoftype(node::RuleNode, ruleset::Set{Int}, ignoreNode::RuleNode)
    # Typed set: rule indices are `Int`s, matching the other `rulesoftype` methods.
    retval = Set{Int}()

    # Everything at or below `ignoreNode` is excluded from the result.
    node == ignoreNode && return retval

    rule = get_rule(node)
    rule ∈ ruleset && push!(retval, rule)

    for child ∈ node.children
        # Keep passing `ignoreNode` down. Without it the ignored subtree would
        # only be skipped when it is the root of the search.
        union!(retval, rulesoftype(child, ruleset, ignoreNode))
    end

    return retval
end
# Ignoring a hole has no effect: fall back to the plain search.
rulesoftype(node::RuleNode, ruleset::Set{Int}, ::Hole) = rulesoftype(node, ruleset)
# Holes contribute no rules.
rulesoftype(::Hole, ::Set{Int}, ::RuleNode) = Set{Int}()
rulesoftype(::Hole, ::Set{Int}, ::Hole) = Set{Int}()

"""
rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol, ignoreNode::RuleNode)

Returns every rule of nonterminal symbol `ruletype` that is also used in the [`AbstractRuleNode`](@ref) tree, but not in the `ignoreNode` subtree.

!!! warning
The `ignoreNode` must be a subtree of `node` for it to have an effect.
"""
function rulesoftype(node::RuleNode, grammar::AbstractGrammar, ruletype::Symbol, ignoreNode::RuleNode)
    # Resolve the rule indices of `ruletype`, then delegate to the set-based method.
    ruleset = Set(grammar[ruletype])
    return rulesoftype(node, ruleset, ignoreNode)
end
# Holes contribute no rules.
function rulesoftype(::Hole, ::AbstractGrammar, ::Symbol, ::RuleNode)
    return Set()
end

"""
swap_node(expr::AbstractRuleNode, new_expr::AbstractRuleNode, path::Vector{Int})

Expand Down Expand Up @@ -377,13 +322,41 @@ function expr2rulenode(expr::Union{Symbol,Number}, grammar::AbstractGrammar)
end

"""
rulenode_log_probability(node::RuleNode, grammar::AbstractGrammar)

Calculates the log probability associated with a rulenode in a probabilistic grammar.
"""
function rulenode_log_probability(node::RuleNode, grammar::AbstractGrammar)
    own_log_prob = log_probability(grammar, get_rule(node))
    # Log probabilities add up, so the sum over the children must start at
    # 0 (= log(1)). Starting at 1 would add 1 for every node in the tree.
    children_log_prob = sum(
        (rulenode_log_probability(c, grammar) for c ∈ node.children); init=0
    )
    return own_log_prob + children_log_prob
end

function rulenode_log_probability(hole::AbstractHole, grammar::AbstractGrammar)
    # Only a domain whose entries sum to exactly one identifies a single rule.
    if sum(hole.domain) != 1
        throw(ArgumentError("Log probability of a UniformHole requested, which has more than 1 element within its domain. This is ambiguous."))
    end

    rule_index = only(findall(hole.domain))
    return log_probability(grammar, rule_index)
end
# A plain `Hole` always yields 0.
function rulenode_log_probability(::Hole, ::AbstractGrammar)
    return 0
end

"""
max_rulenode_log_probability(rulenode::AbstractRuleNode, grammar::AbstractGrammar)

Calculates the highest possible probability within an `AbstractRuleNode`.
That is, for each node and its domain, get the highest probability and multiply it with the probabilities of its children, if present.
As we operate with log probabilities, we sum the logarithms.
"""
function max_rulenode_log_probability(rulenode::RuleNode, grammar::AbstractGrammar)
    # A `RuleNode` has a fixed rule, so its own contribution is that rule's log probability.
    own_log_prob = log_probability(grammar, get_rule(rulenode))
    children_log_prob = sum(
        (max_rulenode_log_probability(c, grammar) for c ∈ rulenode.children); init=0
    )
    return own_log_prob + children_log_prob
end

rulenode_log_probability(::Hole, ::AbstractGrammar) = 1
function max_rulenode_log_probability(hole::AbstractHole, grammar::AbstractGrammar)
    # Best log probability among the rules selected by the hole's domain.
    domain_indices = findall(hole.domain)
    best_own = maximum(grammar.log_probabilities[domain_indices])
    best_children = sum(
        (max_rulenode_log_probability(c, grammar) for c ∈ hole.children); init=0
    )
    return best_own + best_children
end

function max_rulenode_log_probability(hole::Hole, grammar::AbstractGrammar)
    # This method does not visit children: only the hole's own domain is considered.
    domain_indices = findall(hole.domain)
    return maximum(grammar.log_probabilities[domain_indices])
end


"""
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ using Test
include("test_rulenode2expr.jl")
include("test_expr2rulenode.jl")
include("test_utils.jl")
include("test_pcsg.jl")
end
Loading