Skip to content

Commit

Permalink
Initialize recursive protobuf fields more efficiently.
Browse files Browse the repository at this point in the history
Currently, a field (F0) is not initialized if there are subfields of the form F0 -> F1 -> ... -> Fs -> ... -> Fn -> Fs. This means that for a huge protobuf field that has a recursive subfield deep in its definition, the whole field is not initialized. And later, even when F0 is initialized, F1 won't get initialized, etc. This could be very inefficient. To avoid this, we define "recursion breaker fields". For example, "Fs" becomes a recursion breaker. Then, all fields up to Fs are initialized. And later when Fs gets initialized, all Fs -> ... -> Fn get initialized.

PiperOrigin-RevId: 722802824
  • Loading branch information
hadi88 authored and copybara-github committed Feb 3, 2025
1 parent 1752ee0 commit d72cecb
Showing 1 changed file with 68 additions and 23 deletions.
91 changes: 68 additions & 23 deletions fuzztest/internal/domains/protobuf_domain_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ class ProtobufDomainUntypedImpl
}
if (oneof_to_field[oneof->index()] != field->index()) continue;
} else if (!MustBeSet(field) && IsCustomizedRecursivelyOnly() &&
IsFieldFinitelyRecursive(field)) {
IsRecursionBreaker(field)) {
// We avoid initializing non-required recursive fields by default (if
// they are not explicitly customized). Otherwise, the initialization
// may never terminate. If a proto has only non-required recursive
Expand Down Expand Up @@ -951,7 +951,7 @@ class ProtobufDomainUntypedImpl
OptionalPolicy policy = GetOneofFieldPolicy(oneof->field(i));
if (policy == OptionalPolicy::kAlwaysNull) continue;
if (IsCustomizedRecursivelyOnly() &&
IsFieldFinitelyRecursive(oneof->field(i))) {
IsRecursionBreaker(oneof->field(i))) {
continue;
}
fields.push_back(i);
Expand Down Expand Up @@ -1697,29 +1697,58 @@ class ProtobufDomainUntypedImpl
kFinitelyRecursive,
};

// Returns true if there are subfields in the proto that form an infinite
// recursion of the form: F0 -> F1 -> Fs -> ... -> Fn -> Fs, because all Fi-s
// have to be set.
bool IsInfinitelyRecursive() {
absl::flat_hash_set<decltype(prototype_.Get()->GetDescriptor())> parents;
return IsProtoRecursive(prototype_.Get()->GetDescriptor(), parents,
auto descriptor = prototype_.Get()->GetDescriptor();
absl::flat_hash_set<decltype(descriptor)> parents;
return IsProtoRecursive(descriptor, descriptor, parents,
RecursionType::kInfinitelyRecursive);
}

bool IsFieldFinitelyRecursive(const FieldDescriptor* field) {
if (!field->message_type()) return false;
ABSL_CONST_INIT static absl::Mutex mutex(absl::kConstInit);
class RecursionBreakerCache {
static absl::Mutex mutex_;
static absl::NoDestructor<absl::flat_hash_map<const FieldDescriptor*, bool>>
cache ABSL_GUARDED_BY(mutex);
cache ABSL_GUARDED_BY(mutex_);

static std::optional<bool> IsRecursionBreaker(
const FieldDescriptor* field) {
absl::MutexLock l(&mutex_);
auto it = cache->find(field);
return it != cache->end() ? it->second : std::nullopt;
}

static void Insert(const FieldDescriptor* field,
bool is_recursion_breaker) {
absl::MutexLock l(&mutex_);
cache->insert({field, is_recursion_breaker});
}
};

// A `field` (F0) is recursion breaker if it not infinitely recursive and
// there are subfields in the form: F0 -> F1 -> ... -> Fn -> F0 where none of
// other Fi-s are recursion breakers. In other words, for each recursion loop,
// it is sufficient to have only one recursion breaker.
//
// During domain initialization, the recursion breakers are detected and won't
// get initialized to avoid non-terminating initialization.
//
// Returns true if the `field` (F0) is not infinitely recursive but there are
// subfields in the form: F0 -> F1 -> ... -> Fn -> F0 and none of other Fi-s
// are marked as recursion breakers so far.
bool IsRecursionBreaker(const FieldDescriptor* field) {
if (!field->message_type()) return false;
bool can_use_cache = IsCustomizedRecursivelyOnly();
if (can_use_cache) {
absl::MutexLock l(&mutex);
auto it = cache->find(field);
if (it != cache->end()) return it->second;
auto cache = RecursionBreakerCache::IsRecursionBreaker(field);
if (cache.has_value()) return *cache;
}
absl::flat_hash_set<decltype(field->message_type())> parents;
bool result = IsProtoRecursive(field->message_type(), parents,
RecursionType::kFinitelyRecursive);
if (can_use_cache) {
absl::MutexLock l(&mutex);
cache->insert({field, result});
RecursionBreakerCache::Insert(field, result);
}
return result;
}
Expand All @@ -1732,7 +1761,7 @@ class ProtobufDomainUntypedImpl
return index == kFieldCountIndex;
}

bool IsOneofRecursive(const OneofDescriptor* oneof,
bool IsOneofRecursive(const Descriptor* root, const OneofDescriptor* oneof,
absl::flat_hash_set<const Descriptor*>& parents,
RecursionType recursion_type) const {
bool is_oneof_recursive = false;
Expand All @@ -1742,14 +1771,14 @@ class ProtobufDomainUntypedImpl
if (field_policy == OptionalPolicy::kAlwaysNull) continue;
const auto* child = field->message_type();
if (recursion_type == RecursionType::kInfinitelyRecursive) {
is_oneof_recursive = field_policy != OptionalPolicy::kWithNull &&
child &&
IsProtoRecursive(child, parents, recursion_type);
is_oneof_recursive =
field_policy != OptionalPolicy::kWithNull && child &&
IsProtoRecursive(root, child, parents, recursion_type);
if (!is_oneof_recursive) {
return false;
}
} else {
if (child && IsProtoRecursive(child, parents, recursion_type)) {
if (child && IsProtoRecursive(root, child, parents, recursion_type)) {
return true;
}
}
Expand Down Expand Up @@ -1789,15 +1818,26 @@ class ProtobufDomainUntypedImpl
return false;
}

template <typename Descriptor>
bool IsProtoRecursive(const Descriptor* descriptor,
bool IsProtoRecursive(const Descriptor* root, const Descriptor* descriptor,
absl::flat_hash_set<const Descriptor*>& parents,
RecursionType recursion_type) const {
if (parents.contains(descriptor)) return true;
if (parents.contains(descriptor)) {
if (recursion_type == RecursionType::kInfinitelyRecursive) {
// Any path in the lasso form: F0 -> F1 -> Fs -> ... -> Fn -> Fs yields
// an infinitely recursion.
return true;
} else {
// A Path in the lasso form: F0 -> F1 -> Fs -> ... -> Fn -> Fs is OK as
// root is not in the recursion path and we can postpone the recursion
// detection. But, we signal the recursion when we find a loop of the
// form : F0 -> F1 -> ... -> Fn -> F0.
return descriptor == root;
};
}
parents.insert(descriptor);
for (int i = 0; i < descriptor->oneof_decl_count(); ++i) {
const auto* oneof = descriptor->oneof_decl(i);
if (IsOneofRecursive(oneof, parents, recursion_type)) {
if (IsOneofRecursive(root, oneof, parents, recursion_type)) {
parents.erase(descriptor);
return true;
}
Expand All @@ -1815,8 +1855,13 @@ class ProtobufDomainUntypedImpl
if (!MustBeSet(field)) continue;
} else {
if (MustBeUnset(field)) continue;
// Skip if the recursion is already broken through `field`.
if (bool can_use_cache = IsCustomizedRecursivelyOnly()) {
auto cache = RecursionBreakerCache::find(field);
if (cache.has_value() && *cache) continue;
}
}
if (IsProtoRecursive(child, parents, recursion_type)) {
if (IsProtoRecursive(root, child, parents, recursion_type)) {
parents.erase(descriptor);
return true;
}
Expand Down

0 comments on commit d72cecb

Please sign in to comment.