Initial import
[go-signalcontext.git] / signalcontext_test.go
1 package signalcontext
2
3 import (
4 "context"
5 "syscall"
6 "testing"
7 "time"
8
9 "github.com/stretchr/testify/assert"
10 )
11
12 func TestReceivesSignal(t *testing.T) {
13 ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
14 assert.NoError(t, ctx.Err())
15 syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
16 <-ctx.Done()
17 assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
18 assert.EqualError(t, ctx.Err(),
19 "received signal: "+syscall.SIGUSR2.String())
20 }
21
22 func TestForwardsParent(t *testing.T) {
23 parent, _ := context.WithTimeout(
24 context.WithValue(context.Background(), t, "test"),
25 time.Millisecond)
26 ctx := UntilSignal(parent, syscall.SIGUSR2)
27 assert.NoError(t, ctx.Err())
28 dl, ok := ctx.Deadline()
29 assert.True(t, ok)
30 assert.WithinDuration(t, time.Now(), dl, time.Millisecond)
31 assert.EqualValues(t, "test", ctx.Value(t))
32 <-ctx.Done()
33 assert.Equal(t, context.DeadlineExceeded, ctx.Err())
34 }
35
36 func TestChildForwardsErr(t *testing.T) {
37 ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
38 child, cancel := context.WithTimeout(ctx, time.Second)
39 syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
40 <-child.Done()
41 <-ctx.Done()
42 cancel()
43 assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
44 assert.Equal(t, SignalError{syscall.SIGUSR2}, child.Err())
45 }
46
47 func TestSignalAfterCancel(t *testing.T) {
48 ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
49 assert.NoError(t, ctx.Err())
50 ctx.Cancel()
51 <-ctx.Done()
52 assert.Equal(t, context.Canceled, ctx.Err())
53 syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
54 time.Sleep(5 * time.Millisecond)
55 assert.Equal(t, context.Canceled, ctx.Err())
56 }
57
58 func TestCancelAfterSignal(t *testing.T) {
59 ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
60 assert.NoError(t, ctx.Err())
61 syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
62 <-ctx.Done()
63 assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
64 ctx.Cancel()
65 time.Sleep(5 * time.Millisecond)
66 assert.Equal(t, SignalError{syscall.SIGUSR2}, ctx.Err())
67 }
68
69 func TestImmediateCompletion(t *testing.T) {
70 parent, cancel := context.WithCancel(context.Background())
71 cancel()
72 <-parent.Done()
73 ctx := UntilSignal(parent, syscall.SIGUSR2)
74 // peek inside to be certain we never set up the signal channel.
75 assert.Nil(t, ctx.c)
76 select {
77 case _, ok := <-ctx.Done():
78 assert.False(t, ok, "Done() should be closed")
79 default:
80 assert.False(t, true, "context should be complete")
81 }
82 assert.Equal(t, context.Canceled, ctx.Err())
83 }
84
85 func BenchmarkReceivesSignal(b *testing.B) {
86 for i := 0; i < b.N; i++ {
87 ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
88 syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
89 <-ctx.Done()
90 }
91 }
92
93 func BenchmarkCancelChildren(b *testing.B) {
94 children := make([]context.Context, b.N)
95 cancels := make([]context.CancelFunc, b.N)
96 b.ResetTimer()
97 ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
98 for i := range children {
99 children[i], cancels[i] = context.WithTimeout(ctx, time.Hour)
100 }
101 syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
102 <-ctx.Done()
103 for i := range children {
104 <-children[i].Done()
105 }
106 b.StopTimer()
107 for i := range cancels {
108 cancels[i]()
109 }
110 }
111
112 func BenchmarkCanceledAsChild(b *testing.B) {
113 children := make([]context.Context, b.N)
114 parent, cancel := context.WithCancel(context.Background())
115 b.ResetTimer()
116 for i := range children {
117 children[i] = UntilSignal(parent, syscall.SIGUSR2)
118 }
119 cancel()
120 <-parent.Done()
121 for i := range children {
122 <-children[i].Done()
123 }
124 }