--- /dev/null
+GOPKGSRC ?= $(wildcard *.go)
+GOTESTSRC ?= $(wildcard *_test.go)
+GOSRC ?= $(filter-out $(GOTESTSRC),$(GOPKGSRC))
+GOTESTDATA ?= $(shell test ! -d testdata || find testdata)
+
+GOMAINSRC ?= $(wildcard cmd/*/main.go)
+GOBIN ?= $(patsubst cmd/%/main.go,%,$(GOMAINSRC))
+
+GOCOVERAGE ?= go.coverage
+GOBENCHMARK ?= go.benchmark
+
+%: cmd/%/main.go $(GOPKGSRC)
+ go build -o $@ $(GOBUILDFLAGS) $<
+
+test:: $(GOCOVERAGE).out
+
+$(GOCOVERAGE).out: GOTESTFLAGS ?= -race
+$(GOCOVERAGE).out: $(GOPKGSRC) $(GOTESTSRC) $(GOTESTDATA) go.mod
+ go test -coverprofile=$@ $(GOTESTFLAGS) ./...
+
+$(GOCOVERAGE).html: $(GOCOVERAGE).out
+ go tool cover -html $< -o $@
+
+bench:: $(GOBENCHMARK)
+
+$(GOBENCHMARK): GOBENCHFLAGS ?= -benchmem
+$(GOBENCHMARK): $(GOPKGSRC) $(GOTESTSRC) $(GOTESTDATA) go.mod
+ go test -bench . $(GOBENCHFLAGS) | tee $@
+
+clean::
+ $(RM) $(GOCOVERAGE).{html,out} $(GOBENCHMARK) $(GOBIN)
+
+.PHONY: clean test $(GOCOVERAGE)
+.DELETE_ON_ERROR:
--- /dev/null
+package signalcontext
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "os/signal"
+ "sync"
+ "time"
+)
+
+// A SignalContext is an implementation of the context.Context interface
+// which completes when a signal (e.g. os.Interrupt) is received.
+//
+// SignalContexts should be created via the UntilSignal function.
+type SignalContext struct {
+ parent context.Context
+ done chan struct{}
+ err error
+
+ // The mutex synchronizes access to err and clearing the
+ // internal Signal channel after initialization.
+ m sync.Mutex
+ c chan os.Signal
+}
+
+// UntilSignal returns a new SignalContext which will complete when the
+// parent does or when any of the specified signals are received.
+func UntilSignal(parent context.Context, sig ...os.Signal) *SignalContext {
+ ctx := new(SignalContext)
+ ctx.parent = parent
+ ctx.done = make(chan struct{})
+
+ if err := parent.Err(); err != nil {
+ close(ctx.done)
+ ctx.err = err
+ return ctx
+ }
+
+ ctx.c = make(chan os.Signal, 1)
+ signal.Notify(ctx.c, sig...)
+ go ctx.wait(sig...)
+ return ctx
+}
+
+func (s *SignalContext) wait(sig ...os.Signal) {
+ var err error
+ select {
+ case <-s.parent.Done():
+ err = s.parent.Err()
+ case v := <-s.c:
+ if v != nil {
+ err = SignalError{v}
+ }
+ }
+ signal.Stop(s.c)
+ s.m.Lock()
+ if s.err == nil {
+ s.err = err
+ }
+ close(s.c)
+ s.c = nil
+ s.m.Unlock()
+ close(s.done)
+}
+
+// Cancel cancels this context, if it hasn’t already been canceled or
+// received a signal. (If it has, this is safe but has no effect.)
+func (s *SignalContext) Cancel() {
+ s.m.Lock()
+ if s.c != nil {
+ s.err = context.Canceled
+ select {
+ case s.c <- nil:
+ default:
+ }
+ }
+ s.m.Unlock()
+}
+
+// Deadline implements context.Context; a SignalContext’s deadline is
+// that of its parent.
+func (s *SignalContext) Deadline() (time.Time, bool) {
+ return s.parent.Deadline()
+}
+
+// Value implements context.Context; any value is that of its parent.
+func (s *SignalContext) Value(key interface{}) interface{} {
+ return s.parent.Value(key)
+}
+
+// Done implements context.Context.
+func (s *SignalContext) Done() <-chan struct{} {
+ return s.done
+}
+
+// Err implements context.Context; it returns context.Canceled if the
+// context was canceled; a SignalError if the context completed due to a
+// signal; the parent’s error if the parent was done before either of
+// those; or nil if the context is not yet done.
+func (s *SignalContext) Err() error {
+ s.m.Lock()
+ err := s.err
+ s.m.Unlock()
+ return err
+}
+
+// A SignalError will be returned by a SignalContext’s Err() method when
+// it was finished due to a signal (rather than e.g. parent
+// cancellation).
+type SignalError struct {
+ os.Signal
+}
+
+func (e SignalError) Error() string {
+ return e.String()
+}
+
+func (e SignalError) String() string {
+ return fmt.Sprintf("received signal: %s", e.Signal)
+}
--- /dev/null
+package signalcontext
+
+import (
+ "context"
+ "syscall"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestReceivesSignal(t *testing.T) {
+ ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+ assert.NoError(t, ctx.Err())
+ syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+ <-ctx.Done()
+ assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
+ assert.EqualError(t, ctx.Err(),
+ "received signal: "+syscall.SIGUSR2.String())
+}
+
+func TestForwardsParent(t *testing.T) {
+ parent, _ := context.WithTimeout(
+ context.WithValue(context.Background(), t, "test"),
+ time.Millisecond)
+ ctx := UntilSignal(parent, syscall.SIGUSR2)
+ assert.NoError(t, ctx.Err())
+ dl, ok := ctx.Deadline()
+ assert.True(t, ok)
+ assert.WithinDuration(t, time.Now(), dl, time.Millisecond)
+ assert.EqualValues(t, "test", ctx.Value(t))
+ <-ctx.Done()
+ assert.Equal(t, context.DeadlineExceeded, ctx.Err())
+}
+
+func TestChildForwardsErr(t *testing.T) {
+ ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+ child, cancel := context.WithTimeout(ctx, time.Second)
+ syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+ <-child.Done()
+ <-ctx.Done()
+ cancel()
+ assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
+ assert.Equal(t, SignalError{syscall.SIGUSR2}, child.Err())
+}
+
+func TestSignalAfterCancel(t *testing.T) {
+ ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+ assert.NoError(t, ctx.Err())
+ ctx.Cancel()
+ <-ctx.Done()
+ assert.Equal(t, context.Canceled, ctx.Err())
+ syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+ time.Sleep(5 * time.Millisecond)
+ assert.Equal(t, context.Canceled, ctx.Err())
+}
+
+func TestCancelAfterSignal(t *testing.T) {
+ ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+ assert.NoError(t, ctx.Err())
+ syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+ <-ctx.Done()
+ assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
+ ctx.Cancel()
+ time.Sleep(5 * time.Millisecond)
+ assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
+}
+
+func TestImmediateCompletion(t *testing.T) {
+ parent, cancel := context.WithCancel(context.Background())
+ cancel()
+ <-parent.Done()
+ ctx := UntilSignal(parent, syscall.SIGUSR2)
+ // peek inside to be certain we never set up the signal channel.
+ assert.Nil(t, ctx.c)
+ select {
+ case _, ok := <-ctx.Done():
+ assert.False(t, ok, "Done() should be closed")
+ default:
+ assert.False(t, true, "context should be complete")
+ }
+ assert.Equal(t, context.Canceled, ctx.Err())
+}
+
+func BenchmarkReceivesSignal(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+ syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+ <-ctx.Done()
+ }
+}
+
+func BenchmarkCancelChildren(b *testing.B) {
+ children := make([]context.Context, b.N)
+ cancels := make([]context.CancelFunc, b.N)
+ b.ResetTimer()
+ ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+ for i := range children {
+ children[i], cancels[i] = context.WithTimeout(ctx, time.Hour)
+ }
+ syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+ <-ctx.Done()
+ for i := range children {
+ <-children[i].Done()
+ }
+ b.StopTimer()
+ for i := range cancels {
+ cancels[i]()
+ }
+}
+
+func BenchmarkCanceledAsChild(b *testing.B) {
+ children := make([]context.Context, b.N)
+ parent, cancel := context.WithCancel(context.Background())
+ b.ResetTimer()
+ for i := range children {
+ children[i] = UntilSignal(parent, syscall.SIGUSR2)
+ }
+ cancel()
+ <-parent.Done()
+ for i := range children {
+ <-children[i].Done()
+ }
+}