From 192f0f6222482cc59e0dea4a133f99ba8ccdaf2a Mon Sep 17 00:00:00 2001 From: Hadi Ravanbakhsh Date: Mon, 3 Feb 2025 14:36:06 -0800 Subject: [PATCH] Initialize recursive protobuf fields more efficiently. Currently, a field (F0) is not initialized if there are subfields of the form F0 -> F1 -> ... -> Fs -> ... -> Fn -> Fs. This means that for a huge protobuf field that has a recursive subfield deep in its definition, the whole field is not initialized. And later, even when F0 is initialized, F1 won't get initialized, etc. This could be very inefficient. To avoid this, we define "recursion breaker fields". For example, "Fs" becomes a recursion breaker. Then, all fields up to Fs are initialized. And later when Fs gets initialized, all Fs -> ... -> Fn get initialized. PiperOrigin-RevId: 722802824 --- .../internal/domains/protobuf_domain_impl.h | 113 +++++++++++++----- 1 file changed, 85 insertions(+), 28 deletions(-) diff --git a/fuzztest/internal/domains/protobuf_domain_impl.h b/fuzztest/internal/domains/protobuf_domain_impl.h index 1485f62a..0e170e19 100644 --- a/fuzztest/internal/domains/protobuf_domain_impl.h +++ b/fuzztest/internal/domains/protobuf_domain_impl.h @@ -474,7 +474,7 @@ class ProtobufDomainUntypedImpl } if (oneof_to_field[oneof->index()] != field->index()) continue; } else if (!MustBeSet(field) && IsCustomizedRecursivelyOnly() && - IsFieldFinitelyRecursive(field)) { + IsRecursionBreaker(field)) { // We avoid initializing non-required recursive fields by default (if // they are not explicitly customized). Otherwise, the initialization // may never terminate. If a proto has only non-required recursive @@ -951,7 +951,7 @@ class ProtobufDomainUntypedImpl OptionalPolicy policy = GetOneofFieldPolicy(oneof->field(i)); if (policy == OptionalPolicy::kAlwaysNull) continue; if (IsCustomizedRecursivelyOnly() && - IsFieldFinitelyRecursive(oneof->field(i))) { + IsRecursionBreaker(oneof->field(i))) { continue; } fields.push_back(i); @@ -1697,30 +1697,65 @@ class ProtobufDomainUntypedImpl kFinitelyRecursive, }; + // Returns true if there are subfields in the proto that form an infinite + // recursion of the form: F0 -> F1 -> Fs -> ... -> Fn -> Fs, because all Fi-s + // have to be set. bool IsInfinitelyRecursive() { - absl::flat_hash_setGetDescriptor())> parents; - return IsProtoRecursive(prototype_.Get()->GetDescriptor(), parents, + auto descriptor = prototype_.Get()->GetDescriptor(); + absl::flat_hash_set parents; + return IsProtoRecursive(descriptor, descriptor, parents, RecursionType::kInfinitelyRecursive); } - bool IsFieldFinitelyRecursive(const FieldDescriptor* field) { - if (!field->message_type()) return false; - ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit); - static absl::NoDestructor> - cache ABSL_GUARDED_BY(mutex); - bool can_use_cache = IsCustomizedRecursivelyOnly(); - if (can_use_cache) { + // A `field` (F0) is recursion breaker if it not infinitely recursive and + // there are subfields in the form: F0 -> F1 -> ... -> Fn -> F0 where none of + // other Fi-s are recursion breakers. In other words, for each recursion loop, + // it is sufficient to have only one recursion breaker. + // + // During domain initialization, the recursion breakers are detected and won't + // get initialized to avoid non-terminating initialization. + class RecursionBreakers { + public: + // Returns nullopt if the answer is not known yet. + static std::optional IsRecursionBreaker( + const FieldDescriptor* field) { absl::MutexLock l(&mutex); - auto it = cache->find(field); - if (it != cache->end()) return it->second; + auto it = GetKnownBreakers().find(field); + return it != GetKnownBreakers().end() ? std::optional(it->second) + : std::nullopt; } - absl::flat_hash_setmessage_type())> parents; - bool result = IsProtoRecursive(field->message_type(), parents, - RecursionType::kFinitelyRecursive); - if (can_use_cache) { + + static void Add(const FieldDescriptor* field, bool is_recursion_breaker) { absl::MutexLock l(&mutex); - cache->insert({field, result}); + GetKnownBreakers().insert({field, is_recursion_breaker}); + } + + private: + static absl::flat_hash_map& + GetKnownBreakers() { + static absl::NoDestructor< + absl::flat_hash_map> + known_breakers; + return *known_breakers; } + static absl::Mutex mutex; + }; + + // Returns true if the `field` (F0) is not infinitely recursive but there are + // subfields in the form: F0 -> F1 -> ... -> Fn -> F0 and none of other Fi-s + // are marked as recursion breakers so far. + bool IsRecursionBreaker(const FieldDescriptor* field) { + auto descriptor = field->message_type(); + if (!descriptor) return false; + bool can_be_recursion_breaker = IsCustomizedRecursivelyOnly(); + if (can_be_recursion_breaker) { + auto cache = RecursionBreakers::IsRecursionBreaker(field); + if (cache.has_value()) return *cache; + } + absl::flat_hash_set parents; + bool result = IsProtoRecursive(descriptor, descriptor, parents, + RecursionType::kFinitelyRecursive); + if (can_be_recursion_breaker) RecursionBreakers::Add(field, result); return result; } @@ -1732,7 +1767,7 @@ class ProtobufDomainUntypedImpl return index == kFieldCountIndex; } - bool IsOneofRecursive(const OneofDescriptor* oneof, + bool IsOneofRecursive(const Descriptor* root, const OneofDescriptor* oneof, absl::flat_hash_set& parents, RecursionType recursion_type) const { bool is_oneof_recursive = false; @@ -1742,14 +1777,14 @@ class ProtobufDomainUntypedImpl if (field_policy == OptionalPolicy::kAlwaysNull) continue; const auto* child = field->message_type(); if (recursion_type == RecursionType::kInfinitelyRecursive) { - is_oneof_recursive = field_policy != OptionalPolicy::kWithNull && - child && - IsProtoRecursive(child, parents, recursion_type); + is_oneof_recursive = + field_policy != OptionalPolicy::kWithNull && child && + IsProtoRecursive(root, child, parents, recursion_type); if (!is_oneof_recursive) { return false; } } else { - if (child && IsProtoRecursive(child, parents, recursion_type)) { + if (child && IsProtoRecursive(root, child, parents, recursion_type)) { return true; } } @@ -1789,15 +1824,26 @@ class ProtobufDomainUntypedImpl return false; } - template - bool IsProtoRecursive(const Descriptor* descriptor, + bool IsProtoRecursive(const Descriptor* root, const Descriptor* descriptor, absl::flat_hash_set& parents, RecursionType recursion_type) const { - if (parents.contains(descriptor)) return true; + if (parents.contains(descriptor)) { + if (recursion_type == RecursionType::kInfinitelyRecursive) { + // Any path in the lasso form: F0 -> F1 -> Fs -> ... -> Fn -> Fs yields + // an infinitely recursion. + return true; + } else { + // A Path in the lasso form: F0 -> F1 -> Fs -> ... -> Fn -> Fs is OK as + // root is not in the recursion path and we can postpone the recursion + // detection. But, we signal the recursion when we find a loop of the + // form : F0 -> F1 -> ... -> Fn -> F0. + return descriptor == root; + }; + } parents.insert(descriptor); for (int i = 0; i < descriptor->oneof_decl_count(); ++i) { const auto* oneof = descriptor->oneof_decl(i); - if (IsOneofRecursive(oneof, parents, recursion_type)) { + if (IsOneofRecursive(root, oneof, parents, recursion_type)) { parents.erase(descriptor); return true; } @@ -1815,8 +1861,14 @@ class ProtobufDomainUntypedImpl if (!MustBeSet(field)) continue; } else { if (MustBeUnset(field)) continue; + // Skip if the recursion is already broken through `field`. + if (bool can_be_recursion_breaker = IsCustomizedRecursivelyOnly()) { + std::optional is_recursion_breaker = + RecursionBreakers::IsRecursionBreaker(field); + if (is_recursion_breaker && *is_recursion_breaker) continue; + } } - if (IsProtoRecursive(child, parents, recursion_type)) { + if (IsProtoRecursive(root, child, parents, recursion_type)) { parents.erase(descriptor); return true; } @@ -1857,6 +1909,11 @@ class ProtobufDomainUntypedImpl absl::flat_hash_set unset_oneof_fields_; }; +template +ABSL_CONST_INIT absl::Mutex + ProtobufDomainUntypedImpl::RecursionBreakers::mutex( + absl::kConstInit); + // Domain for `T` where `T` is a Protobuf message type. // It is a small wrapper around `ProtobufDomainUntypedImpl` to make its API more // convenient.