diff --git a/pipe.go b/pipe.go index 44340b81..82cbe7af 100644 --- a/pipe.go +++ b/pipe.go @@ -22,6 +22,7 @@ import ( const ( cERROR_PIPE_BUSY = syscall.Errno(231) + cERROR_NO_DATA = syscall.Errno(232) cERROR_PIPE_CONNECTED = syscall.Errno(535) cERROR_SEM_TIMEOUT = syscall.Errno(121) @@ -254,6 +255,36 @@ func (l *win32PipeListener) makeServerPipe() (*win32File, error) { return f, nil } +func (l *win32PipeListener) makeConnectedServerPipe() (*win32File, error) { + p, err := l.makeServerPipe() + if err != nil { + return nil, err + } + + // Wait for the client to connect. + ch := make(chan error) + go func(p *win32File) { + ch <- connectPipe(p) + }(p) + + select { + case err = <-ch: + if err != nil { + p.Close() + p = nil + } + case <-l.closeCh: + // Abort the connect request by closing the handle. + p.Close() + p = nil + err = <-ch + if err == nil || err == ErrFileClosed { + err = ErrPipeListenerClosed + } + } + return p, err +} + func (l *win32PipeListener) listenerRoutine() { closed := false for !closed { @@ -261,31 +292,20 @@ func (l *win32PipeListener) listenerRoutine() { case <-l.closeCh: closed = true case responseCh := <-l.acceptCh: - p, err := l.makeServerPipe() - if err == nil { - // Wait for the client to connect. - ch := make(chan error) - go func(p *win32File) { - ch <- connectPipe(p) - }(p) - select { - case err = <-ch: - if err != nil { - p.Close() - p = nil - } - case <-l.closeCh: - // Abort the connect request by closing the handle. - p.Close() - p = nil - err = <-ch - if err == nil || err == ErrFileClosed { - err = ErrPipeListenerClosed - } - closed = true + var ( + p *win32File + err error + ) + for { + p, err = l.makeConnectedServerPipe() + // If the connection was immediately closed by the client, try + // again. + if err != cERROR_NO_DATA { + break } } responseCh <- acceptResponse{p, err} + closed = err == ErrPipeListenerClosed } } syscall.Close(l.firstHandle) diff --git a/pipe_test.go b/pipe_test.go index 3bc02dae..c0d1a774 100644 --- a/pipe_test.go +++ b/pipe_test.go @@ -422,3 +422,32 @@ func TestEchoWithMessaging(t *testing.T) { <-listenerDone <-clientDone } + +func TestConnectRace(t *testing.T) { + l, err := ListenPipe(testPipeName, nil) + if err != nil { + t.Fatal(err) + } + defer l.Close() + go func() { + for { + s, err := l.Accept() + if err == ErrPipeListenerClosed { + return + } + + if err != nil { + t.Fatal(err) + } + s.Close() + } + }() + + for i := 0; i < 1000; i++ { + c, err := DialPipe(testPipeName, nil) + if err != nil { + t.Fatal(err) + } + c.Close() + } +}