diff --git a/worker/cgroup.go b/worker/cgroup.go index 097aaa7..26ae7e9 100644 --- a/worker/cgroup.go +++ b/worker/cgroup.go @@ -2,18 +2,20 @@ package worker import ( "bufio" + "errors" "fmt" "os" "path/filepath" "strconv" "syscall" + "time" "golang.org/x/sys/unix" "github.com/codeskyblue/go-sh" ) -var cgSubsystem string = "cpu" +var cgSubsystem = "cpu" type cgroupHook struct { emptyHook @@ -82,23 +84,43 @@ func (c *cgroupHook) killAll() error { return nil } name := c.provider.Name() - taskFile, err := os.Open(filepath.Join(c.basePath, cgSubsystem, c.baseGroup, name, "tasks")) - if err != nil { - return err + + readTaskList := func() ([]int, error) { + taskList := []int{} + taskFile, err := os.Open(filepath.Join(c.basePath, cgSubsystem, c.baseGroup, name, "tasks")) + if err != nil { + return taskList, err + } + defer taskFile.Close() + + scanner := bufio.NewScanner(taskFile) + for scanner.Scan() { + pid, err := strconv.Atoi(scanner.Text()) + if err != nil { + return taskList, err + } + taskList = append(taskList, pid) + } + return taskList, nil } - defer taskFile.Close() - taskList := []int{} - scanner := bufio.NewScanner(taskFile) - for scanner.Scan() { - pid, err := strconv.Atoi(scanner.Text()) + + for i := 0; i < 4; i++ { + if i == 3 { + return errors.New("Unable to kill all child tasks") + } + taskList, err := readTaskList() if err != nil { return err } - taskList = append(taskList, pid) - } - for _, pid := range taskList { - logger.Debugf("Killing process: %d", pid) - unix.Kill(pid, syscall.SIGKILL) + if len(taskList) == 0 { + return nil + } + for _, pid := range taskList { + logger.Debugf("Killing process: %d", pid) + unix.Kill(pid, syscall.SIGKILL) + } + // sleep 10ms for the first round, and 1.01s, 2.01s, 3.01s for the rest + time.Sleep(time.Duration(i)*time.Second + 10*time.Millisecond) } return nil