From: Joe Wreschnig Date: Sat, 13 Jun 2020 15:12:54 +0000 (+0200) Subject: Remove ‘Signal’ from structure names X-Git-Url: https://git.korewanetadesu.com/?p=go-signalcontext.git;a=commitdiff_plain;h=8cf685c73e37c718d1dcfe057caad080196f3e18 Remove ‘Signal’ from structure names The package is already named ‘signalcontext’, there’s no reason to further scope `SignalError` or `SignalContext`. Split error code and tests into a separate file. --- diff --git a/context.go b/context.go new file mode 100644 index 0000000..1bd24d3 --- /dev/null +++ b/context.go @@ -0,0 +1,105 @@ +package signalcontext + +import ( + "context" + "os" + "os/signal" + "sync" + "time" +) + +// A Context is an implementation of the context.Context interface which +// completes when a signal (e.g. os.Interrupt) is received. +// +// Contexts should be created via the UntilSignal function. +type Context struct { + parent context.Context + done chan struct{} + err error + + // The mutex synchronizes access to err and clearing the + // internal Signal channel after initialization. + m sync.Mutex + c chan os.Signal +} + +// UntilSignal returns a new Context which will complete when the parent +// does or when any of the specified signals are received. +func UntilSignal(parent context.Context, sig ...os.Signal) *Context { + ctx := new(Context) + ctx.parent = parent + ctx.done = make(chan struct{}) + + if err := parent.Err(); err != nil { + close(ctx.done) + ctx.err = err + return ctx + } + + ctx.c = make(chan os.Signal, 1) + signal.Notify(ctx.c, sig...) + go ctx.wait(sig...) + return ctx +} + +func (s *Context) wait(sig ...os.Signal) { + var err error + select { + case <-s.parent.Done(): + err = s.parent.Err() + case v := <-s.c: + if v != nil { + err = Error{v} + } + } + signal.Stop(s.c) + s.m.Lock() + if s.err == nil { + s.err = err + } + close(s.c) + s.c = nil + s.m.Unlock() + close(s.done) +} + +// Cancel cancels this context, if it hasn’t already been completed. (If +// it has, this is safe but has no effect.) +func (s *Context) Cancel() { + s.m.Lock() + if s.c != nil { + s.err = context.Canceled + select { + case s.c <- nil: + default: + } + } + s.m.Unlock() +} + +// Deadline implements context.Context; a Context’s deadline is that of +// its parent. +func (s *Context) Deadline() (time.Time, bool) { + return s.parent.Deadline() +} + +// Value implements context.Context; any value is that of its parent. +func (s *Context) Value(key interface{}) interface{} { + return s.parent.Value(key) +} + +// Done implements context.Context. +func (s *Context) Done() <-chan struct{} { + return s.done +} + +// Err implements context.Context; it returns context.Canceled if the +// context was canceled; an Error if the context completed due to a +// signal; the parent’s error if the parent was done before either of +// those; or nil if the context is not yet done. +func (s *Context) Err() error { + s.m.Lock() + err := s.err + s.m.Unlock() + return err +} diff --git a/context_test.go b/context_test.go new file mode 100644 index 0000000..a7b30a4 --- /dev/null +++ b/context_test.go @@ -0,0 +1,122 @@ +package signalcontext + +import ( + "context" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestReceivesSignal(t *testing.T) { + ctx := UntilSignal(context.Background(), syscall.SIGUSR2) + assert.NoError(t, ctx.Err()) + syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) + <-ctx.Done() + assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err()) +} + +func TestForwardsParent(t *testing.T) { + parent, _ := context.WithTimeout( + context.WithValue(context.Background(), t, "test"), + time.Millisecond) + ctx := UntilSignal(parent, syscall.SIGUSR2) + assert.NoError(t, ctx.Err()) + dl, ok := ctx.Deadline() + assert.True(t, ok) + assert.WithinDuration(t, time.Now(), dl, time.Millisecond) + assert.EqualValues(t, "test", ctx.Value(t)) + <-ctx.Done() + assert.Equal(t, context.DeadlineExceeded, ctx.Err()) +} + +func TestChildForwardsErr(t *testing.T) { + ctx := UntilSignal(context.Background(), syscall.SIGUSR2) + child, cancel := context.WithTimeout(ctx, time.Second) + syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) + <-child.Done() + <-ctx.Done() + cancel() + assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err()) + assert.Equal(t, Error{syscall.SIGUSR2}, child.Err()) +} + +func TestSignalAfterCancel(t *testing.T) { + ctx := UntilSignal(context.Background(), syscall.SIGUSR2) + assert.NoError(t, ctx.Err()) + ctx.Cancel() + <-ctx.Done() + assert.Equal(t, context.Canceled, ctx.Err()) + syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) + time.Sleep(5 * time.Millisecond) + assert.Equal(t, context.Canceled, ctx.Err()) +} + +func TestCancelAfterSignal(t *testing.T) { + ctx := UntilSignal(context.Background(), syscall.SIGUSR2) + assert.NoError(t, ctx.Err()) + syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) + <-ctx.Done() + assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err()) + ctx.Cancel() + time.Sleep(5 * time.Millisecond) + assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err()) +} + +func TestImmediateCompletion(t *testing.T) { + parent, cancel := context.WithCancel(context.Background()) + cancel() + <-parent.Done() + ctx := UntilSignal(parent, syscall.SIGUSR2) + // peek inside to be certain we never set up the signal channel. + assert.Nil(t, ctx.c) + select { + case _, ok := <-ctx.Done(): + assert.False(t, ok, "Done() should be closed") + default: + assert.False(t, true, "context should be complete") + } + assert.Equal(t, context.Canceled, ctx.Err()) +} + +func BenchmarkReceivesSignal(b *testing.B) { + for i := 0; i < b.N; i++ { + ctx := UntilSignal(context.Background(), syscall.SIGUSR2) + syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) + <-ctx.Done() + } +} + +func BenchmarkCancelChildren(b *testing.B) { + children := make([]context.Context, b.N) + cancels := make([]context.CancelFunc, b.N) + b.ResetTimer() + ctx := UntilSignal(context.Background(), syscall.SIGUSR2) + for i := range children { + children[i], cancels[i] = context.WithTimeout(ctx, time.Hour) + } + syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) + <-ctx.Done() + for i := range children { + <-children[i].Done() + } + b.StopTimer() + for i := range cancels { + cancels[i]() + } +} + +func BenchmarkCanceledAsChild(b *testing.B) { + children := make([]context.Context, b.N) + parent, cancel := context.WithCancel(context.Background()) + b.ResetTimer() + for i := range children { + children[i] = UntilSignal(parent, syscall.SIGUSR2) + } + cancel() + <-parent.Done() + for i := range children { + <-children[i].Done() + } +} diff --git a/error.go b/error.go new file mode 100644 index 0000000..c0a4f97 --- /dev/null +++ b/error.go @@ -0,0 +1,20 @@ +package signalcontext + +import ( + "fmt" + "os" +) + +// A Error will be returned by a SignalContext’s Err() method when it +// was finished due to a signal (rather than e.g. parent cancellation). +type Error struct { + os.Signal +} + +func (e Error) Error() string { + return e.String() +} + +func (e Error) String() string { + return fmt.Sprintf("received signal: %s", e.Signal) +} diff --git a/error_test.go b/error_test.go new file mode 100644 index 0000000..eb347ef --- /dev/null +++ b/error_test.go @@ -0,0 +1,13 @@ +package signalcontext + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestError(t *testing.T) { + assert.EqualError(t, Error{os.Interrupt}, + "received signal: "+os.Interrupt.String()) +} diff --git a/signalcontext.go b/signalcontext.go deleted file mode 100644 index d718e11..0000000 --- a/signalcontext.go +++ /dev/null @@ -1,121 +0,0 @@ -package signalcontext - -import ( - "context" - "fmt" - "os" - "os/signal" - "sync" - "time" -) - -// A SignalContext is an implementation of the context.Context interface -// which completes when a signal (e.g. os.Interrupt) is received. -// -// SignalContexts should be created via the UntilSignal function. -type SignalContext struct { - parent context.Context - done chan struct{} - err error - - // The mutex synchronizes access to err and clearing the - // internal Signal channel after initialization. - m sync.Mutex - c chan os.Signal -} - -// UntilSignal returns a new SignalContext which will complete when the -// parent does or when any of the specified signals are received. -func UntilSignal(parent context.Context, sig ...os.Signal) *SignalContext { - ctx := new(SignalContext) - ctx.parent = parent - ctx.done = make(chan struct{}) - - if err := parent.Err(); err != nil { - close(ctx.done) - ctx.err = err - return ctx - } - - ctx.c = make(chan os.Signal, 1) - signal.Notify(ctx.c, sig...) - go ctx.wait(sig...) - return ctx -} - -func (s *SignalContext) wait(sig ...os.Signal) { - var err error - select { - case <-s.parent.Done(): - err = s.parent.Err() - case v := <-s.c: - if v != nil { - err = SignalError{v} - } - } - signal.Stop(s.c) - s.m.Lock() - if s.err == nil { - s.err = err - } - close(s.c) - s.c = nil - s.m.Unlock() - close(s.done) -} - -// Cancel cancels this context, if it hasn’t already been canceled or -// received a signal. (If it has, this is safe but has no effect.) -func (s *SignalContext) Cancel() { - s.m.Lock() - if s.c != nil { - s.err = context.Canceled - select { - case s.c <- nil: - default: - } - } - s.m.Unlock() -} - -// Deadline implements context.Context; a SignalContext’s deadline is -// that of its parent. -func (s *SignalContext) Deadline() (time.Time, bool) { - return s.parent.Deadline() -} - -// Value implements context.Context; any value is that of its parent. -func (s *SignalContext) Value(key interface{}) interface{} { - return s.parent.Value(key) -} - -// Done implements context.Context. -func (s *SignalContext) Done() <-chan struct{} { - return s.done -} - -// Err implements context.Context; it returns context.Canceled if the -// context was canceled; a SignalError if the context completed due to a -// signal; the parent’s error if the parent was done before either of -// those; or nil if the context is not yet done. -func (s *SignalContext) Err() error { - s.m.Lock() - err := s.err - s.m.Unlock() - return err -} - -// A SignalError will be returned by a SignalContext’s Err() method when -// it was finished due to a signal (rather than e.g. parent -// cancellation). -type SignalError struct { - os.Signal -} - -func (e SignalError) Error() string { - return e.String() -} - -func (e SignalError) String() string { - return fmt.Sprintf("received signal: %s", e.Signal) -} diff --git a/signalcontext_test.go b/signalcontext_test.go deleted file mode 100644 index 07109c5..0000000 --- a/signalcontext_test.go +++ /dev/null @@ -1,124 +0,0 @@ -package signalcontext - -import ( - "context" - "syscall" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestReceivesSignal(t *testing.T) { - ctx := UntilSignal(context.Background(), syscall.SIGUSR2) - assert.NoError(t, ctx.Err()) - syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) - <-ctx.Done() - assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err()) - assert.EqualError(t, ctx.Err(), - "received signal: "+syscall.SIGUSR2.String()) -} - -func TestForwardsParent(t *testing.T) { - parent, _ := context.WithTimeout( - context.WithValue(context.Background(), t, "test"), - time.Millisecond) - ctx := UntilSignal(parent, syscall.SIGUSR2) - assert.NoError(t, ctx.Err()) - dl, ok := ctx.Deadline() - assert.True(t, ok) - assert.WithinDuration(t, time.Now(), dl, time.Millisecond) - assert.EqualValues(t, "test", ctx.Value(t)) - <-ctx.Done() - assert.Equal(t, context.DeadlineExceeded, ctx.Err()) -} - -func TestChildForwardsErr(t *testing.T) { - ctx := UntilSignal(context.Background(), syscall.SIGUSR2) - child, cancel := context.WithTimeout(ctx, time.Second) - syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) - <-child.Done() - <-ctx.Done() - cancel() - assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err()) - assert.Equal(t, SignalError{syscall.SIGUSR2}, child.Err()) -} - -func TestSignalAfterCancel(t *testing.T) { - ctx := UntilSignal(context.Background(), syscall.SIGUSR2) - assert.NoError(t, ctx.Err()) - ctx.Cancel() - <-ctx.Done() - assert.Equal(t, context.Canceled, ctx.Err()) - syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) - time.Sleep(5 * time.Millisecond) - assert.Equal(t, context.Canceled, ctx.Err()) -} - -func TestCancelAfterSignal(t *testing.T) { - ctx := UntilSignal(context.Background(), syscall.SIGUSR2) - assert.NoError(t, ctx.Err()) - syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) - <-ctx.Done() - assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err()) - ctx.Cancel() - time.Sleep(5 * time.Millisecond) - assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err()) -} - -func TestImmediateCompletion(t *testing.T) { - parent, cancel := context.WithCancel(context.Background()) - cancel() - <-parent.Done() - ctx := UntilSignal(parent, syscall.SIGUSR2) - // peek inside to be certain we never set up the signal channel. - assert.Nil(t, ctx.c) - select { - case _, ok := <-ctx.Done(): - assert.False(t, ok, "Done() should be closed") - default: - assert.False(t, true, "context should be complete") - } - assert.Equal(t, context.Canceled, ctx.Err()) -} - -func BenchmarkReceivesSignal(b *testing.B) { - for i := 0; i < b.N; i++ { - ctx := UntilSignal(context.Background(), syscall.SIGUSR2) - syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) - <-ctx.Done() - } -} - -func BenchmarkCancelChildren(b *testing.B) { - children := make([]context.Context, b.N) - cancels := make([]context.CancelFunc, b.N) - b.ResetTimer() - ctx := UntilSignal(context.Background(), syscall.SIGUSR2) - for i := range children { - children[i], cancels[i] = context.WithTimeout(ctx, time.Hour) - } - syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) - <-ctx.Done() - for i := range children { - <-children[i].Done() - } - b.StopTimer() - for i := range cancels { - cancels[i]() - } -} - -func BenchmarkCanceledAsChild(b *testing.B) { - children := make([]context.Context, b.N) - parent, cancel := context.WithCancel(context.Background()) - b.ResetTimer() - for i := range children { - children[i] = UntilSignal(parent, syscall.SIGUSR2) - } - cancel() - <-parent.Done() - for i := range children { - <-children[i].Done() - } -}