diff --git a/runc.go b/runc.go index c5eda63..f7d90d4 100644 --- a/runc.go +++ b/runc.go @@ -327,6 +327,7 @@ func (r *Runc) Delete(context context.Context, id string, opts *DeleteOpts) erro type KillOpts struct { All bool ExtraArgs []string + RawSignal string } func (o *KillOpts) args() (out []string) { @@ -344,10 +345,14 @@ func (r *Runc) Kill(context context.Context, id string, sig int, opts *KillOpts) args := []string{ "kill", } + killSignal := strconv.Itoa(sig) if opts != nil { args = append(args, opts.args()...) + if opts.RawSignal != "" { + killSignal = opts.RawSignal + } } - return r.runOrError(r.command(context, append(args, id, strconv.Itoa(sig))...)) + return r.runOrError(r.command(context, append(args, id, killSignal)...)) } // Stats return the stats for a container like cpu, memory, and io diff --git a/runc_test.go b/runc_test.go index 4b1275a..9a0e666 100644 --- a/runc_test.go +++ b/runc_test.go @@ -21,6 +21,7 @@ import ( "errors" "io/ioutil" "os" + "strings" "sync" "syscall" "testing" @@ -287,9 +288,7 @@ func interrupt(ctx context.Context, t *testing.T, started <-chan int) { } } -// dummySleepRunc creates s simple script that just runs `sleep 10` to replace -// runc for testing process that are longer running. -func dummySleepRunc() (_ string, err error) { +func createScript(content string) (_ string, err error) { fh, err := ioutil.TempFile("", "*.sh") if err != nil { return "", err @@ -299,7 +298,7 @@ func dummySleepRunc() (_ string, err error) { os.Remove(fh.Name()) } }() - _, err = fh.Write([]byte("#!/bin/sh\nexec /bin/sleep 10")) + _, err = fh.Write([]byte(content)) if err != nil { return "", err } @@ -314,6 +313,22 @@ func dummySleepRunc() (_ string, err error) { return fh.Name(), nil } +// dummySleepRunc creates a simple script that just runs `sleep 10` to replace +// runc for testing process that are longer running. +func dummySleepRunc() (_ string, err error) { + return createScript("#!/bin/sh\nexec /bin/sleep 10") +} + +// debugCommand creates a simple script that echos the arguments passed to +// runc, and returns them as part of the error message. +func debugCommand() (string, error) { + return createScript(`#!/bin/sh + echo "$@" + # force non-zero exit code, so that the error message contains the output + exit 1 + `) +} + func TestCreateArgs(t *testing.T) { o := &CreateOpts{} args, err := o.args() @@ -336,3 +351,72 @@ func TestCreateArgs(t *testing.T) { } } + +func TestRuncKill(t *testing.T) { + ctx, timeout := context.WithTimeout(context.Background(), 10*time.Second) + defer timeout() + + dummyCmd, err := debugCommand() + if err != nil { + t.Fatalf("Failed to create dummy debug command: %v", err) + } + defer os.Remove(dummyCmd) + + debugRunc := &Runc{Command: dummyCmd} + + type config struct { + name string + rawSignal string + numericalSignal int + expectedSignal string + } + tests := []config{ + { + name: "Kill sends raw signal", + rawSignal: "SIGTERM", + expectedSignal: "SIGTERM", + }, + { + name: "Kill sends raw signal number", + rawSignal: "15", + expectedSignal: "15", + }, + { + name: "Kill prefers raw signal over numerical signal", + rawSignal: "SIGTERM", + numericalSignal: 9, + expectedSignal: "SIGTERM", + }, + { + name: "Kill prefers raw signal number over numerical signal", + rawSignal: "15", + numericalSignal: 9, + expectedSignal: "15", + }, + { + name: "Kill sends numerical signal when no raw signal specified", + numericalSignal: 9, + expectedSignal: "9", + }, + } + for _, test := range tests { + t.Run(test.name, func(_ *testing.T) { + opts := &KillOpts{ + RawSignal: test.rawSignal, + } + err = debugRunc.Kill(ctx, "fake_id", test.numericalSignal, opts) + if err == nil { + t.Fatal("expected dummy debug command to return error, instead got nil") + } + errorMessage := err.Error() + words := strings.Fields(errorMessage) + if len(words) < 3 { + t.Fatalf("expected dummy debug command to error with the kill command sent, instead got %s", errorMessage) + } + actualSignal := words[len(words)-1] + if actualSignal != test.expectedSignal { + t.Fatalf("expected kill command to send signal %v, instead got %v", test.expectedSignal, actualSignal) + } + }) + } +}