Skip to content

Commit

Permalink
Add ValidateCorpusValue to the InGrammar domains.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 732989178
  • Loading branch information
hadi88 authored and copybara-github committed Mar 7, 2025
1 parent 169a96b commit 4bfb207
Showing 1 changed file with 110 additions and 37 deletions.
147 changes: 110 additions & 37 deletions fuzztest/internal/domains/in_grammar_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ class StringLiteralDomain {
result.children.emplace<std::monostate>();
return result;
}

static absl::Status ValidateCorpusValue(const ASTNode& astnode) {
if (!CheckASTNodeTypeIdAndChildType<std::monostate>(astnode, id)) {
return absl::InvalidArgumentError("Invalid node type!");
}
return absl::OkStatus();
}
};

template <ASTTypeId id, const absl::string_view& value>
Expand All @@ -151,23 +158,25 @@ class RegexLiteralDomain {
static void Mutate(ASTNode& val, absl::BitGenRef prng,
const domain_implementor::MutationMetadata& metadata,
bool only_shrink) {
GetInnerRegexpDomain().Mutate(std::get<1>(val.children), prng, metadata,
only_shrink);
GetInnerRegexpDomain().Mutate(std::get<DFAPath>(val.children), prng,
metadata, only_shrink);
}

static ASTTypeId TypeId() { return id; }

static void ToString(std::string& output, const ASTNode& val) {
FUZZTEST_INTERNAL_CHECK(val.children.index() == 1, "Not a regex literal!");
absl::StrAppend(&output,
GetInnerRegexpDomain().GetValue(std::get<1>(val.children)));
FUZZTEST_INTERNAL_CHECK(CheckASTNodeTypeIdAndChildType<DFAPath>(val, id),
"Not a regex literal!");
absl::StrAppend(&output, GetInnerRegexpDomain().GetValue(
std::get<DFAPath>(val.children)));
}

static bool IsMutable(const ASTNode& /*val*/) { return true; }

static IRObject SerializeCorpus(const ASTNode& astnode) {
FUZZTEST_INTERNAL_CHECK(
CheckASTNodeTypeIdAndChildType<DFAPath>(astnode, id), "Invalid node!");
CheckASTNodeTypeIdAndChildType<DFAPath>(astnode, id),
"Not a regex literal!");
return WrapASTIntoIRObject(astnode,
GetInnerRegexpDomain().SerializeCorpus(
std::get<DFAPath>(astnode.children)));
Expand All @@ -190,6 +199,13 @@ class RegexLiteralDomain {
return result;
}

static absl::Status ValidateCorpusValue(const ASTNode& astnode) {
if (!CheckASTNodeTypeIdAndChildType<DFAPath>(astnode, id)) {
return absl::InvalidArgumentError("Not a regex literal!");
}
return absl::OkStatus();
}

private:
static internal::InRegexpImpl& GetInnerRegexpDomain() {
static internal::InRegexpImpl* inner_domain =
Expand Down Expand Up @@ -240,8 +256,11 @@ class VectorDomain {
static void Mutate(ASTNode& val, absl::BitGenRef prng,
const domain_implementor::MutationMetadata& metadata,
bool only_shrink) {
FUZZTEST_INTERNAL_CHECK(val.children.index() == 2, "Not a vector!");
std::vector<ASTNode>& elements = std::get<2>(val.children);
FUZZTEST_INTERNAL_CHECK(
CheckASTNodeTypeIdAndChildType<std::vector<ASTNode>>(val, id),
"Not a vector!");
std::vector<ASTNode>& elements =
std::get<std::vector<ASTNode>>(val.children);
if (only_shrink) {
ShrinkElements(elements, prng);
return;
Expand Down Expand Up @@ -272,7 +291,7 @@ class VectorDomain {
}

static void ToString(std::string& output, const ASTNode& val) {
for (const auto& child : std::get<2>(val.children)) {
for (const auto& child : std::get<std::vector<ASTNode>>(val.children)) {
ElementT::ToString(output, child);
}
}
Expand All @@ -282,7 +301,7 @@ class VectorDomain {
static IRObject SerializeCorpus(const ASTNode& astnode) {
FUZZTEST_INTERNAL_CHECK(
CheckASTNodeTypeIdAndChildType<std::vector<ASTNode>>(astnode, id),
"Invalid node!");
"Not a vector!");
IRObject expansion_obj;
auto& inner_subs = expansion_obj.MutableSubs();
for (auto& node : std::get<std::vector<ASTNode>>(astnode.children)) {
Expand Down Expand Up @@ -319,6 +338,18 @@ class VectorDomain {
return result;
}

static absl::Status ValidateCorpusValue(const ASTNode& astnode) {
if (!CheckASTNodeTypeIdAndChildType<std::vector<ASTNode>>(astnode, id)) {
return absl::InvalidArgumentError("Not a vector!");
}
absl::Status status = absl::OkStatus();
for (const auto& child : std::get<std::vector<ASTNode>>(astnode.children)) {
status.Update(ElementT::ValidateCorpusValue(child));
if (!status.ok()) return status;
}
return status;
}

private:
static void ShrinkElements(std::vector<ASTNode>& elements,
absl::BitGenRef prng) {
Expand Down Expand Up @@ -393,13 +424,14 @@ class TupleDomain {
const domain_implementor::MutationMetadata& metadata,
bool only_shrink) {
FUZZTEST_INTERNAL_CHECK(
val.children.index() == 2 &&
std::get<2>(val.children).size() == sizeof...(ElementT),
CheckASTNodeTypeIdAndChildType<std::vector<ASTNode>>(val, id) &&
std::get<std::vector<ASTNode>>(val.children).size() ==
sizeof...(ElementT),
"Tuple elements number doesn't match!");

std::vector<int> mutables;
ApplyIndex<sizeof...(ElementT)>([&](auto... I) {
((ElementT::IsMutable(std::get<2>(val.children)[I])
((ElementT::IsMutable(std::get<std::vector<ASTNode>>(val.children)[I])
? mutables.push_back(I)
: (void)0),
...);
Expand All @@ -411,23 +443,28 @@ class TupleDomain {

int choice = mutables[absl::Uniform<int>(prng, 0, mutables.size())];
ApplyIndex<sizeof...(ElementT)>([&](auto... I) {
((choice == I ? (ElementT::Mutate(std::get<2>(val.children)[I], prng,
metadata, only_shrink))
: (void)0),
((choice == I
? (ElementT::Mutate(std::get<std::vector<ASTNode>>(val.children)[I],
prng, metadata, only_shrink))
: (void)0),
...);
});
}

static void ToString(std::string& output, const ASTNode& val) {
ApplyIndex<sizeof...(ElementT)>([&](auto... I) {
(ElementT::ToString(output, std::get<2>(val.children)[I]), ...);
(ElementT::ToString(output,
std::get<std::vector<ASTNode>>(val.children)[I]),
...);
});
}

static bool IsMutable(const ASTNode& val) {
bool result = false;
ApplyIndex<sizeof...(ElementT)>([&](auto... I) {
result = (ElementT::IsMutable(std::get<2>(val.children)[I]) || ...);
result = (ElementT::IsMutable(
std::get<std::vector<ASTNode>>(val.children)[I]) ||
...);
});
return result;
}
Expand All @@ -439,8 +476,8 @@ class TupleDomain {
IRObject expansion_obj;
auto& inner_subs = expansion_obj.MutableSubs();
ApplyIndex<sizeof...(ElementT)>([&](auto... I) {
(inner_subs.push_back(
ElementT::SerializeCorpus(std::get<2>(astnode.children)[I])),
(inner_subs.push_back(ElementT::SerializeCorpus(
std::get<std::vector<ASTNode>>(astnode.children)[I])),
...);
});
return WrapASTIntoIRObject(astnode, expansion_obj);
Expand Down Expand Up @@ -476,6 +513,24 @@ class TupleDomain {
}
return result;
}

static absl::Status ValidateCorpusValue(const ASTNode& astnode) {
if (!CheckASTNodeTypeIdAndChildType<std::vector<ASTNode>>(astnode, id)) {
return absl::InvalidArgumentError("Not a vector!");
}
if (std::get<std::vector<ASTNode>>(astnode.children).size() ==
sizeof...(ElementT)) {
return absl::InvalidArgumentError("Tuple elements number doesn't match!");
}
absl::Status status = absl::OkStatus();
ApplyIndex<sizeof...(ElementT)>([&](auto... I) {
(void)((status.Update(ElementT::ValidateCorpusValue(
std::get<std::vector<ASTNode>>(astnode.children)[I])),
status.ok()) &&
...);
});
return status;
}
};

template <ASTTypeId id, int fallback_index, typename... ElementT>
Expand All @@ -493,7 +548,7 @@ class VariantDomain {
: absl::Uniform<int>(prng, 0, (sizeof...(ElementT)));
ASTNode result{id, std::vector<ASTNode>()};
Switch<sizeof...(ElementT)>(choice, [&](auto I) {
std::get<2>(result.children)
std::get<std::vector<ASTNode>>(result.children)
.push_back(std::tuple_element<I, std::tuple<ElementT...>>::type::
InitWithBudget(prng, generation_budget));
});
Expand All @@ -504,7 +559,8 @@ class VariantDomain {
const domain_implementor::MutationMetadata& metadata,
bool only_shrink) {
constexpr bool has_alternative = sizeof...(ElementT) > 1;
ASTNode current_value = std::get<2>(val.children).front();
ASTNode current_value =
std::get<std::vector<ASTNode>>(val.children).front();
ASTTypeId current_value_id = current_value.type_id;
bool is_current_value_mutable;
((ElementT::TypeId() == current_value_id
Expand Down Expand Up @@ -536,12 +592,12 @@ class VariantDomain {
}

static void ToString(std::string& output, const ASTNode& val) {
FUZZTEST_INTERNAL_CHECK(std::get<2>(val.children).size() == 1,
"This is not a variant ast node.");
ASTTypeId child_id = std::get<2>(val.children).front().type_id;
((ElementT::TypeId() == child_id
? (ElementT::ToString(output, std::get<2>(val.children).front()))
: (void)0),
FUZZTEST_INTERNAL_CHECK(
std::get<std::vector<ASTNode>>(val.children).size() == 1,
"This is not a variant ast node.");
auto child = std::get<std::vector<ASTNode>>(val.children).front();
((ElementT::TypeId() == child.type_id ? (ElementT::ToString(output, child))
: (void)0),
...);
}

Expand All @@ -551,10 +607,9 @@ class VariantDomain {

// Otherwise, we check whether the only choice is mutable.
bool result = false;
ASTTypeId child_id = std::get<2>(val.children).front().type_id;
((ElementT::TypeId() == child_id
? (result = ElementT::IsMutable(std::get<2>(val.children).front()),
(void)0)
auto child = std::get<std::vector<ASTNode>>(val.children).front();
((ElementT::TypeId() == child.type_id
? (result = ElementT::IsMutable(child), (void)0)
: (void)0),
...);
return result;
Expand Down Expand Up @@ -603,11 +658,29 @@ class VariantDomain {
return result;
}

static absl::Status ValidateCorpusValue(const ASTNode& astnode) {
if (!CheckASTNodeTypeIdAndChildType<std::vector<ASTNode>>(astnode, id)) {
return absl::InvalidArgumentError("Invalid node type!");
}
if (std::get<std::vector<ASTNode>>(astnode.children).size() != 1) {
return absl::InvalidArgumentError("This is not a variant ast node.");
}
absl::Status status = absl::OkStatus();
auto child = std::get<std::vector<ASTNode>>(astnode.children).front();
(void)((ElementT::TypeId() == child.type_id
? (status.Update(ElementT::ValidateCorpusValue(child)),
status.ok())
: true) &&
...);
return status;
}

private:
static void MutateCurrentValue(
ASTNode& val, absl::BitGenRef prng,
const domain_implementor::MutationMetadata& metadata, bool only_shrink) {
ASTNode& current_value = std::get<2>(val.children).front();
ASTNode& current_value =
std::get<std::vector<ASTNode>>(val.children).front();
((ElementT::TypeId() == current_value.type_id
? (ElementT::Mutate(current_value, prng, metadata, only_shrink))
: (void)0),
Expand All @@ -616,7 +689,8 @@ class VariantDomain {
static void SwitchToAlternative(ASTNode& val, absl::BitGenRef prng) {
constexpr int n_alternative = sizeof...(ElementT);
FUZZTEST_INTERNAL_CHECK(n_alternative > 1, "No alternative to switch!");
int child_type_id = std::get<2>(val.children).front().type_id;
int child_type_id =
std::get<std::vector<ASTNode>>(val.children).front().type_id;
int current_choice = 0;
ApplyIndex<n_alternative>([&](auto... I) {
((ElementT::TypeId() == child_type_id ? (current_choice = I, (void)0)
Expand All @@ -630,7 +704,7 @@ class VariantDomain {

// Switch to an alternative value.
ApplyIndex<n_alternative>([&](auto... I) {
((choice == I ? (std::get<2>(val.children).front() =
((choice == I ? (std::get<std::vector<ASTNode>>(val.children).front() =
ElementT::InitWithBudget(prng, kMaxGenerationNum),
(void)0)
: (void)0),
Expand Down Expand Up @@ -690,8 +764,7 @@ class InGrammarImpl
absl::Status ValidateCorpusValue(const corpus_type& corpus_value) const {
// Validation is currently done during Parsing, and UserToCorpusValue() is
// not supported yet.
// TODO(lszekeres): Refactor so that validation happens here instead.
return absl::OkStatus();
return TopDomain::ValidateCorpusValue(corpus_value);
}

private:
Expand Down

0 comments on commit 4bfb207

Please sign in to comment.