Initial import
authorJoe Wreschnig <joe.wreschnig@gmail.com>
Sat, 13 Jun 2020 07:13:34 +0000 (09:13 +0200)
committerJoe Wreschnig <joe.wreschnig@gmail.com>
Sun, 14 Jun 2020 13:49:51 +0000 (15:49 +0200)
.gitignore [new file with mode: 0644]
Makefile [new file with mode: 0755]
cmd/example/main.go [new file with mode: 0644]
go.mk [new file with mode: 0644]
go.mod [new file with mode: 0644]
go.sum [new file with mode: 0644]
signalcontext.go [new file with mode: 0644]
signalcontext_test.go [new file with mode: 0644]

diff --git a/.gitignore b/.gitignore
new file mode 100644 (file)
index 0000000..510252d
--- /dev/null
@@ -0,0 +1,4 @@
+/example
+/go.benchmark
+/go.coverage.out
+/go.coverage.html
diff --git a/Makefile b/Makefile
new file mode 100755 (executable)
index 0000000..37f78a2
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,7 @@
+#!/usr/bin/make -rf
+
+include ./go.mk
+
+all:: $(GOBIN)
+
+.DEFAULT_GOAL := all
diff --git a/cmd/example/main.go b/cmd/example/main.go
new file mode 100644 (file)
index 0000000..81585b6
--- /dev/null
@@ -0,0 +1,20 @@
+package main
+
+import (
+       "context"
+       "log"
+       "os"
+       "syscall"
+       "time"
+
+       "git.korewanetadesu.com/go-signalcontext"
+)
+
+func main() {
+       log.Print("waiting 10 seconds or until signal...")
+       p, cancel := context.WithTimeout(context.Background(), time.Second*10)
+       ctx := signalcontext.UntilSignal(p, os.Interrupt, syscall.SIGTERM)
+       <-ctx.Done()
+       cancel()
+       log.Print(ctx.Err())
+}
diff --git a/go.mk b/go.mk
new file mode 100644 (file)
index 0000000..8360cfb
--- /dev/null
+++ b/go.mk
@@ -0,0 +1,34 @@
+GOPKGSRC ?= $(wildcard *.go)
+GOTESTSRC ?= $(wildcard *_test.go)
+GOSRC ?= $(filter-out $(GOTESTSRC),$(GOPKGSRC))
+GOTESTDATA ?= $(shell test ! -d testdata || find testdata)
+
+GOMAINSRC ?= $(wildcard cmd/*/main.go)
+GOBIN ?= $(patsubst cmd/%/main.go,%,$(GOMAINSRC))
+
+GOCOVERAGE ?= go.coverage
+GOBENCHMARK ?= go.benchmark
+
+%: cmd/%/main.go $(GOPKGSRC)
+       go build -o $@ $(GOBUILDFLAGS) $<
+
+test:: $(GOCOVERAGE).out
+
+$(GOCOVERAGE).out: GOTESTFLAGS ?= -race
+$(GOCOVERAGE).out: $(GOPKGSRC) $(GOTESTSRC) $(GOTESTDATA) go.mod
+       go test -coverprofile=$@ $(GOTESTFLAGS) ./...
+
+$(GOCOVERAGE).html: $(GOCOVERAGE).out
+       go tool cover -html $< -o $@
+
+bench:: $(GOBENCHMARK)
+
+$(GOBENCHMARK): GOBENCHFLAGS ?= -benchmem
+$(GOBENCHMARK): $(GOPKGSRC) $(GOTESTSRC) $(GOTESTDATA) go.mod
+       go test -bench . $(GOBENCHFLAGS) | tee $@
+
+clean::
+       $(RM) $(GOCOVERAGE).{html,out} $(GOBENCHMARK) $(GOBIN)
+
+.PHONY: clean test $(GOCOVERAGE)
+.DELETE_ON_ERROR:
diff --git a/go.mod b/go.mod
new file mode 100644 (file)
index 0000000..ec2df26
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,5 @@
+module git.korewanetadesu.com/go-signalcontext
+
+go 1.13
+
+require github.com/stretchr/testify v1.6.1
diff --git a/go.sum b/go.sum
new file mode 100644 (file)
index 0000000..afe7890
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,11 @@
+github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
+github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/signalcontext.go b/signalcontext.go
new file mode 100644 (file)
index 0000000..d718e11
--- /dev/null
@@ -0,0 +1,121 @@
+package signalcontext
+
+import (
+       "context"
+       "fmt"
+       "os"
+       "os/signal"
+       "sync"
+       "time"
+)
+
+// A SignalContext is an implementation of the context.Context interface
+// which completes when a signal (e.g. os.Interrupt) is received.
+//
+// SignalContexts should be created via the UntilSignal function.
+type SignalContext struct {
+       parent context.Context
+       done   chan struct{}
+       err    error
+
+       // The mutex synchronizes access to err and clearing the
+       // internal Signal channel after initialization.
+       m sync.Mutex
+       c chan os.Signal
+}
+
+// UntilSignal returns a new SignalContext which will complete when the
+// parent does or when any of the specified signals are received.
+func UntilSignal(parent context.Context, sig ...os.Signal) *SignalContext {
+       ctx := new(SignalContext)
+       ctx.parent = parent
+       ctx.done = make(chan struct{})
+
+       if err := parent.Err(); err != nil {
+               close(ctx.done)
+               ctx.err = err
+               return ctx
+       }
+
+       ctx.c = make(chan os.Signal, 1)
+       signal.Notify(ctx.c, sig...)
+       go ctx.wait(sig...)
+       return ctx
+}
+
+func (s *SignalContext) wait(sig ...os.Signal) {
+       var err error
+       select {
+       case <-s.parent.Done():
+               err = s.parent.Err()
+       case v := <-s.c:
+               if v != nil {
+                       err = SignalError{v}
+               }
+       }
+       signal.Stop(s.c)
+       s.m.Lock()
+       if s.err == nil {
+               s.err = err
+       }
+       close(s.c)
+       s.c = nil
+       s.m.Unlock()
+       close(s.done)
+}
+
+// Cancel cancels this context, if it hasn’t already been canceled or
+// received a signal. (If it has, this is safe but has no effect.)
+func (s *SignalContext) Cancel() {
+       s.m.Lock()
+       if s.c != nil {
+               s.err = context.Canceled
+               select {
+               case s.c <- nil:
+               default:
+               }
+       }
+       s.m.Unlock()
+}
+
+// Deadline implements context.Context; a SignalContext’s deadline is
+// that of its parent.
+func (s *SignalContext) Deadline() (time.Time, bool) {
+       return s.parent.Deadline()
+}
+
+// Value implements context.Context; any value is that of its parent.
+func (s *SignalContext) Value(key interface{}) interface{} {
+       return s.parent.Value(key)
+}
+
+// Done implements context.Context.
+func (s *SignalContext) Done() <-chan struct{} {
+       return s.done
+}
+
+// Err implements context.Context; it returns context.Canceled if the
+// context was canceled; a SignalError if the context completed due to a
+// signal; the parent’s error if the parent was done before either of
+// those; or nil if the context is not yet done.
+func (s *SignalContext) Err() error {
+       s.m.Lock()
+       err := s.err
+       s.m.Unlock()
+       return err
+}
+
+// A SignalError will be returned by a SignalContext’s Err() method when
+// it was finished due to a signal (rather than e.g. parent
+// cancellation).
+type SignalError struct {
+       os.Signal
+}
+
+func (e SignalError) Error() string {
+       return e.String()
+}
+
+func (e SignalError) String() string {
+       return fmt.Sprintf("received signal: %s", e.Signal)
+}
diff --git a/signalcontext_test.go b/signalcontext_test.go
new file mode 100644 (file)
index 0000000..07109c5
--- /dev/null
@@ -0,0 +1,124 @@
+package signalcontext
+
+import (
+       "context"
+       "syscall"
+       "testing"
+       "time"
+
+       "github.com/stretchr/testify/assert"
+)
+
+func TestReceivesSignal(t *testing.T) {
+       ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+       assert.NoError(t, ctx.Err())
+       syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+       <-ctx.Done()
+       assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
+       assert.EqualError(t, ctx.Err(),
+               "received signal: "+syscall.SIGUSR2.String())
+}
+
+func TestForwardsParent(t *testing.T) {
+       parent, _ := context.WithTimeout(
+               context.WithValue(context.Background(), t, "test"),
+               time.Millisecond)
+       ctx := UntilSignal(parent, syscall.SIGUSR2)
+       assert.NoError(t, ctx.Err())
+       dl, ok := ctx.Deadline()
+       assert.True(t, ok)
+       assert.WithinDuration(t, time.Now(), dl, time.Millisecond)
+       assert.EqualValues(t, "test", ctx.Value(t))
+       <-ctx.Done()
+       assert.Equal(t, context.DeadlineExceeded, ctx.Err())
+}
+
+func TestChildForwardsErr(t *testing.T) {
+       ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+       child, cancel := context.WithTimeout(ctx, time.Second)
+       syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+       <-child.Done()
+       <-ctx.Done()
+       cancel()
+       assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
+       assert.Equal(t, SignalError{syscall.SIGUSR2}, child.Err())
+}
+
+func TestSignalAfterCancel(t *testing.T) {
+       ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+       assert.NoError(t, ctx.Err())
+       ctx.Cancel()
+       <-ctx.Done()
+       assert.Equal(t, context.Canceled, ctx.Err())
+       syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+       time.Sleep(5 * time.Millisecond)
+       assert.Equal(t, context.Canceled, ctx.Err())
+}
+
+func TestCancelAfterSignal(t *testing.T) {
+       ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+       assert.NoError(t, ctx.Err())
+       syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+       <-ctx.Done()
+       assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
+       ctx.Cancel()
+       time.Sleep(5 * time.Millisecond)
+       assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
+}
+
+func TestImmediateCompletion(t *testing.T) {
+       parent, cancel := context.WithCancel(context.Background())
+       cancel()
+       <-parent.Done()
+       ctx := UntilSignal(parent, syscall.SIGUSR2)
+       // peek inside to be certain we never set up the signal channel.
+       assert.Nil(t, ctx.c)
+       select {
+       case _, ok := <-ctx.Done():
+               assert.False(t, ok, "Done() should be closed")
+       default:
+               assert.False(t, true, "context should be complete")
+       }
+       assert.Equal(t, context.Canceled, ctx.Err())
+}
+
+func BenchmarkReceivesSignal(b *testing.B) {
+       for i := 0; i < b.N; i++ {
+               ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+               syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+               <-ctx.Done()
+       }
+}
+
+func BenchmarkCancelChildren(b *testing.B) {
+       children := make([]context.Context, b.N)
+       cancels := make([]context.CancelFunc, b.N)
+       b.ResetTimer()
+       ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+       for i := range children {
+               children[i], cancels[i] = context.WithTimeout(ctx, time.Hour)
+       }
+       syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+       <-ctx.Done()
+       for i := range children {
+               <-children[i].Done()
+       }
+       b.StopTimer()
+       for i := range cancels {
+               cancels[i]()
+       }
+}
+
+func BenchmarkCanceledAsChild(b *testing.B) {
+       children := make([]context.Context, b.N)
+       parent, cancel := context.WithCancel(context.Background())
+       b.ResetTimer()
+       for i := range children {
+               children[i] = UntilSignal(parent, syscall.SIGUSR2)
+       }
+       cancel()
+       <-parent.Done()
+       for i := range children {
+               <-children[i].Done()
+       }
+}