Skip to content
This repository has been archived by the owner on Aug 1, 2024. It is now read-only.

Commit

Permalink
Enforce timeout on network read/writes. (#985)
Browse files Browse the repository at this point in the history
  • Loading branch information
kasobol-msft authored Apr 30, 2020
1 parent af1fe1d commit bed4848
Show file tree
Hide file tree
Showing 12 changed files with 637 additions and 7 deletions.
15 changes: 15 additions & 0 deletions Lib/ClassLibraryCommon/Core/ByteCountingStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ namespace Microsoft.Azure.Storage.Core
using Microsoft.Azure.Storage.Core.Util;
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

/// <summary>
/// This class provides a wrapper that will update the Ingress / Egress bytes of a given request result as the stream is used.
Expand Down Expand Up @@ -107,6 +109,13 @@ public override int Read(byte[] buffer, int offset, int count)
return read;
}

public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
int read = await this.wrappedStream.ReadAsync(buffer, offset, count, cancellationToken);
this.requestObject.IngressBytes += read;
return read;
}

public override int ReadByte()
{
int val = this.wrappedStream.ReadByte();
Expand Down Expand Up @@ -181,6 +190,12 @@ public override void Write(byte[] buffer, int offset, int count)
this.requestObject.EgressBytes += count;
}

public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await this.wrappedStream.WriteAsync(buffer, offset, count, cancellationToken);
this.requestObject.EgressBytes += count;
}

public override void WriteByte(byte value)
{
this.wrappedStream.WriteByte(value);
Expand Down
7 changes: 6 additions & 1 deletion Lib/ClassLibraryCommon/Core/Executor/Executor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,12 @@ public static async Task<T> ExecuteAsync<T>(RESTCommand<T> cmd, IRetryPolicy pol

// 8. (Potentially reads stream from server)
executionState.CurrentOperation = ExecutorOperation.GetResponseStream;
cmd.ResponseStream = await executionState.Resp.Content.ReadAsStreamAsync().ConfigureAwait(false);
var responseStream = await executionState.Resp.Content.ReadAsStreamAsync().ConfigureAwait(false);
if (cmd.NetworkTimeout.HasValue)
{
responseStream = new TimeoutStream(responseStream, cmd.NetworkTimeout.Value);
}
cmd.ResponseStream = responseStream;

// The stream is now available in ResponseStream. Use the stream to parse out the response or error
if (executionState.ExceptionRef != null)
Expand Down
281 changes: 281 additions & 0 deletions Lib/ClassLibraryCommon/Core/TimeoutStream.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
//-----------------------------------------------------------------------
// <copyright file="ByteCountingStream.cs" company="Microsoft">
// Copyright 2013 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// </copyright>
//-----------------------------------------------------------------------


namespace Microsoft.Azure.Storage.Core
{
using Microsoft.Azure.Storage.Core.Util;
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

/// <summary>
/// Stream that will throw a <see cref="OperationCanceledException"/> if it has to wait longer than a configurable timeout to read or write more data
/// </summary>
internal class TimeoutStream : Stream
{
private readonly Stream wrappedStream;
private TimeSpan readTimeout;
private TimeSpan writeTimeout;
private CancellationTokenSource cancellationTokenSource;

public TimeoutStream(Stream wrappedStream, TimeSpan timeout)
: this(wrappedStream, timeout, timeout) { }

public TimeoutStream(Stream wrappedStream, TimeSpan readTimeout, TimeSpan writeTimeout)
{
CommonUtility.AssertNotNull("WrappedStream", wrappedStream);
CommonUtility.AssertNotNull("ReadTimeout", readTimeout);
CommonUtility.AssertNotNull("WriteTimeout", writeTimeout);
this.wrappedStream = wrappedStream;
this.readTimeout = readTimeout;
this.writeTimeout = writeTimeout;
this.UpdateReadTimeout();
this.UpdateWriteTimeout();
this.cancellationTokenSource = new CancellationTokenSource();
}

public override long Position
{
get { return this.wrappedStream.Position; }
set { this.wrappedStream.Position = value; }
}

public override long Length
{
get { return this.wrappedStream.Length; }
}

public override bool CanWrite
{
get { return this.wrappedStream.CanWrite; }
}

public override bool CanTimeout
{
get { return this.wrappedStream.CanTimeout; }
}

public override bool CanSeek
{
get { return this.wrappedStream.CanSeek; }
}

public override bool CanRead
{
get { return this.wrappedStream.CanRead; }
}

public override int ReadTimeout
{
get { return (int) this.readTimeout.TotalMilliseconds; }
set {
this.readTimeout = TimeSpan.FromMilliseconds(value);
this.UpdateReadTimeout();
}
}

public override int WriteTimeout
{
get { return (int) this.writeTimeout.TotalMilliseconds; }
set
{
this.writeTimeout = TimeSpan.FromMilliseconds(value);
this.UpdateWriteTimeout();
}
}

public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
return this.wrappedStream.BeginRead(buffer, offset, count, callback, state);
}

public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
return this.wrappedStream.BeginWrite(buffer, offset, count, callback, state);
}

public override void Close()
{
this.wrappedStream.Close();
}

public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
{
return this.wrappedStream.CopyToAsync(destination, bufferSize, cancellationToken);
}

public override int EndRead(IAsyncResult asyncResult)
{
return this.wrappedStream.EndRead(asyncResult);
}

public override void EndWrite(IAsyncResult asyncResult)
{
this.wrappedStream.EndWrite(asyncResult);
}

public override void Flush()
{
this.wrappedStream.Flush();
}

public override async Task FlushAsync(CancellationToken cancellationToken)
{
var source = StartTimeout(cancellationToken, out bool dispose);
try
{
await this.wrappedStream.FlushAsync(source.Token);
}
finally
{
StopTimeout(source, dispose);
}
}

public override int Read(byte[] buffer, int offset, int count)
{
return wrappedStream.Read(buffer, offset, count);
}

public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
var source = StartTimeout(cancellationToken, out bool dispose);
try
{
return await this.wrappedStream.ReadAsync(buffer, offset, count, source.Token);
}
finally
{
StopTimeout(source, dispose);
}
}

public override int ReadByte()
{
return this.wrappedStream.ReadByte();
}

public override long Seek(long offset, SeekOrigin origin)
{
return this.wrappedStream.Seek(offset, origin);
}

public override void SetLength(long value)
{
this.wrappedStream.SetLength(value);
}

public override void Write(byte[] buffer, int offset, int count)
{
this.wrappedStream.Write(buffer, offset, count);
}

public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
var source = StartTimeout(cancellationToken, out bool dispose);
try
{
await this.wrappedStream.WriteAsync(buffer, offset, count, source.Token);
}
finally
{
StopTimeout(source, dispose);
}
}

public override void WriteByte(byte value)
{
this.wrappedStream.WriteByte(value);
}

private CancellationTokenSource StartTimeout(CancellationToken additionalToken, out bool dispose)
{
if (this.cancellationTokenSource.IsCancellationRequested)
{
this.cancellationTokenSource = new CancellationTokenSource();
}

CancellationTokenSource source;
if (additionalToken.CanBeCanceled)
{
source = CancellationTokenSource.CreateLinkedTokenSource(additionalToken, this.cancellationTokenSource.Token);
dispose = true;
}
else
{
source = this.cancellationTokenSource;
dispose = false;
}

this.cancellationTokenSource.CancelAfter(this.readTimeout);

return source;
}

private void StopTimeout(CancellationTokenSource source, bool dispose)
{
this.cancellationTokenSource.CancelAfter(Timeout.InfiniteTimeSpan);
if (dispose)
{
source.Dispose();
}
}

private void UpdateReadTimeout()
{
if (this.wrappedStream.CanTimeout)
{
try
{
this.wrappedStream.ReadTimeout = (int)this.readTimeout.TotalMilliseconds;
}
catch
{
// ignore
}
}
}

private void UpdateWriteTimeout()
{
if (this.wrappedStream.CanTimeout)
{
try
{
this.wrappedStream.WriteTimeout = (int)this.writeTimeout.TotalMilliseconds;
}
catch
{
// ignore
}
}
}

protected override void Dispose(bool disposing)
{
base.Dispose(disposing);

if (disposing)
{
this.cancellationTokenSource.Dispose();
this.wrappedStream.Dispose();
}
}
}
}
14 changes: 14 additions & 0 deletions Lib/Common/Blob/BlobRequestOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public sealed class BlobRequestOptions : IRequestOptions
LocationMode = RetryPolicies.LocationMode.PrimaryOnly,
ServerTimeout = null,
MaximumExecutionTime = null,
NetworkTimeout = Constants.DefaultNetworkTimeout,
ParallelOperationThreadCount = 1,
SingleBlobUploadThresholdInBytes = Constants.MaxSingleUploadBlobSize / 2,

Expand Down Expand Up @@ -114,6 +115,7 @@ internal BlobRequestOptions(BlobRequestOptions other)
this.LocationMode = other.LocationMode;
this.ServerTimeout = other.ServerTimeout;
this.MaximumExecutionTime = other.MaximumExecutionTime;
this.NetworkTimeout = other.NetworkTimeout;
this.OperationExpiryTime = other.OperationExpiryTime;
this.ChecksumOptions.CopyFrom(other.ChecksumOptions);
this.ParallelOperationThreadCount = other.ParallelOperationThreadCount;
Expand Down Expand Up @@ -162,6 +164,11 @@ internal static BlobRequestOptions ApplyDefaults(BlobRequestOptions options, Blo
?? serviceClient.DefaultRequestOptions.MaximumExecutionTime
?? BaseDefaultRequestOptions.MaximumExecutionTime;

modifiedOptions.NetworkTimeout =
modifiedOptions.NetworkTimeout
?? serviceClient.DefaultRequestOptions.NetworkTimeout
?? BaseDefaultRequestOptions.NetworkTimeout;

modifiedOptions.ParallelOperationThreadCount =
modifiedOptions.ParallelOperationThreadCount
?? serviceClient.DefaultRequestOptions.ParallelOperationThreadCount
Expand Down Expand Up @@ -242,6 +249,8 @@ internal void ApplyToStorageCommand<T>(RESTCommand<T> cmd)
{
cmd.OperationExpiryTime = DateTime.Now + this.MaximumExecutionTime.Value;
}

cmd.NetworkTimeout = this.NetworkTimeout;
}

#if !(WINDOWS_RT || NETCORE)
Expand Down Expand Up @@ -413,6 +422,11 @@ public TimeSpan? MaximumExecutionTime
}
}

/// <summary>
/// Gets or sets the timeout applied to an individual network operations.
/// </summary>
public TimeSpan? NetworkTimeout { get; set; }

/// <summary>
/// Gets or sets the number of blocks that may be simultaneously uploaded.
/// </summary>
Expand Down
3 changes: 3 additions & 0 deletions Lib/Common/Core/Executor/StorageCommandBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ internal abstract class StorageCommandBase<T>
// Max client timeout, enforced over entire operation on client side
internal DateTime? OperationExpiryTime = null;

// Timeout applied to an individual network operations.
internal TimeSpan? NetworkTimeout = null;

// State- different than async state, this is used for ops to communicate state between invocations, i.e. bytes downloaded etc
internal object OperationState = null;

Expand Down
Loading

0 comments on commit bed4848

Please sign in to comment.