Add project documentation and licensing
[go-signalcontext.git] / context_test.go
1 package signalcontext
2
3 import (
4 "context"
5 "syscall"
6 "testing"
7 "time"
8
9 "github.com/stretchr/testify/assert"
10 )
11
12 func TestReceivesSignal(t *testing.T) {
13 ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
14 assert.NoError(t, ctx.Err())
15 syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
16 <-ctx.Done()
17 assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err())
18 }
19
20 func TestForwardsParent(t *testing.T) {
21 parent, _ := context.WithTimeout(
22 context.WithValue(context.Background(), t, "test"),
23 time.Millisecond)
24 ctx := UntilSignal(parent, syscall.SIGUSR2)
25 assert.NoError(t, ctx.Err())
26 dl, ok := ctx.Deadline()
27 assert.True(t, ok)
28 assert.WithinDuration(t, time.Now(), dl, time.Millisecond)
29 assert.EqualValues(t, "test", ctx.Value(t))
30 <-ctx.Done()
31 assert.Equal(t, context.DeadlineExceeded, ctx.Err())
32 }
33
34 func TestChildForwardsErr(t *testing.T) {
35 ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
36 child, cancel := context.WithTimeout(ctx, time.Second)
37 syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
38 <-child.Done()
39 <-ctx.Done()
40 cancel()
41 assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err())
42 assert.Equal(t, Error{syscall.SIGUSR2}, child.Err())
43 }
44
45 func TestSignalAfterCancel(t *testing.T) {
46 ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
47 assert.NoError(t, ctx.Err())
48 ctx.Cancel()
49 <-ctx.Done()
50 assert.Equal(t, context.Canceled, ctx.Err())
51 syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
52 time.Sleep(5 * time.Millisecond)
53 assert.Equal(t, context.Canceled, ctx.Err())
54 }
55
56 func TestCancelAfterSignal(t *testing.T) {
57 ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
58 assert.NoError(t, ctx.Err())
59 syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
60 <-ctx.Done()
61 assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err())
62 ctx.Cancel()
63 time.Sleep(5 * time.Millisecond)
64 assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err())
65 }
66
67 func TestPrecanceled(t *testing.T) {
68 parent, cancel := context.WithCancel(context.Background())
69 cancel()
70 <-parent.Done()
71 ctx := UntilSignal(parent, syscall.SIGUSR2)
72 // peek inside to be certain we never set up the signal channel.
73 assert.Nil(t, ctx.c)
74 select {
75 case _, ok := <-ctx.Done():
76 assert.False(t, ok, "Done() should be closed")
77 default:
78 assert.False(t, true, "Done() should be closed")
79 }
80 assert.Equal(t, context.Canceled, ctx.Err())
81 }
82
83 func BenchmarkReceivesSignal(b *testing.B) {
84 for i := 0; i < b.N; i++ {
85 ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
86 syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
87 <-ctx.Done()
88 }
89 }
90
91 func BenchmarkCancelChildren(b *testing.B) {
92 children := make([]context.Context, b.N)
93 cancels := make([]context.CancelFunc, b.N)
94 b.ResetTimer()
95 ctx := UntilSignal(context.Background(), syscall.SIGUSR2)
96 for i := range children {
97 children[i], cancels[i] = context.WithTimeout(ctx, time.Hour)
98 }
99 syscall.Kill(syscall.Getpid(), syscall.SIGUSR2)
100 <-ctx.Done()
101 for i := range children {
102 <-children[i].Done()
103 }
104 b.StopTimer()
105 for i := range cancels {
106 cancels[i]()
107 }
108 }
109
110 func BenchmarkCanceledAsChild(b *testing.B) {
111 children := make([]context.Context, b.N)
112 parent, cancel := context.WithCancel(context.Background())
113 b.ResetTimer()
114 for i := range children {
115 children[i] = UntilSignal(parent, syscall.SIGUSR2)
116 }
117 cancel()
118 <-parent.Done()
119 for i := range children {
120 <-children[i].Done()
121 }
122 }
123
124 func BenchmarkPrecanceled(b *testing.B) {
125 parent, cancel := context.WithCancel(context.Background())
126 cancel()
127 <-parent.Done()
128 b.ResetTimer()
129 for i := 0; i < b.N; i++ {
130 ctx := UntilSignal(parent, syscall.SIGUSR2)
131 <-ctx.Done()
132 }
133 }