+package signalcontext
+
+import (
+ "context"
+ "syscall"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestReceivesSignal(t *testing.T) {
+ ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+ assert.NoError(t, ctx.Err())
+ syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+ <-ctx.Done()
+ assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
+ assert.EqualError(t, ctx.Err(),
+ "received signal: "+syscall.SIGUSR2.String())
+}
+
+func TestForwardsParent(t *testing.T) {
+ parent, _ := context.WithTimeout(
+ context.WithValue(context.Background(), t, "test"),
+ time.Millisecond)
+ ctx := UntilSignal(parent, syscall.SIGUSR2)
+ assert.NoError(t, ctx.Err())
+ dl, ok := ctx.Deadline()
+ assert.True(t, ok)
+ assert.WithinDuration(t, time.Now(), dl, time.Millisecond)
+ assert.EqualValues(t, "test", ctx.Value(t))
+ <-ctx.Done()
+ assert.Equal(t, context.DeadlineExceeded, ctx.Err())
+}
+
+func TestChildForwardsErr(t *testing.T) {
+ ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+ child, cancel := context.WithTimeout(ctx, time.Second)
+ syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+ <-child.Done()
+ <-ctx.Done()
+ cancel()
+ assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
+ assert.Equal(t, SignalError{syscall.SIGUSR2}, child.Err())
+}
+
+func TestSignalAfterCancel(t *testing.T) {
+ ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+ assert.NoError(t, ctx.Err())
+ ctx.Cancel()
+ <-ctx.Done()
+ assert.Equal(t, context.Canceled, ctx.Err())
+ syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+ time.Sleep(5 * time.Millisecond)
+ assert.Equal(t, context.Canceled, ctx.Err())
+}
+
+func TestCancelAfterSignal(t *testing.T) {
+ ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+ assert.NoError(t, ctx.Err())
+ syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+ <-ctx.Done()
+ assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
+ ctx.Cancel()
+ time.Sleep(5 * time.Millisecond)
+ assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
+}
+
+func TestImmediateCompletion(t *testing.T) {
+ parent, cancel := context.WithCancel(context.Background())
+ cancel()
+ <-parent.Done()
+ ctx := UntilSignal(parent, syscall.SIGUSR2)
+ // peek inside to be certain we never set up the signal channel.
+ assert.Nil(t, ctx.c)
+ select {
+ case _, ok := <-ctx.Done():
+ assert.False(t, ok, "Done() should be closed")
+ default:
+ assert.False(t, true, "context should be complete")
+ }
+ assert.Equal(t, context.Canceled, ctx.Err())
+}
+
+func BenchmarkReceivesSignal(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+ syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+ <-ctx.Done()
+ }
+}
+
+func BenchmarkCancelChildren(b *testing.B) {
+ children := make([]context.Context, b.N)
+ cancels := make([]context.CancelFunc, b.N)
+ b.ResetTimer()
+ ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
+ for i := range children {
+ children[i], cancels[i] = context.WithTimeout(ctx, time.Hour)
+ }
+ syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
+ <-ctx.Done()
+ for i := range children {
+ <-children[i].Done()
+ }
+ b.StopTimer()
+ for i := range cancels {
+ cancels[i]()
+ }
+}
+
+func BenchmarkCanceledAsChild(b *testing.B) {
+ children := make([]context.Context, b.N)
+ parent, cancel := context.WithCancel(context.Background())
+ b.ResetTimer()
+ for i := range children {
+ children[i] = UntilSignal(parent, syscall.SIGUSR2)
+ }
+ cancel()
+ <-parent.Done()
+ for i := range children {
+ <-children[i].Done()
+ }
+}