diff --git a/BUILD b/BUILD index b3f5d3e2e4295..310d4d53f57c1 100644 --- a/BUILD +++ b/BUILD @@ -605,7 +605,6 @@ GRPC_XDS_TARGETS = [ "//src/core:grpc_lb_policy_xds_cluster_manager", "//src/core:grpc_lb_policy_xds_override_host", "//src/core:grpc_lb_policy_xds_wrr_locality", - "//src/core:grpc_lb_policy_ring_hash", "//src/core:grpc_resolver_xds", "//src/core:grpc_resolver_c2p", "//src/core:grpc_xds_server_config_fetcher", @@ -838,6 +837,7 @@ grpc_cc_library( "//src/core:grpc_lb_policy_outlier_detection", "//src/core:grpc_lb_policy_pick_first", "//src/core:grpc_lb_policy_priority", + "//src/core:grpc_lb_policy_ring_hash", "//src/core:grpc_lb_policy_round_robin", "//src/core:grpc_lb_policy_weighted_round_robin", "//src/core:grpc_lb_policy_weighted_target", diff --git a/CMakeLists.txt b/CMakeLists.txt index 5aa53b28b9257..00003281ac66a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3671,6 +3671,7 @@ add_library(grpc_unsecure src/core/load_balancing/outlier_detection/outlier_detection.cc src/core/load_balancing/pick_first/pick_first.cc src/core/load_balancing/priority/priority.cc + src/core/load_balancing/ring_hash/ring_hash.cc src/core/load_balancing/rls/rls.cc src/core/load_balancing/round_robin/round_robin.cc src/core/load_balancing/weighted_round_robin/static_stride_scheduler.cc diff --git a/build_autogenerated.yaml b/build_autogenerated.yaml index 45df9d5128eab..9e0e3020e6810 100644 --- a/build_autogenerated.yaml +++ b/build_autogenerated.yaml @@ -2673,6 +2673,7 @@ libs: - src/core/load_balancing/oob_backend_metric_internal.h - src/core/load_balancing/outlier_detection/outlier_detection.h - src/core/load_balancing/pick_first/pick_first.h + - src/core/load_balancing/ring_hash/ring_hash.h - src/core/load_balancing/rls/rls.h - src/core/load_balancing/subchannel_interface.h - src/core/load_balancing/weighted_round_robin/static_stride_scheduler.h @@ -2762,7 +2763,9 @@ libs: - src/core/util/uuid_v4.h - src/core/util/validation_errors.h - src/core/util/work_serializer.h + - src/core/util/xxhash_inline.h - third_party/upb/upb/generated_code_support.h + - third_party/xxhash/xxhash.h src: - src/core/call/request_buffer.cc - src/core/channelz/channel_trace.cc @@ -3097,6 +3100,7 @@ libs: - src/core/load_balancing/outlier_detection/outlier_detection.cc - src/core/load_balancing/pick_first/pick_first.cc - src/core/load_balancing/priority/priority.cc + - src/core/load_balancing/ring_hash/ring_hash.cc - src/core/load_balancing/rls/rls.cc - src/core/load_balancing/round_robin/round_robin.cc - src/core/load_balancing/weighted_round_robin/static_stride_scheduler.cc @@ -20435,6 +20439,7 @@ targets: - test/core/event_engine/event_engine_test_utils.h - test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h - test/core/load_balancing/lb_policy_test_lib.h + - test/core/test_util/scoped_env_var.h src: - test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.proto - test/core/event_engine/event_engine_test_utils.cc @@ -26825,6 +26830,7 @@ targets: run: false language: c++ headers: + - test/core/test_util/scoped_env_var.h - test/cpp/end2end/connection_attempt_injector.h - test/cpp/end2end/counted_service.h - test/cpp/end2end/test_service_impl.h diff --git a/src/core/BUILD b/src/core/BUILD index 9335eb1f22889..27bc4d6d9e0f6 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -6011,6 +6011,7 @@ grpc_cc_library( "grpc_fake_credentials", "grpc_fault_injection_filter", "grpc_lb_policy_pick_first", + "grpc_lb_policy_ring_hash", "grpc_lb_xds_channel_args", "grpc_matchers", 
"grpc_outlier_detection_header", @@ -6667,6 +6668,7 @@ grpc_cc_library( "absl/container:inlined_vector", "absl/log:check", "absl/log:log", + "absl/random", "absl/status", "absl/status:statusor", "absl/strings", @@ -6679,6 +6681,7 @@ grpc_cc_library( "closure", "connectivity_state", "delegating_helper", + "env", "error", "grpc_lb_policy_pick_first", "grpc_service_config", @@ -6690,6 +6693,7 @@ grpc_cc_library( "lb_policy_registry", "pollset_set", "ref_counted", + "ref_counted_string", "resolved_address", "unique_type_name", "validation_errors", diff --git a/src/core/load_balancing/lb_policy.h b/src/core/load_balancing/lb_policy.h index 1bf499bce447c..e7ff9d56939ca 100644 --- a/src/core/load_balancing/lb_policy.h +++ b/src/core/load_balancing/lb_policy.h @@ -152,14 +152,6 @@ class LoadBalancingPolicy : public InternallyRefCounted { /// The LB policy may use the existing metadata to influence its routing /// decision, and it may add new metadata elements to be sent with the /// call to the chosen backend. - // TODO(roth): Before making the LB policy API public, consider - // whether this is the right way to expose metadata to the picker. - // This approach means that if a pick modifies metadata but then we - // discard the pick because the subchannel is not connected, the - // metadata change will still have been made. Maybe we actually - // want to somehow provide metadata changes in PickResult::Complete - // instead? Or maybe we use a CallTracer that can add metadata when - // the call actually starts on the subchannel? MetadataInterface* initial_metadata; /// An interface for accessing call state. Can be used to allocate /// memory associated with the call in an efficient way. diff --git a/src/core/load_balancing/ring_hash/ring_hash.cc b/src/core/load_balancing/ring_hash/ring_hash.cc index 2409a40e425b1..f8f2c052ca0cb 100644 --- a/src/core/load_balancing/ring_hash/ring_hash.cc +++ b/src/core/load_balancing/ring_hash/ring_hash.cc @@ -35,6 +35,7 @@ #include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/random/random.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -59,10 +60,12 @@ #include "src/core/resolver/endpoint_addresses.h" #include "src/core/util/crash.h" #include "src/core/util/debug_location.h" +#include "src/core/util/env.h" #include "src/core/util/json/json.h" #include "src/core/util/orphanable.h" #include "src/core/util/ref_counted.h" #include "src/core/util/ref_counted_ptr.h" +#include "src/core/util/ref_counted_string.h" #include "src/core/util/unique_type_name.h" #include "src/core/util/work_serializer.h" #include "src/core/util/xxhash_inline.h" @@ -74,53 +77,79 @@ UniqueTypeName RequestHashAttribute::TypeName() { return kFactory.Create(); } -// Helper Parser method +namespace { + +constexpr absl::string_view kRingHash = "ring_hash_experimental"; -const JsonLoaderInterface* RingHashConfig::JsonLoader(const JsonArgs&) { - static const auto* loader = - JsonObjectLoader() - .OptionalField("minRingSize", &RingHashConfig::min_ring_size) - .OptionalField("maxRingSize", &RingHashConfig::max_ring_size) - .Finish(); - return loader; +bool XdsRingHashSetRequestHashKeyEnabled() { + auto value = GetEnv("GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY"); + if (!value.has_value()) return false; + bool parsed_value; + bool parse_succeeded = gpr_parse_bool_value(value->c_str(), &parsed_value); + return parse_succeeded && parsed_value; } -void RingHashConfig::JsonPostLoad(const 
Json&, const JsonArgs&, - ValidationErrors* errors) { - { - ValidationErrors::ScopedField field(errors, ".minRingSize"); - if (!errors->FieldHasErrors() && - (min_ring_size == 0 || min_ring_size > 8388608)) { - errors->AddError("must be in the range [1, 8388608]"); - } - } - { - ValidationErrors::ScopedField field(errors, ".maxRingSize"); - if (!errors->FieldHasErrors() && - (max_ring_size == 0 || max_ring_size > 8388608)) { - errors->AddError("must be in the range [1, 8388608]"); +class RingHashJsonArgs final : public JsonArgs { + public: + bool IsEnabled(absl::string_view key) const override { + if (key == "request_hash_header") { + return XdsRingHashSetRequestHashKeyEnabled(); } + return true; } - if (min_ring_size > max_ring_size) { - errors->AddError("max_ring_size cannot be smaller than min_ring_size"); - } -} - -namespace { - -constexpr absl::string_view kRingHash = "ring_hash_experimental"; +}; class RingHashLbConfig final : public LoadBalancingPolicy::Config { public: - RingHashLbConfig(size_t min_ring_size, size_t max_ring_size) - : min_ring_size_(min_ring_size), max_ring_size_(max_ring_size) {} + RingHashLbConfig() = default; + + RingHashLbConfig(const RingHashLbConfig&) = delete; + RingHashLbConfig& operator=(const RingHashLbConfig&) = delete; + + RingHashLbConfig(RingHashLbConfig&& other) = delete; + RingHashLbConfig& operator=(RingHashLbConfig&& other) = delete; + absl::string_view name() const override { return kRingHash; } size_t min_ring_size() const { return min_ring_size_; } size_t max_ring_size() const { return max_ring_size_; } + absl::string_view request_hash_header() const { return request_hash_header_; } + + static const JsonLoaderInterface* JsonLoader(const JsonArgs&) { + static const auto* loader = + JsonObjectLoader() + .OptionalField("minRingSize", &RingHashLbConfig::min_ring_size_) + .OptionalField("maxRingSize", &RingHashLbConfig::max_ring_size_) + .OptionalField("requestHashHeader", + &RingHashLbConfig::request_hash_header_, + "request_hash_header") + .Finish(); + return loader; + } + + void JsonPostLoad(const Json&, const JsonArgs&, ValidationErrors* errors) { + { + ValidationErrors::ScopedField field(errors, ".minRingSize"); + if (!errors->FieldHasErrors() && + (min_ring_size_ == 0 || min_ring_size_ > 8388608)) { + errors->AddError("must be in the range [1, 8388608]"); + } + } + { + ValidationErrors::ScopedField field(errors, ".maxRingSize"); + if (!errors->FieldHasErrors() && + (max_ring_size_ == 0 || max_ring_size_ > 8388608)) { + errors->AddError("must be in the range [1, 8388608]"); + } + } + if (min_ring_size_ > max_ring_size_) { + errors->AddError("maxRingSize cannot be smaller than minRingSize"); + } + } private: - size_t min_ring_size_; - size_t max_ring_size_; + uint64_t min_ring_size_ = 1024; + uint64_t max_ring_size_ = 4096; + std::string request_hash_header_; }; // @@ -217,9 +246,13 @@ class RingHash final : public LoadBalancingPolicy { explicit Picker(RefCountedPtr ring_hash) : ring_hash_(std::move(ring_hash)), ring_(ring_hash_->ring_), - endpoints_(ring_hash_->endpoints_.size()) { + endpoints_(ring_hash_->endpoints_.size()), + request_hash_header_(ring_hash_->request_hash_header_) { for (const auto& p : ring_hash_->endpoint_map_) { endpoints_[p.second->index()] = p.second->GetInfoForPicker(); + if (endpoints_[p.second->index()].state == GRPC_CHANNEL_CONNECTING) { + has_endpoint_in_connecting_state_ = true; + } } } @@ -260,6 +293,8 @@ class RingHash final : public LoadBalancingPolicy { RefCountedPtr ring_hash_; RefCountedPtr ring_; std::vector 
endpoints_; + bool has_endpoint_in_connecting_state_ = false; + RefCountedStringValue request_hash_header_; }; ~RingHash() override; @@ -278,6 +313,7 @@ class RingHash final : public LoadBalancingPolicy { // Current endpoint list, channel args, and ring. EndpointAddressesList endpoints_; ChannelArgs args_; + RefCountedStringValue request_hash_header_; RefCountedPtr ring_; std::map> endpoint_map_; @@ -297,17 +333,34 @@ class RingHash final : public LoadBalancingPolicy { // RingHash::PickResult RingHash::Picker::Pick(PickArgs args) { - auto* call_state = static_cast(args.call_state); - auto* hash_attribute = call_state->GetCallAttribute(); - if (hash_attribute == nullptr) { - return PickResult::Fail(absl::InternalError("hash attribute not present")); + // Determine request hash. + bool using_random_hash = false; + uint64_t request_hash; + if (request_hash_header_.as_string_view().empty()) { + // Being used in xDS. Request hash is passed in via an attribute. + auto* call_state = static_cast(args.call_state); + auto* hash_attribute = call_state->GetCallAttribute(); + if (hash_attribute == nullptr) { + return PickResult::Fail( + absl::InternalError("hash attribute not present")); + } + request_hash = hash_attribute->request_hash(); + } else { + std::string buffer; + auto header_value = args.initial_metadata->Lookup( + request_hash_header_.as_string_view(), &buffer); + if (header_value.has_value()) { + request_hash = XXH64(header_value->data(), header_value->size(), 0); + } else { + request_hash = absl::Uniform(absl::BitGen()); + using_random_hash = true; + } } - uint64_t request_hash = hash_attribute->request_hash(); - const auto& ring = ring_->ring(); // Find the index in the ring to use for this RPC. // Ported from https://github.com/RJ/ketama/blob/master/libketama/ketama.c // (ketama_get_server) NOTE: The algorithm depends on using signed integers // for lowp, highp, and index. Do not change them! + const auto& ring = ring_->ring(); int64_t lowp = 0; int64_t highp = ring.size(); int64_t index = 0; @@ -333,22 +386,42 @@ RingHash::PickResult RingHash::Picker::Pick(PickArgs args) { } } // Find the first endpoint we can use from the selected index. - for (size_t i = 0; i < ring.size(); ++i) { - const auto& entry = ring[(index + i) % ring.size()]; - const auto& endpoint_info = endpoints_[entry.endpoint_index]; - switch (endpoint_info.state) { - case GRPC_CHANNEL_READY: + if (!using_random_hash) { + for (size_t i = 0; i < ring.size(); ++i) { + const auto& entry = ring[(index + i) % ring.size()]; + const auto& endpoint_info = endpoints_[entry.endpoint_index]; + switch (endpoint_info.state) { + case GRPC_CHANNEL_READY: + return endpoint_info.picker->Pick(args); + case GRPC_CHANNEL_IDLE: + new EndpointConnectionAttempter( + ring_hash_.Ref(DEBUG_LOCATION, "EndpointConnectionAttempter"), + endpoint_info.endpoint); + ABSL_FALLTHROUGH_INTENDED; + case GRPC_CHANNEL_CONNECTING: + return PickResult::Queue(); + default: + break; + } + } + } else { + // Using a random hash. We will use the first READY endpoint we + // find, triggering at most one endpoint to attempt connecting. 
+ bool requested_connection = has_endpoint_in_connecting_state_; + for (size_t i = 0; i < ring.size(); ++i) { + const auto& entry = ring[(index + i) % ring.size()]; + const auto& endpoint_info = endpoints_[entry.endpoint_index]; + if (endpoint_info.state == GRPC_CHANNEL_READY) { return endpoint_info.picker->Pick(args); - case GRPC_CHANNEL_IDLE: + } + if (!requested_connection && endpoint_info.state == GRPC_CHANNEL_IDLE) { new EndpointConnectionAttempter( ring_hash_.Ref(DEBUG_LOCATION, "EndpointConnectionAttempter"), endpoint_info.endpoint); - ABSL_FALLTHROUGH_INTENDED; - case GRPC_CHANNEL_CONNECTING: - return PickResult::Queue(); - default: - break; + requested_connection = true; + } } + if (requested_connection) return PickResult::Queue(); } return PickResult::Fail(absl::UnavailableError(absl::StrCat( "ring hash cannot find a connected endpoint; first failure: ", @@ -362,7 +435,7 @@ RingHash::PickResult RingHash::Picker::Pick(PickArgs args) { RingHash::Ring::Ring(RingHash* ring_hash, RingHashLbConfig* config) { // Store the weights while finding the sum. struct EndpointWeight { - std::string address; // Key by endpoint's first address. + std::string hash_key; // By default, endpoint's first address. // Default weight is 1 for the cases where a weight is not provided, // each occurrence of the address will be counted a weight value of 1. uint32_t weight = 1; @@ -374,8 +447,14 @@ RingHash::Ring::Ring(RingHash* ring_hash, RingHashLbConfig* config) { endpoint_weights.reserve(endpoints.size()); for (const auto& endpoint : endpoints) { EndpointWeight endpoint_weight; - endpoint_weight.address = - grpc_sockaddr_to_string(&endpoint.addresses().front(), false).value(); + auto hash_key = + endpoint.args().GetString(GRPC_ARG_RING_HASH_ENDPOINT_HASH_KEY); + if (hash_key.has_value()) { + endpoint_weight.hash_key = std::string(*hash_key); + } else { + endpoint_weight.hash_key = + grpc_sockaddr_to_string(&endpoint.addresses().front(), false).value(); + } // Weight should never be zero, but ignore it just in case, since // that value would screw up the ring-building algorithm. auto weight_arg = endpoint.args().GetInt(GRPC_ARG_ADDRESS_WEIGHT); @@ -425,8 +504,8 @@ RingHash::Ring::Ring(RingHash* ring_hash, RingHashLbConfig* config) { uint64_t min_hashes_per_host = ring_size; uint64_t max_hashes_per_host = 0; for (size_t i = 0; i < endpoints.size(); ++i) { - const std::string& address_string = endpoint_weights[i].address; - hash_key_buffer.assign(address_string.begin(), address_string.end()); + const std::string& hash_key = endpoint_weights[i].hash_key; + hash_key_buffer.assign(hash_key.begin(), hash_key.end()); hash_key_buffer.emplace_back('_'); auto offset_start = hash_key_buffer.end(); target_hashes += scale * endpoint_weights[i].normalized_weight; @@ -652,9 +731,11 @@ absl::Status RingHash::UpdateLocked(UpdateArgs args) { } // Save channel args. args_ = std::move(args.args); + // Save config. + auto* config = DownCast(args.config.get()); + request_hash_header_ = RefCountedStringValue(config->request_hash_header()); // Build new ring. - ring_ = MakeRefCounted( - this, static_cast(args.config.get())); + ring_ = MakeRefCounted(this, config); // Update endpoint map. 
std::map> endpoint_map; std::vector errors; @@ -853,11 +934,9 @@ class RingHashFactory final : public LoadBalancingPolicyFactory { absl::StatusOr> ParseLoadBalancingConfig(const Json& json) const override { - auto config = LoadFromJson( - json, JsonArgs(), "errors validating ring_hash LB policy config"); - if (!config.ok()) return config.status(); - return MakeRefCounted(config->min_ring_size, - config->max_ring_size); + return LoadFromJson>( + json, RingHashJsonArgs(), + "errors validating ring_hash LB policy config"); } }; diff --git a/src/core/load_balancing/ring_hash/ring_hash.h b/src/core/load_balancing/ring_hash/ring_hash.h index 95f17cbe81dfd..b3e9886202ce7 100644 --- a/src/core/load_balancing/ring_hash/ring_hash.h +++ b/src/core/load_balancing/ring_hash/ring_hash.h @@ -27,6 +27,10 @@ #include "src/core/util/unique_type_name.h" #include "src/core/util/validation_errors.h" +// Optional endpoint attribute specifying the hash key. +#define GRPC_ARG_RING_HASH_ENDPOINT_HASH_KEY \ + GRPC_ARG_NO_SUBCHANNEL_PREFIX "hash_key" + namespace grpc_core { class RequestHashAttribute final @@ -45,17 +49,6 @@ class RequestHashAttribute final uint64_t request_hash_; }; -// Helper Parsing method to parse ring hash policy configs; for example, ring -// hash size validity. -struct RingHashConfig { - uint64_t min_ring_size = 1024; - uint64_t max_ring_size = 4096; - - static const JsonLoaderInterface* JsonLoader(const JsonArgs&); - void JsonPostLoad(const Json& json, const JsonArgs&, - ValidationErrors* errors); -}; - } // namespace grpc_core #endif // GRPC_SRC_CORE_LOAD_BALANCING_RING_HASH_RING_HASH_H diff --git a/src/core/plugin_registry/grpc_plugin_registry.cc b/src/core/plugin_registry/grpc_plugin_registry.cc index 4970e2a012747..a276ea9222355 100644 --- a/src/core/plugin_registry/grpc_plugin_registry.cc +++ b/src/core/plugin_registry/grpc_plugin_registry.cc @@ -62,6 +62,7 @@ extern void RegisterOutlierDetectionLbPolicy( CoreConfiguration::Builder* builder); extern void RegisterWeightedTargetLbPolicy(CoreConfiguration::Builder* builder); extern void RegisterPickFirstLbPolicy(CoreConfiguration::Builder* builder); +extern void RegisterRingHashLbPolicy(CoreConfiguration::Builder* builder); extern void RegisterRoundRobinLbPolicy(CoreConfiguration::Builder* builder); extern void RegisterWeightedRoundRobinLbPolicy( CoreConfiguration::Builder* builder); @@ -102,6 +103,7 @@ void BuildCoreConfiguration(CoreConfiguration::Builder* builder) { RegisterWeightedTargetLbPolicy(builder); RegisterPickFirstLbPolicy(builder); RegisterRoundRobinLbPolicy(builder); + RegisterRingHashLbPolicy(builder); RegisterWeightedRoundRobinLbPolicy(builder); BuildClientChannelConfiguration(builder); SecurityRegisterHandshakerFactories(builder); diff --git a/src/core/plugin_registry/grpc_plugin_registry_extra.cc b/src/core/plugin_registry/grpc_plugin_registry_extra.cc index aa91cba241788..7854b5154ab59 100644 --- a/src/core/plugin_registry/grpc_plugin_registry_extra.cc +++ b/src/core/plugin_registry/grpc_plugin_registry_extra.cc @@ -37,7 +37,6 @@ extern void RegisterCdsLbPolicy(CoreConfiguration::Builder* builder); extern void RegisterXdsOverrideHostLbPolicy( CoreConfiguration::Builder* builder); extern void RegisterXdsWrrLocalityLbPolicy(CoreConfiguration::Builder* builder); -extern void RegisterRingHashLbPolicy(CoreConfiguration::Builder* builder); extern void RegisterFileWatcherCertificateProvider( CoreConfiguration::Builder* builder); extern void RegisterXdsHttpProxyMapper(CoreConfiguration::Builder* builder); @@ -60,7 +59,6 @@ void 
RegisterExtraFilters(CoreConfiguration::Builder* builder) { RegisterCdsLbPolicy(builder); RegisterXdsOverrideHostLbPolicy(builder); RegisterXdsWrrLocalityLbPolicy(builder); - RegisterRingHashLbPolicy(builder); RegisterFileWatcherCertificateProvider(builder); RegisterXdsHttpProxyMapper(builder); #endif diff --git a/src/core/xds/grpc/xds_endpoint_parser.cc b/src/core/xds/grpc/xds_endpoint_parser.cc index 2bf1f44386365..92a17c6743ad6 100644 --- a/src/core/xds/grpc/xds_endpoint_parser.cc +++ b/src/core/xds/grpc/xds_endpoint_parser.cc @@ -45,8 +45,11 @@ #include "src/core/lib/channel/channel_args.h" #include "src/core/lib/debug/trace.h" #include "src/core/lib/iomgr/resolved_address.h" +#include "src/core/load_balancing/ring_hash/ring_hash.h" #include "src/core/util/down_cast.h" #include "src/core/util/env.h" +#include "src/core/util/json/json_args.h" +#include "src/core/util/json/json_object_loader.h" #include "src/core/util/string.h" #include "src/core/util/upb_utils.h" #include "src/core/util/validation_errors.h" @@ -72,6 +75,16 @@ bool XdsDualstackEndpointsEnabled() { return parse_succeeded && parsed_value; } +// TODO(roth): Flip the default to false once this proves stable, then +// remove it entirely at some point in the future. +bool XdsEndpointHashKeyBackwardCompatEnabled() { + auto value = GetEnv("GRPC_XDS_ENDPOINT_HASH_KEY_BACKWARD_COMPAT"); + if (!value.has_value()) return true; + bool parsed_value; + bool parse_succeeded = gpr_parse_bool_value(value->c_str(), &parsed_value); + return parse_succeeded && parsed_value; +} + void MaybeLogClusterLoadAssignment( const XdsResourceType::DecodeContext& context, const envoy_config_endpoint_v3_ClusterLoadAssignment* cla) { @@ -87,21 +100,22 @@ void MaybeLogClusterLoadAssignment( } } -std::string GetProxyAddressFromMetadata( - const XdsResourceType::DecodeContext& context, - const envoy_config_core_v3_Metadata* metadata, ValidationErrors* errors) { - if (XdsHttpConnectEnabled() && metadata != nullptr) { - XdsMetadataMap metadata_map = - ParseXdsMetadataMap(context, metadata, errors); - auto* proxy_address_entry = - metadata_map.Find("envoy.http11_proxy_transport_socket.proxy_address"); - if (proxy_address_entry != nullptr && - proxy_address_entry->type() == XdsAddressMetadataValue::Type()) { - return DownCast(proxy_address_entry) - ->address(); - } - } - return ""; +std::string GetProxyAddressFromMetadata(const XdsMetadataMap& metadata_map) { + auto* proxy_address_entry = metadata_map.FindType( + "envoy.http11_proxy_transport_socket.proxy_address"); + if (proxy_address_entry == nullptr) return ""; + return proxy_address_entry->address(); +} + +std::string GetHashKeyFromMetadata(const XdsMetadataMap& metadata_map) { + auto* hash_key_entry = + metadata_map.FindType("envoy.lb"); + if (hash_key_entry == nullptr) return ""; + ValidationErrors unused_errors; + return LoadJsonObjectField(hash_key_entry->json().object(), + JsonArgs(), "hash_key", + &unused_errors) + .value_or(""); } absl::optional EndpointAddressesParse( @@ -126,9 +140,19 @@ absl::optional EndpointAddressesParse( } } // metadata - std::string proxy_address = GetProxyAddressFromMetadata( - context, envoy_config_endpoint_v3_LbEndpoint_metadata(lb_endpoint), - errors); + std::string proxy_address; + std::string hash_key; + if (XdsHttpConnectEnabled() || !XdsEndpointHashKeyBackwardCompatEnabled()) { + XdsMetadataMap metadata_map = ParseXdsMetadataMap( + context, envoy_config_endpoint_v3_LbEndpoint_metadata(lb_endpoint), + errors); + if (XdsHttpConnectEnabled()) { + proxy_address = 
GetProxyAddressFromMetadata(metadata_map); + } + if (!XdsEndpointHashKeyBackwardCompatEnabled()) { + hash_key = GetHashKeyFromMetadata(metadata_map); + } + } // endpoint std::vector addresses; absl::string_view hostname; @@ -177,6 +201,9 @@ absl::optional EndpointAddressesParse( } else if (!locality_proxy_address.empty()) { args = args.Set(GRPC_ARG_XDS_HTTP_PROXY, locality_proxy_address); } + if (!hash_key.empty()) { + args = args.Set(GRPC_ARG_RING_HASH_ENDPOINT_HASH_KEY, hash_key); + } return EndpointAddresses(addresses, args); } @@ -231,11 +258,15 @@ absl::optional LocalityParse( parsed_locality.locality.name = MakeRefCounted( std::move(region), std::move(zone), std::move(sub_zone)); // metadata - std::string proxy_address = GetProxyAddressFromMetadata( - context, - envoy_config_endpoint_v3_LocalityLbEndpoints_metadata( - locality_lb_endpoints), - errors); + std::string proxy_address; + if (XdsHttpConnectEnabled()) { + XdsMetadataMap metadata_map = ParseXdsMetadataMap( + context, + envoy_config_endpoint_v3_LocalityLbEndpoints_metadata( + locality_lb_endpoints), + errors); + proxy_address = GetProxyAddressFromMetadata(metadata_map); + } // lb_endpoints size_t size; const envoy_config_endpoint_v3_LbEndpoint* const* lb_endpoints = diff --git a/src/core/xds/grpc/xds_metadata.h b/src/core/xds/grpc/xds_metadata.h index 3e8fa4f214b79..419540cb3631b 100644 --- a/src/core/xds/grpc/xds_metadata.h +++ b/src/core/xds/grpc/xds_metadata.h @@ -60,6 +60,14 @@ class XdsMetadataMap { const XdsMetadataValue* Find(absl::string_view key) const; + template + const T* FindType(absl::string_view key) const { + auto it = map_.find(key); + if (it == map_.end()) return nullptr; + if (it->second->type() != T::Type()) return nullptr; + return DownCast(it->second.get()); + } + bool empty() const { return map_.empty(); } size_t size() const { return map_.size(); } diff --git a/test/core/load_balancing/BUILD b/test/core/load_balancing/BUILD index 022d2ca07691d..2c89018f7d56e 100644 --- a/test/core/load_balancing/BUILD +++ b/test/core/load_balancing/BUILD @@ -265,6 +265,7 @@ grpc_cc_test( "//src/core:channel_args", "//src/core:grpc_lb_policy_ring_hash", "//test/core/test_util:grpc_test_util", + "//test/core/test_util:scoped_env_var", ], ) diff --git a/test/core/load_balancing/lb_policy_test_lib.h b/test/core/load_balancing/lb_policy_test_lib.h index f784f925a6cbe..1194b02e2d548 100644 --- a/test/core/load_balancing/lb_policy_test_lib.h +++ b/test/core/load_balancing/lb_policy_test_lib.h @@ -948,7 +948,7 @@ class LoadBalancingPolicyTest : public ::testing::Test { EXPECT_TRUE(update.status.ok()) << update.status << " at " << location.file() << ":" << location.line(); - ExpectPickQueued(update.picker.get(), {}, location); + ExpectPickQueued(update.picker.get(), {}, {}, location); return true; // Keep going. } EXPECT_EQ(update.state, GRPC_CHANNEL_READY) @@ -992,7 +992,7 @@ class LoadBalancingPolicyTest : public ::testing::Test { EXPECT_TRUE(update.status.ok()) << update.status << " at " << location.file() << ":" << location.line(); - ExpectPickQueued(update.picker.get(), {}, location); + ExpectPickQueued(update.picker.get(), {}, {}, location); return true; // Keep going. 
} EXPECT_EQ(update.state, GRPC_CHANNEL_TRANSIENT_FAILURE) @@ -1059,7 +1059,7 @@ class LoadBalancingPolicyTest : public ::testing::Test { absl::Status expected_status = absl::OkStatus(), SourceLocation location = SourceLocation()) { auto picker = ExpectState(expected_state, expected_status, location); - return ExpectPickQueued(picker.get(), {}, location); + return ExpectPickQueued(picker.get(), {}, {}, location); } // Convenient frontend to ExpectStateAndQueuingPicker() for CONNECTING. @@ -1076,20 +1076,22 @@ class LoadBalancingPolicyTest : public ::testing::Test { // Does a pick and returns the result. LoadBalancingPolicy::PickResult DoPick( LoadBalancingPolicy::SubchannelPicker* picker, - const CallAttributes& call_attributes = {}) { + const CallAttributes& call_attributes = {}, + const std::map& metadata = {}) { ExecCtx exec_ctx; - FakeMetadata metadata({}); + FakeMetadata md(metadata); FakeCallState call_state(call_attributes); - return picker->Pick({"/service/method", &metadata, &call_state}); + return picker->Pick({"/service/method", &md, &call_state}); } // Requests a pick on picker and expects a Queue result. bool ExpectPickQueued(LoadBalancingPolicy::SubchannelPicker* picker, const CallAttributes call_attributes = {}, + const std::map& metadata = {}, SourceLocation location = SourceLocation()) { EXPECT_NE(picker, nullptr) << location.file() << ":" << location.line(); if (picker == nullptr) return false; - auto pick_result = DoPick(picker, call_attributes); + auto pick_result = DoPick(picker, call_attributes, metadata); EXPECT_TRUE(absl::holds_alternative( pick_result.result)) << PickResultString(pick_result) << "\nat " << location.file() << ":" @@ -1108,6 +1110,7 @@ class LoadBalancingPolicyTest : public ::testing::Test { absl::optional ExpectPickComplete( LoadBalancingPolicy::SubchannelPicker* picker, const CallAttributes& call_attributes = {}, + const std::map& metadata = {}, std::unique_ptr* subchannel_call_tracker = nullptr, SubchannelState::FakeSubchannel** picked_subchannel = nullptr, @@ -1116,7 +1119,7 @@ class LoadBalancingPolicyTest : public ::testing::Test { if (picker == nullptr) { return absl::nullopt; } - auto pick_result = DoPick(picker, call_attributes); + auto pick_result = DoPick(picker, call_attributes, metadata); auto* complete = absl::get_if( &pick_result.result); EXPECT_NE(complete, nullptr) << PickResultString(pick_result) << " at " @@ -1167,6 +1170,7 @@ class LoadBalancingPolicyTest : public ::testing::Test { std::unique_ptr subchannel_call_tracker; auto address = ExpectPickComplete(picker, call_attributes, + /*metadata=*/{}, subchannel_call_trackers == nullptr ? 
nullptr : &subchannel_call_tracker, diff --git a/test/core/load_balancing/outlier_detection_test.cc b/test/core/load_balancing/outlier_detection_test.cc index d73a4ed54bf0d..40a8cfc280ff3 100644 --- a/test/core/load_balancing/outlier_detection_test.cc +++ b/test/core/load_balancing/outlier_detection_test.cc @@ -154,7 +154,7 @@ class OutlierDetectionTest : public LoadBalancingPolicyTest { LoadBalancingPolicy::SubchannelPicker* picker) { std::unique_ptr subchannel_call_tracker; - auto address = ExpectPickComplete(picker, {}, &subchannel_call_tracker); + auto address = ExpectPickComplete(picker, {}, {}, &subchannel_call_tracker); if (address.has_value()) { subchannel_call_tracker->Start(); FakeMetadata metadata({}); diff --git a/test/core/load_balancing/ring_hash_test.cc b/test/core/load_balancing/ring_hash_test.cc index 522f5794921dc..c0b8dcaa3747a 100644 --- a/test/core/load_balancing/ring_hash_test.cc +++ b/test/core/load_balancing/ring_hash_test.cc @@ -38,6 +38,7 @@ #include "src/core/util/ref_counted_ptr.h" #include "src/core/util/xxhash_inline.h" #include "test/core/load_balancing/lb_policy_test_lib.h" +#include "test/core/test_util/scoped_env_var.h" #include "test/core/test_util/test_config.h" namespace grpc_core { @@ -53,7 +54,8 @@ class RingHashTest : public LoadBalancingPolicyTest { RingHashTest() : LoadBalancingPolicyTest("ring_hash_experimental") {} static RefCountedPtr MakeRingHashConfig( - int min_ring_size = 0, int max_ring_size = 0) { + int min_ring_size = 0, int max_ring_size = 0, + const std::string& request_hash_header = "") { Json::Object fields; if (min_ring_size > 0) { fields["minRingSize"] = Json::FromString(absl::StrCat(min_ring_size)); @@ -61,19 +63,25 @@ class RingHashTest : public LoadBalancingPolicyTest { if (max_ring_size > 0) { fields["maxRingSize"] = Json::FromString(absl::StrCat(max_ring_size)); } + if (!request_hash_header.empty()) { + fields["requestHashHeader"] = Json::FromString(request_hash_header); + } return MakeConfig(Json::FromArray({Json::FromObject( {{"ring_hash_experimental", Json::FromObject(fields)}})})); } - RequestHashAttribute* MakeHashAttribute(absl::string_view address) { - std::string hash_input = - absl::StrCat(absl::StripPrefix(address, "ipv4:"), "_0"); - uint64_t hash = XXH64(hash_input.data(), hash_input.size(), 0); + RequestHashAttribute* MakeHashAttributeForString(absl::string_view key) { + std::string key_str = absl::StrCat(key, "_0"); + uint64_t hash = XXH64(key_str.data(), key_str.size(), 0); attribute_storage_.emplace_back( std::make_unique(hash)); return attribute_storage_.back().get(); } + RequestHashAttribute* MakeHashAttribute(absl::string_view address) { + return MakeHashAttributeForString(absl::StripPrefix(address, "ipv4:")); + } + std::vector> attribute_storage_; }; @@ -94,6 +102,8 @@ TEST_F(RingHashTest, Basic) { subchannel->SetConnectivityState(GRPC_CHANNEL_CONNECTING); picker = ExpectState(GRPC_CHANNEL_CONNECTING); ExpectPickQueued(picker.get(), {address0_attribute}); + EXPECT_EQ(nullptr, FindSubchannel(kAddresses[1])); + EXPECT_EQ(nullptr, FindSubchannel(kAddresses[2])); subchannel->SetConnectivityState(GRPC_CHANNEL_READY); picker = ExpectState(GRPC_CHANNEL_READY); auto address = ExpectPickComplete(picker.get(), {address0_attribute}); @@ -182,6 +192,131 @@ TEST_F(RingHashTest, MultipleAddressesPerEndpoint) { EXPECT_EQ(address, kEndpoint1Addresses[1]); } +TEST_F(RingHashTest, EndpointHashKeys) { + const std::array kAddresses = { + "ipv4:127.0.0.1:441", "ipv4:127.0.0.1:442", "ipv4:127.0.0.1:443"}; + const std::array 
kHashKeys = {"foo", "bar", "baz"}; + std::vector endpoints; + for (size_t i = 0; i < 3; ++i) { + endpoints.push_back(MakeEndpointAddresses( + {kAddresses[i]}, + ChannelArgs().Set(GRPC_ARG_RING_HASH_ENDPOINT_HASH_KEY, kHashKeys[i]))); + }; + EXPECT_EQ( + ApplyUpdate(BuildUpdate(endpoints, MakeRingHashConfig()), lb_policy()), + absl::OkStatus()); + auto picker = ExpectState(GRPC_CHANNEL_IDLE); + auto* hash_attribute = MakeHashAttributeForString(kHashKeys[1]); + ExpectPickQueued(picker.get(), {hash_attribute}); + WaitForWorkSerializerToFlush(); + WaitForWorkSerializerToFlush(); + auto* subchannel = FindSubchannel(kAddresses[1]); + ASSERT_NE(subchannel, nullptr); + EXPECT_TRUE(subchannel->ConnectionRequested()); + subchannel->SetConnectivityState(GRPC_CHANNEL_CONNECTING); + picker = ExpectState(GRPC_CHANNEL_CONNECTING); + ExpectPickQueued(picker.get(), {hash_attribute}); + EXPECT_EQ(nullptr, FindSubchannel(kAddresses[0])); + EXPECT_EQ(nullptr, FindSubchannel(kAddresses[2])); + subchannel->SetConnectivityState(GRPC_CHANNEL_READY); + picker = ExpectState(GRPC_CHANNEL_READY); + auto address = ExpectPickComplete(picker.get(), {hash_attribute}); + EXPECT_EQ(address, kAddresses[1]); +} + +TEST_F(RingHashTest, PickFailsWithoutRequestHashAttribute) { + const std::array kAddresses = { + "ipv4:127.0.0.1:441", "ipv4:127.0.0.1:442", "ipv4:127.0.0.1:443"}; + EXPECT_EQ( + ApplyUpdate(BuildUpdate(kAddresses, MakeRingHashConfig()), lb_policy()), + absl::OkStatus()); + auto picker = ExpectState(GRPC_CHANNEL_IDLE); + ExpectPickFail(picker.get(), [&](const absl::Status& status) { + EXPECT_EQ(status, absl::InternalError("hash attribute not present")); + }); +} + +TEST_F(RingHashTest, RequestHashHeaderNotEnabled) { + const std::array kAddresses = { + "ipv4:127.0.0.1:441", "ipv4:127.0.0.1:442", "ipv4:127.0.0.1:443"}; + EXPECT_EQ( + ApplyUpdate(BuildUpdate(kAddresses, MakeRingHashConfig(0, 0, "foo")), + lb_policy()), + absl::OkStatus()); + auto picker = ExpectState(GRPC_CHANNEL_IDLE); + ExpectPickFail(picker.get(), [&](const absl::Status& status) { + EXPECT_EQ(status, absl::InternalError("hash attribute not present")); + }); +} + +TEST_F(RingHashTest, RequestHashHeader) { + ScopedExperimentalEnvVar env( + "GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY"); + const std::array kAddresses = { + "ipv4:127.0.0.1:441", "ipv4:127.0.0.1:442", "ipv4:127.0.0.1:443"}; + EXPECT_EQ( + ApplyUpdate(BuildUpdate(kAddresses, MakeRingHashConfig(0, 0, "foo")), + lb_policy()), + absl::OkStatus()); + auto picker = ExpectState(GRPC_CHANNEL_IDLE); + std::string hash_key = + absl::StrCat(absl::StripPrefix(kAddresses[0], "ipv4:"), "_0"); + std::map metadata = {{"foo", hash_key}}; + ExpectPickQueued(picker.get(), /*call_attributes=*/{}, metadata); + WaitForWorkSerializerToFlush(); + WaitForWorkSerializerToFlush(); + auto* subchannel = FindSubchannel(kAddresses[0]); + ASSERT_NE(subchannel, nullptr); + EXPECT_TRUE(subchannel->ConnectionRequested()); + subchannel->SetConnectivityState(GRPC_CHANNEL_CONNECTING); + picker = ExpectState(GRPC_CHANNEL_CONNECTING); + ExpectPickQueued(picker.get(), {}, metadata); + EXPECT_EQ(nullptr, FindSubchannel(kAddresses[1])); + EXPECT_EQ(nullptr, FindSubchannel(kAddresses[2])); + subchannel->SetConnectivityState(GRPC_CHANNEL_READY); + picker = ExpectState(GRPC_CHANNEL_READY); + auto address = ExpectPickComplete(picker.get(), {}, metadata); + EXPECT_EQ(address, kAddresses[0]); +} + +TEST_F(RingHashTest, RequestHashHeaderNotPresent) { + ScopedExperimentalEnvVar env( + 
"GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY"); + const std::array kAddresses = { + "ipv4:127.0.0.1:441", "ipv4:127.0.0.1:442", "ipv4:127.0.0.1:443"}; + EXPECT_EQ( + ApplyUpdate(BuildUpdate(kAddresses, MakeRingHashConfig(0, 0, "foo")), + lb_policy()), + absl::OkStatus()); + auto picker = ExpectState(GRPC_CHANNEL_IDLE); + ExpectPickQueued(picker.get()); + WaitForWorkSerializerToFlush(); + WaitForWorkSerializerToFlush(); + // It will randomly pick one. + size_t index = 0; + SubchannelState* subchannel = nullptr; + for (; index < kAddresses.size(); ++index) { + subchannel = FindSubchannel(kAddresses[index]); + if (subchannel != nullptr) { + LOG(INFO) << "Randomly picked subchannel index " << index; + break; + } + } + ASSERT_NE(subchannel, nullptr); + EXPECT_TRUE(subchannel->ConnectionRequested()); + subchannel->SetConnectivityState(GRPC_CHANNEL_CONNECTING); + picker = ExpectState(GRPC_CHANNEL_CONNECTING); + ExpectPickQueued(picker.get()); + // No other subchannels should have been created yet. + for (size_t i = 0; i < kAddresses.size(); ++i) { + if (i != index) EXPECT_EQ(nullptr, FindSubchannel(kAddresses[i])); + } + subchannel->SetConnectivityState(GRPC_CHANNEL_READY); + picker = ExpectState(GRPC_CHANNEL_READY); + auto address = ExpectPickComplete(picker.get()); + EXPECT_EQ(address, kAddresses[index]); +} + } // namespace } // namespace testing } // namespace grpc_core diff --git a/test/core/load_balancing/xds_override_host_test.cc b/test/core/load_balancing/xds_override_host_test.cc index e487b6ab32366..7f65f5a222196 100644 --- a/test/core/load_balancing/xds_override_host_test.cc +++ b/test/core/load_balancing/xds_override_host_test.cc @@ -183,7 +183,7 @@ class XdsOverrideHostTest : public LoadBalancingPolicyTest { } std::string expected_addresses_str = absl::StrJoin(expected_addresses, ","); for (size_t i = 0; i < 3; ++i) { - EXPECT_EQ(ExpectPickComplete(picker, {attribute}, + EXPECT_EQ(ExpectPickComplete(picker, {attribute}, /*metadata=*/{}, /*subchannel_call_tracker=*/nullptr, /*picked_subchannel=*/nullptr, location), expected) @@ -202,9 +202,10 @@ class XdsOverrideHostTest : public LoadBalancingPolicyTest { SourceLocation location = SourceLocation()) { std::vector actual_picks; for (size_t i = 0; i < expected.size(); ++i) { - auto address = ExpectPickComplete( - picker, {attribute}, /*subchannel_call_tracker=*/nullptr, - /*picked_subchannel=*/nullptr, location); + auto address = + ExpectPickComplete(picker, {attribute}, /*metadata=*/{}, + /*subchannel_call_tracker=*/nullptr, + /*picked_subchannel=*/nullptr, location); ASSERT_TRUE(address.has_value()) << location.file() << ":" << location.line(); EXPECT_THAT(*address, ::testing::AnyOfArray(expected)) diff --git a/test/cpp/end2end/xds/BUILD b/test/cpp/end2end/xds/BUILD index 22e9441ac6a12..77e8e4f99fa90 100644 --- a/test/cpp/end2end/xds/BUILD +++ b/test/cpp/end2end/xds/BUILD @@ -414,6 +414,7 @@ grpc_cc_test( "//:grpc++", "//:grpc_resolver_fake", "//test/core/test_util:grpc_test_util", + "//test/core/test_util:scoped_env_var", "//test/cpp/end2end:connection_attempt_injector", "@envoy_api//envoy/config/cluster/v3:pkg_cc_proto", "@envoy_api//envoy/config/endpoint/v3:pkg_cc_proto", diff --git a/test/cpp/end2end/xds/xds_end2end_test_lib.h b/test/cpp/end2end/xds/xds_end2end_test_lib.h index bfed1231aa664..5a91d4554e03c 100644 --- a/test/cpp/end2end/xds/xds_end2end_test_lib.h +++ b/test/cpp/end2end/xds/xds_end2end_test_lib.h @@ -487,7 +487,8 @@ class XdsEnd2endTest : public ::testing::TestWithParam, ::envoy::config::core::v3::HealthStatus 
health_status =
           ::envoy::config::core::v3::HealthStatus::UNKNOWN,
       int lb_weight = 1, std::vector<size_t> additional_backend_indexes = {},
-      absl::string_view hostname = "") {
+      absl::string_view hostname = "",
+      const std::map<std::string, std::string>& metadata = {}) {
     std::vector<int> additional_ports;
     additional_ports.reserve(additional_backend_indexes.size());
     for (size_t idx : additional_backend_indexes) {
@@ -495,7 +496,7 @@ class XdsEnd2endTest : public ::testing::TestWithParam,
     }
     return EdsResourceArgs::Endpoint(backends_[backend_idx]->port(),
                                      health_status, lb_weight, additional_ports,
-                                     hostname);
+                                     hostname, metadata);
   }

   // Creates a vector of endpoints for a specified range of backends,
diff --git a/test/cpp/end2end/xds/xds_ring_hash_end2end_test.cc b/test/cpp/end2end/xds/xds_ring_hash_end2end_test.cc
index 75afa6c12bfb3..4755ac0349aea 100644
--- a/test/cpp/end2end/xds/xds_ring_hash_end2end_test.cc
+++ b/test/cpp/end2end/xds/xds_ring_hash_end2end_test.cc
@@ -33,6 +33,7 @@
 #include "src/core/resolver/fake/fake_resolver.h"
 #include "src/core/util/env.h"
 #include "test/core/test_util/resolve_localhost_ip46.h"
+#include "test/core/test_util/scoped_env_var.h"
 #include "test/cpp/end2end/connection_attempt_injector.h"
 #include "test/cpp/end2end/xds/xds_end2end_test_lib.h"

@@ -455,6 +456,160 @@ TEST_P(RingHashTest, HeaderHashingWithRegexRewrite) {
   EXPECT_TRUE(found);
 }

+TEST_P(RingHashTest, HashKeysInEds) {
+  grpc_core::testing::ScopedEnvVar env(
+      "GRPC_XDS_ENDPOINT_HASH_KEY_BACKWARD_COMPAT", "false");
+  CreateAndStartBackends(4);
+  auto cluster = default_cluster_;
+  cluster.set_lb_policy(Cluster::RING_HASH);
+  balancer_->ads_service()->SetCdsResource(cluster);
+  auto new_route_config = default_route_config_;
+  auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  auto* hash_policy = route->mutable_route()->add_hash_policy();
+  hash_policy->mutable_header()->set_header_name("address_hash");
+  SetListenerAndRouteConfiguration(balancer_.get(), default_listener_,
+                                   new_route_config);
+  EdsResourceArgs args(
+      {{"locality0",
+        {
+            CreateEndpoint(0,
+                           /*health_status=*/
+                           ::envoy::config::core::v3::HealthStatus::UNKNOWN,
+                           /*lb_weight=*/1, /*additional_backend_indexes=*/{},
+                           /*hostname=*/"",
+                           {{"envoy.lb", "{\"hash_key\":\"foo\"}"}}),
+            CreateEndpoint(1,
+                           /*health_status=*/
+                           ::envoy::config::core::v3::HealthStatus::UNKNOWN,
+                           /*lb_weight=*/1, /*additional_backend_indexes=*/{},
+                           /*hostname=*/"",
+                           {{"envoy.lb", "{\"hash_key\":\"bar\"}"}}),
+            CreateEndpoint(2,
+                           /*health_status=*/
+                           ::envoy::config::core::v3::HealthStatus::UNKNOWN,
+                           /*lb_weight=*/1, /*additional_backend_indexes=*/{},
+                           /*hostname=*/"",
+                           {{"envoy.lb", "{\"hash_key\":\"baz\"}"}}),
+            CreateEndpoint(3,
+                           /*health_status=*/
+                           ::envoy::config::core::v3::HealthStatus::UNKNOWN,
+                           /*lb_weight=*/1, /*additional_backend_indexes=*/{},
+                           /*hostname=*/"",
+                           {{"envoy.lb", "{\"hash_key\":\"quux\"}"}}),
+        }}});
+  balancer_->ads_service()->SetEdsResource(BuildEdsResource(args));
+  // Note that each type of RPC carries a header value that always hashes to
+  // a specific backend, because the header value matches a value used to
+  // create an entry in the ring.
+  std::vector<std::pair<std::string, std::string>> metadata = {
+      {"address_hash", "foo_0"}};
+  std::vector<std::pair<std::string, std::string>> metadata1 = {
+      {"address_hash", "bar_0"}};
+  std::vector<std::pair<std::string, std::string>> metadata2 = {
+      {"address_hash", "baz_0"}};
+  std::vector<std::pair<std::string, std::string>> metadata3 = {
+      {"address_hash", "quux_0"}};
+  const auto rpc_options =
+      RpcOptions().set_metadata(std::move(metadata)).set_timeout_ms(5000);
+  const auto rpc_options1 =
+      RpcOptions().set_metadata(std::move(metadata1)).set_timeout_ms(5000);
+  const auto rpc_options2 =
+      RpcOptions().set_metadata(std::move(metadata2)).set_timeout_ms(5000);
+  const auto rpc_options3 =
+      RpcOptions().set_metadata(std::move(metadata3)).set_timeout_ms(5000);
+  WaitForBackend(DEBUG_LOCATION, 0, /*check_status=*/nullptr,
+                 WaitForBackendOptions(), rpc_options);
+  WaitForBackend(DEBUG_LOCATION, 1, /*check_status=*/nullptr,
+                 WaitForBackendOptions(), rpc_options1);
+  WaitForBackend(DEBUG_LOCATION, 2, /*check_status=*/nullptr,
+                 WaitForBackendOptions(), rpc_options2);
+  WaitForBackend(DEBUG_LOCATION, 3, /*check_status=*/nullptr,
+                 WaitForBackendOptions(), rpc_options3);
+  CheckRpcSendOk(DEBUG_LOCATION, 100, rpc_options);
+  CheckRpcSendOk(DEBUG_LOCATION, 100, rpc_options1);
+  CheckRpcSendOk(DEBUG_LOCATION, 100, rpc_options2);
+  CheckRpcSendOk(DEBUG_LOCATION, 100, rpc_options3);
+  for (size_t i = 0; i < backends_.size(); ++i) {
+    EXPECT_EQ(100, backends_[i]->backend_service()->request_count());
+  }
+}
+
+TEST_P(RingHashTest, HashKeysInEdsNotEnabled) {
+  CreateAndStartBackends(4);
+  auto cluster = default_cluster_;
+  cluster.set_lb_policy(Cluster::RING_HASH);
+  balancer_->ads_service()->SetCdsResource(cluster);
+  auto new_route_config = default_route_config_;
+  auto* route = new_route_config.mutable_virtual_hosts(0)->mutable_routes(0);
+  auto* hash_policy = route->mutable_route()->add_hash_policy();
+  hash_policy->mutable_header()->set_header_name("address_hash");
+  SetListenerAndRouteConfiguration(balancer_.get(), default_listener_,
+                                   new_route_config);
+  EdsResourceArgs args(
+      {{"locality0",
+        {
+            CreateEndpoint(0,
+                           /*health_status=*/
+                           ::envoy::config::core::v3::HealthStatus::UNKNOWN,
+                           /*lb_weight=*/1, /*additional_backend_indexes=*/{},
+                           /*hostname=*/"",
+                           {{"envoy.lb", "{\"hash_key\":\"foo\"}"}}),
+            CreateEndpoint(1,
+                           /*health_status=*/
+                           ::envoy::config::core::v3::HealthStatus::UNKNOWN,
+                           /*lb_weight=*/1, /*additional_backend_indexes=*/{},
+                           /*hostname=*/"",
+                           {{"envoy.lb", "{\"hash_key\":\"bar\"}"}}),
+            CreateEndpoint(2,
+                           /*health_status=*/
+                           ::envoy::config::core::v3::HealthStatus::UNKNOWN,
+                           /*lb_weight=*/1, /*additional_backend_indexes=*/{},
+                           /*hostname=*/"",
+                           {{"envoy.lb", "{\"hash_key\":\"baz\"}"}}),
+            CreateEndpoint(3,
+                           /*health_status=*/
+                           ::envoy::config::core::v3::HealthStatus::UNKNOWN,
+                           /*lb_weight=*/1, /*additional_backend_indexes=*/{},
+                           /*hostname=*/"",
+                           {{"envoy.lb", "{\"hash_key\":\"quux\"}"}}),
+        }}});
+  balancer_->ads_service()->SetEdsResource(BuildEdsResource(args));
+  // Note that each type of RPC carries a header value that always hashes to
+  // a specific backend, because the header value matches a value used to
+  // create an entry in the ring.
+ std::vector> metadata = { + {"address_hash", CreateMetadataValueThatHashesToBackend(0)}}; + std::vector> metadata1 = { + {"address_hash", CreateMetadataValueThatHashesToBackend(1)}}; + std::vector> metadata2 = { + {"address_hash", CreateMetadataValueThatHashesToBackend(2)}}; + std::vector> metadata3 = { + {"address_hash", CreateMetadataValueThatHashesToBackend(3)}}; + const auto rpc_options = + RpcOptions().set_metadata(std::move(metadata)).set_timeout_ms(5000); + const auto rpc_options1 = + RpcOptions().set_metadata(std::move(metadata1)).set_timeout_ms(5000); + const auto rpc_options2 = + RpcOptions().set_metadata(std::move(metadata2)).set_timeout_ms(5000); + const auto rpc_options3 = + RpcOptions().set_metadata(std::move(metadata3)).set_timeout_ms(5000); + WaitForBackend(DEBUG_LOCATION, 0, /*check_status=*/nullptr, + WaitForBackendOptions(), rpc_options); + WaitForBackend(DEBUG_LOCATION, 1, /*check_status=*/nullptr, + WaitForBackendOptions(), rpc_options1); + WaitForBackend(DEBUG_LOCATION, 2, /*check_status=*/nullptr, + WaitForBackendOptions(), rpc_options2); + WaitForBackend(DEBUG_LOCATION, 3, /*check_status=*/nullptr, + WaitForBackendOptions(), rpc_options3); + CheckRpcSendOk(DEBUG_LOCATION, 100, rpc_options); + CheckRpcSendOk(DEBUG_LOCATION, 100, rpc_options1); + CheckRpcSendOk(DEBUG_LOCATION, 100, rpc_options2); + CheckRpcSendOk(DEBUG_LOCATION, 100, rpc_options3); + for (size_t i = 0; i < backends_.size(); ++i) { + EXPECT_EQ(100, backends_[i]->backend_service()->request_count()); + } +} + // Tests that ring hash policy that hashes using a random value. TEST_P(RingHashTest, NoHashPolicy) { CreateAndStartBackends(2); @@ -641,7 +796,7 @@ TEST_P(RingHashTest, UnsupportedHashPolicyDefaultToRandomHashing) { } // Tests that ring hash policy that hashes using a random value can spread -// RPCs across all the backends according to locality weight. +// RPCs across all the backends according to endpoint weight. TEST_P(RingHashTest, RandomHashingDistributionAccordingToEndpointWeight) { CreateAndStartBackends(2); const size_t kWeight1 = 1; @@ -680,7 +835,7 @@ TEST_P(RingHashTest, RandomHashingDistributionAccordingToEndpointWeight) { } // Tests that ring hash policy that hashes using a random value can spread -// RPCs across all the backends according to locality weight. +// RPCs across all the backends according to locality and endpoint weight. 
TEST_P(RingHashTest, RandomHashingDistributionAccordingToLocalityAndEndpointWeight) { CreateAndStartBackends(2); diff --git a/test/cpp/end2end/xds/xds_utils.cc b/test/cpp/end2end/xds/xds_utils.cc index ad8f7a0c02abb..1608099a9dd76 100644 --- a/test/cpp/end2end/xds/xds_utils.cc +++ b/test/cpp/end2end/xds/xds_utils.cc @@ -26,6 +26,7 @@ #include #include "absl/memory/memory.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -370,6 +371,16 @@ ClusterLoadAssignment XdsResourceUtils::BuildEdsResource( if (!endpoint.hostname.empty()) { endpoint_proto->set_hostname(endpoint.hostname); } + if (!endpoint.metadata.empty()) { + auto& filter_map = + *lb_endpoints->mutable_metadata()->mutable_filter_metadata(); + for (const auto& p : endpoint.metadata) { + absl::Status status = grpc::protobuf::json::JsonStringToMessage( + p.second, &filter_map[p.first], + grpc::protobuf::json::JsonParseOptions()); + CHECK(status.ok()) << status; + } + } } } if (!args.drop_categories.empty()) { diff --git a/test/cpp/end2end/xds/xds_utils.h b/test/cpp/end2end/xds/xds_utils.h index 194827664562c..b8d13ee436cce 100644 --- a/test/cpp/end2end/xds/xds_utils.h +++ b/test/cpp/end2end/xds/xds_utils.h @@ -214,23 +214,26 @@ class XdsResourceUtils { struct EdsResourceArgs { // An individual endpoint for a backend running on a specified port. struct Endpoint { - explicit Endpoint(int port, - ::envoy::config::core::v3::HealthStatus health_status = - ::envoy::config::core::v3::HealthStatus::UNKNOWN, - int lb_weight = 1, - std::vector additional_ports = {}, - absl::string_view hostname = "") + explicit Endpoint( + int port, + ::envoy::config::core::v3::HealthStatus health_status = + ::envoy::config::core::v3::HealthStatus::UNKNOWN, + int lb_weight = 1, std::vector additional_ports = {}, + absl::string_view hostname = "", + const std::map& metadata = {}) : port(port), health_status(health_status), lb_weight(lb_weight), additional_ports(std::move(additional_ports)), - hostname(hostname) {} + hostname(hostname), + metadata(metadata) {} int port; ::envoy::config::core::v3::HealthStatus health_status; int lb_weight; std::vector additional_ports; std::string hostname; + std::map metadata; }; // A locality.
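
A minimal usage sketch, not part of this patch, of the new requestHashHeader field that the config loader above accepts: it assumes the GRPC_EXPERIMENTAL_RING_HASH_SET_REQUEST_HASH_KEY environment variable introduced in this change is set to true, and the "session-id" header name, function name, and target are placeholder choices for illustration only.

#include <memory>
#include <string>

#include <grpcpp/create_channel.h>
#include <grpcpp/security/credentials.h>
#include <grpcpp/support/channel_arguments.h>

// Sketch: build a channel that selects ring_hash_experimental directly via
// service config (outside xDS) and hashes each RPC on the value of the
// hypothetical "session-id" request header. Per the picker logic above, RPCs
// that omit the header fall back to a random hash.
std::shared_ptr<grpc::Channel> MakeRingHashChannel(const std::string& target) {
  grpc::ChannelArguments channel_args;
  channel_args.SetServiceConfigJSON(
      "{\"loadBalancingConfig\": [{\"ring_hash_experimental\": {"
      "\"minRingSize\": 1024, \"maxRingSize\": 4096, "
      "\"requestHashHeader\": \"session-id\"}}]}");
  return grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(),
                                   channel_args);
}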