Remove ‘Signal’ from structure names
[go-signalcontext.git] / context.go
diff --git a/context.go b/context.go
new file mode 100644 (file)
index 0000000..1bd24d3
--- /dev/null
@@ -0,0 +1,105 @@
+package signalcontext
+
+import (
+       "context"
+       "os"
+       "os/signal"
+       "sync"
+       "time"
+)
+
+// A Context is an implementation of the context.Context interface which
+// completes when a signal (e.g. os.Interrupt) is received.
+//
+// Contexts should be created via the UntilSignal function.
+type Context struct {
+       parent context.Context
+       done   chan struct{}
+       err    error
+
+       // The mutex synchronizes access to err and clearing the
+       // internal Signal channel after initialization.
+       m sync.Mutex
+       c chan os.Signal
+}
+
+// UntilSignal returns a new Context which will complete when the parent
+// does or when any of the specified signals are received.
+func UntilSignal(parent context.Context, sig ...os.Signal) *Context {
+       ctx := new(Context)
+       ctx.parent = parent
+       ctx.done = make(chan struct{})
+
+       if err := parent.Err(); err != nil {
+               close(ctx.done)
+               ctx.err = err
+               return ctx
+       }
+
+       ctx.c = make(chan os.Signal, 1)
+       signal.Notify(ctx.c, sig...)
+       go ctx.wait(sig...)
+       return ctx
+}
+
+func (s *Context) wait(sig ...os.Signal) {
+       var err error
+       select {
+       case <-s.parent.Done():
+               err = s.parent.Err()
+       case v := <-s.c:
+               if v != nil {
+                       err = Error{v}
+               }
+       }
+       signal.Stop(s.c)
+       s.m.Lock()
+       if s.err == nil {
+               s.err = err
+       }
+       close(s.c)
+       s.c = nil
+       s.m.Unlock()
+       close(s.done)
+}
+
+// Cancel cancels this context, if it hasn’t already been completed. (If
+// it has, this is safe but has no effect.)
+func (s *Context) Cancel() {
+       s.m.Lock()
+       if s.c != nil {
+               s.err = context.Canceled
+               select {
+               case s.c <- nil:
+               default:
+               }
+       }
+       s.m.Unlock()
+}
+
+// Deadline implements context.Context; a Context’s deadline is that of
+// its parent.
+func (s *Context) Deadline() (time.Time, bool) {
+       return s.parent.Deadline()
+}
+
+// Value implements context.Context; any value is that of its parent.
+func (s *Context) Value(key interface{}) interface{} {
+       return s.parent.Value(key)
+}
+
+// Done implements context.Context.
+func (s *Context) Done() <-chan struct{} {
+       return s.done
+}
+
+// Err implements context.Context; it returns context.Canceled if the
+// context was canceled; an Error if the context completed due to a
+// signal; the parent’s error if the parent was done before either of
+// those; or nil if the context is not yet done.
+func (s *Context) Err() error {
+       s.m.Lock()
+       err := s.err
+       s.m.Unlock()
+       return err
+}