Skip to content

Commit

Permalink
Initialize recursive protobuf fields more efficiently.
Browse files Browse the repository at this point in the history
Currently, a field (F0) is not initialized if there are subfields of the form F0 -> F1 -> ... -> Fs -> ... -> Fn -> Fs. This means that for a huge protobuf field that has a recursive subfield deep in its definition, the whole field is not initialized. And later, even when F0 is initialized, F1 won't get initialized, etc. This could be very inefficient. To avoid this, we define "recursion breaker fields". For example, "Fs" becomes a recursion breaker. Then, all fields up to Fs are initialized. And later when Fs gets initialized, all Fs -> ... -> Fn get initialized.

PiperOrigin-RevId: 722802824
  • Loading branch information
hadi88 authored and copybara-github committed Feb 4, 2025
1 parent 1752ee0 commit 192f0f6
Showing 1 changed file with 85 additions and 28 deletions.
113 changes: 85 additions & 28 deletions fuzztest/internal/domains/protobuf_domain_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ class ProtobufDomainUntypedImpl
}
if (oneof_to_field[oneof->index()] != field->index()) continue;
} else if (!MustBeSet(field) && IsCustomizedRecursivelyOnly() &&
IsFieldFinitelyRecursive(field)) {
IsRecursionBreaker(field)) {
// We avoid initializing non-required recursive fields by default (if
// they are not explicitly customized). Otherwise, the initialization
// may never terminate. If a proto has only non-required recursive
Expand Down Expand Up @@ -951,7 +951,7 @@ class ProtobufDomainUntypedImpl
OptionalPolicy policy = GetOneofFieldPolicy(oneof->field(i));
if (policy == OptionalPolicy::kAlwaysNull) continue;
if (IsCustomizedRecursivelyOnly() &&
IsFieldFinitelyRecursive(oneof->field(i))) {
IsRecursionBreaker(oneof->field(i))) {
continue;
}
fields.push_back(i);
Expand Down Expand Up @@ -1697,30 +1697,65 @@ class ProtobufDomainUntypedImpl
kFinitelyRecursive,
};

// Returns true if there are subfields in the proto that form an infinite
// recursion of the form: F0 -> F1 -> Fs -> ... -> Fn -> Fs, because all Fi-s
// have to be set.
bool IsInfinitelyRecursive() {
absl::flat_hash_set<decltype(prototype_.Get()->GetDescriptor())> parents;
return IsProtoRecursive(prototype_.Get()->GetDescriptor(), parents,
auto descriptor = prototype_.Get()->GetDescriptor();
absl::flat_hash_set<decltype(descriptor)> parents;
return IsProtoRecursive(descriptor, descriptor, parents,
RecursionType::kInfinitelyRecursive);
}

bool IsFieldFinitelyRecursive(const FieldDescriptor* field) {
if (!field->message_type()) return false;
ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit);
static absl::NoDestructor<absl::flat_hash_map<const FieldDescriptor*, bool>>
cache ABSL_GUARDED_BY(mutex);
bool can_use_cache = IsCustomizedRecursivelyOnly();
if (can_use_cache) {
// A `field` (F0) is recursion breaker if it not infinitely recursive and
// there are subfields in the form: F0 -> F1 -> ... -> Fn -> F0 where none of
// other Fi-s are recursion breakers. In other words, for each recursion loop,
// it is sufficient to have only one recursion breaker.
//
// During domain initialization, the recursion breakers are detected and won't
// get initialized to avoid non-terminating initialization.
class RecursionBreakers {
public:
// Returns nullopt if the answer is not known yet.
static std::optional<bool> IsRecursionBreaker(
const FieldDescriptor* field) {
absl::MutexLock l(&mutex);
auto it = cache->find(field);
if (it != cache->end()) return it->second;
auto it = GetKnownBreakers().find(field);
return it != GetKnownBreakers().end() ? std::optional(it->second)
: std::nullopt;
}
absl::flat_hash_set<decltype(field->message_type())> parents;
bool result = IsProtoRecursive(field->message_type(), parents,
RecursionType::kFinitelyRecursive);
if (can_use_cache) {

static void Add(const FieldDescriptor* field, bool is_recursion_breaker) {
absl::MutexLock l(&mutex);
cache->insert({field, result});
GetKnownBreakers().insert({field, is_recursion_breaker});
}

private:
static absl::flat_hash_map<const FieldDescriptor*, bool>&
GetKnownBreakers() {
static absl::NoDestructor<
absl::flat_hash_map<const FieldDescriptor*, bool>>
known_breakers;
return *known_breakers;
}
static absl::Mutex mutex;
};

// Returns true if the `field` (F0) is not infinitely recursive but there are
// subfields in the form: F0 -> F1 -> ... -> Fn -> F0 and none of other Fi-s
// are marked as recursion breakers so far.
bool IsRecursionBreaker(const FieldDescriptor* field) {
auto descriptor = field->message_type();
if (!descriptor) return false;
bool can_be_recursion_breaker = IsCustomizedRecursivelyOnly();
if (can_be_recursion_breaker) {
auto cache = RecursionBreakers::IsRecursionBreaker(field);
if (cache.has_value()) return *cache;
}
absl::flat_hash_set<decltype(descriptor)> parents;
bool result = IsProtoRecursive(descriptor, descriptor, parents,
RecursionType::kFinitelyRecursive);
if (can_be_recursion_breaker) RecursionBreakers::Add(field, result);
return result;
}

Expand All @@ -1732,7 +1767,7 @@ class ProtobufDomainUntypedImpl
return index == kFieldCountIndex;
}

bool IsOneofRecursive(const OneofDescriptor* oneof,
bool IsOneofRecursive(const Descriptor* root, const OneofDescriptor* oneof,
absl::flat_hash_set<const Descriptor*>& parents,
RecursionType recursion_type) const {
bool is_oneof_recursive = false;
Expand All @@ -1742,14 +1777,14 @@ class ProtobufDomainUntypedImpl
if (field_policy == OptionalPolicy::kAlwaysNull) continue;
const auto* child = field->message_type();
if (recursion_type == RecursionType::kInfinitelyRecursive) {
is_oneof_recursive = field_policy != OptionalPolicy::kWithNull &&
child &&
IsProtoRecursive(child, parents, recursion_type);
is_oneof_recursive =
field_policy != OptionalPolicy::kWithNull && child &&
IsProtoRecursive(root, child, parents, recursion_type);
if (!is_oneof_recursive) {
return false;
}
} else {
if (child && IsProtoRecursive(child, parents, recursion_type)) {
if (child && IsProtoRecursive(root, child, parents, recursion_type)) {
return true;
}
}
Expand Down Expand Up @@ -1789,15 +1824,26 @@ class ProtobufDomainUntypedImpl
return false;
}

template <typename Descriptor>
bool IsProtoRecursive(const Descriptor* descriptor,
bool IsProtoRecursive(const Descriptor* root, const Descriptor* descriptor,
absl::flat_hash_set<const Descriptor*>& parents,
RecursionType recursion_type) const {
if (parents.contains(descriptor)) return true;
if (parents.contains(descriptor)) {
if (recursion_type == RecursionType::kInfinitelyRecursive) {
// Any path in the lasso form: F0 -> F1 -> Fs -> ... -> Fn -> Fs yields
// an infinitely recursion.
return true;
} else {
// A Path in the lasso form: F0 -> F1 -> Fs -> ... -> Fn -> Fs is OK as
// root is not in the recursion path and we can postpone the recursion
// detection. But, we signal the recursion when we find a loop of the
// form : F0 -> F1 -> ... -> Fn -> F0.
return descriptor == root;
};
}
parents.insert(descriptor);
for (int i = 0; i < descriptor->oneof_decl_count(); ++i) {
const auto* oneof = descriptor->oneof_decl(i);
if (IsOneofRecursive(oneof, parents, recursion_type)) {
if (IsOneofRecursive(root, oneof, parents, recursion_type)) {
parents.erase(descriptor);
return true;
}
Expand All @@ -1815,8 +1861,14 @@ class ProtobufDomainUntypedImpl
if (!MustBeSet(field)) continue;
} else {
if (MustBeUnset(field)) continue;
// Skip if the recursion is already broken through `field`.
if (bool can_be_recursion_breaker = IsCustomizedRecursivelyOnly()) {
std::optional<bool> is_recursion_breaker =
RecursionBreakers::IsRecursionBreaker(field);
if (is_recursion_breaker && *is_recursion_breaker) continue;
}
}
if (IsProtoRecursive(child, parents, recursion_type)) {
if (IsProtoRecursive(root, child, parents, recursion_type)) {
parents.erase(descriptor);
return true;
}
Expand Down Expand Up @@ -1857,6 +1909,11 @@ class ProtobufDomainUntypedImpl
absl::flat_hash_set<int> unset_oneof_fields_;
};

template <typename Message>
ABSL_CONST_INIT absl::Mutex
ProtobufDomainUntypedImpl<Message>::RecursionBreakers::mutex(
absl::kConstInit);

// Domain for `T` where `T` is a Protobuf message type.
// It is a small wrapper around `ProtobufDomainUntypedImpl` to make its API more
// convenient.
Expand Down

0 comments on commit 192f0f6

Please sign in to comment.