From c71fd86d8a75393659d9342724e480c1d3fe9f07 Mon Sep 17 00:00:00 2001 From: Kamil Sobol <61715331+kasobol-msft@users.noreply.github.com> Date: Thu, 15 Apr 2021 09:49:55 -0700 Subject: [PATCH] Fix race conditions in BlobWriteStream / OpenWrite (#1038) * Fix blob write stream releasing lock prematurely. * Fix race conditions in BlobWriteStream. * fix file write stream as well. --- .../Blob/BlobWriteStream.cs | 19 +-- .../File/FileWriteStream.cs | 19 +-- .../Core/Util/AsyncExtensions.Common.cs | 6 +- Lib/Common/Core/Util/CounterEvent.cs | 20 ++- .../Core/AsyncManualResetEventTests.cs | 126 ++++++++++++++++++ 5 files changed, 166 insertions(+), 24 deletions(-) create mode 100644 Test/ClassLibraryCommon/Core/AsyncManualResetEventTests.cs diff --git a/Lib/ClassLibraryCommon/Blob/BlobWriteStream.cs b/Lib/ClassLibraryCommon/Blob/BlobWriteStream.cs index 08a051f8a..cd6a4ab3a 100644 --- a/Lib/ClassLibraryCommon/Blob/BlobWriteStream.cs +++ b/Lib/ClassLibraryCommon/Blob/BlobWriteStream.cs @@ -22,6 +22,7 @@ namespace Microsoft.Azure.Storage.Blob using Microsoft.Azure.Storage.Core.Util; using Microsoft.Azure.Storage.Shared.Protocol; using System; + using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.IO; using System.Net; @@ -181,8 +182,7 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc int initialOffset = offset; int initialCount = count; - TaskCompletionSource continueTCS = new TaskCompletionSource(); - Task continueTask = continueTCS.Task; + List continueTasks = new List(); if (this.lastException == null) { @@ -202,6 +202,9 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc if (bytesToWrite == maxBytesToWrite) { + TaskCompletionSource continueTCS = new TaskCompletionSource(); + continueTasks.Add(continueTCS.Task); + // Note that we do not await on temptask, nor do we store it. // We do not await temptask so as to enable parallel reads and writes. // We could store it and await on it later, but that ends up being more complicated @@ -220,7 +223,6 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc } token.ThrowIfCancellationRequested(); - continueTCS = null; } } } @@ -233,10 +235,8 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc this.blobChecksum.UpdateHash(buffer, initialOffset, initialCount); } - if (continueTCS == null) - { - await continueTask.ConfigureAwait(false); - } + // Wait until all continueTasks complete to let all dispatched writes increment noPendingWritesEvent counter. + await Task.WhenAll(continueTasks); } /// @@ -304,7 +304,10 @@ public override async Task FlushAsync(CancellationToken token) if (!this.IgnoreFlush) { this.ThrowLastExceptionIfExists(); - await this.DispatchWriteAsync(null, token).ConfigureAwait(false); + TaskCompletionSource continueTCS = new TaskCompletionSource(); + await this.DispatchWriteAsync(continueTCS, token).ConfigureAwait(false); + // Make sure DispatchWriteAsync had a chance to increment noPendingWritesEvent if there's anything to write. + await continueTCS.Task; await this.noPendingWritesEvent.WaitAsync().WithCancellation(token).ConfigureAwait(false); this.ThrowLastExceptionIfExists(); } diff --git a/Lib/ClassLibraryCommon/File/FileWriteStream.cs b/Lib/ClassLibraryCommon/File/FileWriteStream.cs index b88785a11..025c2fa39 100644 --- a/Lib/ClassLibraryCommon/File/FileWriteStream.cs +++ b/Lib/ClassLibraryCommon/File/FileWriteStream.cs @@ -21,6 +21,7 @@ namespace Microsoft.Azure.Storage.File using Microsoft.Azure.Storage.Core.Util; using Microsoft.Azure.Storage.Shared.Protocol; using System; + using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.IO; using System.Threading; @@ -134,8 +135,7 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc int initialOffset = offset; int initialCount = count; - TaskCompletionSource continueTCS = new TaskCompletionSource(); - Task continueTask = continueTCS.Task; + List continueTasks = new List(); if (this.lastException == null) { @@ -155,6 +155,9 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc if (bytesToWrite == maxBytesToWrite) { + TaskCompletionSource continueTCS = new TaskCompletionSource(); + continueTasks.Add(continueTCS.Task); + // Note that we do not await on temptask, nor do we store it. // We do not await temptask so as to enable parallel reads and writes. // We could store it and await on it later, but that ends up being more complicated @@ -173,7 +176,6 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc } cancellationToken.ThrowIfCancellationRequested(); - continueTCS = null; } } } @@ -186,10 +188,8 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc this.fileChecksum.UpdateHash(buffer, initialOffset, initialCount); } - if (continueTCS == null) - { - await continueTask.ConfigureAwait(false); - } + // Wait until all continueTasks complete to let all dispatched writes increment noPendingWritesEvent counter. + await Task.WhenAll(continueTasks); } /// @@ -235,7 +235,10 @@ public override async Task FlushAsync(CancellationToken cancellationToken) } ThrowLastExceptionIfExists(); - await this.DispatchWriteAsync(null, cancellationToken).ConfigureAwait(false); + TaskCompletionSource continueTCS = new TaskCompletionSource(); + await this.DispatchWriteAsync(continueTCS, cancellationToken).ConfigureAwait(false); + // Make sure DispatchWriteAsync had a chance to increment noPendingWritesEvent if there's anything to write. + await continueTCS.Task; await this.noPendingWritesEvent.WaitAsync().WithCancellation(cancellationToken).ConfigureAwait(false); ThrowLastExceptionIfExists(); } diff --git a/Lib/Common/Core/Util/AsyncExtensions.Common.cs b/Lib/Common/Core/Util/AsyncExtensions.Common.cs index 229e964ed..58d55a9a9 100644 --- a/Lib/Common/Core/Util/AsyncExtensions.Common.cs +++ b/Lib/Common/Core/Util/AsyncExtensions.Common.cs @@ -229,7 +229,11 @@ public class AsyncManualResetEvent public AsyncManualResetEvent(bool initialStateSignaled) { - Task.Run(() => m_tcs.TrySetResult(initialStateSignaled)); + if (initialStateSignaled) + { + // There's nobody awaiting the task nor is there any continuation, so this should be safe to do. + m_tcs.SetResult(true); + } } public Task WaitAsync() { return m_tcs.Task; } diff --git a/Lib/Common/Core/Util/CounterEvent.cs b/Lib/Common/Core/Util/CounterEvent.cs index 40595f100..f2e0bdc67 100644 --- a/Lib/Common/Core/Util/CounterEvent.cs +++ b/Lib/Common/Core/Util/CounterEvent.cs @@ -17,12 +17,13 @@ namespace Microsoft.Azure.Storage.Core.Util { + using System.Threading; using System.Threading.Tasks; internal sealed class CounterEventAsync { private AsyncManualResetEvent internalEvent = new AsyncManualResetEvent(true); - private object counterLock = new object(); + private SemaphoreSlim semaphoreSlim = new SemaphoreSlim(1, 1); private int counter = 0; /// @@ -30,11 +31,16 @@ internal sealed class CounterEventAsync /// public void Increment() { - lock (this.counterLock) + semaphoreSlim.Wait(); + try { this.counter++; this.internalEvent.Reset(); } + finally + { + semaphoreSlim.Release(); + } } /// @@ -42,17 +48,17 @@ public void Increment() /// public async Task DecrementAsync() { - bool setEvent = false; - lock (this.counterLock) + await semaphoreSlim.WaitAsync().ConfigureAwait(false); + try { if (--this.counter == 0) { - setEvent = true; + await this.internalEvent.Set().ConfigureAwait(false); } } - if (setEvent) + finally { - await this.internalEvent.Set().ConfigureAwait(false); + semaphoreSlim.Release(); } } diff --git a/Test/ClassLibraryCommon/Core/AsyncManualResetEventTests.cs b/Test/ClassLibraryCommon/Core/AsyncManualResetEventTests.cs new file mode 100644 index 000000000..f608b240f --- /dev/null +++ b/Test/ClassLibraryCommon/Core/AsyncManualResetEventTests.cs @@ -0,0 +1,126 @@ +// ----------------------------------------------------------------------------------------- +// +// Copyright 2013 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// ----------------------------------------------------------------------------------------- + +using System; +using System.Threading.Tasks; +using Microsoft.Azure.Storage.Core.Util; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.Azure.Storage.Core +{ + [TestClass] + public class AsyncManualResetEventTests + { + [TestMethod] + public void CtorCreateSet() + { + // arrange + var theEvent = new AsyncManualResetEvent(true); + + // act + bool completed = theEvent.WaitAsync().Wait(TimeSpan.FromSeconds(2)); + + // assert + Assert.IsTrue(completed); + } + + [TestMethod] + public void CtorCreateUnSet() + { + // arrange + var theEvent = new AsyncManualResetEvent(false); + + // act + bool completed = theEvent.WaitAsync().Wait(TimeSpan.FromSeconds(2)); + + // assert + Assert.IsFalse(completed); + } + + [DataTestMethod] + [DataRow(true)] + [DataRow(false)] + public void ShouldReset(bool initialState) + { + // arrange + var theEvent = new AsyncManualResetEvent(initialState); + + // act + theEvent.Reset(); + + // assert + bool completed = theEvent.WaitAsync().Wait(TimeSpan.FromSeconds(2)); + Assert.IsFalse(completed); + } + + [DataTestMethod] + [DataRow(true)] + [DataRow(false)] + public async Task ShouldResetAfterSequenceOfTransitions(bool initialState) + { + // arrange + var theEvent = new AsyncManualResetEvent(initialState); + + // act + await theEvent.Set(); + theEvent.Reset(); + await theEvent.Set(); + theEvent.Reset(); + + // assert + bool completed = theEvent.WaitAsync().Wait(TimeSpan.FromSeconds(2)); + Assert.IsFalse(completed); + } + + [DataTestMethod] + [DataRow(true)] + [DataRow(false)] + public async Task ShouldSet(bool initialState) + { + // arrange + var theEvent = new AsyncManualResetEvent(initialState); + + // act + await theEvent.Set(); + + // assert + bool completed = theEvent.WaitAsync().Wait(TimeSpan.FromSeconds(2)); + Assert.IsTrue(completed); + } + + [DataTestMethod] + [DataRow(true)] + [DataRow(false)] + public async Task ShouldSetAfterSequenceOfTransitions(bool initialState) + { + // arrange + var theEvent = new AsyncManualResetEvent(initialState); + + // act + await theEvent.Set(); + theEvent.Reset(); + await theEvent.Set(); + theEvent.Reset(); + await theEvent.Set(); + + // assert + bool completed = theEvent.WaitAsync().Wait(TimeSpan.FromSeconds(2)); + Assert.IsTrue(completed); + } + } + +} \ No newline at end of file