diff --git a/fuzztest/internal/domains/protobuf_domain_impl.h b/fuzztest/internal/domains/protobuf_domain_impl.h index 1485f62a..0e170e19 100644 --- a/fuzztest/internal/domains/protobuf_domain_impl.h +++ b/fuzztest/internal/domains/protobuf_domain_impl.h @@ -474,7 +474,7 @@ class ProtobufDomainUntypedImpl } if (oneof_to_field[oneof->index()] != field->index()) continue; } else if (!MustBeSet(field) && IsCustomizedRecursivelyOnly() && - IsFieldFinitelyRecursive(field)) { + IsRecursionBreaker(field)) { // We avoid initializing non-required recursive fields by default (if // they are not explicitly customized). Otherwise, the initialization // may never terminate. If a proto has only non-required recursive @@ -951,7 +951,7 @@ class ProtobufDomainUntypedImpl OptionalPolicy policy = GetOneofFieldPolicy(oneof->field(i)); if (policy == OptionalPolicy::kAlwaysNull) continue; if (IsCustomizedRecursivelyOnly() && - IsFieldFinitelyRecursive(oneof->field(i))) { + IsRecursionBreaker(oneof->field(i))) { continue; } fields.push_back(i); @@ -1697,30 +1697,65 @@ class ProtobufDomainUntypedImpl kFinitelyRecursive, }; + // Returns true if there are subfields in the proto that form an infinite + // recursion of the form: F0 -> F1 -> Fs -> ... -> Fn -> Fs, because all Fi-s + // have to be set. bool IsInfinitelyRecursive() { - absl::flat_hash_setGetDescriptor())> parents; - return IsProtoRecursive(prototype_.Get()->GetDescriptor(), parents, + auto descriptor = prototype_.Get()->GetDescriptor(); + absl::flat_hash_set parents; + return IsProtoRecursive(descriptor, descriptor, parents, RecursionType::kInfinitelyRecursive); } - bool IsFieldFinitelyRecursive(const FieldDescriptor* field) { - if (!field->message_type()) return false; - ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit); - static absl::NoDestructor> - cache ABSL_GUARDED_BY(mutex); - bool can_use_cache = IsCustomizedRecursivelyOnly(); - if (can_use_cache) { + // A `field` (F0) is recursion breaker if it not infinitely recursive and + // there are subfields in the form: F0 -> F1 -> ... -> Fn -> F0 where none of + // other Fi-s are recursion breakers. In other words, for each recursion loop, + // it is sufficient to have only one recursion breaker. + // + // During domain initialization, the recursion breakers are detected and won't + // get initialized to avoid non-terminating initialization. + class RecursionBreakers { + public: + // Returns nullopt if the answer is not known yet. + static std::optional IsRecursionBreaker( + const FieldDescriptor* field) { absl::MutexLock l(&mutex); - auto it = cache->find(field); - if (it != cache->end()) return it->second; + auto it = GetKnownBreakers().find(field); + return it != GetKnownBreakers().end() ? std::optional(it->second) + : std::nullopt; } - absl::flat_hash_setmessage_type())> parents; - bool result = IsProtoRecursive(field->message_type(), parents, - RecursionType::kFinitelyRecursive); - if (can_use_cache) { + + static void Add(const FieldDescriptor* field, bool is_recursion_breaker) { absl::MutexLock l(&mutex); - cache->insert({field, result}); + GetKnownBreakers().insert({field, is_recursion_breaker}); + } + + private: + static absl::flat_hash_map& + GetKnownBreakers() { + static absl::NoDestructor< + absl::flat_hash_map> + known_breakers; + return *known_breakers; } + static absl::Mutex mutex; + }; + + // Returns true if the `field` (F0) is not infinitely recursive but there are + // subfields in the form: F0 -> F1 -> ... -> Fn -> F0 and none of other Fi-s + // are marked as recursion breakers so far. + bool IsRecursionBreaker(const FieldDescriptor* field) { + auto descriptor = field->message_type(); + if (!descriptor) return false; + bool can_be_recursion_breaker = IsCustomizedRecursivelyOnly(); + if (can_be_recursion_breaker) { + auto cache = RecursionBreakers::IsRecursionBreaker(field); + if (cache.has_value()) return *cache; + } + absl::flat_hash_set parents; + bool result = IsProtoRecursive(descriptor, descriptor, parents, + RecursionType::kFinitelyRecursive); + if (can_be_recursion_breaker) RecursionBreakers::Add(field, result); return result; } @@ -1732,7 +1767,7 @@ class ProtobufDomainUntypedImpl return index == kFieldCountIndex; } - bool IsOneofRecursive(const OneofDescriptor* oneof, + bool IsOneofRecursive(const Descriptor* root, const OneofDescriptor* oneof, absl::flat_hash_set& parents, RecursionType recursion_type) const { bool is_oneof_recursive = false; @@ -1742,14 +1777,14 @@ class ProtobufDomainUntypedImpl if (field_policy == OptionalPolicy::kAlwaysNull) continue; const auto* child = field->message_type(); if (recursion_type == RecursionType::kInfinitelyRecursive) { - is_oneof_recursive = field_policy != OptionalPolicy::kWithNull && - child && - IsProtoRecursive(child, parents, recursion_type); + is_oneof_recursive = + field_policy != OptionalPolicy::kWithNull && child && + IsProtoRecursive(root, child, parents, recursion_type); if (!is_oneof_recursive) { return false; } } else { - if (child && IsProtoRecursive(child, parents, recursion_type)) { + if (child && IsProtoRecursive(root, child, parents, recursion_type)) { return true; } } @@ -1789,15 +1824,26 @@ class ProtobufDomainUntypedImpl return false; } - template - bool IsProtoRecursive(const Descriptor* descriptor, + bool IsProtoRecursive(const Descriptor* root, const Descriptor* descriptor, absl::flat_hash_set& parents, RecursionType recursion_type) const { - if (parents.contains(descriptor)) return true; + if (parents.contains(descriptor)) { + if (recursion_type == RecursionType::kInfinitelyRecursive) { + // Any path in the lasso form: F0 -> F1 -> Fs -> ... -> Fn -> Fs yields + // an infinitely recursion. + return true; + } else { + // A Path in the lasso form: F0 -> F1 -> Fs -> ... -> Fn -> Fs is OK as + // root is not in the recursion path and we can postpone the recursion + // detection. But, we signal the recursion when we find a loop of the + // form : F0 -> F1 -> ... -> Fn -> F0. + return descriptor == root; + }; + } parents.insert(descriptor); for (int i = 0; i < descriptor->oneof_decl_count(); ++i) { const auto* oneof = descriptor->oneof_decl(i); - if (IsOneofRecursive(oneof, parents, recursion_type)) { + if (IsOneofRecursive(root, oneof, parents, recursion_type)) { parents.erase(descriptor); return true; } @@ -1815,8 +1861,14 @@ class ProtobufDomainUntypedImpl if (!MustBeSet(field)) continue; } else { if (MustBeUnset(field)) continue; + // Skip if the recursion is already broken through `field`. + if (bool can_be_recursion_breaker = IsCustomizedRecursivelyOnly()) { + std::optional is_recursion_breaker = + RecursionBreakers::IsRecursionBreaker(field); + if (is_recursion_breaker && *is_recursion_breaker) continue; + } } - if (IsProtoRecursive(child, parents, recursion_type)) { + if (IsProtoRecursive(root, child, parents, recursion_type)) { parents.erase(descriptor); return true; } @@ -1857,6 +1909,11 @@ class ProtobufDomainUntypedImpl absl::flat_hash_set unset_oneof_fields_; }; +template +ABSL_CONST_INIT absl::Mutex + ProtobufDomainUntypedImpl::RecursionBreakers::mutex( + absl::kConstInit); + // Domain for `T` where `T` is a Protobuf message type. // It is a small wrapper around `ProtobufDomainUntypedImpl` to make its API more // convenient.