test: add RWLock to fix data race in MockDBAdapter

Signed-off-by: Shengqi Chen <harry-chen@outlook.com>
This commit is contained in:
Shengqi Chen 2025-01-11 15:37:37 +08:00
parent 3ad551f73d
commit 15e87a5f48
No known key found for this signature in database

View File

@ -7,6 +7,7 @@ import (
"math/rand" "math/rand"
"net/http" "net/http"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
@ -424,6 +425,8 @@ func TestHTTPServer(t *testing.T) {
type mockDBAdapter struct { type mockDBAdapter struct {
workerStore map[string]WorkerStatus workerStore map[string]WorkerStatus
statusStore map[string]MirrorStatus statusStore map[string]MirrorStatus
workerLock sync.RWMutex
statusLock sync.RWMutex
} }
func (b *mockDBAdapter) Init() error { func (b *mockDBAdapter) Init() error {
@ -431,17 +434,22 @@ func (b *mockDBAdapter) Init() error {
} }
func (b *mockDBAdapter) ListWorkers() ([]WorkerStatus, error) { func (b *mockDBAdapter) ListWorkers() ([]WorkerStatus, error) {
b.workerLock.RLock()
workers := make([]WorkerStatus, len(b.workerStore)) workers := make([]WorkerStatus, len(b.workerStore))
idx := 0 idx := 0
for _, w := range b.workerStore { for _, w := range b.workerStore {
workers[idx] = w workers[idx] = w
idx++ idx++
} }
b.workerLock.RUnlock()
return workers, nil return workers, nil
} }
func (b *mockDBAdapter) GetWorker(workerID string) (WorkerStatus, error) { func (b *mockDBAdapter) GetWorker(workerID string) (WorkerStatus, error) {
b.workerLock.RLock()
defer b.workerLock.RUnlock()
w, ok := b.workerStore[workerID] w, ok := b.workerStore[workerID]
if !ok { if !ok {
return WorkerStatus{}, fmt.Errorf("invalid workerId") return WorkerStatus{}, fmt.Errorf("invalid workerId")
} }
@ -449,7 +457,9 @@ func (b *mockDBAdapter) GetWorker(workerID string) (WorkerStatus, error) {
} }
func (b *mockDBAdapter) DeleteWorker(workerID string) error { func (b *mockDBAdapter) DeleteWorker(workerID string) error {
b.workerLock.Lock()
delete(b.workerStore, workerID) delete(b.workerStore, workerID)
b.workerLock.Unlock()
return nil return nil
} }
@ -458,7 +468,9 @@ func (b *mockDBAdapter) CreateWorker(w WorkerStatus) (WorkerStatus, error) {
// if ok { // if ok {
// return workerStatus{}, fmt.Errorf("duplicate worker name") // return workerStatus{}, fmt.Errorf("duplicate worker name")
// } // }
b.workerLock.Lock()
b.workerStore[w.ID] = w b.workerStore[w.ID] = w
b.workerLock.Unlock()
return w, nil return w, nil
} }
@ -473,7 +485,9 @@ func (b *mockDBAdapter) RefreshWorker(workerID string) (w WorkerStatus, err erro
func (b *mockDBAdapter) GetMirrorStatus(workerID, mirrorID string) (MirrorStatus, error) { func (b *mockDBAdapter) GetMirrorStatus(workerID, mirrorID string) (MirrorStatus, error) {
id := mirrorID + "/" + workerID id := mirrorID + "/" + workerID
b.statusLock.RLock()
status, ok := b.statusStore[id] status, ok := b.statusStore[id]
b.statusLock.RUnlock()
if !ok { if !ok {
return MirrorStatus{}, fmt.Errorf("no mirror %s exists in worker %s", mirrorID, workerID) return MirrorStatus{}, fmt.Errorf("no mirror %s exists in worker %s", mirrorID, workerID)
} }
@ -487,7 +501,9 @@ func (b *mockDBAdapter) UpdateMirrorStatus(workerID, mirrorID string, status Mir
// } // }
id := mirrorID + "/" + workerID id := mirrorID + "/" + workerID
b.statusLock.Lock()
b.statusStore[id] = status b.statusStore[id] = status
b.statusLock.Unlock()
return status, nil return status, nil
} }
@ -497,19 +513,23 @@ func (b *mockDBAdapter) ListMirrorStatus(workerID string) ([]MirrorStatus, error
if workerID == _magicBadWorkerID { if workerID == _magicBadWorkerID {
return []MirrorStatus{}, fmt.Errorf("database fail") return []MirrorStatus{}, fmt.Errorf("database fail")
} }
b.statusLock.RLock()
for k, v := range b.statusStore { for k, v := range b.statusStore {
if wID := strings.Split(k, "/")[1]; wID == workerID { if wID := strings.Split(k, "/")[1]; wID == workerID {
mirrorStatusList = append(mirrorStatusList, v) mirrorStatusList = append(mirrorStatusList, v)
} }
} }
b.statusLock.RUnlock()
return mirrorStatusList, nil return mirrorStatusList, nil
} }
func (b *mockDBAdapter) ListAllMirrorStatus() ([]MirrorStatus, error) { func (b *mockDBAdapter) ListAllMirrorStatus() ([]MirrorStatus, error) {
var mirrorStatusList []MirrorStatus var mirrorStatusList []MirrorStatus
b.statusLock.RLock()
for _, v := range b.statusStore { for _, v := range b.statusStore {
mirrorStatusList = append(mirrorStatusList, v) mirrorStatusList = append(mirrorStatusList, v)
} }
b.statusLock.RUnlock()
return mirrorStatusList, nil return mirrorStatusList, nil
} }