Skip to content
This repository has been archived by the owner on Aug 1, 2024. It is now read-only.

Commit

Permalink
Fix race conditions in BlobWriteStream / OpenWrite (#1038)
Browse files Browse the repository at this point in the history
* Fix blob write stream releasing lock prematurely.

* Fix race conditions in BlobWriteStream.

* fix file write stream as well.
  • Loading branch information
kasobol-msft authored Apr 15, 2021
1 parent b0d47bf commit c71fd86
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 24 deletions.
19 changes: 11 additions & 8 deletions Lib/ClassLibraryCommon/Blob/BlobWriteStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ namespace Microsoft.Azure.Storage.Blob
using Microsoft.Azure.Storage.Core.Util;
using Microsoft.Azure.Storage.Shared.Protocol;
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Net;
Expand Down Expand Up @@ -181,8 +182,7 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc
int initialOffset = offset;
int initialCount = count;

TaskCompletionSource<bool> continueTCS = new TaskCompletionSource<bool>();
Task<bool> continueTask = continueTCS.Task;
List<Task> continueTasks = new List<Task>();

if (this.lastException == null)
{
Expand All @@ -202,6 +202,9 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc

if (bytesToWrite == maxBytesToWrite)
{
TaskCompletionSource<bool> continueTCS = new TaskCompletionSource<bool>();
continueTasks.Add(continueTCS.Task);

// Note that we do not await on temptask, nor do we store it.
// We do not await temptask so as to enable parallel reads and writes.
// We could store it and await on it later, but that ends up being more complicated
Expand All @@ -220,7 +223,6 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc
}

token.ThrowIfCancellationRequested();
continueTCS = null;
}
}
}
Expand All @@ -233,10 +235,8 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc
this.blobChecksum.UpdateHash(buffer, initialOffset, initialCount);
}

if (continueTCS == null)
{
await continueTask.ConfigureAwait(false);
}
// Wait until all continueTasks complete to let all dispatched writes increment noPendingWritesEvent counter.
await Task.WhenAll(continueTasks);
}

/// <summary>
Expand Down Expand Up @@ -304,7 +304,10 @@ public override async Task FlushAsync(CancellationToken token)
if (!this.IgnoreFlush)
{
this.ThrowLastExceptionIfExists();
await this.DispatchWriteAsync(null, token).ConfigureAwait(false);
TaskCompletionSource<bool> continueTCS = new TaskCompletionSource<bool>();
await this.DispatchWriteAsync(continueTCS, token).ConfigureAwait(false);
// Make sure DispatchWriteAsync had a chance to increment noPendingWritesEvent if there's anything to write.
await continueTCS.Task;
await this.noPendingWritesEvent.WaitAsync().WithCancellation(token).ConfigureAwait(false);
this.ThrowLastExceptionIfExists();
}
Expand Down
19 changes: 11 additions & 8 deletions Lib/ClassLibraryCommon/File/FileWriteStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace Microsoft.Azure.Storage.File
using Microsoft.Azure.Storage.Core.Util;
using Microsoft.Azure.Storage.Shared.Protocol;
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Threading;
Expand Down Expand Up @@ -134,8 +135,7 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc
int initialOffset = offset;
int initialCount = count;

TaskCompletionSource<bool> continueTCS = new TaskCompletionSource<bool>();
Task<bool> continueTask = continueTCS.Task;
List<Task> continueTasks = new List<Task>();

if (this.lastException == null)
{
Expand All @@ -155,6 +155,9 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc

if (bytesToWrite == maxBytesToWrite)
{
TaskCompletionSource<bool> continueTCS = new TaskCompletionSource<bool>();
continueTasks.Add(continueTCS.Task);

// Note that we do not await on temptask, nor do we store it.
// We do not await temptask so as to enable parallel reads and writes.
// We could store it and await on it later, but that ends up being more complicated
Expand All @@ -173,7 +176,6 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc
}

cancellationToken.ThrowIfCancellationRequested();
continueTCS = null;
}
}
}
Expand All @@ -186,10 +188,8 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc
this.fileChecksum.UpdateHash(buffer, initialOffset, initialCount);
}

if (continueTCS == null)
{
await continueTask.ConfigureAwait(false);
}
// Wait until all continueTasks complete to let all dispatched writes increment noPendingWritesEvent counter.
await Task.WhenAll(continueTasks);
}

/// <summary>
Expand Down Expand Up @@ -235,7 +235,10 @@ public override async Task FlushAsync(CancellationToken cancellationToken)
}

ThrowLastExceptionIfExists();
await this.DispatchWriteAsync(null, cancellationToken).ConfigureAwait(false);
TaskCompletionSource<bool> continueTCS = new TaskCompletionSource<bool>();
await this.DispatchWriteAsync(continueTCS, cancellationToken).ConfigureAwait(false);
// Make sure DispatchWriteAsync had a chance to increment noPendingWritesEvent if there's anything to write.
await continueTCS.Task;
await this.noPendingWritesEvent.WaitAsync().WithCancellation(cancellationToken).ConfigureAwait(false);
ThrowLastExceptionIfExists();
}
Expand Down
6 changes: 5 additions & 1 deletion Lib/Common/Core/Util/AsyncExtensions.Common.cs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,11 @@ public class AsyncManualResetEvent

public AsyncManualResetEvent(bool initialStateSignaled)
{
Task.Run(() => m_tcs.TrySetResult(initialStateSignaled));
if (initialStateSignaled)
{
// There's nobody awaiting the task nor is there any continuation, so this should be safe to do.
m_tcs.SetResult(true);
}
}

public Task WaitAsync() { return m_tcs.Task; }
Expand Down
20 changes: 13 additions & 7 deletions Lib/Common/Core/Util/CounterEvent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,42 +17,48 @@

namespace Microsoft.Azure.Storage.Core.Util
{
using System.Threading;
using System.Threading.Tasks;

internal sealed class CounterEventAsync
{
private AsyncManualResetEvent internalEvent = new AsyncManualResetEvent(true);
private object counterLock = new object();
private SemaphoreSlim semaphoreSlim = new SemaphoreSlim(1, 1);
private int counter = 0;

/// <summary>
/// Increments the counter by one and thus sets the state of the event to non-signaled, causing threads to block.
/// </summary>
public void Increment()
{
lock (this.counterLock)
semaphoreSlim.Wait();
try
{
this.counter++;
this.internalEvent.Reset();
}
finally
{
semaphoreSlim.Release();
}
}

/// <summary>
/// Decrements the counter by one. If the counter reaches zero, sets the state of the event to signaled, allowing one or more waiting threads to proceed.
/// </summary>
public async Task DecrementAsync()
{
bool setEvent = false;
lock (this.counterLock)
await semaphoreSlim.WaitAsync().ConfigureAwait(false);
try
{
if (--this.counter == 0)
{
setEvent = true;
await this.internalEvent.Set().ConfigureAwait(false);
}
}
if (setEvent)
finally
{
await this.internalEvent.Set().ConfigureAwait(false);
semaphoreSlim.Release();
}
}

Expand Down
126 changes: 126 additions & 0 deletions Test/ClassLibraryCommon/Core/AsyncManualResetEventTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// -----------------------------------------------------------------------------------------
// <copyright file="AsyncStreamCopierTests.cs" company="Microsoft">
// Copyright 2013 Microsoft Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// </copyright>
// -----------------------------------------------------------------------------------------

using System;
using System.Threading.Tasks;
using Microsoft.Azure.Storage.Core.Util;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace Microsoft.Azure.Storage.Core
{
[TestClass]
public class AsyncManualResetEventTests
{
[TestMethod]
public void CtorCreateSet()
{
// arrange
var theEvent = new AsyncManualResetEvent(true);

// act
bool completed = theEvent.WaitAsync().Wait(TimeSpan.FromSeconds(2));

// assert
Assert.IsTrue(completed);
}

[TestMethod]
public void CtorCreateUnSet()
{
// arrange
var theEvent = new AsyncManualResetEvent(false);

// act
bool completed = theEvent.WaitAsync().Wait(TimeSpan.FromSeconds(2));

// assert
Assert.IsFalse(completed);
}

[DataTestMethod]
[DataRow(true)]
[DataRow(false)]
public void ShouldReset(bool initialState)
{
// arrange
var theEvent = new AsyncManualResetEvent(initialState);

// act
theEvent.Reset();

// assert
bool completed = theEvent.WaitAsync().Wait(TimeSpan.FromSeconds(2));
Assert.IsFalse(completed);
}

[DataTestMethod]
[DataRow(true)]
[DataRow(false)]
public async Task ShouldResetAfterSequenceOfTransitions(bool initialState)
{
// arrange
var theEvent = new AsyncManualResetEvent(initialState);

// act
await theEvent.Set();
theEvent.Reset();
await theEvent.Set();
theEvent.Reset();

// assert
bool completed = theEvent.WaitAsync().Wait(TimeSpan.FromSeconds(2));
Assert.IsFalse(completed);
}

[DataTestMethod]
[DataRow(true)]
[DataRow(false)]
public async Task ShouldSet(bool initialState)
{
// arrange
var theEvent = new AsyncManualResetEvent(initialState);

// act
await theEvent.Set();

// assert
bool completed = theEvent.WaitAsync().Wait(TimeSpan.FromSeconds(2));
Assert.IsTrue(completed);
}

[DataTestMethod]
[DataRow(true)]
[DataRow(false)]
public async Task ShouldSetAfterSequenceOfTransitions(bool initialState)
{
// arrange
var theEvent = new AsyncManualResetEvent(initialState);

// act
await theEvent.Set();
theEvent.Reset();
await theEvent.Set();
theEvent.Reset();
await theEvent.Set();

// assert
bool completed = theEvent.WaitAsync().Wait(TimeSpan.FromSeconds(2));
Assert.IsTrue(completed);
}
}

}

0 comments on commit c71fd86

Please sign in to comment.