From f6b6f942533c75bd0b6884816e652794def00e44 Mon Sep 17 00:00:00 2001
From: Joe Wreschnig
Date: Sat, 13 Jun 2020 09:13:34 +0200
Subject: [PATCH 1/1] Initial import

---
 .gitignore            |   4 ++
 Makefile              |   7 +++
 cmd/example/main.go   |  20 +++++++
 go.mk                 |  36 ++++++++++++
 go.mod                |   5 ++
 go.sum                |  11 ++++
 signalcontext.go      | 142 +++++++++++++++++++++++++++++++++++++++++
 signalcontext_test.go | 124 ++++++++++++++++++++++++++++++++++++++++++
 8 files changed, 349 insertions(+)
 create mode 100644 .gitignore
 create mode 100755 Makefile
 create mode 100644 cmd/example/main.go
 create mode 100644 go.mk
 create mode 100644 go.mod
 create mode 100644 go.sum
 create mode 100644 signalcontext.go
 create mode 100644 signalcontext_test.go

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..510252d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+/example
+/go.benchmark
+/go.coverage.out
+/go.coverage.html
diff --git a/Makefile b/Makefile
new file mode 100755
index 0000000..37f78a2
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,7 @@
+#!/usr/bin/make -rf
+
+include ./go.mk
+
+all:: $(GOBIN)
+
+.DEFAULT_GOAL := all
diff --git a/cmd/example/main.go b/cmd/example/main.go
new file mode 100644
index 0000000..81585b6
--- /dev/null
+++ b/cmd/example/main.go
@@ -0,0 +1,20 @@
+package main
+
+import (
+	"context"
+	"log"
+	"os"
+	"syscall"
+	"time"
+
+	"git.korewanetadesu.com/go-signalcontext"
+)
+
+func main() {
+	log.Print("waiting 10 seconds or until signal...")
+	p, cancel := context.WithTimeout(context.Background(), time.Second*10)
+	ctx := signalcontext.UntilSignal(p, os.Interrupt, syscall.SIGTERM)
+	<-ctx.Done()
+	cancel()
+	log.Print(ctx.Err())
+}
diff --git a/go.mk b/go.mk
new file mode 100644
index 0000000..8360cfb
--- /dev/null
+++ b/go.mk
@@ -0,0 +1,36 @@
+GOPKGSRC ?= $(wildcard *.go)
+GOTESTSRC ?= $(wildcard *_test.go)
+GOSRC ?= $(filter-out $(GOTESTSRC),$(GOPKGSRC))
+GOTESTDATA ?= $(shell test ! -d testdata || find testdata)
+
+GOMAINSRC ?= $(wildcard cmd/*/main.go)
+GOBIN ?= $(patsubst cmd/%/main.go,%,$(GOMAINSRC))
+
+GOCOVERAGE ?= go.coverage
+GOBENCHMARK ?= go.benchmark
+
+%: cmd/%/main.go $(GOPKGSRC)
+	go build -o $@ $(GOBUILDFLAGS) $<
+
+test:: $(GOCOVERAGE).out
+
+$(GOCOVERAGE).out: GOTESTFLAGS ?= -race
+$(GOCOVERAGE).out: $(GOPKGSRC) $(GOTESTSRC) $(GOTESTDATA) go.mod
+	go test -coverprofile=$@ $(GOTESTFLAGS) ./...
+
+$(GOCOVERAGE).html: $(GOCOVERAGE).out
+	go tool cover -html $< -o $@
+
+bench:: $(GOBENCHMARK)
+
+# Not piped through tee: a pipeline exits with tee's status, hiding failures.
+$(GOBENCHMARK): GOBENCHFLAGS ?= -benchmem
+$(GOBENCHMARK): $(GOPKGSRC) $(GOTESTSRC) $(GOTESTDATA) go.mod
+	go test -bench . $(GOBENCHFLAGS) > $@; status=$$?; cat $@; exit $$status
+
+# Files are listed individually: /bin/sh need not do {a,b} brace expansion.
+clean::
+	$(RM) $(GOCOVERAGE).html $(GOCOVERAGE).out $(GOBENCHMARK) $(GOBIN)
+
+.PHONY: bench clean test
+.DELETE_ON_ERROR:
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..ec2df26
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,5 @@
+module git.korewanetadesu.com/go-signalcontext
+
+go 1.13
+
+require github.com/stretchr/testify v1.6.1
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..afe7890
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,11 @@
+github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
+github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/signalcontext.go b/signalcontext.go
new file mode 100644
index 0000000..d718e11
--- /dev/null
+++ b/signalcontext.go
@@ -0,0 +1,142 @@
+package signalcontext
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"os/signal"
+	"sync"
+	"time"
+)
+
+// A SignalContext is an implementation of the context.Context interface
+// which completes when a signal (e.g. os.Interrupt) is received.
+//
+// SignalContexts should be created via the UntilSignal function.
+type SignalContext struct {
+	parent context.Context
+	done   chan struct{}
+	err    error
+
+	// The mutex synchronizes access to err, the closing of done, and
+	// clearing the internal Signal channel after initialization.
+	m sync.Mutex
+	c chan os.Signal
+}
+
+// UntilSignal returns a new SignalContext which will complete when the
+// parent does or when any of the specified signals are received.
+//
+// The signals are passed unchanged to signal.Notify, so if none are
+// given every incoming signal is relayed and will complete the context.
+func UntilSignal(parent context.Context, sig ...os.Signal) *SignalContext {
+	ctx := new(SignalContext)
+	ctx.parent = parent
+	ctx.done = make(chan struct{})
+
+	if err := parent.Err(); err != nil {
+		ctx.err = err
+		close(ctx.done)
+		return ctx
+	}
+
+	// The signal package never blocks when sending, so leave room for
+	// a signal that arrives before wait is ready to receive it.
+	ctx.c = make(chan os.Signal, 1)
+	signal.Notify(ctx.c, sig...)
+	go ctx.wait()
+	return ctx
+}
+
+// wait runs until the parent is done, a signal arrives, or Cancel wakes
+// it; it then stops signal delivery and, unless Cancel already did so,
+// records the reason and closes done.
+func (s *SignalContext) wait() {
+	var err error
+	select {
+	case <-s.parent.Done():
+		err = s.parent.Err()
+	case v := <-s.c:
+		// A nil value means Cancel woke this goroutine; Cancel has
+		// already set err and closed done.
+		if v != nil {
+			err = SignalError{v}
+		}
+	}
+	// No signals are sent on s.c once Stop returns, and Cancel only
+	// sends while holding the lock with s.c non-nil, so closing is safe.
+	signal.Stop(s.c)
+	s.m.Lock()
+	if s.err == nil {
+		// Record err and close done in one critical section so Err
+		// never reports an error before Done is closed.
+		s.err = err
+		close(s.done)
+	}
+	close(s.c)
+	s.c = nil
+	s.m.Unlock()
+}
+
+// Cancel cancels this context, if it hasn’t already been canceled or
+// received a signal. (If it has, this is safe but has no effect.)
+//
+// When Cancel returns, Done is closed; Err reports context.Canceled
+// unless the context had already completed for another reason.
+func (s *SignalContext) Cancel() {
+	s.m.Lock()
+	defer s.m.Unlock()
+	if s.err != nil || s.c == nil {
+		return // already complete
+	}
+	s.err = context.Canceled
+	close(s.done)
+	// Wake wait so it stops signal delivery. Non-blocking: a full buffer
+	// means a signal is already pending, which wakes wait just as well.
+	select {
+	case s.c <- nil:
+	default:
+	}
+}
+
+// Deadline implements context.Context; a SignalContext’s deadline is
+// that of its parent.
+func (s *SignalContext) Deadline() (time.Time, bool) {
+	return s.parent.Deadline()
+}
+
+// Value implements context.Context; any value is that of its parent.
+func (s *SignalContext) Value(key interface{}) interface{} {
+	return s.parent.Value(key)
+}
+
+// Done implements context.Context.
+func (s *SignalContext) Done() <-chan struct{} {
+	return s.done
+}
+
+// Err implements context.Context; it returns context.Canceled if the
+// context was canceled; a SignalError if the context completed due to a
+// signal; the parent’s error if the parent was done before either of
+// those; or nil if the context is not yet done.
+func (s *SignalContext) Err() error {
+	s.m.Lock()
+	err := s.err
+	s.m.Unlock()
+	return err
+}
+
+// A SignalError will be returned by a SignalContext’s Err() method when
+// it was finished due to a signal (rather than e.g. parent
+// cancellation).
+type SignalError struct {
+	os.Signal
+}
+
+// Error implements the error interface; it returns the same text as String.
+func (e SignalError) Error() string { return e.String() }
+
+// String formats the error as "received signal: " followed by the signal.
+func (e SignalError) String() string {
+	return fmt.Sprintf("received signal: %s", e.Signal)
+}
diff --git a/signalcontext_test.go b/signalcontext_test.go
new file mode 100644
index 0000000..07109c5
--- /dev/null
+++ b/signalcontext_test.go
@@ -0,0 +1,138 @@
+package signalcontext
+
+import (
+	"context"
+	"syscall"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// raise sends sig to the current process. It fails tb immediately if
+// the signal cannot be sent, rather than leaving the caller blocked on
+// a Done channel that will never close.
+func raise(tb testing.TB, sig syscall.Signal) {
+	tb.Helper()
+	if err := syscall.Kill(syscall.Getpid(), sig); err != nil {
+		tb.Fatalf("sending %v: %v", sig, err)
+	}
+}
+
+func TestReceivesSignal(t *testing.T) {
+	ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+	assert.NoError(t, ctx.Err())
+	raise(t, syscall.SIGUSR2)
+	<-ctx.Done()
+	assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
+	assert.EqualError(t, ctx.Err(),
+		"received signal: "+syscall.SIGUSR2.String())
+}
+
+func TestForwardsParent(t *testing.T) {
+	// The timeout must outlast the setup below; with a 1ms timeout the
+	// parent could expire before UntilSignal runs, failing NoError.
+	parent, cancel := context.WithTimeout(
+		context.WithValue(context.Background(), t, "test"),
+		50*time.Millisecond)
+	defer cancel()
+	ctx := UntilSignal(parent, syscall.SIGUSR2)
+	assert.NoError(t, ctx.Err())
+	want, _ := parent.Deadline()
+	dl, ok := ctx.Deadline()
+	assert.True(t, ok)
+	assert.Equal(t, want, dl)
+	assert.EqualValues(t, "test", ctx.Value(t))
+	<-ctx.Done()
+	assert.Equal(t, context.DeadlineExceeded, ctx.Err())
+}
+
+func TestChildForwardsErr(t *testing.T) {
+	ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+	child, cancel := context.WithTimeout(ctx, time.Second)
+	raise(t, syscall.SIGUSR2)
+	<-child.Done()
+	<-ctx.Done()
+	cancel()
+	assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
+	assert.Equal(t, SignalError{syscall.SIGUSR2}, child.Err())
+}
+
+func TestSignalAfterCancel(t *testing.T) {
+	ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+	assert.NoError(t, ctx.Err())
+	ctx.Cancel()
+	<-ctx.Done()
+	assert.Equal(t, context.Canceled, ctx.Err())
+	raise(t, syscall.SIGUSR2)
+	time.Sleep(5 * time.Millisecond)
+	assert.Equal(t, context.Canceled, ctx.Err())
+}
+
+func TestCancelAfterSignal(t *testing.T) {
+	ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+	assert.NoError(t, ctx.Err())
+	raise(t, syscall.SIGUSR2)
+	<-ctx.Done()
+	assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
+	ctx.Cancel()
+	time.Sleep(5 * time.Millisecond)
+	assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
+}
+
+func TestImmediateCompletion(t *testing.T) {
+	parent, cancel := context.WithCancel(context.Background())
+	cancel()
+	<-parent.Done()
+	ctx := UntilSignal(parent, syscall.SIGUSR2)
+	// peek inside to be certain we never set up the signal channel.
+	assert.Nil(t, ctx.c)
+	select {
+	case _, ok := <-ctx.Done():
+		assert.False(t, ok, "Done() should be closed")
+	default:
+		assert.Fail(t, "context should be complete")
+	}
+	assert.Equal(t, context.Canceled, ctx.Err())
+}
+
+func BenchmarkReceivesSignal(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+		raise(b, syscall.SIGUSR2)
+		<-ctx.Done()
+	}
+}
+
+func BenchmarkCancelChildren(b *testing.B) {
+	children := make([]context.Context, b.N)
+	cancels := make([]context.CancelFunc, b.N)
+	b.ResetTimer()
+	ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+	for i := range children {
+		children[i], cancels[i] = context.WithTimeout(ctx, time.Hour)
+	}
+	raise(b, syscall.SIGUSR2)
+	<-ctx.Done()
+	for i := range children {
+		<-children[i].Done()
+	}
+	b.StopTimer()
+	for i := range cancels {
+		cancels[i]()
+	}
+}
+
+func BenchmarkCanceledAsChild(b *testing.B) {
+	children := make([]context.Context, b.N)
+	parent, cancel := context.WithCancel(context.Background())
+	b.ResetTimer()
+	for i := range children {
+		children[i] = UntilSignal(parent, syscall.SIGUSR2)
+	}
+	cancel()
+	<-parent.Done()
+	for i := range children {
+		<-children[i].Done()
+	}
+}
-- 
2.30.2