diff --git a/manager/middleware.go b/manager/middleware.go index df00426..f620261 100644 --- a/manager/middleware.go +++ b/manager/middleware.go @@ -1,6 +1,9 @@ package manager import ( + "fmt" + "net/http" + "github.com/gin-gonic/gin" ) @@ -14,3 +17,17 @@ func contextErrorLogger(c *gin.Context) { // pass on to the next middleware in chain c.Next() } + +func (s *managerServer) workerIDValidator(c *gin.Context) { + workerID := c.Param("id") + _, err := s.adapter.GetWorker(workerID) + if err != nil { + // no worker named `workerID` exists + err := fmt.Errorf("invalid workerID %s", workerID) + s.returnErrJSON(c, http.StatusBadRequest, err) + c.Abort() + return + } + // pass on to the next middleware in chain + c.Next() +} diff --git a/manager/server.go b/manager/server.go index 828e975..70a2b00 100644 --- a/manager/server.go +++ b/manager/server.go @@ -20,10 +20,8 @@ const ( ) type worker struct { - // worker name - id string - // session token - token string + ID string `json:"id"` // worker name + Token string `json:"token"` // session token } var ( @@ -64,7 +62,7 @@ func (s *managerServer) listWorkers(c *gin.Context) { } for _, w := range workers { workerInfos = append(workerInfos, - WorkerInfoMsg{w.id}) + WorkerInfoMsg{w.ID}) } c.JSON(http.StatusOK, workerInfos) } @@ -85,7 +83,7 @@ func (s *managerServer) registerWorker(c *gin.Context) { // create workerCmd channel for this worker workerChannelMu.Lock() defer workerChannelMu.Unlock() - workerChannels[_worker.id] = make(chan WorkerCmd, maxQueuedCmdNum) + workerChannels[_worker.ID] = make(chan WorkerCmd, maxQueuedCmdNum) c.JSON(http.StatusOK, newWorker) } @@ -200,8 +198,12 @@ func makeHTTPServer(debug bool) *managerServer { gin.Default(), nil, } + + // common log middleware + s.Use(contextErrorLogger) + s.GET("/ping", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"msg": "pong"}) + c.JSON(http.StatusOK, gin.H{_infoKey: "pong"}) }) // list jobs, status page s.GET("/jobs", s.listAllJobs) @@ -209,15 +211,17 @@ func makeHTTPServer(debug bool) *managerServer { // list workers s.GET("/workers", s.listWorkers) // worker online - s.POST("/workers/:id", s.registerWorker) + s.POST("/workers", s.registerWorker) + // workerID should be valid in this route group + workerValidateGroup := s.Group("/workers", s.workerIDValidator) // get job list - s.GET("/workers/:id/jobs", s.listJobsOfWorker) + workerValidateGroup.GET(":id/jobs", s.listJobsOfWorker) // post job status - s.POST("/workers/:id/jobs/:job", s.updateJobOfWorker) + workerValidateGroup.POST(":id/jobs/:job", s.updateJobOfWorker) // worker command polling - s.GET("/workers/:id/cmd_stream", s.getCmdOfWorker) + workerValidateGroup.GET(":id/cmd_stream", s.getCmdOfWorker) // for tunasynctl to post commands s.POST("/cmd/", s.handleClientCmd) diff --git a/manager/server_test.go b/manager/server_test.go index 09ca72c..cfc229b 100644 --- a/manager/server_test.go +++ b/manager/server_test.go @@ -1,6 +1,7 @@ package manager import ( + "bytes" "encoding/json" "fmt" "io/ioutil" @@ -11,8 +12,146 @@ import ( "time" . "github.com/smartystreets/goconvey/convey" + . "github.com/tuna/tunasync/internal" ) +const ( + _magicBadWorkerID = "magic_bad_worker_id" +) + +func postJSON(url string, obj interface{}) (*http.Response, error) { + b := new(bytes.Buffer) + json.NewEncoder(b).Encode(obj) + return http.Post(url, "application/json; charset=utf-8", b) +} + +func TestHTTPServer(t *testing.T) { + Convey("HTTP server should work", t, func() { + InitLogger(true, true, false) + s := makeHTTPServer(false) + So(s, ShouldNotBeNil) + s.setDBAdapter(&mockDBAdapter{ + workerStore: map[string]worker{ + _magicBadWorkerID: worker{ + ID: _magicBadWorkerID, + }}, + statusStore: make(map[string]mirrorStatus), + }) + port := rand.Intn(10000) + 20000 + baseURL := fmt.Sprintf("http://127.0.0.1:%d", port) + go func() { + s.Run(fmt.Sprintf("127.0.0.1:%d", port)) + }() + time.Sleep(50 * time.Microsecond) + resp, err := http.Get(baseURL + "/ping") + So(err, ShouldBeNil) + So(resp.StatusCode, ShouldEqual, http.StatusOK) + So(resp.Header.Get("Content-Type"), ShouldEqual, "application/json; charset=utf-8") + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + So(err, ShouldBeNil) + var p map[string]string + err = json.Unmarshal(body, &p) + So(err, ShouldBeNil) + So(p[_infoKey], ShouldEqual, "pong") + + Convey("when database fail", func() { + resp, err := http.Get(fmt.Sprintf("%s/workers/%s/jobs", baseURL, _magicBadWorkerID)) + So(err, ShouldBeNil) + So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError) + defer resp.Body.Close() + var msg map[string]string + err = json.NewDecoder(resp.Body).Decode(&msg) + So(err, ShouldBeNil) + So(msg[_errorKey], ShouldEqual, fmt.Sprintf("failed to list jobs of worker %s: %s", _magicBadWorkerID, "database fail")) + }) + + Convey("when register a worker", func() { + w := worker{ + ID: "test_worker1", + } + resp, err := postJSON(baseURL+"/workers", w) + So(err, ShouldBeNil) + So(resp.StatusCode, ShouldEqual, http.StatusOK) + + Convey("list all workers", func() { + So(err, ShouldBeNil) + resp, err := http.Get(baseURL + "/workers") + So(err, ShouldBeNil) + defer resp.Body.Close() + var actualResponseObj []WorkerInfoMsg + err = json.NewDecoder(resp.Body).Decode(&actualResponseObj) + So(err, ShouldBeNil) + So(len(actualResponseObj), ShouldEqual, 2) + }) + + Convey("update mirror status of a existed worker", func() { + status := mirrorStatus{ + Name: "arch-sync1", + Worker: "test_worker1", + IsMaster: true, + Status: Success, + LastUpdate: time.Now(), + Upstream: "mirrors.tuna.tsinghua.edu.cn", + Size: "3GB", + } + resp, err := postJSON(fmt.Sprintf("%s/workers/%s/jobs/%s", baseURL, status.Worker, status.Name), status) + So(err, ShouldBeNil) + So(resp.StatusCode, ShouldEqual, http.StatusOK) + + Convey("list mirror status of an existed worker", func() { + + expectedResponse, err := json.Marshal([]mirrorStatus{status}) + So(err, ShouldBeNil) + resp, err := http.Get(baseURL + "/workers/test_worker1/jobs") + So(err, ShouldBeNil) + So(resp.StatusCode, ShouldEqual, http.StatusOK) + // err = json.NewDecoder(resp.Body).Decode(&mirrorStatusList) + body, err := ioutil.ReadAll(resp.Body) + defer resp.Body.Close() + So(err, ShouldBeNil) + So(strings.TrimSpace(string(body)), ShouldEqual, string(expectedResponse)) + }) + + Convey("list all job status of all workers", func() { + expectedResponse, err := json.Marshal([]mirrorStatus{status}) + So(err, ShouldBeNil) + resp, err := http.Get(baseURL + "/jobs") + So(err, ShouldBeNil) + So(resp.StatusCode, ShouldEqual, http.StatusOK) + body, err := ioutil.ReadAll(resp.Body) + defer resp.Body.Close() + So(err, ShouldBeNil) + So(strings.TrimSpace(string(body)), ShouldEqual, string(expectedResponse)) + + }) + }) + + Convey("update mirror status of an inexisted worker", func() { + invalidWorker := "test_worker2" + status := mirrorStatus{ + Name: "arch-sync2", + Worker: invalidWorker, + IsMaster: true, + Status: Success, + LastUpdate: time.Now(), + Upstream: "mirrors.tuna.tsinghua.edu.cn", + Size: "4GB", + } + resp, err := postJSON(fmt.Sprintf("%s/workers/%s/jobs/%s", + baseURL, status.Worker, status.Name), status) + So(err, ShouldBeNil) + So(resp.StatusCode, ShouldEqual, http.StatusBadRequest) + defer resp.Body.Close() + var msg map[string]string + err = json.NewDecoder(resp.Body).Decode(&msg) + So(err, ShouldBeNil) + So(msg[_errorKey], ShouldEqual, "invalid workerID "+invalidWorker) + }) + }) + }) +} + type mockDBAdapter struct { workerStore map[string]worker statusStore map[string]mirrorStatus @@ -31,23 +170,22 @@ func (b *mockDBAdapter) ListWorkers() ([]worker, error) { func (b *mockDBAdapter) GetWorker(workerID string) (worker, error) { w, ok := b.workerStore[workerID] if !ok { - return worker{}, fmt.Errorf("inexist workerId") + return worker{}, fmt.Errorf("invalid workerId") } return w, nil } func (b *mockDBAdapter) CreateWorker(w worker) (worker, error) { - _, ok := b.workerStore[w.id] - if ok { - return worker{}, fmt.Errorf("duplicate worker name") - } - b.workerStore[w.id] = w + // _, ok := b.workerStore[w.ID] + // if ok { + // return worker{}, fmt.Errorf("duplicate worker name") + // } + b.workerStore[w.ID] = w return w, nil } func (b *mockDBAdapter) GetMirrorStatus(workerID, mirrorID string) (mirrorStatus, error) { - // TODO: need to check worker exist first - id := workerID + "/" + mirrorID + id := mirrorID + "/" + workerID status, ok := b.statusStore[id] if !ok { return mirrorStatus{}, fmt.Errorf("no mirror %s exists in worker %s", mirrorID, workerID) @@ -56,13 +194,22 @@ func (b *mockDBAdapter) GetMirrorStatus(workerID, mirrorID string) (mirrorStatus } func (b *mockDBAdapter) UpdateMirrorStatus(workerID, mirrorID string, status mirrorStatus) (mirrorStatus, error) { - id := workerID + "/" + mirrorID + // if _, ok := b.workerStore[workerID]; !ok { + // // unregistered worker + // return mirrorStatus{}, fmt.Errorf("invalid workerID %s", workerID) + // } + + id := mirrorID + "/" + workerID b.statusStore[id] = status return status, nil } func (b *mockDBAdapter) ListMirrorStatus(workerID string) ([]mirrorStatus, error) { var mirrorStatusList []mirrorStatus + // simulating a database fail + if workerID == _magicBadWorkerID { + return []mirrorStatus{}, fmt.Errorf("database fail") + } for k, v := range b.statusStore { if wID := strings.Split(k, "/")[1]; wID == workerID { mirrorStatusList = append(mirrorStatusList, v) @@ -79,26 +226,6 @@ func (b *mockDBAdapter) ListAllMirrorStatus() ([]mirrorStatus, error) { return mirrorStatusList, nil } -func TestHTTPServer(t *testing.T) { - Convey("HTTP server should work", t, func() { - s := makeHTTPServer(false) - So(s, ShouldNotBeNil) - port := rand.Intn(10000) + 20000 - go func() { - s.Run(fmt.Sprintf("127.0.0.1:%d", port)) - }() - time.Sleep(50 * time.Microsecond) - resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/ping", port)) - So(err, ShouldBeNil) - So(resp.StatusCode, ShouldEqual, http.StatusOK) - So(resp.Header.Get("Content-Type"), ShouldEqual, "application/json; charset=utf-8") - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - So(err, ShouldBeNil) - var p map[string]string - err = json.Unmarshal(body, &p) - So(err, ShouldBeNil) - So(p["msg"], ShouldEqual, "pong") - }) - +func (b *mockDBAdapter) Close() error { + return nil }