Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace RequestData with BoundConfiguration #141

Merged
merged 1 commit into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ public ExposingPipelineFactory(TConfiguration configuration)
private TConfiguration Configuration { get; }
public ITransport<TConfiguration> Transport { get; }

public override RequestPipeline Create(RequestData requestData) =>
new RequestPipeline(requestData);
public override RequestPipeline Create(BoundConfiguration boundConfiguration) => new(boundConfiguration);
}
#nullable restore
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,12 @@ private void UpdateCluster(VirtualCluster cluster)
private bool IsPingRequest(Endpoint endpoint) => _productRegistration.IsPingRequest(endpoint);

/// <inheritdoc cref="IRequestInvoker.RequestAsync{TResponse}"/>>
public Task<TResponse> RequestAsync<TResponse>(Endpoint endpoint, RequestData requestData, PostData? postData, CancellationToken cancellationToken)
public Task<TResponse> RequestAsync<TResponse>(Endpoint endpoint, BoundConfiguration boundConfiguration, PostData? postData, CancellationToken cancellationToken)
where TResponse : TransportResponse, new() =>
Task.FromResult(Request<TResponse>(endpoint, requestData, postData));
Task.FromResult(Request<TResponse>(endpoint, boundConfiguration, postData));

/// <inheritdoc cref="IRequestInvoker.Request{TResponse}"/>>
public TResponse Request<TResponse>(Endpoint endpoint, RequestData requestData, PostData? postData)
public TResponse Request<TResponse>(Endpoint endpoint, BoundConfiguration boundConfiguration, PostData? postData)
where TResponse : TransportResponse, new()
{
if (!_calls.ContainsKey(endpoint.Uri.Port))
Expand All @@ -138,11 +138,11 @@ public TResponse Request<TResponse>(Endpoint endpoint, RequestData requestData,
_ = Interlocked.Increment(ref state.Sniffed);
return HandleRules<TResponse, ISniffRule>(
endpoint,
requestData,
boundConfiguration,
postData,
nameof(VirtualCluster.Sniff),
_cluster.SniffingRules,
requestData.RequestTimeout,
boundConfiguration.RequestTimeout,
(r) => UpdateCluster(r.NewClusterState),
(r) => _productRegistration.CreateSniffResponseBytes(_cluster.Nodes, _cluster.ElasticsearchVersion, _cluster.PublishAddressOverride, _cluster.SniffShouldReturnFqnd)
);
Expand All @@ -152,36 +152,36 @@ public TResponse Request<TResponse>(Endpoint endpoint, RequestData requestData,
_ = Interlocked.Increment(ref state.Pinged);
return HandleRules<TResponse, IRule>(
endpoint,
requestData,
boundConfiguration,
postData,
nameof(VirtualCluster.Ping),
_cluster.PingingRules,
requestData.PingTimeout,
boundConfiguration.PingTimeout,
(r) => { },
(r) => null //HEAD request
);
}
_ = Interlocked.Increment(ref state.Called);
return HandleRules<TResponse, IClientCallRule>(
endpoint,
requestData,
boundConfiguration,
postData,
nameof(VirtualCluster.ClientCalls),
_cluster.ClientCallRules,
requestData.RequestTimeout,
boundConfiguration.RequestTimeout,
(r) => { },
CallResponse
);
}
catch (TheException e)
{
return ResponseFactory.Create<TResponse>(endpoint, requestData, postData, e, null, null, Stream.Null, null, -1, null, null);
return ResponseFactory.Create<TResponse>(endpoint, boundConfiguration, postData, e, null, null, Stream.Null, null, -1, null, null);
}
}

private TResponse HandleRules<TResponse, TRule>(
Endpoint endpoint,
RequestData requestData,
BoundConfiguration boundConfiguration,
PostData? postData,
string origin,
IList<TRule> rules,
Expand All @@ -203,28 +203,28 @@ private TResponse HandleRules<TResponse, TRule>(
if (rule.OnPort == null || rule.OnPort.Value != endpoint.Uri.Port) continue;

if (always)
return Always<TResponse, TRule>(endpoint, requestData, postData, timeout, beforeReturn, successResponse, rule);
return Always<TResponse, TRule>(endpoint, boundConfiguration, postData, timeout, beforeReturn, successResponse, rule);

if (rule.ExecuteCount > times) continue;

return Sometimes<TResponse, TRule>(endpoint, requestData, postData, timeout, beforeReturn, successResponse, rule);
return Sometimes<TResponse, TRule>(endpoint, boundConfiguration, postData, timeout, beforeReturn, successResponse, rule);
}
foreach (var rule in rules.Where(s => !s.OnPort.HasValue))
{
var always = rule.Times.Match(t => true, t => false);
var times = rule.Times.Match(t => -1, t => t);
if (always)
return Always<TResponse, TRule>(endpoint, requestData, postData, timeout, beforeReturn, successResponse, rule);
return Always<TResponse, TRule>(endpoint, boundConfiguration, postData, timeout, beforeReturn, successResponse, rule);

if (rule.ExecuteCount > times) continue;

return Sometimes<TResponse, TRule>(endpoint, requestData, postData, timeout, beforeReturn, successResponse, rule);
return Sometimes<TResponse, TRule>(endpoint, boundConfiguration, postData, timeout, beforeReturn, successResponse, rule);
}
var count = _calls.Select(kv => kv.Value.Called).Sum();
throw new Exception($@"No global or port specific {origin} rule ({endpoint.Uri.Port}) matches any longer after {count} calls in to the cluster");
}

private TResponse Always<TResponse, TRule>(Endpoint endpoint, RequestData requestData, PostData? postData, TimeSpan timeout, Action<TRule> beforeReturn, Func<TRule, byte[]?> successResponse, TRule rule
private TResponse Always<TResponse, TRule>(Endpoint endpoint, BoundConfiguration boundConfiguration, PostData? postData, TimeSpan timeout, Action<TRule> beforeReturn, Func<TRule, byte[]?> successResponse, TRule rule
)
where TResponse : TransportResponse, new()
where TRule : IRule
Expand All @@ -233,20 +233,20 @@ private TResponse Always<TResponse, TRule>(Endpoint endpoint, RequestData reques
{
var time = timeout < rule.Takes.Value ? timeout : rule.Takes.Value;
_dateTimeProvider.ChangeTime(d => d.Add(time));
if (rule.Takes.Value > requestData.RequestTimeout)
if (rule.Takes.Value > boundConfiguration.RequestTimeout)
{
throw new TheException(
$"Request timed out after {time} : call configured to take {rule.Takes.Value} while requestTimeout was: {timeout}");
}
}

return rule.Succeeds
? Success<TResponse, TRule>(endpoint, requestData, postData, beforeReturn, successResponse, rule)
: Fail<TResponse, TRule>(endpoint, requestData, postData, rule);
? Success<TResponse, TRule>(endpoint, boundConfiguration, postData, beforeReturn, successResponse, rule)
: Fail<TResponse, TRule>(endpoint, boundConfiguration, postData, rule);
}

private TResponse Sometimes<TResponse, TRule>(
Endpoint endpoint, RequestData requestData, PostData? postData, TimeSpan timeout, Action<TRule> beforeReturn, Func<TRule, byte[]?> successResponse, TRule rule
Endpoint endpoint, BoundConfiguration boundConfiguration, PostData? postData, TimeSpan timeout, Action<TRule> beforeReturn, Func<TRule, byte[]?> successResponse, TRule rule
)
where TResponse : TransportResponse, new()
where TRule : IRule
Expand All @@ -255,20 +255,20 @@ private TResponse Sometimes<TResponse, TRule>(
{
var time = timeout < rule.Takes.Value ? timeout : rule.Takes.Value;
_dateTimeProvider.ChangeTime(d => d.Add(time));
if (rule.Takes.Value > requestData.RequestTimeout)
if (rule.Takes.Value > boundConfiguration.RequestTimeout)
{
throw new TheException(
$"Request timed out after {time} : call configured to take {rule.Takes.Value} while requestTimeout was: {timeout}");
}
}

if (rule.Succeeds)
return Success<TResponse, TRule>(endpoint, requestData, postData, beforeReturn, successResponse, rule);
return Success<TResponse, TRule>(endpoint, boundConfiguration, postData, beforeReturn, successResponse, rule);

return Fail<TResponse, TRule>(endpoint, requestData, postData, rule);
return Fail<TResponse, TRule>(endpoint, boundConfiguration, postData, rule);
}

private TResponse Fail<TResponse, TRule>(Endpoint endpoint, RequestData requestData, PostData? postData, TRule rule, RuleOption<Exception, int>? returnOverride = null)
private TResponse Fail<TResponse, TRule>(Endpoint endpoint, BoundConfiguration boundConfiguration, PostData? postData, TRule rule, RuleOption<Exception, int>? returnOverride = null)
where TResponse : TransportResponse, new()
where TRule : IRule
{
Expand All @@ -282,13 +282,13 @@ private TResponse Fail<TResponse, TRule>(Endpoint endpoint, RequestData requestD

return ret.Match(
e => throw e,
statusCode => _inMemoryRequestInvoker.BuildResponse<TResponse>(endpoint, requestData, postData, CallResponse(rule),
statusCode => _inMemoryRequestInvoker.BuildResponse<TResponse>(endpoint, boundConfiguration, postData, CallResponse(rule),
//make sure we never return a valid status code in Fail responses because of a bad rule.
statusCode >= 200 && statusCode < 300 ? 502 : statusCode, rule.ReturnContentType)
);
}

private TResponse Success<TResponse, TRule>(Endpoint endpoint, RequestData requestData, PostData? postData, Action<TRule> beforeReturn, Func<TRule, byte[]?> successResponse,
private TResponse Success<TResponse, TRule>(Endpoint endpoint, BoundConfiguration boundConfiguration, PostData? postData, Action<TRule> beforeReturn, Func<TRule, byte[]?> successResponse,
TRule rule
)
where TResponse : TransportResponse, new()
Expand All @@ -299,7 +299,7 @@ TRule rule
rule.RecordExecuted();

beforeReturn?.Invoke(rule);
return _inMemoryRequestInvoker.BuildResponse<TResponse>(endpoint, requestData, postData, successResponse(rule), contentType: rule.ReturnContentType);
return _inMemoryRequestInvoker.BuildResponse<TResponse>(endpoint, boundConfiguration, postData, successResponse(rule), contentType: rule.ReturnContentType);
}

private static byte[] CallResponse<TRule>(TRule rule)
Expand Down
4 changes: 2 additions & 2 deletions src/Elastic.Transport.VirtualizedCluster/Rules/RuleBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ public TRule ReturnResponse<T>(T response)
r = ms.ToArray();
}
Self.ReturnResponse = r;
Self.ReturnContentType = RequestData.DefaultContentType;
Self.ReturnContentType = BoundConfiguration.DefaultContentType;
return (TRule)this;
}

public TRule ReturnByteResponse(byte[] response, string responseContentType = RequestData.DefaultContentType)
public TRule ReturnByteResponse(byte[] response, string responseContentType = BoundConfiguration.DefaultContentType)
{
Self.ReturnResponse = response;
Self.ReturnContentType = responseContentType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,10 @@
namespace Elastic.Transport;

/// <summary>
/// Where and how <see cref="IRequestInvoker.Request{TResponse}" /> should connect to.
/// <para>
/// Represents the cumulative configuration from <see cref="ITransportConfiguration" />
/// and <see cref="IRequestConfiguration" />.
/// </para>
/// </summary>
public sealed record RequestData
public sealed record BoundConfiguration : IRequestConfiguration
{
private const string OpaqueIdHeader = "X-Opaque-Id";

Expand All @@ -27,8 +24,8 @@ public sealed record RequestData
/// The security header used to run requests as a different user.
public const string RunAsSecurityHeader = "es-security-runas-user";

/// <inheritdoc cref="RequestData"/>
public RequestData(ITransportConfiguration global, IRequestConfiguration? local = null)
/// <inheritdoc cref="BoundConfiguration"/>
public BoundConfiguration(ITransportConfiguration global, IRequestConfiguration? local = null)
{
ConnectionSettings = global;
MemoryStreamFactory = global.MemoryStreamFactory;
Expand All @@ -55,7 +52,7 @@ public RequestData(ITransportConfiguration global, IRequestConfiguration? local
Accept = local?.Accept ?? global.Accept ?? DefaultContentType;
ThrowExceptions = local?.ThrowExceptions ?? global.ThrowExceptions ?? false;
RequestTimeout = local?.RequestTimeout ?? global.RequestTimeout ?? RequestConfiguration.DefaultRequestTimeout;
RequestMetaData = local?.RequestMetaData?.Items ?? EmptyReadOnly<string, string>.Dictionary;
RequestMetaData = local?.RequestMetaData;
AuthenticationHeader = local?.Authentication ?? global.Authentication;
AllowedStatusCodes = local?.AllowedStatusCodes ?? EmptyReadOnly<int>.Collection;
ClientCertificates = local?.ClientCertificates ?? global.ClientCertificates;
Expand All @@ -81,6 +78,7 @@ public RequestData(ITransportConfiguration global, IRequestConfiguration? local
Headers[key] = local.Headers[key];
}

OpaqueId = local?.OpaqueId;
if (!string.IsNullOrEmpty(local?.OpaqueId))
{
Headers ??= [];
Expand Down Expand Up @@ -115,6 +113,7 @@ public RequestData(ITransportConfiguration global, IRequestConfiguration? local
}

ProductResponseBuilders = global.ProductRegistration.ResponseBuilders;
DisableAuditTrail = local?.DisableAuditTrail ?? global.DisableAuditTrail ?? false;
}

/// <inheritdoc cref="ITransportConfiguration.MemoryStreamFactory"/>
Expand All @@ -140,7 +139,7 @@ public RequestData(ITransportConfiguration global, IRequestConfiguration? local
/// <inheritdoc cref="ITransportConfiguration.DnsRefreshTimeout"/>
public TimeSpan DnsRefreshTimeout { get; }
/// <inheritdoc cref="IRequestConfiguration.RequestMetaData"/>
public IReadOnlyDictionary<string, string> RequestMetaData { get; }
public RequestMetaData? RequestMetaData { get; }
/// <inheritdoc cref="IRequestConfiguration.Accept"/>
public string Accept { get; }
/// <inheritdoc cref="IRequestConfiguration.AllowedStatusCodes"/>
Expand Down Expand Up @@ -191,4 +190,45 @@ public RequestData(ITransportConfiguration global, IRequestConfiguration? local
public IReadOnlyCollection<IResponseBuilder> ProductResponseBuilders { get; }
/// <inheritdoc cref="IRequestConfiguration.ResponseBuilders"/>
public IReadOnlyCollection<IResponseBuilder> ResponseBuilders { get; }
/// <inheritdoc cref="IRequestConfiguration.DisableAuditTrail"/>
public bool DisableAuditTrail { get; }
/// <inheritdoc cref="IRequestConfiguration.OpaqueId"/>
public string? OpaqueId { get; }

string? IRequestConfiguration.Accept => Accept;
IReadOnlyCollection<int>? IRequestConfiguration.AllowedStatusCodes => AllowedStatusCodes;
AuthorizationHeader? IRequestConfiguration.Authentication => AuthenticationHeader;
X509CertificateCollection? IRequestConfiguration.ClientCertificates => ClientCertificates;
string? IRequestConfiguration.ContentType => ContentType;
bool? IRequestConfiguration.DisableDirectStreaming => DisableDirectStreaming;
bool? IRequestConfiguration.DisableAuditTrail => DisableAuditTrail;
bool? IRequestConfiguration.DisablePings => DisablePings;
bool? IRequestConfiguration.DisableSniff => DisableSniff;
bool? IRequestConfiguration.HttpPipeliningEnabled => HttpPipeliningEnabled;
bool? IRequestConfiguration.EnableHttpCompression => HttpCompression;
Uri? IRequestConfiguration.ForceNode => ForceNode;
int? IRequestConfiguration.MaxRetries => MaxRetries;
TimeSpan? IRequestConfiguration.MaxRetryTimeout => RequestTimeout;
string? IRequestConfiguration.OpaqueId => OpaqueId;
bool? IRequestConfiguration.ParseAllHeaders => ParseAllHeaders;
TimeSpan? IRequestConfiguration.PingTimeout => PingTimeout;
TimeSpan? IRequestConfiguration.RequestTimeout => RequestTimeout;
IReadOnlyCollection<IResponseBuilder> IRequestConfiguration.ResponseBuilders => ResponseBuilders;
HeadersList? IRequestConfiguration.ResponseHeadersToParse => ResponseHeadersToParse;
string? IRequestConfiguration.RunAs => RunAs;
bool? IRequestConfiguration.ThrowExceptions => ThrowExceptions;
bool? IRequestConfiguration.TransferEncodingChunked => TransferEncodingChunked;
NameValueCollection? IRequestConfiguration.Headers => Headers;
bool? IRequestConfiguration.EnableTcpStats => EnableTcpStats;
bool? IRequestConfiguration.EnableThreadPoolStats => EnableThreadPoolStats;
RequestMetaData? IRequestConfiguration.RequestMetaData => RequestMetaData;

/// <summary>
/// Create a cachable instance of <see cref="BoundConfiguration"/> for use in high-performance scenarios.
/// </summary>
/// <param name="transport">An existing <see cref="ITransport{TConfiguration}"/> from which to bind transport configuration.</param>
/// <param name="requestConfiguration">A request specific <see cref="IRequestConfiguration"/>.</param>
/// <returns></returns>
public static BoundConfiguration Create(ITransport<ITransportConfiguration> transport, IRequestConfiguration requestConfiguration) =>
new(transport.Configuration, requestConfiguration);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,26 @@ internal sealed class DefaultResponseBuilder : IResponseBuilder
bool IResponseBuilder.CanBuild<TResponse>() => true;

/// <inheritdoc/>
public TResponse Build<TResponse>(ApiCallDetails apiCallDetails, RequestData requestData,
public TResponse Build<TResponse>(ApiCallDetails apiCallDetails, BoundConfiguration boundConfiguration,
Stream responseStream, string contentType, long contentLength)
where TResponse : TransportResponse, new() =>
SetBodyCoreAsync<TResponse>(false, apiCallDetails, requestData, responseStream).EnsureCompleted();
SetBodyCoreAsync<TResponse>(false, apiCallDetails, boundConfiguration, responseStream).EnsureCompleted();

/// <inheritdoc/>
public Task<TResponse> BuildAsync<TResponse>(
ApiCallDetails apiCallDetails, RequestData requestData, Stream responseStream, string contentType, long contentLength,
ApiCallDetails apiCallDetails, BoundConfiguration boundConfiguration, Stream responseStream, string contentType, long contentLength,
CancellationToken cancellationToken) where TResponse : TransportResponse, new() =>
SetBodyCoreAsync<TResponse>(true, apiCallDetails, requestData, responseStream, cancellationToken).AsTask();
SetBodyCoreAsync<TResponse>(true, apiCallDetails, boundConfiguration, responseStream, cancellationToken).AsTask();

private static async ValueTask<TResponse> SetBodyCoreAsync<TResponse>(bool isAsync,
ApiCallDetails details, RequestData requestData, Stream responseStream,
ApiCallDetails details, BoundConfiguration boundConfiguration, Stream responseStream,
CancellationToken cancellationToken = default)
where TResponse : TransportResponse, new()
{
TResponse response = null;

if (details.HttpStatusCode.HasValue &&
requestData.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value))
boundConfiguration.SkipDeserializationForStatusCodes.Contains(details.HttpStatusCode.Value))
{
return response;
}
Expand All @@ -51,9 +51,9 @@ private static async ValueTask<TResponse> SetBodyCoreAsync<TResponse>(bool isAsy
var beforeTicks = Stopwatch.GetTimestamp();

if (isAsync)
response = await requestData.ConnectionSettings.RequestResponseSerializer.DeserializeAsync<TResponse>(responseStream, cancellationToken).ConfigureAwait(false);
response = await boundConfiguration.ConnectionSettings.RequestResponseSerializer.DeserializeAsync<TResponse>(responseStream, cancellationToken).ConfigureAwait(false);
else
response = requestData.ConnectionSettings.RequestResponseSerializer.Deserialize<TResponse>(responseStream);
response = boundConfiguration.ConnectionSettings.RequestResponseSerializer.Deserialize<TResponse>(responseStream);

var deserializeResponseMs = (Stopwatch.GetTimestamp() - beforeTicks) / (Stopwatch.Frequency / 1000);

Expand Down
Loading
Loading