package signalcontext import ( "context" "fmt" "os" "os/signal" "sync" "time" ) // A SignalContext is an implementation of the context.Context interface // which completes when a signal (e.g. os.Interrupt) is received. // // SignalContexts should be created via the UntilSignal function. type SignalContext struct { parent context.Context done chan struct{} err error // The mutex synchronizes access to err and clearing the // internal Signal channel after initialization. m sync.Mutex c chan os.Signal } // UntilSignal returns a new SignalContext which will complete when the // parent does or when any of the specified signals are received. func UntilSignal(parent context.Context, sig ...os.Signal) *SignalContext { ctx := new(SignalContext) ctx.parent = parent ctx.done = make(chan struct{}) if err := parent.Err(); err != nil { close(ctx.done) ctx.err = err return ctx } ctx.c = make(chan os.Signal, 1) signal.Notify(ctx.c, sig...) go ctx.wait(sig...) return ctx } func (s *SignalContext) wait(sig ...os.Signal) { var err error select { case <-s.parent.Done(): err = s.parent.Err() case v := <-s.c: if v != nil { err = SignalError{v} } } signal.Stop(s.c) s.m.Lock() if s.err == nil { s.err = err } close(s.c) s.c = nil s.m.Unlock() close(s.done) } // Cancel cancels this context, if it hasn’t already been canceled or // received a signal. (If it has, this is safe but has no effect.) func (s *SignalContext) Cancel() { s.m.Lock() if s.c != nil { s.err = context.Canceled select { case s.c <- nil: default: } } s.m.Unlock() } // Deadline implements context.Context; a SignalContext’s deadline is // that of its parent. func (s *SignalContext) Deadline() (time.Time, bool) { return s.parent.Deadline() } // Value implements context.Context; any value is that of its parent. func (s *SignalContext) Value(key interface{}) interface{} { return s.parent.Value(key) } // Done implements context.Context. 
// The returned channel is closed once the context has completed; every
// call returns the same channel.
func (s *SignalContext) Done() <-chan struct{} {
	return s.done
}

// Err implements context.Context. It reports context.Canceled if the
// context was canceled, a SignalError if it completed because a signal
// arrived, or the parent's error if the parent finished before either of
// those. It returns nil if none of these has happened yet.
func (s *SignalContext) Err() error {
	s.m.Lock()
	defer s.m.Unlock()
	return s.err
}

// A SignalError will be returned by a SignalContext’s Err() method when
// it was finished due to a signal (rather than e.g. parent
// cancellation). The embedded os.Signal identifies the signal received.
type SignalError struct {
	os.Signal
}

// Compile-time checks that SignalError is usable as an error and prints
// through fmt.Stringer.
var (
	_ error        = SignalError{}
	_ fmt.Stringer = SignalError{}
)

// Error implements the error interface, naming the received signal.
func (e SignalError) Error() string {
	return fmt.Sprintf("received signal: %s", e.Signal)
}

// String returns the same text as Error.
func (e SignalError) String() string {
	return e.Error()
}