From 313a52217cccb2c3454981cde7dc05fda245ebcb Mon Sep 17 00:00:00 2001 From: Hadi Ravanbakhsh Date: Tue, 25 Feb 2025 06:50:03 -0800 Subject: [PATCH] Initialize recursive protobuf fields more efficiently. Currently, a field (F0) is not initialized if there are subfields of the form F0 -> F1 -> ... -> Fs -> ... -> Fn -> Fs. This means that for a huge protobuf field that has a recursive subfield deep in its definition, the whole field is not initialized. And later, even when F0 is initialized, F1 won't get initialized, etc. This could be very inefficient. To avoid this, we define "recursion breaker fields". For example, "Fs" becomes a recursion breaker. Then, all fields up to Fs are initialized. And later when Fs gets initialized, all Fs -> ... -> Fn get initialized. This CL consists of the following changes: - IsProtoRecursive deals with infinite recursions only. - IsFinitelyRecursive is replaced with IsRecursionBreaker which is implemented separately. PiperOrigin-RevId: 730867271 --- .../internal/domains/protobuf_domain_impl.h | 122 +++++++++++------- 1 file changed, 76 insertions(+), 46 deletions(-) diff --git a/fuzztest/internal/domains/protobuf_domain_impl.h b/fuzztest/internal/domains/protobuf_domain_impl.h index 3b21e1f9..6d44da33 100644 --- a/fuzztest/internal/domains/protobuf_domain_impl.h +++ b/fuzztest/internal/domains/protobuf_domain_impl.h @@ -484,7 +484,7 @@ class ProtobufDomainUntypedImpl } if (oneof_to_field[oneof->index()] != field->index()) continue; } else if (IsCustomizedRecursivelyOnly()) { - if (!MustBeSet(field) && IsFieldFinitelyRecursive(field)) { + if (!MustBeSet(field) && IsRecursionBreaker(field)) { // We avoid initializing non-required recursive fields by default (if // they are not explicitly customized). Otherwise, the initialization // may never terminate. If a proto has only non-required recursive @@ -966,7 +966,7 @@ class ProtobufDomainUntypedImpl OptionalPolicy policy = GetOneofFieldPolicy(oneof->field(i)); if (policy == OptionalPolicy::kAlwaysNull) continue; if (IsCustomizedRecursivelyOnly()) { - if (IsFieldFinitelyRecursive(oneof->field(i))) continue; + if (IsRecursionBreaker(oneof->field(i))) continue; if (MustBeUnset(oneof->field(i))) continue; } fields.push_back(i); @@ -1704,21 +1704,12 @@ class ProtobufDomainUntypedImpl return GetDomainForField(field, /*use_policy=*/false); } - // Analysis type for protobuf recursions. - enum class RecursionType { - // The proto contains a proto of type P, that must contain another P. - kInfinitelyRecursive, - // The proto contains a proto of type P, that can contain another P. - kFinitelyRecursive, - }; - // Returns true if there are subprotos in the `descriptor` that form an // infinite recursion. bool IsInfinitelyRecursive(const Descriptor* descriptor) const { FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error."); absl::flat_hash_set parents; - return IsProtoRecursive(/*field=*/nullptr, parents, - RecursionType::kInfinitelyRecursive, descriptor); + return IsProtoRecursive(/*field=*/nullptr, parents, descriptor); } // Returns true if there are subfields in the `field` that form an @@ -1727,14 +1718,6 @@ class ProtobufDomainUntypedImpl // customized using `WithFieldsAlwaysSet`). bool IsInfinitelyRecursive(const FieldDescriptor* field) const { FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error."); - absl::flat_hash_set parents; - return IsProtoRecursive(field, parents, - RecursionType::kInfinitelyRecursive); - } - - bool IsFieldFinitelyRecursive(const FieldDescriptor* field) { - FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error."); - if (!field->message_type()) return false; ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit); static absl::NoDestructor< absl::flat_hash_map, bool>> @@ -1745,8 +1728,7 @@ class ProtobufDomainUntypedImpl if (it != cache->end()) return it->second; } absl::flat_hash_set parents; - bool result = - IsProtoRecursive(field, parents, RecursionType::kFinitelyRecursive); + bool result = IsProtoRecursive(field, parents); absl::MutexLock l(&mutex); cache->insert({{policy_.id(), field}, result}); return result; @@ -1760,27 +1742,18 @@ class ProtobufDomainUntypedImpl return index == kFieldCountIndex; } - bool IsOneofRecursive(const OneofDescriptor* oneof, - absl::flat_hash_set& parents, - RecursionType recursion_type) const { + bool IsOneofRecursive( + const OneofDescriptor* oneof, + absl::flat_hash_set& parents) const { bool is_oneof_recursive = false; for (int i = 0; i < oneof->field_count(); ++i) { const auto* field = oneof->field(i); const auto field_policy = policy_.GetOptionalPolicy(field); if (field_policy == OptionalPolicy::kAlwaysNull) continue; - if (recursion_type == RecursionType::kInfinitelyRecursive) { - is_oneof_recursive = field_policy != OptionalPolicy::kWithNull && - field->message_type() && - IsProtoRecursive(field, parents, recursion_type); - if (!is_oneof_recursive) { - return false; - } - } else { - if (field->message_type() && - IsProtoRecursive(field, parents, recursion_type)) { - return true; - } - } + is_oneof_recursive = field_policy != OptionalPolicy::kWithNull && + field->message_type() && + IsProtoRecursive(field, parents); + if (!is_oneof_recursive) return false; } return is_oneof_recursive; } @@ -1829,7 +1802,6 @@ class ProtobufDomainUntypedImpl // If `field` is nullptr, all fields of `descriptor` are checked. bool IsProtoRecursive(const FieldDescriptor* field, absl::flat_hash_set& parents, - RecursionType recursion_type, const Descriptor* descriptor = nullptr) const { if (field != nullptr) { if (parents.contains(field)) return true; @@ -1841,7 +1813,7 @@ class ProtobufDomainUntypedImpl } for (int i = 0; i < descriptor->oneof_decl_count(); ++i) { const auto* oneof = descriptor->oneof_decl(i); - if (IsOneofRecursive(oneof, parents, recursion_type)) { + if (IsOneofRecursive(oneof, parents)) { if (field != nullptr) parents.erase(field); return true; } @@ -1857,12 +1829,8 @@ class ProtobufDomainUntypedImpl default_domain->Init(prng); continue; } - if (recursion_type == RecursionType::kInfinitelyRecursive) { - if (!MustBeSet(subfield)) continue; - } else { - if (MustBeUnset(subfield)) continue; - } - if (IsProtoRecursive(subfield, parents, recursion_type)) { + if (!MustBeSet(subfield)) continue; + if (IsProtoRecursive(subfield, parents)) { if (field != nullptr) parents.erase(field); return true; } @@ -1871,6 +1839,68 @@ class ProtobufDomainUntypedImpl return false; } + // A subset of proto types are considered as recursion breakers and won't + // get recursively initialized during domain initialization to avoid + // non-terminating initialization. + // + // Returns true if the `field` (F0) does not have to be set, and there are + // subfields in the form: F0 -> F1 -> ... -> Fn -> F0 or F20 -> F19 ... -> F0 + // and none of other Fi-s are marked as recursion breakers so far. In other + // words, this method computes recursion breakers and check membership of + // `field` in the set of recursion breakers. + bool IsRecursionBreaker(const FieldDescriptor* field) { + FUZZTEST_INTERNAL_CHECK(IsCustomizedRecursivelyOnly(), "Internal error."); + if (!field->message_type()) return false; + absl::flat_hash_set parents; + return IsRecursionBreaker(/*root=*/field, field, parents); + } + + bool IsRecursionBreaker( + const FieldDescriptor* root, const FieldDescriptor* field, + absl::flat_hash_set& parents) const { + ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit); + static absl::NoDestructor< + absl::flat_hash_map, bool>> + cache ABSL_GUARDED_BY(mutex); + { + absl::MutexLock l(&mutex); + auto it = cache->find({policy_.id(), field}); + if (it != cache->end()) return it->second; + } + // Cannot break the recursion for required fields. + bool can_be_unset = !MustBeSet(field); + if (field->containing_oneof() && !can_be_unset) { // oneof must be set + // We check whether `field` is infinitely recursive without considering + // other oneof fields. If it is, there's another field in the oneof that + // can be set. + absl::flat_hash_set subfield_parents; + subfield_parents.insert(field); + can_be_unset = IsProtoRecursive(field, subfield_parents); + } + if (can_be_unset) { + // Break recursion for deeply nested or recursive protos. + if (parents.size() > 20 || parents.contains(field)) { + absl::MutexLock l(&mutex); + cache->insert({{policy_.id(), field}, true}); + return true; + } + parents.insert(field); + } + for (const FieldDescriptor* subfield : + GetProtobufFields(field->message_type())) { + if (!subfield->message_type()) continue; + if (MustBeUnset(subfield)) continue; + IsRecursionBreaker(root, subfield, parents); + } + if (can_be_unset) parents.erase(field); + absl::MutexLock l(&mutex); + // If the result is computed while visiting the children, we shouldn't + // overwrite. For example, if we visit A -> B -> C -> A, we can return the + // result of the nested A for upper-level A. + auto [it, inserted] = cache->insert({{policy_.id(), field}, false}); + return it->second; + } + bool IsRequired(const FieldDescriptor* field) const { return field->is_required() || IsMapValueMessage(field); }