Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid initializing infinitely recursive sub-fields. #1554

Merged
merged 1 commit into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions domain_tests/arbitrary_domains_protobuf_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,21 @@ TEST(ProtocolBufferWithRequiredFields, ShrinkingNeverRemovesRequiredFields) {
}
}

TEST(ProtocolBufferWithRecursiveFields, InfiniteleyRecursiveFieldsAreNotSet) {
auto domain = Arbitrary<internal::TestProtobufWithRepeatedRecursionSubproto>()
.WithRepeatedFieldsAlwaysSet();
absl::BitGen bitgen;
Value val(domain, bitgen);

ASSERT_TRUE(val.user_value.IsInitialized()) << val.user_value;

for (int i = 0; i < 1000; ++i) {
val.Mutate(domain, bitgen, {}, false);
ASSERT_TRUE(val.user_value.IsInitialized()) << val.user_value;
ASSERT_FALSE(val.user_value.has_list()) << val.user_value;
}
}

TEST(ProtocolBuffer, CanUsePerFieldDomains) {
Domain<TestProtobuf> domain =
Arbitrary<TestProtobuf>()
Expand Down
1 change: 1 addition & 0 deletions fuzztest/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
],
Expand Down
1 change: 1 addition & 0 deletions fuzztest/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ fuzztest_cc_library(
absl::status
absl::statusor
absl::strings
absl::str_format
absl::synchronization
absl::span
)
Expand Down
135 changes: 86 additions & 49 deletions fuzztest/internal/domains/protobuf_domain_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "absl/random/bit_gen_ref.h"
#include "absl/random/random.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
Expand Down Expand Up @@ -466,10 +467,11 @@ class ProtobufDomainUntypedImpl

corpus_type Init(absl::BitGenRef prng) {
if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed;
FUZZTEST_INTERNAL_CHECK(
!IsCustomizedRecursivelyOnly() || !IsInfinitelyRecursive(),
"Cannot set recursive fields by default.");
const auto* descriptor = prototype_.Get()->GetDescriptor();
FUZZTEST_INTERNAL_CHECK(
!IsCustomizedRecursivelyOnly() || !IsInfinitelyRecursive(descriptor),
absl::StrCat("Cannot set recursive fields for ",
descriptor->full_name(), " by default."));
corpus_type val;
absl::flat_hash_map<int, int> oneof_to_field;

Expand All @@ -481,15 +483,18 @@ class ProtobufDomainUntypedImpl
SelectAFieldIndexInOneof(oneof, prng);
}
if (oneof_to_field[oneof->index()] != field->index()) continue;
} else if (!MustBeSet(field) && IsCustomizedRecursivelyOnly() &&
IsFieldFinitelyRecursive(field)) {
// We avoid initializing non-required recursive fields by default (if
// they are not explicitly customized). Otherwise, the initialization
// may never terminate. If a proto has only non-required recursive
// fields, the initialization will be deterministic, which violates the
// assumption on domain Init. However, such cases should be extremely
// rare and breaking the assumption would not have severe consequences.
continue;
} else if (IsCustomizedRecursivelyOnly()) {
if (!MustBeSet(field) && IsFieldFinitelyRecursive(field)) {
// We avoid initializing non-required recursive fields by default (if
// they are not explicitly customized). Otherwise, the initialization
// may never terminate. If a proto has only non-required recursive
// fields, the initialization will be deterministic, which violates
// the assumption on domain Init. However, such cases should be
// extremely rare and breaking the assumption would not have severe
// consequences.
continue;
}
if (MustBeUnset(field)) continue;
}
VisitProtobufField(field, InitializeVisitor{prng, *this, val});
}
Expand Down Expand Up @@ -609,6 +614,7 @@ class ProtobufDomainUntypedImpl
GetOneofFieldPolicy(field) == OptionalPolicy::kAlwaysNull) {
continue;
}
if (IsCustomizedRecursivelyOnly() && MustBeUnset(field)) continue;
++total_weight;

if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
Expand Down Expand Up @@ -646,6 +652,7 @@ class ProtobufDomainUntypedImpl
GetOneofFieldPolicy(field) == OptionalPolicy::kAlwaysNull) {
continue;
}
if (IsCustomizedRecursivelyOnly() && MustBeUnset(field)) continue;
++field_counter;
if (field_counter == selected_field_index) {
VisitProtobufField(
Expand Down Expand Up @@ -958,9 +965,9 @@ class ProtobufDomainUntypedImpl
for (int i = 0; i < oneof->field_count(); ++i) {
OptionalPolicy policy = GetOneofFieldPolicy(oneof->field(i));
if (policy == OptionalPolicy::kAlwaysNull) continue;
if (IsCustomizedRecursivelyOnly() &&
IsFieldFinitelyRecursive(oneof->field(i))) {
continue;
if (IsCustomizedRecursivelyOnly()) {
if (IsFieldFinitelyRecursive(oneof->field(i))) continue;
if (MustBeUnset(oneof->field(i))) continue;
}
fields.push_back(i);
}
Expand Down Expand Up @@ -1705,31 +1712,43 @@ class ProtobufDomainUntypedImpl
kFinitelyRecursive,
};

bool IsInfinitelyRecursive() {
absl::flat_hash_set<decltype(prototype_.Get()->GetDescriptor())> parents;
return IsProtoRecursive(prototype_.Get()->GetDescriptor(), parents,
// Returns true if there are subprotos in the `descriptor` that form an
// infinite recursion.
bool IsInfinitelyRecursive(const Descriptor* descriptor) const {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
absl::flat_hash_set<const FieldDescriptor*> parents;
return IsProtoRecursive(/*field=*/nullptr, parents,
RecursionType::kInfinitelyRecursive, descriptor);
}

// Returns true if there are subfields in the `field` that form an
// infinite recursion of the form: F0 -> F1 -> ... -> Fs -> ... -> Fn -> Fs,
// because all Fi-s have to be set (e.g., Fi is a required field, or is
// customized using `WithFieldsAlwaysSet`).
bool IsInfinitelyRecursive(const FieldDescriptor* field) const {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
absl::flat_hash_set<const FieldDescriptor*> parents;
return IsProtoRecursive(field, parents,
RecursionType::kInfinitelyRecursive);
}

bool IsFieldFinitelyRecursive(const FieldDescriptor* field) {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
if (!field->message_type()) return false;
ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit);
static absl::NoDestructor<
absl::flat_hash_map<std::pair<int64_t, const FieldDescriptor*>, bool>>
cache ABSL_GUARDED_BY(mutex);
bool can_use_cache = IsCustomizedRecursivelyOnly();
if (can_use_cache) {
{
absl::MutexLock l(&mutex);
auto it = cache->find({policy_.id(), field});
if (it != cache->end()) return it->second;
}
absl::flat_hash_set<decltype(field->message_type())> parents;
bool result = IsProtoRecursive(field->message_type(), parents,
RecursionType::kFinitelyRecursive);
if (can_use_cache) {
absl::MutexLock l(&mutex);
cache->insert({{policy_.id(), field}, result});
}
absl::flat_hash_set<const FieldDescriptor*> parents;
bool result =
IsProtoRecursive(field, parents, RecursionType::kFinitelyRecursive);
absl::MutexLock l(&mutex);
cache->insert({{policy_.id(), field}, result});
return result;
}

Expand All @@ -1742,23 +1761,23 @@ class ProtobufDomainUntypedImpl
}

bool IsOneofRecursive(const OneofDescriptor* oneof,
absl::flat_hash_set<const Descriptor*>& parents,
absl::flat_hash_set<const FieldDescriptor*>& parents,
RecursionType recursion_type) const {
bool is_oneof_recursive = false;
for (int i = 0; i < oneof->field_count(); ++i) {
const auto* field = oneof->field(i);
const auto field_policy = policy_.GetOptionalPolicy(field);
if (field_policy == OptionalPolicy::kAlwaysNull) continue;
const auto* child = field->message_type();
if (recursion_type == RecursionType::kInfinitelyRecursive) {
is_oneof_recursive = field_policy != OptionalPolicy::kWithNull &&
child &&
IsProtoRecursive(child, parents, recursion_type);
field->message_type() &&
IsProtoRecursive(field, parents, recursion_type);
if (!is_oneof_recursive) {
return false;
}
} else {
if (child && IsProtoRecursive(child, parents, recursion_type)) {
if (field->message_type() &&
IsProtoRecursive(field, parents, recursion_type)) {
return true;
}
}
Expand All @@ -1767,6 +1786,7 @@ class ProtobufDomainUntypedImpl
}

bool MustBeSet(const FieldDescriptor* field) const {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
if (IsRequired(field)) {
return true;
} else if (field->containing_oneof()) {
Expand All @@ -1783,6 +1803,14 @@ class ProtobufDomainUntypedImpl
}

bool MustBeUnset(const FieldDescriptor* field) const {
FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error.");
if (field->message_type() && IsInfinitelyRecursive(field)) {
absl::FPrintF(
GetStderr(),
"[!] Infinite recursion detected for %s and it remains unset.\n",
field->full_name());
return true;
}
if (IsRequired(field)) {
return false;
} else if (field->containing_oneof()) {
Expand All @@ -1798,39 +1826,48 @@ class ProtobufDomainUntypedImpl
return false;
}

template <typename Descriptor>
bool IsProtoRecursive(const Descriptor* descriptor,
absl::flat_hash_set<const Descriptor*>& parents,
RecursionType recursion_type) const {
if (parents.contains(descriptor)) return true;
parents.insert(descriptor);
// If `field` is nullptr, all fields of `descriptor` are checked.
bool IsProtoRecursive(const FieldDescriptor* field,
absl::flat_hash_set<const FieldDescriptor*>& parents,
RecursionType recursion_type,
const Descriptor* descriptor = nullptr) const {
if (field != nullptr) {
if (parents.contains(field)) return true;
parents.insert(field);
descriptor = field->message_type();
} else {
FUZZTEST_INTERNAL_CHECK(descriptor,
"one of field or descriptor must be non-null!");
}
for (int i = 0; i < descriptor->oneof_decl_count(); ++i) {
const auto* oneof = descriptor->oneof_decl(i);
if (IsOneofRecursive(oneof, parents, recursion_type)) {
parents.erase(descriptor);
if (field != nullptr) parents.erase(field);
return true;
}
}
for (const FieldDescriptor* field : GetProtobufFields(descriptor)) {
if (field->containing_oneof()) continue;
const auto* child = field->message_type();
if (!child) continue;
if (policy_.GetDefaultDomainForProtobufs(field) != std::nullopt) {
for (const FieldDescriptor* subfield : GetProtobufFields(descriptor)) {
if (subfield->containing_oneof()) continue;
if (!subfield->message_type()) continue;
if (auto default_domain = policy_.GetDefaultDomainForProtobufs(subfield);
default_domain != std::nullopt) { // For handling WithProtobufFields.
// If this field is recursive, it will be detected when initializing
// its default domain. Otherwise, this field can always be set safely.
absl::BitGen prng;
default_domain->Init(prng);
continue;
}
if (recursion_type == RecursionType::kInfinitelyRecursive) {
if (!MustBeSet(field)) continue;
if (!MustBeSet(subfield)) continue;
} else {
if (MustBeUnset(field)) continue;
if (MustBeUnset(subfield)) continue;
}
if (IsProtoRecursive(child, parents, recursion_type)) {
parents.erase(descriptor);
if (IsProtoRecursive(subfield, parents, recursion_type)) {
if (field != nullptr) parents.erase(field);
return true;
}
}
parents.erase(descriptor);
if (field != nullptr) parents.erase(field);
return false;
}

Expand Down
8 changes: 8 additions & 0 deletions fuzztest/internal/test_protobuf.proto
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ message TestProtobufWithRecursion {
optional TestProtobufWithExtension ext = 4;
}

message TestProtobufWithRepeatedRecursion {
repeated TestProtobufWithRepeatedRecursion items = 1;
}

message TestProtobufWithRepeatedRecursionSubproto {
optional TestProtobufWithRepeatedRecursion list = 1;
}

message MessageWithGroup {
optional group GroupField = 1 {
optional int64 field1 = 2;
Expand Down
Loading