Remove ‘Signal’ from structure names
authorJoe Wreschnig <joe.wreschnig@gmail.com>
Sat, 13 Jun 2020 15:12:54 +0000 (17:12 +0200)
committerJoe Wreschnig <joe.wreschnig@gmail.com>
Sun, 14 Jun 2020 13:50:00 +0000 (15:50 +0200)
The package is already named ‘signalcontext’, there’s no reason to
further scope `SignalError` or `SignalContext`.

Split error code and tests into a separate file.

context.go [new file with mode: 0644]
context_test.go [new file with mode: 0644]
error.go [new file with mode: 0644]
error_test.go [new file with mode: 0644]
signalcontext.go [deleted file]
signalcontext_test.go [deleted file]

diff --git a/context.go b/context.go
new file mode 100644 (file)
index 0000000..1bd24d3
--- /dev/null
@@ -0,0 +1,105 @@
+package signalcontext
+
+import (
+       "context"
+       "os"
+       "os/signal"
+       "sync"
+       "time"
+)
+
+// A Context is an implementation of the context.Context interface which
+// completes when a signal (e.g. os.Interrupt) is received.
+//
+// Contexts should be created via the UntilSignal function.
+type Context struct {
+       parent context.Context
+       done   chan struct{}
+       err    error
+
+       // The mutex synchronizes access to err and clearing the
+       // internal Signal channel after initialization.
+       m sync.Mutex
+       c chan os.Signal
+}
+
+// UntilSignal returns a new Context which will complete when the parent
+// does or when any of the specified signals are received.
+func UntilSignal(parent context.Context, sig ...os.Signal) *Context {
+       ctx := new(Context)
+       ctx.parent = parent
+       ctx.done = make(chan struct{})
+
+       if err := parent.Err(); err != nil {
+               close(ctx.done)
+               ctx.err = err
+               return ctx
+       }
+
+       ctx.c = make(chan os.Signal, 1)
+       signal.Notify(ctx.c, sig...)
+       go ctx.wait(sig...)
+       return ctx
+}
+
+func (s *Context) wait(sig ...os.Signal) {
+       var err error
+       select {
+       case <-s.parent.Done():
+               err = s.parent.Err()
+       case v := <-s.c:
+               if v != nil {
+                       err = Error{v}
+               }
+       }
+       signal.Stop(s.c)
+       s.m.Lock()
+       if s.err == nil {
+               s.err = err
+       }
+       close(s.c)
+       s.c = nil
+       s.m.Unlock()
+       close(s.done)
+}
+
+// Cancel cancels this context, if it hasn’t already been completed. (If
+// it has, this is safe but has no effect.)
+func (s *Context) Cancel() {
+       s.m.Lock()
+       if s.c != nil {
+               s.err = context.Canceled
+               select {
+               case s.c <- nil:
+               default:
+               }
+       }
+       s.m.Unlock()
+}
+
+// Deadline implements context.Context; a Context’s deadline is that of
+// its parent.
+func (s *Context) Deadline() (time.Time, bool) {
+       return s.parent.Deadline()
+}
+
+// Value implements context.Context; any value is that of its parent.
+func (s *Context) Value(key interface{}) interface{} {
+       return s.parent.Value(key)
+}
+
+// Done implements context.Context.
+func (s *Context) Done() <-chan struct{} {
+       return s.done
+}
+
+// Err implements context.Context; it returns context.Canceled if the
+// context was canceled; an Error if the context completed due to a
+// signal; the parent’s error if the parent was done before either of
+// those; or nil if the context is not yet done.
+func (s *Context) Err() error {
+       s.m.Lock()
+       err := s.err
+       s.m.Unlock()
+       return err
+}
diff --git a/context_test.go b/context_test.go
new file mode 100644 (file)
index 0000000..a7b30a4
--- /dev/null
@@ -0,0 +1,122 @@
+package signalcontext
+
+import (
+       "context"
+       "syscall"
+       "testing"
+       "time"
+
+       "github.com/stretchr/testify/assert"
+)
+
+func TestReceivesSignal(t *testing.T) {
+       ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+       assert.NoError(t, ctx.Err())
+       syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+       <-ctx.Done()
+       assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err())
+}
+
+func TestForwardsParent(t *testing.T) {
+       parent, _ := context.WithTimeout(
+               context.WithValue(context.Background(), t, "test"),
+               time.Millisecond)
+       ctx := UntilSignal(parent, syscall.SIGUSR2)
+       assert.NoError(t, ctx.Err())
+       dl, ok := ctx.Deadline()
+       assert.True(t, ok)
+       assert.WithinDuration(t, time.Now(), dl, time.Millisecond)
+       assert.EqualValues(t, "test", ctx.Value(t))
+       <-ctx.Done()
+       assert.Equal(t, context.DeadlineExceeded, ctx.Err())
+}
+
+func TestChildForwardsErr(t *testing.T) {
+       ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+       child, cancel := context.WithTimeout(ctx, time.Second)
+       syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+       <-child.Done()
+       <-ctx.Done()
+       cancel()
+       assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err())
+       assert.Equal(t, Error{syscall.SIGUSR2}, child.Err())
+}
+
+func TestSignalAfterCancel(t *testing.T) {
+       ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+       assert.NoError(t, ctx.Err())
+       ctx.Cancel()
+       <-ctx.Done()
+       assert.Equal(t, context.Canceled, ctx.Err())
+       syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+       time.Sleep(5 * time.Millisecond)
+       assert.Equal(t, context.Canceled, ctx.Err())
+}
+
+func TestCancelAfterSignal(t *testing.T) {
+       ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+       assert.NoError(t, ctx.Err())
+       syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+       <-ctx.Done()
+       assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err())
+       ctx.Cancel()
+       time.Sleep(5 * time.Millisecond)
+       assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err())
+}
+
+func TestImmediateCompletion(t *testing.T) {
+       parent, cancel := context.WithCancel(context.Background())
+       cancel()
+       <-parent.Done()
+       ctx := UntilSignal(parent, syscall.SIGUSR2)
+       // peek inside to be certain we never set up the signal channel.
+       assert.Nil(t, ctx.c)
+       select {
+       case _, ok := <-ctx.Done():
+               assert.False(t, ok, "Done() should be closed")
+       default:
+               assert.False(t, true, "context should be complete")
+       }
+       assert.Equal(t, context.Canceled, ctx.Err())
+}
+
+func BenchmarkReceivesSignal(b *testing.B) {
+       for i := 0; i < b.N; i++ {
+               ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+               syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+               <-ctx.Done()
+       }
+}
+
+func BenchmarkCancelChildren(b *testing.B) {
+       children := make([]context.Context, b.N)
+       cancels := make([]context.CancelFunc, b.N)
+       b.ResetTimer()
+       ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+       for i := range children {
+               children[i], cancels[i] = context.WithTimeout(ctx, time.Hour)
+       }
+       syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+       <-ctx.Done()
+       for i := range children {
+               <-children[i].Done()
+       }
+       b.StopTimer()
+       for i := range cancels {
+               cancels[i]()
+       }
+}
+
+func BenchmarkCanceledAsChild(b *testing.B) {
+       children := make([]context.Context, b.N)
+       parent, cancel := context.WithCancel(context.Background())
+       b.ResetTimer()
+       for i := range children {
+               children[i] = UntilSignal(parent, syscall.SIGUSR2)
+       }
+       cancel()
+       <-parent.Done()
+       for i := range children {
+               <-children[i].Done()
+       }
+}
diff --git a/error.go b/error.go
new file mode 100644 (file)
index 0000000..c0a4f97
--- /dev/null
+++ b/error.go
@@ -0,0 +1,20 @@
+package signalcontext
+
+import (
+       "fmt"
+       "os"
+)
+
+// A Error will be returned by a SignalContext’s Err() method when it
+// was finished due to a signal (rather than e.g. parent cancellation).
+type Error struct {
+       os.Signal
+}
+
+func (e Error) Error() string {
+       return e.String()
+}
+
+func (e Error) String() string {
+       return fmt.Sprintf("received signal: %s", e.Signal)
+}
diff --git a/error_test.go b/error_test.go
new file mode 100644 (file)
index 0000000..eb347ef
--- /dev/null
@@ -0,0 +1,13 @@
+package signalcontext
+
+import (
+       "os"
+       "testing"
+
+       "github.com/stretchr/testify/assert"
+)
+
+func TestError(t *testing.T) {
+       assert.EqualError(t, Error{os.Interrupt},
+               "received signal: "+os.Interrupt.String())
+}
diff --git a/signalcontext.go b/signalcontext.go
deleted file mode 100644 (file)
index d718e11..0000000
+++ /dev/null
@@ -1,121 +0,0 @@
-package signalcontext
-
-import (
-       "context"
-       "fmt"
-       "os"
-       "os/signal"
-       "sync"
-       "time"
-)
-
-// A SignalContext is an implementation of the context.Context interface
-// which completes when a signal (e.g. os.Interrupt) is received.
-//
-// SignalContexts should be created via the UntilSignal function.
-type SignalContext struct {
-       parent context.Context
-       done   chan struct{}
-       err    error
-
-       // The mutex synchronizes access to err and clearing the
-       // internal Signal channel after initialization.
-       m sync.Mutex
-       c chan os.Signal
-}
-
-// UntilSignal returns a new SignalContext which will complete when the
-// parent does or when any of the specified signals are received.
-func UntilSignal(parent context.Context, sig ...os.Signal) *SignalContext {
-       ctx := new(SignalContext)
-       ctx.parent = parent
-       ctx.done = make(chan struct{})
-
-       if err := parent.Err(); err != nil {
-               close(ctx.done)
-               ctx.err = err
-               return ctx
-       }
-
-       ctx.c = make(chan os.Signal, 1)
-       signal.Notify(ctx.c, sig...)
-       go ctx.wait(sig...)
-       return ctx
-}
-
-func (s *SignalContext) wait(sig ...os.Signal) {
-       var err error
-       select {
-       case <-s.parent.Done():
-               err = s.parent.Err()
-       case v := <-s.c:
-               if v != nil {
-                       err = SignalError{v}
-               }
-       }
-       signal.Stop(s.c)
-       s.m.Lock()
-       if s.err == nil {
-               s.err = err
-       }
-       close(s.c)
-       s.c = nil
-       s.m.Unlock()
-       close(s.done)
-}
-
-// Cancel cancels this context, if it hasn’t already been canceled or
-// received a signal. (If it has, this is safe but has no effect.)
-func (s *SignalContext) Cancel() {
-       s.m.Lock()
-       if s.c != nil {
-               s.err = context.Canceled
-               select {
-               case s.c <- nil:
-               default:
-               }
-       }
-       s.m.Unlock()
-}
-
-// Deadline implements context.Context; a SignalContext’s deadline is
-// that of its parent.
-func (s *SignalContext) Deadline() (time.Time, bool) {
-       return s.parent.Deadline()
-}
-
-// Value implements context.Context; any value is that of its parent.
-func (s *SignalContext) Value(key interface{}) interface{} {
-       return s.parent.Value(key)
-}
-
-// Done implements context.Context.
-func (s *SignalContext) Done() <-chan struct{} {
-       return s.done
-}
-
-// Err implements context.Context; it returns context.Canceled if the
-// context was canceled; a SignalError if the context completed due to a
-// signal; the parent’s error if the parent was done before either of
-// those; or nil if the context is not yet done.
-func (s *SignalContext) Err() error {
-       s.m.Lock()
-       err := s.err
-       s.m.Unlock()
-       return err
-}
-
-// A SignalError will be returned by a SignalContext’s Err() method when
-// it was finished due to a signal (rather than e.g. parent
-// cancellation).
-type SignalError struct {
-       os.Signal
-}
-
-func (e SignalError) Error() string {
-       return e.String()
-}
-
-func (e SignalError) String() string {
-       return fmt.Sprintf("received signal: %s", e.Signal)
-}
diff --git a/signalcontext_test.go b/signalcontext_test.go
deleted file mode 100644 (file)
index 07109c5..0000000
+++ /dev/null
@@ -1,124 +0,0 @@
-package signalcontext
-
-import (
-       "context"
-       "syscall"
-       "testing"
-       "time"
-
-       "github.com/stretchr/testify/assert"
-)
-
-func TestReceivesSignal(t *testing.T) {
-       ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
-       assert.NoError(t, ctx.Err())
-       syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
-       <-ctx.Done()
-       assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
-       assert.EqualError(t, ctx.Err(),
-               "received signal: "+syscall.SIGUSR2.String())
-}
-
-func TestForwardsParent(t *testing.T) {
-       parent, _ := context.WithTimeout(
-               context.WithValue(context.Background(), t, "test"),
-               time.Millisecond)
-       ctx := UntilSignal(parent, syscall.SIGUSR2)
-       assert.NoError(t, ctx.Err())
-       dl, ok := ctx.Deadline()
-       assert.True(t, ok)
-       assert.WithinDuration(t, time.Now(), dl, time.Millisecond)
-       assert.EqualValues(t, "test", ctx.Value(t))
-       <-ctx.Done()
-       assert.Equal(t, context.DeadlineExceeded, ctx.Err())
-}
-
-func TestChildForwardsErr(t *testing.T) {
-       ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
-       child, cancel := context.WithTimeout(ctx, time.Second)
-       syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
-       <-child.Done()
-       <-ctx.Done()
-       cancel()
-       assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
-       assert.Equal(t, SignalError{syscall.SIGUSR2}, child.Err())
-}
-
-func TestSignalAfterCancel(t *testing.T) {
-       ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
-       assert.NoError(t, ctx.Err())
-       ctx.Cancel()
-       <-ctx.Done()
-       assert.Equal(t, context.Canceled, ctx.Err())
-       syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
-       time.Sleep(5 * time.Millisecond)
-       assert.Equal(t, context.Canceled, ctx.Err())
-}
-
-func TestCancelAfterSignal(t *testing.T) {
-       ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
-       assert.NoError(t, ctx.Err())
-       syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
-       <-ctx.Done()
-       assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
-       ctx.Cancel()
-       time.Sleep(5 * time.Millisecond)
-       assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
-}
-
-func TestImmediateCompletion(t *testing.T) {
-       parent, cancel := context.WithCancel(context.Background())
-       cancel()
-       <-parent.Done()
-       ctx := UntilSignal(parent, syscall.SIGUSR2)
-       // peek inside to be certain we never set up the signal channel.
-       assert.Nil(t, ctx.c)
-       select {
-       case _, ok := <-ctx.Done():
-               assert.False(t, ok, "Done() should be closed")
-       default:
-               assert.False(t, true, "context should be complete")
-       }
-       assert.Equal(t, context.Canceled, ctx.Err())
-}
-
-func BenchmarkReceivesSignal(b *testing.B) {
-       for i := 0; i < b.N; i++ {
-               ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
-               syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
-               <-ctx.Done()
-       }
-}
-
-func BenchmarkCancelChildren(b *testing.B) {
-       children := make([]context.Context, b.N)
-       cancels := make([]context.CancelFunc, b.N)
-       b.ResetTimer()
-       ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
-       for i := range children {
-               children[i], cancels[i] = context.WithTimeout(ctx, time.Hour)
-       }
-       syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
-       <-ctx.Done()
-       for i := range children {
-               <-children[i].Done()
-       }
-       b.StopTimer()
-       for i := range cancels {
-               cancels[i]()
-       }
-}
-
-func BenchmarkCanceledAsChild(b *testing.B) {
-       children := make([]context.Context, b.N)
-       parent, cancel := context.WithCancel(context.Background())
-       b.ResetTimer()
-       for i := range children {
-               children[i] = UntilSignal(parent, syscall.SIGUSR2)
-       }
-       cancel()
-       <-parent.Done()
-       for i := range children {
-               <-children[i].Done()
-       }
-}