From 8be290e2690dcb72730afb47d67e77beb3e68fdd Mon Sep 17 00:00:00 2001 From: Jonathan Hess Date: Tue, 21 Jan 2025 12:15:03 -0700 Subject: [PATCH] fix: Handle the errors when the listener socket is closed gracefully. --- internal/proxy/proxy.go | 91 +++++++++++++++++++++++------------------ 1 file changed, 52 insertions(+), 39 deletions(-) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 3081cc3a0..6b6d590fb 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -731,50 +731,63 @@ func (c *Client) Close() error { // serveSocketMount persistently listens to the socketMounts listener and proxies connections to a // given Cloud SQL instance. -func (c *Client) serveSocketMount(_ context.Context, s *socketMount) error { +func (c *Client) serveSocketMount(ctx context.Context, s *socketMount) error { for { - cConn, err := s.Accept() - if err != nil { - if nerr, ok := err.(net.Error); ok && nerr.Timeout() { - c.logger.Errorf("[%s] Error accepting connection: %v", s.inst, err) - // For transient errors, wait a small amount of time to see if it resolves itself - time.Sleep(10 * time.Millisecond) - continue - } - return err - } - // handle the connection in a separate goroutine - go func() { - c.logger.Infof("[%s] Accepted connection from %s", s.inst, cConn.RemoteAddr()) - - // A client has established a connection to the local socket. Before - // we initiate a connection to the Cloud SQL backend, increment the - // connection counter. If the total number of connections exceeds - // the maximum, refuse to connect and close the client connection. - count := atomic.AddUint64(&c.connCount, 1) - defer atomic.AddUint64(&c.connCount, ^uint64(0)) - - if c.conf.MaxConnections > 0 && count > c.conf.MaxConnections { - c.logger.Infof("max connections (%v) exceeded, refusing new connection", c.conf.MaxConnections) - if c.connRefuseNotify != nil { - go c.connRefuseNotify() + select { + case <-ctx.Done(): + // If the context was canceled, do not accept any more connections, + // exit gracefully. + return nil + default: + // Wait to accept a connection. When s.Accept() returns io.EOF, exit + // gracefully. + cConn, err := s.Accept() + if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + c.logger.Errorf("[%s] Error accepting connection: %v", s.inst, err) + // For transient errors, wait a small amount of time to see if it resolves itself + time.Sleep(10 * time.Millisecond) + continue + } else if err == io.EOF { + // The socket was closed gracefully. Stop processing connections. + return nil } - _ = cConn.Close() - return + return err } - // give a max of 30 seconds to connect to the instance - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() + // handle the connection in a separate goroutine + go func() { + c.logger.Infof("[%s] Accepted connection from %s", s.inst, cConn.RemoteAddr()) + + // A client has established a connection to the local socket. Before + // we initiate a connection to the Cloud SQL backend, increment the + // connection counter. If the total number of connections exceeds + // the maximum, refuse to connect and close the client connection. + count := atomic.AddUint64(&c.connCount, 1) + defer atomic.AddUint64(&c.connCount, ^uint64(0)) + + if c.conf.MaxConnections > 0 && count > c.conf.MaxConnections { + c.logger.Infof("max connections (%v) exceeded, refusing new connection", c.conf.MaxConnections) + if c.connRefuseNotify != nil { + go c.connRefuseNotify() + } + _ = cConn.Close() + return + } - sConn, err := c.dialer.Dial(ctx, s.inst, s.dialOpts...) - if err != nil { - c.logger.Errorf("[%s] failed to connect to instance: %v", s.inst, err) - _ = cConn.Close() - return - } - c.proxyConn(s.inst, cConn, sConn) - }() + // give a max of 30 seconds to connect to the instance + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + sConn, err := c.dialer.Dial(ctx, s.inst, s.dialOpts...) + if err != nil { + c.logger.Errorf("[%s] failed to connect to instance: %v", s.inst, err) + _ = cConn.Close() + return + } + c.proxyConn(s.inst, cConn, sConn) + }() + } } }