--- /dev/null
+package signalcontext
+
+import (
+ "context"
+ "os"
+ "os/signal"
+ "sync"
+ "time"
+)
+
+// A Context is an implementation of the context.Context interface which
+// completes when a signal (e.g. os.Interrupt) is received.
+//
+// Contexts should be created via the UntilSignal function.
+type Context struct {
+ parent context.Context
+ done chan struct{}
+ err error
+
+ // The mutex synchronizes access to err and clearing the
+ // internal Signal channel after initialization.
+ m sync.Mutex
+ c chan os.Signal
+}
+
+// UntilSignal returns a new Context which will complete when the parent
+// does or when any of the specified signals are received.
+func UntilSignal(parent context.Context, sig ...os.Signal) *Context {
+ ctx := new(Context)
+ ctx.parent = parent
+ ctx.done = make(chan struct{})
+
+ if err := parent.Err(); err != nil {
+ close(ctx.done)
+ ctx.err = err
+ return ctx
+ }
+
+ ctx.c = make(chan os.Signal, 1)
+ signal.Notify(ctx.c, sig...)
+ go ctx.wait(sig...)
+ return ctx
+}
+
+func (s *Context) wait(sig ...os.Signal) {
+ var err error
+ select {
+ case <-s.parent.Done():
+ err = s.parent.Err()
+ case v := <-s.c:
+ if v != nil {
+ err = Error{v}
+ }
+ }
+ signal.Stop(s.c)
+ s.m.Lock()
+ if s.err == nil {
+ s.err = err
+ }
+ close(s.c)
+ s.c = nil
+ s.m.Unlock()
+ close(s.done)
+}
+
+// Cancel cancels this context, if it hasn’t already been completed. (If
+// it has, this is safe but has no effect.)
+func (s *Context) Cancel() {
+ s.m.Lock()
+ if s.c != nil {
+ s.err = context.Canceled
+ select {
+ case s.c <- nil:
+ default:
+ }
+ }
+ s.m.Unlock()
+}
+
+// Deadline implements context.Context; a Context’s deadline is that of
+// its parent.
+func (s *Context) Deadline() (time.Time, bool) {
+ return s.parent.Deadline()
+}
+
+// Value implements context.Context; any value is that of its parent.
+func (s *Context) Value(key interface{}) interface{} {
+ return s.parent.Value(key)
+}
+
+// Done implements context.Context.
+func (s *Context) Done() <-chan struct{} {
+ return s.done
+}
+
+// Err implements context.Context; it returns context.Canceled if the
+// context was canceled; an Error if the context completed due to a
+// signal; the parent’s error if the parent was done before either of
+// those; or nil if the context is not yet done.
+func (s *Context) Err() error {
+ s.m.Lock()
+ err := s.err
+ s.m.Unlock()
+ return err
+}
--- /dev/null
+package signalcontext
+
+import (
+ "context"
+ "syscall"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestReceivesSignal(t *testing.T) {
+ ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+ assert.NoError(t, ctx.Err())
+ syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+ <-ctx.Done()
+ assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err())
+}
+
+func TestForwardsParent(t *testing.T) {
+ parent, _ := context.WithTimeout(
+ context.WithValue(context.Background(), t, "test"),
+ time.Millisecond)
+ ctx := UntilSignal(parent, syscall.SIGUSR2)
+ assert.NoError(t, ctx.Err())
+ dl, ok := ctx.Deadline()
+ assert.True(t, ok)
+ assert.WithinDuration(t, time.Now(), dl, time.Millisecond)
+ assert.EqualValues(t, "test", ctx.Value(t))
+ <-ctx.Done()
+ assert.Equal(t, context.DeadlineExceeded, ctx.Err())
+}
+
+func TestChildForwardsErr(t *testing.T) {
+ ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+ child, cancel := context.WithTimeout(ctx, time.Second)
+ syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+ <-child.Done()
+ <-ctx.Done()
+ cancel()
+ assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err())
+ assert.Equal(t, Error{syscall.SIGUSR2}, child.Err())
+}
+
+func TestSignalAfterCancel(t *testing.T) {
+ ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+ assert.NoError(t, ctx.Err())
+ ctx.Cancel()
+ <-ctx.Done()
+ assert.Equal(t, context.Canceled, ctx.Err())
+ syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+ time.Sleep(5 * time.Millisecond)
+ assert.Equal(t, context.Canceled, ctx.Err())
+}
+
+func TestCancelAfterSignal(t *testing.T) {
+ ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+ assert.NoError(t, ctx.Err())
+ syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+ <-ctx.Done()
+ assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err())
+ ctx.Cancel()
+ time.Sleep(5 * time.Millisecond)
+ assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err())
+}
+
+func TestImmediateCompletion(t *testing.T) {
+ parent, cancel := context.WithCancel(context.Background())
+ cancel()
+ <-parent.Done()
+ ctx := UntilSignal(parent, syscall.SIGUSR2)
+ // peek inside to be certain we never set up the signal channel.
+ assert.Nil(t, ctx.c)
+ select {
+ case _, ok := <-ctx.Done():
+ assert.False(t, ok, "Done() should be closed")
+ default:
+ assert.False(t, true, "context should be complete")
+ }
+ assert.Equal(t, context.Canceled, ctx.Err())
+}
+
+func BenchmarkReceivesSignal(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+ syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+ <-ctx.Done()
+ }
+}
+
+func BenchmarkCancelChildren(b *testing.B) {
+ children := make([]context.Context, b.N)
+ cancels := make([]context.CancelFunc, b.N)
+ b.ResetTimer()
+ ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+ for i := range children {
+ children[i], cancels[i] = context.WithTimeout(ctx, time.Hour)
+ }
+ syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+ <-ctx.Done()
+ for i := range children {
+ <-children[i].Done()
+ }
+ b.StopTimer()
+ for i := range cancels {
+ cancels[i]()
+ }
+}
+
+func BenchmarkCanceledAsChild(b *testing.B) {
+ children := make([]context.Context, b.N)
+ parent, cancel := context.WithCancel(context.Background())
+ b.ResetTimer()
+ for i := range children {
+ children[i] = UntilSignal(parent, syscall.SIGUSR2)
+ }
+ cancel()
+ <-parent.Done()
+ for i := range children {
+ <-children[i].Done()
+ }
+}
--- /dev/null
+package signalcontext
+
+import (
+ "fmt"
+ "os"
+)
+
+// A Error will be returned by a SignalContext’s Err() method when it
+// was finished due to a signal (rather than e.g. parent cancellation).
+type Error struct {
+ os.Signal
+}
+
+func (e Error) Error() string {
+ return e.String()
+}
+
+func (e Error) String() string {
+ return fmt.Sprintf("received signal: %s", e.Signal)
+}
--- /dev/null
+package signalcontext
+
+import (
+ "os"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestError(t *testing.T) {
+ assert.EqualError(t, Error{os.Interrupt},
+ "received signal: "+os.Interrupt.String())
+}
+++ /dev/null
-package signalcontext
-
-import (
- "context"
- "fmt"
- "os"
- "os/signal"
- "sync"
- "time"
-)
-
-// A SignalContext is an implementation of the context.Context interface
-// which completes when a signal (e.g. os.Interrupt) is received.
-//
-// SignalContexts should be created via the UntilSignal function.
-type SignalContext struct {
- parent context.Context
- done chan struct{}
- err error
-
- // The mutex synchronizes access to err and clearing the
- // internal Signal channel after initialization.
- m sync.Mutex
- c chan os.Signal
-}
-
-// UntilSignal returns a new SignalContext which will complete when the
-// parent does or when any of the specified signals are received.
-func UntilSignal(parent context.Context, sig ...os.Signal) *SignalContext {
- ctx := new(SignalContext)
- ctx.parent = parent
- ctx.done = make(chan struct{})
-
- if err := parent.Err(); err != nil {
- close(ctx.done)
- ctx.err = err
- return ctx
- }
-
- ctx.c = make(chan os.Signal, 1)
- signal.Notify(ctx.c, sig...)
- go ctx.wait(sig...)
- return ctx
-}
-
-func (s *SignalContext) wait(sig ...os.Signal) {
- var err error
- select {
- case <-s.parent.Done():
- err = s.parent.Err()
- case v := <-s.c:
- if v != nil {
- err = SignalError{v}
- }
- }
- signal.Stop(s.c)
- s.m.Lock()
- if s.err == nil {
- s.err = err
- }
- close(s.c)
- s.c = nil
- s.m.Unlock()
- close(s.done)
-}
-
-// Cancel cancels this context, if it hasn’t already been canceled or
-// received a signal. (If it has, this is safe but has no effect.)
-func (s *SignalContext) Cancel() {
- s.m.Lock()
- if s.c != nil {
- s.err = context.Canceled
- select {
- case s.c <- nil:
- default:
- }
- }
- s.m.Unlock()
-}
-
-// Deadline implements context.Context; a SignalContext’s deadline is
-// that of its parent.
-func (s *SignalContext) Deadline() (time.Time, bool) {
- return s.parent.Deadline()
-}
-
-// Value implements context.Context; any value is that of its parent.
-func (s *SignalContext) Value(key interface{}) interface{} {
- return s.parent.Value(key)
-}
-
-// Done implements context.Context.
-func (s *SignalContext) Done() <-chan struct{} {
- return s.done
-}
-
-// Err implements context.Context; it returns context.Canceled if the
-// context was canceled; a SignalError if the context completed due to a
-// signal; the parent’s error if the parent was done before either of
-// those; or nil if the context is not yet done.
-func (s *SignalContext) Err() error {
- s.m.Lock()
- err := s.err
- s.m.Unlock()
- return err
-}
-
-// A SignalError will be returned by a SignalContext’s Err() method when
-// it was finished due to a signal (rather than e.g. parent
-// cancellation).
-type SignalError struct {
- os.Signal
-}
-
-func (e SignalError) Error() string {
- return e.String()
-}
-
-func (e SignalError) String() string {
- return fmt.Sprintf("received signal: %s", e.Signal)
-}
+++ /dev/null
-package signalcontext
-
-import (
- "context"
- "syscall"
- "testing"
- "time"
-
- "github.com/stretchr/testify/assert"
-)
-
-func TestReceivesSignal(t *testing.T) {
- ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
- assert.NoError(t, ctx.Err())
- syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
- <-ctx.Done()
- assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
- assert.EqualError(t, ctx.Err(),
- "received signal: "+syscall.SIGUSR2.String())
-}
-
-func TestForwardsParent(t *testing.T) {
- parent, _ := context.WithTimeout(
- context.WithValue(context.Background(), t, "test"),
- time.Millisecond)
- ctx := UntilSignal(parent, syscall.SIGUSR2)
- assert.NoError(t, ctx.Err())
- dl, ok := ctx.Deadline()
- assert.True(t, ok)
- assert.WithinDuration(t, time.Now(), dl, time.Millisecond)
- assert.EqualValues(t, "test", ctx.Value(t))
- <-ctx.Done()
- assert.Equal(t, context.DeadlineExceeded, ctx.Err())
-}
-
-func TestChildForwardsErr(t *testing.T) {
- ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
- child, cancel := context.WithTimeout(ctx, time.Second)
- syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
- <-child.Done()
- <-ctx.Done()
- cancel()
- assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
- assert.Equal(t, SignalError{syscall.SIGUSR2}, child.Err())
-}
-
-func TestSignalAfterCancel(t *testing.T) {
- ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
- assert.NoError(t, ctx.Err())
- ctx.Cancel()
- <-ctx.Done()
- assert.Equal(t, context.Canceled, ctx.Err())
- syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
- time.Sleep(5 * time.Millisecond)
- assert.Equal(t, context.Canceled, ctx.Err())
-}
-
-func TestCancelAfterSignal(t *testing.T) {
- ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
- assert.NoError(t, ctx.Err())
- syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
- <-ctx.Done()
- assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
- ctx.Cancel()
- time.Sleep(5 * time.Millisecond)
- assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
-}
-
-func TestImmediateCompletion(t *testing.T) {
- parent, cancel := context.WithCancel(context.Background())
- cancel()
- <-parent.Done()
- ctx := UntilSignal(parent, syscall.SIGUSR2)
- // peek inside to be certain we never set up the signal channel.
- assert.Nil(t, ctx.c)
- select {
- case _, ok := <-ctx.Done():
- assert.False(t, ok, "Done() should be closed")
- default:
- assert.False(t, true, "context should be complete")
- }
- assert.Equal(t, context.Canceled, ctx.Err())
-}
-
-func BenchmarkReceivesSignal(b *testing.B) {
- for i := 0; i < b.N; i++ {
- ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
- syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
- <-ctx.Done()
- }
-}
-
-func BenchmarkCancelChildren(b *testing.B) {
- children := make([]context.Context, b.N)
- cancels := make([]context.CancelFunc, b.N)
- b.ResetTimer()
- ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
- for i := range children {
- children[i], cancels[i] = context.WithTimeout(ctx, time.Hour)
- }
- syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
- <-ctx.Done()
- for i := range children {
- <-children[i].Done()
- }
- b.StopTimer()
- for i := range cancels {
- cancels[i]()
- }
-}
-
-func BenchmarkCanceledAsChild(b *testing.B) {
- children := make([]context.Context, b.N)
- parent, cancel := context.WithCancel(context.Background())
- b.ResetTimer()
- for i := range children {
- children[i] = UntilSignal(parent, syscall.SIGUSR2)
- }
- cancel()
- <-parent.Done()
- for i := range children {
- <-children[i].Done()
- }
-}