Skip to content

Commit

Permalink
pool: Use concurrency-safe random number generator in sampler
Browse files Browse the repository at this point in the history
Previously, pool used `rand.NewSource` which is documented as 'not safe
for concurrent use by multiple goroutines'. Since pool is generally
accessed from multiple goroutines, it could finish with panic thrown by
the mentioned component.

Now pool uses safe random generators from the same `math/rand` package.
Previous approach is kept for some tests that do not need
multi-threading.

Benchmark runs show that execution speed has obviously dropped due to
sync costs, but it is negligible:
```
goos: linux
goarch: amd64
pkg: github.com/nspcc-dev/neofs-sdk-go/pool
cpu: Intel(R) Core(TM) i5-10210U CPU @ 1.60GHz
          │    sec/op    │
Sampler-8   37.93n ± 13%
          │     B/op     │
Sampler-8     0.000 ± 0%
          │  allocs/op   │
Sampler-8     0.000 ± 0%
```

Fixes #631.

Signed-off-by: Leonard Lyubich <leonard@morphbits.io>
  • Loading branch information
cthulhu-rider committed Dec 4, 2024
1 parent 9eaf6fb commit fd1e188
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 14 deletions.
7 changes: 2 additions & 5 deletions pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"math/rand"
"sort"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -769,8 +768,7 @@ func (p *Pool) Dial(ctx context.Context) error {

atLeastOneHealthy = true
}
source := rand.NewSource(time.Now().UnixNano())
sampl := newSampler(params.weights, source)
sampl := newSampler(params.weights, safeRand{})

inner[i] = &innerPool{
sampler: sampl,
Expand Down Expand Up @@ -940,9 +938,8 @@ func (p *Pool) updateInnerNodesHealth(ctx context.Context, i int, bufferWeights

if healthyChanged.Load() {
probabilities := adjustWeights(bufferWeights)
source := rand.NewSource(time.Now().UnixNano())
pool.lock.Lock()
pool.sampler = newSampler(probabilities, source)
pool.sampler = newSampler(probabilities, safeRand{})
pool.lock.Unlock()
}
}
Expand Down
18 changes: 15 additions & 3 deletions pool/sampler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,37 @@ package pool

import "math/rand"

// [rand.Rand] interface.
type rander interface {
Intn(n int) int
Float64() float64
}

// replacement of [rand.Rand] safe for concurrent use.
type safeRand struct{}

func (safeRand) Intn(n int) int { return rand.Intn(n) }
func (safeRand) Float64() float64 { return rand.Float64() }

// sampler implements weighted random number generation using Vose's Alias
// Method (https://www.keithschwarz.com/darts-dice-coins/).
type sampler struct {
randomGenerator *rand.Rand
randomGenerator rander
probabilities []float64
alias []int
}

// newSampler creates new sampler with a given set of probabilities using
// given source of randomness. Created sampler will produce numbers from
// 0 to len(probabilities).
func newSampler(probabilities []float64, source rand.Source) *sampler {
func newSampler(probabilities []float64, r rander) *sampler {
sampler := &sampler{}
var (
small workList
large workList
)
n := len(probabilities)
sampler.randomGenerator = rand.New(source)
sampler.randomGenerator = r
sampler.probabilities = make([]float64, n)
sampler.alias = make([]int, n)
// Compute scaled probabilities.
Expand Down
12 changes: 6 additions & 6 deletions pool/sampler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestSamplerStability(t *testing.T) {
}

for _, tc := range cases {
sampler := newSampler(tc.probabilities, rand.NewSource(0))
sampler := newSampler(tc.probabilities, rand.New(rand.NewSource(0)))
res := make([]int, len(tc.probabilities))
for range COUNT {
res[sampler.next()]++
Expand All @@ -62,7 +62,7 @@ func TestHealthyReweight(t *testing.T) {
client2 := newMockClient(names[1], neofscryptotest.Signer())

inner := &innerPool{
sampler: newSampler(weights, rand.NewSource(0)),
sampler: newSampler(weights, rand.New(rand.NewSource(0))),
clients: []internalClient{client1, client2},
}
p := &Pool{
Expand Down Expand Up @@ -91,7 +91,7 @@ func TestHealthyReweight(t *testing.T) {
inner.lock.Unlock()

p.updateInnerNodesHealth(context.TODO(), 0, buffer)
inner.sampler = newSampler(weights, rand.NewSource(0))
inner.sampler = newSampler(weights, rand.New(rand.NewSource(0)))

connection0, err = p.connection()
require.NoError(t, err)
Expand All @@ -106,7 +106,7 @@ func TestHealthyNoReweight(t *testing.T) {
buffer = make([]float64, len(weights))
)

sampl := newSampler(weights, rand.NewSource(0))
sampl := newSampler(weights, rand.New(rand.NewSource(0)))
inner := &innerPool{
sampler: sampl,
clients: []internalClient{
Expand Down Expand Up @@ -134,7 +134,7 @@ func TestSamplerSafety(t *testing.T) {
cause any
stack []byte
}
s := newSampler([]float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}, rand.NewSource(0))
s := newSampler([]float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}, safeRand{})
var pr atomic.Value // panicInfo
var wg sync.WaitGroup
for range 1000 {
Expand Down Expand Up @@ -163,7 +163,7 @@ func TestSamplerSafety(t *testing.T) {
}

func BenchmarkSampler(b *testing.B) {
s := newSampler([]float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}, rand.NewSource(0))
s := newSampler([]float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1}, safeRand{})
for range b.N {
s.next()
}
Expand Down

0 comments on commit fd1e188

Please sign in to comment.