diff --git a/pool/pool.go b/pool/pool.go index 8edcf5cf..514196d5 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "math/rand" "sort" "sync" "sync/atomic" @@ -769,8 +768,7 @@ func (p *Pool) Dial(ctx context.Context) error { atLeastOneHealthy = true } - source := rand.NewSource(time.Now().UnixNano()) - sampl := newSampler(params.weights, source) + sampl := newSampler(params.weights, safeRand{}) inner[i] = &innerPool{ sampler: sampl, @@ -940,9 +938,8 @@ func (p *Pool) updateInnerNodesHealth(ctx context.Context, i int, bufferWeights if healthyChanged.Load() { probabilities := adjustWeights(bufferWeights) - source := rand.NewSource(time.Now().UnixNano()) pool.lock.Lock() - pool.sampler = newSampler(probabilities, source) + pool.sampler = newSampler(probabilities, safeRand{}) pool.lock.Unlock() } } diff --git a/pool/sampler.go b/pool/sampler.go index cde60578..aa4bba73 100644 --- a/pool/sampler.go +++ b/pool/sampler.go @@ -2,10 +2,22 @@ package pool import "math/rand" +// [rand.Rand] interface. +type rander interface { + Intn(n int) int + Float64() float64 +} + +// replacement of [rand.Rand] safe for concurrent use. +type safeRand struct{} + +func (safeRand) Intn(n int) int { return rand.Intn(n) } +func (safeRand) Float64() float64 { return rand.Float64() } + // sampler implements weighted random number generation using Vose's Alias // Method (https://www.keithschwarz.com/darts-dice-coins/). type sampler struct { - randomGenerator *rand.Rand + randomGenerator rander probabilities []float64 alias []int } @@ -13,14 +25,14 @@ type sampler struct { // newSampler creates new sampler with a given set of probabilities using // given source of randomness. Created sampler will produce numbers from // 0 to len(probabilities). -func newSampler(probabilities []float64, source rand.Source) *sampler { +func newSampler(probabilities []float64, r rander) *sampler { sampler := &sampler{} var ( small workList large workList ) n := len(probabilities) - sampler.randomGenerator = rand.New(source) + sampler.randomGenerator = r sampler.probabilities = make([]float64, n) sampler.alias = make([]int, n) // Compute scaled probabilities. diff --git a/pool/sampler_test.go b/pool/sampler_test.go index 84f4a174..842bc00a 100644 --- a/pool/sampler_test.go +++ b/pool/sampler_test.go @@ -36,7 +36,7 @@ func TestSamplerStability(t *testing.T) { } for _, tc := range cases { - sampler := newSampler(tc.probabilities, rand.NewSource(0)) + sampler := newSampler(tc.probabilities, rand.New(rand.NewSource(0))) res := make([]int, len(tc.probabilities)) for range COUNT { res[sampler.next()]++ @@ -62,7 +62,7 @@ func TestHealthyReweight(t *testing.T) { client2 := newMockClient(names[1], neofscryptotest.Signer()) inner := &innerPool{ - sampler: newSampler(weights, rand.NewSource(0)), + sampler: newSampler(weights, rand.New(rand.NewSource(0))), clients: []internalClient{client1, client2}, } p := &Pool{ @@ -91,7 +91,7 @@ func TestHealthyReweight(t *testing.T) { inner.lock.Unlock() p.updateInnerNodesHealth(context.TODO(), 0, buffer) - inner.sampler = newSampler(weights, rand.NewSource(0)) + inner.sampler = newSampler(weights, rand.New(rand.NewSource(0))) connection0, err = p.connection() require.NoError(t, err) @@ -106,7 +106,7 @@ func TestHealthyNoReweight(t *testing.T) { buffer = make([]float64, len(weights)) ) - sampl := newSampler(weights, rand.NewSource(0)) + sampl := newSampler(weights, rand.New(rand.NewSource(0))) inner := &innerPool{ sampler: sampl, clients: []internalClient{ @@ -134,7 +134,7 @@ func TestSamplerSafety(t *testing.T) { cause any stack []byte } - s := newSampler([]float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}, rand.NewSource(0)) + s := newSampler([]float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}, safeRand{}) var pr atomic.Value // panicInfo var wg sync.WaitGroup for range 1000 { @@ -163,7 +163,7 @@ func TestSamplerSafety(t *testing.T) { } func BenchmarkSampler(b *testing.B) { - s := newSampler([]float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}, rand.NewSource(0)) + s := newSampler([]float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}, safeRand{}) for range b.N { s.next() }