From fd1e1885057ee327ea5487dff2c9b53f28c1f4d5 Mon Sep 17 00:00:00 2001 From: Leonard Lyubich Date: Wed, 4 Dec 2024 19:17:36 +0300 Subject: [PATCH] pool: Use concurrency-safe random number generator in sampler MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, pool used `rand.NewSource` which is documented as 'not safe for concurrent use by multiple goroutines'. Since pool is generally accessed from multiple goroutines, it could finish with panic thrown by the mentioned component. Now pool uses safe random generators from the same `math/rand` package. Previous approach is kept for some tests that do not need multi-threading. Benchmark runs show that execution speed has obviously dropped due to sync costs, but it is negligible: ``` goos: linux goarch: amd64 pkg: github.com/nspcc-dev/neofs-sdk-go/pool cpu: Intel(R) Core(TM) i5-10210U CPU @ 1.60GHz │ sec/op │ Sampler-8 37.93n ± 13% │ B/op │ Sampler-8 0.000 ± 0% │ allocs/op │ Sampler-8 0.000 ± 0% ``` Fixes #631. Signed-off-by: Leonard Lyubich --- pool/pool.go | 7 ++----- pool/sampler.go | 18 +++++++++++++++--- pool/sampler_test.go | 12 ++++++------ 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/pool/pool.go b/pool/pool.go index 8edcf5cf..514196d5 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "math/rand" "sort" "sync" "sync/atomic" @@ -769,8 +768,7 @@ func (p *Pool) Dial(ctx context.Context) error { atLeastOneHealthy = true } - source := rand.NewSource(time.Now().UnixNano()) - sampl := newSampler(params.weights, source) + sampl := newSampler(params.weights, safeRand{}) inner[i] = &innerPool{ sampler: sampl, @@ -940,9 +938,8 @@ func (p *Pool) updateInnerNodesHealth(ctx context.Context, i int, bufferWeights if healthyChanged.Load() { probabilities := adjustWeights(bufferWeights) - source := rand.NewSource(time.Now().UnixNano()) pool.lock.Lock() - pool.sampler = newSampler(probabilities, source) + pool.sampler = newSampler(probabilities, safeRand{}) pool.lock.Unlock() } } diff --git a/pool/sampler.go b/pool/sampler.go index cde60578..aa4bba73 100644 --- a/pool/sampler.go +++ b/pool/sampler.go @@ -2,10 +2,22 @@ package pool import "math/rand" +// [rand.Rand] interface. +type rander interface { + Intn(n int) int + Float64() float64 +} + +// replacement of [rand.Rand] safe for concurrent use. +type safeRand struct{} + +func (safeRand) Intn(n int) int { return rand.Intn(n) } +func (safeRand) Float64() float64 { return rand.Float64() } + // sampler implements weighted random number generation using Vose's Alias // Method (https://www.keithschwarz.com/darts-dice-coins/). type sampler struct { - randomGenerator *rand.Rand + randomGenerator rander probabilities []float64 alias []int } @@ -13,14 +25,14 @@ type sampler struct { // newSampler creates new sampler with a given set of probabilities using // given source of randomness. Created sampler will produce numbers from // 0 to len(probabilities). -func newSampler(probabilities []float64, source rand.Source) *sampler { +func newSampler(probabilities []float64, r rander) *sampler { sampler := &sampler{} var ( small workList large workList ) n := len(probabilities) - sampler.randomGenerator = rand.New(source) + sampler.randomGenerator = r sampler.probabilities = make([]float64, n) sampler.alias = make([]int, n) // Compute scaled probabilities. diff --git a/pool/sampler_test.go b/pool/sampler_test.go index 84f4a174..842bc00a 100644 --- a/pool/sampler_test.go +++ b/pool/sampler_test.go @@ -36,7 +36,7 @@ func TestSamplerStability(t *testing.T) { } for _, tc := range cases { - sampler := newSampler(tc.probabilities, rand.NewSource(0)) + sampler := newSampler(tc.probabilities, rand.New(rand.NewSource(0))) res := make([]int, len(tc.probabilities)) for range COUNT { res[sampler.next()]++ @@ -62,7 +62,7 @@ func TestHealthyReweight(t *testing.T) { client2 := newMockClient(names[1], neofscryptotest.Signer()) inner := &innerPool{ - sampler: newSampler(weights, rand.NewSource(0)), + sampler: newSampler(weights, rand.New(rand.NewSource(0))), clients: []internalClient{client1, client2}, } p := &Pool{ @@ -91,7 +91,7 @@ func TestHealthyReweight(t *testing.T) { inner.lock.Unlock() p.updateInnerNodesHealth(context.TODO(), 0, buffer) - inner.sampler = newSampler(weights, rand.NewSource(0)) + inner.sampler = newSampler(weights, rand.New(rand.NewSource(0))) connection0, err = p.connection() require.NoError(t, err) @@ -106,7 +106,7 @@ func TestHealthyNoReweight(t *testing.T) { buffer = make([]float64, len(weights)) ) - sampl := newSampler(weights, rand.NewSource(0)) + sampl := newSampler(weights, rand.New(rand.NewSource(0))) inner := &innerPool{ sampler: sampl, clients: []internalClient{ @@ -134,7 +134,7 @@ func TestSamplerSafety(t *testing.T) { cause any stack []byte } - s := newSampler([]float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}, rand.NewSource(0)) + s := newSampler([]float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}, safeRand{}) var pr atomic.Value // panicInfo var wg sync.WaitGroup for range 1000 { @@ -163,7 +163,7 @@ func TestSamplerSafety(t *testing.T) { } func BenchmarkSampler(b *testing.B) { - s := newSampler([]float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}, rand.NewSource(0)) + s := newSampler([]float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}, safeRand{}) for range b.N { s.next() }