mirror of
https://github.com/tuna/tunasync.git
synced 2025-06-15 14:12:47 +00:00
tests(manager): add tests for server.go, validate workerID in middleware
This commit is contained in:
parent
02bb8c16ab
commit
401b6a694e
@ -1,6 +1,9 @@
|
|||||||
package manager
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -14,3 +17,17 @@ func contextErrorLogger(c *gin.Context) {
|
|||||||
// pass on to the next middleware in chain
|
// pass on to the next middleware in chain
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *managerServer) workerIDValidator(c *gin.Context) {
|
||||||
|
workerID := c.Param("id")
|
||||||
|
_, err := s.adapter.GetWorker(workerID)
|
||||||
|
if err != nil {
|
||||||
|
// no worker named `workerID` exists
|
||||||
|
err := fmt.Errorf("invalid workerID %s", workerID)
|
||||||
|
s.returnErrJSON(c, http.StatusBadRequest, err)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// pass on to the next middleware in chain
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
@ -20,10 +20,8 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type worker struct {
|
type worker struct {
|
||||||
// worker name
|
ID string `json:"id"` // worker name
|
||||||
id string
|
Token string `json:"token"` // session token
|
||||||
// session token
|
|
||||||
token string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -64,7 +62,7 @@ func (s *managerServer) listWorkers(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
for _, w := range workers {
|
for _, w := range workers {
|
||||||
workerInfos = append(workerInfos,
|
workerInfos = append(workerInfos,
|
||||||
WorkerInfoMsg{w.id})
|
WorkerInfoMsg{w.ID})
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, workerInfos)
|
c.JSON(http.StatusOK, workerInfos)
|
||||||
}
|
}
|
||||||
@ -85,7 +83,7 @@ func (s *managerServer) registerWorker(c *gin.Context) {
|
|||||||
// create workerCmd channel for this worker
|
// create workerCmd channel for this worker
|
||||||
workerChannelMu.Lock()
|
workerChannelMu.Lock()
|
||||||
defer workerChannelMu.Unlock()
|
defer workerChannelMu.Unlock()
|
||||||
workerChannels[_worker.id] = make(chan WorkerCmd, maxQueuedCmdNum)
|
workerChannels[_worker.ID] = make(chan WorkerCmd, maxQueuedCmdNum)
|
||||||
c.JSON(http.StatusOK, newWorker)
|
c.JSON(http.StatusOK, newWorker)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -200,8 +198,12 @@ func makeHTTPServer(debug bool) *managerServer {
|
|||||||
gin.Default(),
|
gin.Default(),
|
||||||
nil,
|
nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// common log middleware
|
||||||
|
s.Use(contextErrorLogger)
|
||||||
|
|
||||||
s.GET("/ping", func(c *gin.Context) {
|
s.GET("/ping", func(c *gin.Context) {
|
||||||
c.JSON(http.StatusOK, gin.H{"msg": "pong"})
|
c.JSON(http.StatusOK, gin.H{_infoKey: "pong"})
|
||||||
})
|
})
|
||||||
// list jobs, status page
|
// list jobs, status page
|
||||||
s.GET("/jobs", s.listAllJobs)
|
s.GET("/jobs", s.listAllJobs)
|
||||||
@ -209,15 +211,17 @@ func makeHTTPServer(debug bool) *managerServer {
|
|||||||
// list workers
|
// list workers
|
||||||
s.GET("/workers", s.listWorkers)
|
s.GET("/workers", s.listWorkers)
|
||||||
// worker online
|
// worker online
|
||||||
s.POST("/workers/:id", s.registerWorker)
|
s.POST("/workers", s.registerWorker)
|
||||||
|
|
||||||
|
// workerID should be valid in this route group
|
||||||
|
workerValidateGroup := s.Group("/workers", s.workerIDValidator)
|
||||||
// get job list
|
// get job list
|
||||||
s.GET("/workers/:id/jobs", s.listJobsOfWorker)
|
workerValidateGroup.GET(":id/jobs", s.listJobsOfWorker)
|
||||||
// post job status
|
// post job status
|
||||||
s.POST("/workers/:id/jobs/:job", s.updateJobOfWorker)
|
workerValidateGroup.POST(":id/jobs/:job", s.updateJobOfWorker)
|
||||||
|
|
||||||
// worker command polling
|
// worker command polling
|
||||||
s.GET("/workers/:id/cmd_stream", s.getCmdOfWorker)
|
workerValidateGroup.GET(":id/cmd_stream", s.getCmdOfWorker)
|
||||||
|
|
||||||
// for tunasynctl to post commands
|
// for tunasynctl to post commands
|
||||||
s.POST("/cmd/", s.handleClientCmd)
|
s.POST("/cmd/", s.handleClientCmd)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package manager
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
@ -11,8 +12,146 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
. "github.com/smartystreets/goconvey/convey"
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
|
. "github.com/tuna/tunasync/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
_magicBadWorkerID = "magic_bad_worker_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
func postJSON(url string, obj interface{}) (*http.Response, error) {
|
||||||
|
b := new(bytes.Buffer)
|
||||||
|
json.NewEncoder(b).Encode(obj)
|
||||||
|
return http.Post(url, "application/json; charset=utf-8", b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPServer(t *testing.T) {
|
||||||
|
Convey("HTTP server should work", t, func() {
|
||||||
|
InitLogger(true, true, false)
|
||||||
|
s := makeHTTPServer(false)
|
||||||
|
So(s, ShouldNotBeNil)
|
||||||
|
s.setDBAdapter(&mockDBAdapter{
|
||||||
|
workerStore: map[string]worker{
|
||||||
|
_magicBadWorkerID: worker{
|
||||||
|
ID: _magicBadWorkerID,
|
||||||
|
}},
|
||||||
|
statusStore: make(map[string]mirrorStatus),
|
||||||
|
})
|
||||||
|
port := rand.Intn(10000) + 20000
|
||||||
|
baseURL := fmt.Sprintf("http://127.0.0.1:%d", port)
|
||||||
|
go func() {
|
||||||
|
s.Run(fmt.Sprintf("127.0.0.1:%d", port))
|
||||||
|
}()
|
||||||
|
time.Sleep(50 * time.Microsecond)
|
||||||
|
resp, err := http.Get(baseURL + "/ping")
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(resp.StatusCode, ShouldEqual, http.StatusOK)
|
||||||
|
So(resp.Header.Get("Content-Type"), ShouldEqual, "application/json; charset=utf-8")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
var p map[string]string
|
||||||
|
err = json.Unmarshal(body, &p)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(p[_infoKey], ShouldEqual, "pong")
|
||||||
|
|
||||||
|
Convey("when database fail", func() {
|
||||||
|
resp, err := http.Get(fmt.Sprintf("%s/workers/%s/jobs", baseURL, _magicBadWorkerID))
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
var msg map[string]string
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&msg)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(msg[_errorKey], ShouldEqual, fmt.Sprintf("failed to list jobs of worker %s: %s", _magicBadWorkerID, "database fail"))
|
||||||
|
})
|
||||||
|
|
||||||
|
Convey("when register a worker", func() {
|
||||||
|
w := worker{
|
||||||
|
ID: "test_worker1",
|
||||||
|
}
|
||||||
|
resp, err := postJSON(baseURL+"/workers", w)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(resp.StatusCode, ShouldEqual, http.StatusOK)
|
||||||
|
|
||||||
|
Convey("list all workers", func() {
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
resp, err := http.Get(baseURL + "/workers")
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
var actualResponseObj []WorkerInfoMsg
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&actualResponseObj)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(len(actualResponseObj), ShouldEqual, 2)
|
||||||
|
})
|
||||||
|
|
||||||
|
Convey("update mirror status of a existed worker", func() {
|
||||||
|
status := mirrorStatus{
|
||||||
|
Name: "arch-sync1",
|
||||||
|
Worker: "test_worker1",
|
||||||
|
IsMaster: true,
|
||||||
|
Status: Success,
|
||||||
|
LastUpdate: time.Now(),
|
||||||
|
Upstream: "mirrors.tuna.tsinghua.edu.cn",
|
||||||
|
Size: "3GB",
|
||||||
|
}
|
||||||
|
resp, err := postJSON(fmt.Sprintf("%s/workers/%s/jobs/%s", baseURL, status.Worker, status.Name), status)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(resp.StatusCode, ShouldEqual, http.StatusOK)
|
||||||
|
|
||||||
|
Convey("list mirror status of an existed worker", func() {
|
||||||
|
|
||||||
|
expectedResponse, err := json.Marshal([]mirrorStatus{status})
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
resp, err := http.Get(baseURL + "/workers/test_worker1/jobs")
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(resp.StatusCode, ShouldEqual, http.StatusOK)
|
||||||
|
// err = json.NewDecoder(resp.Body).Decode(&mirrorStatusList)
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(strings.TrimSpace(string(body)), ShouldEqual, string(expectedResponse))
|
||||||
|
})
|
||||||
|
|
||||||
|
Convey("list all job status of all workers", func() {
|
||||||
|
expectedResponse, err := json.Marshal([]mirrorStatus{status})
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
resp, err := http.Get(baseURL + "/jobs")
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(resp.StatusCode, ShouldEqual, http.StatusOK)
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(strings.TrimSpace(string(body)), ShouldEqual, string(expectedResponse))
|
||||||
|
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Convey("update mirror status of an inexisted worker", func() {
|
||||||
|
invalidWorker := "test_worker2"
|
||||||
|
status := mirrorStatus{
|
||||||
|
Name: "arch-sync2",
|
||||||
|
Worker: invalidWorker,
|
||||||
|
IsMaster: true,
|
||||||
|
Status: Success,
|
||||||
|
LastUpdate: time.Now(),
|
||||||
|
Upstream: "mirrors.tuna.tsinghua.edu.cn",
|
||||||
|
Size: "4GB",
|
||||||
|
}
|
||||||
|
resp, err := postJSON(fmt.Sprintf("%s/workers/%s/jobs/%s",
|
||||||
|
baseURL, status.Worker, status.Name), status)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(resp.StatusCode, ShouldEqual, http.StatusBadRequest)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
var msg map[string]string
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&msg)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(msg[_errorKey], ShouldEqual, "invalid workerID "+invalidWorker)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
type mockDBAdapter struct {
|
type mockDBAdapter struct {
|
||||||
workerStore map[string]worker
|
workerStore map[string]worker
|
||||||
statusStore map[string]mirrorStatus
|
statusStore map[string]mirrorStatus
|
||||||
@ -31,23 +170,22 @@ func (b *mockDBAdapter) ListWorkers() ([]worker, error) {
|
|||||||
func (b *mockDBAdapter) GetWorker(workerID string) (worker, error) {
|
func (b *mockDBAdapter) GetWorker(workerID string) (worker, error) {
|
||||||
w, ok := b.workerStore[workerID]
|
w, ok := b.workerStore[workerID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return worker{}, fmt.Errorf("inexist workerId")
|
return worker{}, fmt.Errorf("invalid workerId")
|
||||||
}
|
}
|
||||||
return w, nil
|
return w, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *mockDBAdapter) CreateWorker(w worker) (worker, error) {
|
func (b *mockDBAdapter) CreateWorker(w worker) (worker, error) {
|
||||||
_, ok := b.workerStore[w.id]
|
// _, ok := b.workerStore[w.ID]
|
||||||
if ok {
|
// if ok {
|
||||||
return worker{}, fmt.Errorf("duplicate worker name")
|
// return worker{}, fmt.Errorf("duplicate worker name")
|
||||||
}
|
// }
|
||||||
b.workerStore[w.id] = w
|
b.workerStore[w.ID] = w
|
||||||
return w, nil
|
return w, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *mockDBAdapter) GetMirrorStatus(workerID, mirrorID string) (mirrorStatus, error) {
|
func (b *mockDBAdapter) GetMirrorStatus(workerID, mirrorID string) (mirrorStatus, error) {
|
||||||
// TODO: need to check worker exist first
|
id := mirrorID + "/" + workerID
|
||||||
id := workerID + "/" + mirrorID
|
|
||||||
status, ok := b.statusStore[id]
|
status, ok := b.statusStore[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
return mirrorStatus{}, fmt.Errorf("no mirror %s exists in worker %s", mirrorID, workerID)
|
return mirrorStatus{}, fmt.Errorf("no mirror %s exists in worker %s", mirrorID, workerID)
|
||||||
@ -56,13 +194,22 @@ func (b *mockDBAdapter) GetMirrorStatus(workerID, mirrorID string) (mirrorStatus
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *mockDBAdapter) UpdateMirrorStatus(workerID, mirrorID string, status mirrorStatus) (mirrorStatus, error) {
|
func (b *mockDBAdapter) UpdateMirrorStatus(workerID, mirrorID string, status mirrorStatus) (mirrorStatus, error) {
|
||||||
id := workerID + "/" + mirrorID
|
// if _, ok := b.workerStore[workerID]; !ok {
|
||||||
|
// // unregistered worker
|
||||||
|
// return mirrorStatus{}, fmt.Errorf("invalid workerID %s", workerID)
|
||||||
|
// }
|
||||||
|
|
||||||
|
id := mirrorID + "/" + workerID
|
||||||
b.statusStore[id] = status
|
b.statusStore[id] = status
|
||||||
return status, nil
|
return status, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *mockDBAdapter) ListMirrorStatus(workerID string) ([]mirrorStatus, error) {
|
func (b *mockDBAdapter) ListMirrorStatus(workerID string) ([]mirrorStatus, error) {
|
||||||
var mirrorStatusList []mirrorStatus
|
var mirrorStatusList []mirrorStatus
|
||||||
|
// simulating a database fail
|
||||||
|
if workerID == _magicBadWorkerID {
|
||||||
|
return []mirrorStatus{}, fmt.Errorf("database fail")
|
||||||
|
}
|
||||||
for k, v := range b.statusStore {
|
for k, v := range b.statusStore {
|
||||||
if wID := strings.Split(k, "/")[1]; wID == workerID {
|
if wID := strings.Split(k, "/")[1]; wID == workerID {
|
||||||
mirrorStatusList = append(mirrorStatusList, v)
|
mirrorStatusList = append(mirrorStatusList, v)
|
||||||
@ -79,26 +226,6 @@ func (b *mockDBAdapter) ListAllMirrorStatus() ([]mirrorStatus, error) {
|
|||||||
return mirrorStatusList, nil
|
return mirrorStatusList, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPServer(t *testing.T) {
|
func (b *mockDBAdapter) Close() error {
|
||||||
Convey("HTTP server should work", t, func() {
|
return nil
|
||||||
s := makeHTTPServer(false)
|
|
||||||
So(s, ShouldNotBeNil)
|
|
||||||
port := rand.Intn(10000) + 20000
|
|
||||||
go func() {
|
|
||||||
s.Run(fmt.Sprintf("127.0.0.1:%d", port))
|
|
||||||
}()
|
|
||||||
time.Sleep(50 * time.Microsecond)
|
|
||||||
resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/ping", port))
|
|
||||||
So(err, ShouldBeNil)
|
|
||||||
So(resp.StatusCode, ShouldEqual, http.StatusOK)
|
|
||||||
So(resp.Header.Get("Content-Type"), ShouldEqual, "application/json; charset=utf-8")
|
|
||||||
defer resp.Body.Close()
|
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
|
||||||
So(err, ShouldBeNil)
|
|
||||||
var p map[string]string
|
|
||||||
err = json.Unmarshal(body, &p)
|
|
||||||
So(err, ShouldBeNil)
|
|
||||||
So(p["msg"], ShouldEqual, "pong")
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user