mirror of
https://github.com/seaweedfs/seaweedfs.git
synced 2024-01-19 02:48:24 +00:00
Merge pull request #1060 from divinerapier/master
fix: non-thread-safe rand will panic
This commit is contained in:
commit
b0e4771135
|
@ -3,11 +3,11 @@ package wdclient
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/chrislusf/seaweedfs/weed/glog"
|
"github.com/chrislusf/seaweedfs/weed/glog"
|
||||||
)
|
)
|
||||||
|
@ -20,16 +20,27 @@ type Location struct {
|
||||||
type vidMap struct {
|
type vidMap struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
vid2Locations map[uint32][]Location
|
vid2Locations map[uint32][]Location
|
||||||
r *rand.Rand
|
|
||||||
|
cursor int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func newVidMap() vidMap {
|
func newVidMap() vidMap {
|
||||||
return vidMap{
|
return vidMap{
|
||||||
vid2Locations: make(map[uint32][]Location),
|
vid2Locations: make(map[uint32][]Location),
|
||||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
cursor: -1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (vc *vidMap) getLocationIndex(length int64) (int64, error) {
|
||||||
|
if length <= 0 {
|
||||||
|
return 0, fmt.Errorf("invalid length: %d", length)
|
||||||
|
}
|
||||||
|
if atomic.LoadInt64(&vc.cursor) == math.MaxInt64 {
|
||||||
|
atomic.CompareAndSwapInt64(&vc.cursor, math.MaxInt64, -1)
|
||||||
|
}
|
||||||
|
return atomic.AddInt64(&vc.cursor, 1) % length, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (vc *vidMap) LookupVolumeServerUrl(vid string) (serverUrl string, err error) {
|
func (vc *vidMap) LookupVolumeServerUrl(vid string) (serverUrl string, err error) {
|
||||||
id, err := strconv.Atoi(vid)
|
id, err := strconv.Atoi(vid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -94,7 +105,12 @@ func (vc *vidMap) GetRandomLocation(vid uint32) (serverUrl string, err error) {
|
||||||
return "", fmt.Errorf("volume %d not found", vid)
|
return "", fmt.Errorf("volume %d not found", vid)
|
||||||
}
|
}
|
||||||
|
|
||||||
return locations[vc.r.Intn(len(locations))].Url, nil
|
index, err := vc.getLocationIndex(int64(len(locations)))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("volume %d. %v", vid, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return locations[index].Url, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vc *vidMap) addLocation(vid uint32, location Location) {
|
func (vc *vidMap) addLocation(vid uint32, location Location) {
|
||||||
|
|
77
weed/wdclient/vid_map_test.go
Normal file
77
weed/wdclient/vid_map_test.go
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
package wdclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLocationIndex(t *testing.T) {
|
||||||
|
vm := vidMap{}
|
||||||
|
// test must be failed
|
||||||
|
mustFailed := func(length int64) {
|
||||||
|
_, err := vm.getLocationIndex(length)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("length %d must be failed", length)
|
||||||
|
}
|
||||||
|
if err.Error() != fmt.Sprintf("invalid length: %d", length) {
|
||||||
|
t.Errorf("length %d must be failed. error: %v", length, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mustFailed(-1)
|
||||||
|
mustFailed(0)
|
||||||
|
|
||||||
|
mustOk := func(length, cursor, expect int64) {
|
||||||
|
if length <= 0 {
|
||||||
|
t.Fatal("please don't do this")
|
||||||
|
}
|
||||||
|
vm.cursor = cursor
|
||||||
|
got, err := vm.getLocationIndex(length)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("length: %d, why? %v\n", length, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != expect {
|
||||||
|
t.Errorf("cursor: %d, length: %d, expect: %d, got: %d\n", cursor, length, expect, got)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := int64(-1); i < 100; i++ {
|
||||||
|
mustOk(7, i, (i+1)%7)
|
||||||
|
}
|
||||||
|
|
||||||
|
// when cursor reaches MaxInt64
|
||||||
|
mustOk(7, math.MaxInt64, 0)
|
||||||
|
|
||||||
|
// test with constructor
|
||||||
|
vm = newVidMap()
|
||||||
|
length := int64(7)
|
||||||
|
for i := int64(0); i < 100; i++ {
|
||||||
|
got, err := vm.getLocationIndex(length)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("length: %d, why? %v\n", length, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != i%length {
|
||||||
|
t.Errorf("length: %d, i: %d, got: %d\n", length, i, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLocationIndex(b *testing.B) {
|
||||||
|
b.SetParallelism(8)
|
||||||
|
vm := vidMap{
|
||||||
|
cursor: math.MaxInt64 - 10000,
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
_, err := vm.getLocationIndex(3)
|
||||||
|
if err != nil {
|
||||||
|
b.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
Loading…
Reference in a new issue