X-Git-Url: https://git.korewanetadesu.com/?p=go-signalcontext.git;a=blobdiff_plain;f=context_test.go;fp=context_test.go;h=a7b30a480a1e854edf6501e455a2a6006103de05;hp=0000000000000000000000000000000000000000;hb=8cf685c73e37c718d1dcfe057caad080196f3e18;hpb=f6b6f942533c75bd0b6884816e652794def00e44 diff --git a/context_test.go b/context_test.go new file mode 100644 index 0000000..a7b30a4 --- /dev/null +++ b/context_test.go @@ -0,0 +1,122 @@ +package signalcontext + +import ( + "context" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestReceivesSignal(t *testing.T) { + ctx := UntilSignal(context.Background(), syscall.SIGUSR2) + assert.NoError(t, ctx.Err()) + syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) + <-ctx.Done() + assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err()) +} + +func TestForwardsParent(t *testing.T) { + parent, _ := context.WithTimeout( + context.WithValue(context.Background(), t, "test"), + time.Millisecond) + ctx := UntilSignal(parent, syscall.SIGUSR2) + assert.NoError(t, ctx.Err()) + dl, ok := ctx.Deadline() + assert.True(t, ok) + assert.WithinDuration(t, time.Now(), dl, time.Millisecond) + assert.EqualValues(t, "test", ctx.Value(t)) + <-ctx.Done() + assert.Equal(t, context.DeadlineExceeded, ctx.Err()) +} + +func TestChildForwardsErr(t *testing.T) { + ctx := UntilSignal(context.Background(), syscall.SIGUSR2) + child, cancel := context.WithTimeout(ctx, time.Second) + syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) + <-child.Done() + <-ctx.Done() + cancel() + assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err()) + assert.Equal(t, Error{syscall.SIGUSR2}, child.Err()) +} + +func TestSignalAfterCancel(t *testing.T) { + ctx := UntilSignal(context.Background(), syscall.SIGUSR2) + assert.NoError(t, ctx.Err()) + ctx.Cancel() + <-ctx.Done() + assert.Equal(t, context.Canceled, ctx.Err()) + syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) + time.Sleep(5 * time.Millisecond) + assert.Equal(t, context.Canceled, ctx.Err()) +} + +func TestCancelAfterSignal(t *testing.T) { + ctx := UntilSignal(context.Background(), syscall.SIGUSR2) + assert.NoError(t, ctx.Err()) + syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) + <-ctx.Done() + assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err()) + ctx.Cancel() + time.Sleep(5 * time.Millisecond) + assert.Equal(t, Error{syscall.SIGUSR2}, ctx.Err()) +} + +func TestImmediateCompletion(t *testing.T) { + parent, cancel := context.WithCancel(context.Background()) + cancel() + <-parent.Done() + ctx := UntilSignal(parent, syscall.SIGUSR2) + // peek inside to be certain we never set up the signal channel. + assert.Nil(t, ctx.c) + select { + case _, ok := <-ctx.Done(): + assert.False(t, ok, "Done() should be closed") + default: + assert.False(t, true, "context should be complete") + } + assert.Equal(t, context.Canceled, ctx.Err()) +} + +func BenchmarkReceivesSignal(b *testing.B) { + for i := 0; i < b.N; i++ { + ctx := UntilSignal(context.Background(), syscall.SIGUSR2) + syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) + <-ctx.Done() + } +} + +func BenchmarkCancelChildren(b *testing.B) { + children := make([]context.Context, b.N) + cancels := make([]context.CancelFunc, b.N) + b.ResetTimer() + ctx := UntilSignal(context.Background(), syscall.SIGUSR2) + for i := range children { + children[i], cancels[i] = context.WithTimeout(ctx, time.Hour) + } + syscall.Kill(syscall.Getpid(), syscall.SIGUSR2) + <-ctx.Done() + for i := range children { + <-children[i].Done() + } + b.StopTimer() + for i := range cancels { + cancels[i]() + } +} + +func BenchmarkCanceledAsChild(b *testing.B) { + children := make([]context.Context, b.N) + parent, cancel := context.WithCancel(context.Background()) + b.ResetTimer() + for i := range children { + children[i] = UntilSignal(parent, syscall.SIGUSR2) + } + cancel() + <-parent.Done() + for i := range children { + <-children[i].Done() + } +}