d718e1199e5dd60c454f838a6bde5e08635bcf5d
[go-signalcontext.git] / signalcontext.go
1 package signalcontext
2
3 import (
4 "context"
5 "fmt"
6 "os"
7 "os/signal"
8 "sync"
9 "time"
10 )
11
// SignalContext is a context.Context implementation that finishes when
// a signal (e.g. os.Interrupt) arrives, when its parent finishes, or
// when it is canceled explicitly.
//
// Create instances with UntilSignal rather than constructing the
// struct directly.
type SignalContext struct {
	parent context.Context // answers Deadline and Value; its completion also completes this context
	done   chan struct{}   // closed once this context has finished

	// m guards err, and the closing and clearing of c once the
	// context has finished.
	m   sync.Mutex
	err error          // result reported by Err
	c   chan os.Signal // receives signals, or a nil wake-up from Cancel; nil once finished
}
26
27 // UntilSignal returns a new SignalContext which will complete when the
28 // parent does or when any of the specified signals are received.
29 func UntilSignal(parent context.Context, sig ...os.Signal) *SignalContext {
30 ctx := new(SignalContext)
31 ctx.parent = parent
32 ctx.done = make(chan struct{})
33
34 if err := parent.Err(); err != nil {
35 close(ctx.done)
36 ctx.err = err
37 return ctx
38 }
39
40 ctx.c = make(chan os.Signal, 1)
41 signal.Notify(ctx.c, sig...)
42 go ctx.wait(sig...)
43 return ctx
44 }
45
46 func (s *SignalContext) wait(sig ...os.Signal) {
47 var err error
48 select {
49 case <-s.parent.Done():
50 err = s.parent.Err()
51 case v := <-s.c:
52 if v != nil {
53 err = SignalError{v}
54 }
55 }
56 signal.Stop(s.c)
57 s.m.Lock()
58 if s.err == nil {
59 s.err = err
60 }
61 close(s.c)
62 s.c = nil
63 s.m.Unlock()
64 close(s.done)
65 }
66
67 // Cancel cancels this context, if it hasn’t already been canceled or
68 // received a signal. (If it has, this is safe but has no effect.)
69 func (s *SignalContext) Cancel() {
70 s.m.Lock()
71 if s.c != nil {
72 s.err = context.Canceled
73 select {
74 case s.c <- nil:
75 default:
76 }
77 }
78 s.m.Unlock()
79 }
80
81 // Deadline implements context.Context; a SignalContext’s deadline is
82 // that of its parent.
83 func (s *SignalContext) Deadline() (time.Time, bool) {
84 return s.parent.Deadline()
85 }
86
87 // Value implements context.Context; any value is that of its parent.
88 func (s *SignalContext) Value(key interface{}) interface{} {
89 return s.parent.Value(key)
90 }
91
92 // Done implements context.Context.
93 func (s *SignalContext) Done() <-chan struct{} {
94 return s.done
95 }
96
97 // Err implements context.Context; it returns context.Canceled if the
98 // context was canceled; a SignalError if the context completed due to a
99 // signal; the parent’s error if the parent was done before either of
100 // those; or nil if the context is not yet done.
101 func (s *SignalContext) Err() error {
102 s.m.Lock()
103 err := s.err
104 s.m.Unlock()
105 return err
106 }
107
// A SignalError is the error reported by a SignalContext's Err method
// when the context finished because a signal was received (rather
// than e.g. through parent cancellation). The signal is available
// through the embedded os.Signal.
type SignalError struct {
	os.Signal
}

// Error implements the error interface; it returns the same text as
// String.
func (e SignalError) Error() string {
	return e.String()
}

// String describes the received signal. It replaces the String method
// promoted from the embedded os.Signal so that the text carries the
// "received signal: " prefix.
func (e SignalError) String() string {
	// %v rather than %s: both format a non-nil signal through its
	// String method, but only %v renders a nil Signal (the zero-value
	// SignalError) as "<nil>" instead of the "%!s(<nil>)" error marker.
	return fmt.Sprintf("received signal: %v", e.Signal)
}