From 4e2aa904b1539ea841a22352fc31a83f51580d6d Mon Sep 17 00:00:00 2001 From: Che Lin Date: Fri, 11 Oct 2024 23:37:37 -0700 Subject: [PATCH 1/7] feat: support callback notify function for when connection is refused by the proxy --- cmd/options.go | 8 ++++++++ cmd/root.go | 11 ++++++----- internal/proxy/proxy.go | 14 ++++++++++---- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/cmd/options.go b/cmd/options.go index b9fa7edca..7ca969654 100644 --- a/cmd/options.go +++ b/cmd/options.go @@ -97,3 +97,11 @@ func WithLazyRefresh() Option { c.conf.LazyRefresh = true } } + +// WithConnRefuseNotify configures the Proxy to start a goroutine and run the +// given notify callback function in the event of a connection refuse. +func WithConnRefuseNotify(n func(string)) Option { + return func(c *Command) { + c.connRefuseNotify = n + } +} diff --git a/cmd/root.go b/cmd/root.go index e8706a48f..60ce402cf 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -86,10 +86,11 @@ func Execute() { // Command represents an invocation of the Cloud SQL Auth Proxy. type Command struct { *cobra.Command - conf *proxy.Config - logger cloudsql.Logger - dialer cloudsql.Dialer - cleanup func() error + conf *proxy.Config + logger cloudsql.Logger + dialer cloudsql.Dialer + cleanup func() error + connRefuseNotify func(string) } var longHelp = ` @@ -1025,7 +1026,7 @@ func runSignalWrapper(cmd *Command) (err error) { startCh := make(chan *proxy.Client) go func() { defer close(startCh) - p, err := proxy.NewClient(ctx, cmd.dialer, cmd.logger, cmd.conf) + p, err := proxy.NewClient(ctx, cmd.dialer, cmd.logger, cmd.conf, cmd.connRefuseNotify) if err != nil { cmd.logger.Debugf("Error starting proxy: %v", err) shutdownCh <- fmt.Errorf("unable to start: %v", err) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index d0eb3da82..e7cd09213 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -503,12 +503,14 @@ type Client struct { logger cloudsql.Logger + connRefuseNotify func(string) + fuseMount } // NewClient completes the initial setup required to get the proxy to a "steady" // state. -func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf *Config) (*Client, error) { +func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf *Config, connRefuseNotify func(string)) (*Client, error) { // Check if the caller has configured a dialer. // Otherwise, initialize a new one. if d == nil { @@ -523,9 +525,10 @@ func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf * } c := &Client{ - logger: l, - dialer: d, - conf: conf, + logger: l, + dialer: d, + connRefuseNotify: connRefuseNotify, + conf: conf, } if conf.FUSEDir != "" { @@ -753,6 +756,9 @@ func (c *Client) serveSocketMount(_ context.Context, s *socketMount) error { if c.conf.MaxConnections > 0 && count > c.conf.MaxConnections { c.logger.Infof("max connections (%v) exceeded, refusing new connection", c.conf.MaxConnections) + if c.connRefuseNotify != nil { + go c.connRefuseNotify("too many connections") + } _ = cConn.Close() return } From ef29c014fc7ceb9212ae3d314bc170b6d2ede680 Mon Sep 17 00:00:00 2001 From: Che Lin Date: Tue, 22 Oct 2024 20:42:54 -0700 Subject: [PATCH 2/7] Address feedback for PR #2308 Remove uninteresting argument in the callback function. --- cmd/options.go | 2 +- cmd/root.go | 2 +- internal/proxy/proxy.go | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cmd/options.go b/cmd/options.go index 7ca969654..f166d8fba 100644 --- a/cmd/options.go +++ b/cmd/options.go @@ -100,7 +100,7 @@ func WithLazyRefresh() Option { // WithConnRefuseNotify configures the Proxy to start a goroutine and run the // given notify callback function in the event of a connection refuse. -func WithConnRefuseNotify(n func(string)) Option { +func WithConnRefuseNotify(n func()) Option { return func(c *Command) { c.connRefuseNotify = n } diff --git a/cmd/root.go b/cmd/root.go index 60ce402cf..ec28e6af9 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -90,7 +90,7 @@ type Command struct { logger cloudsql.Logger dialer cloudsql.Dialer cleanup func() error - connRefuseNotify func(string) + connRefuseNotify func() } var longHelp = ` diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index e7cd09213..0016e9468 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -503,14 +503,14 @@ type Client struct { logger cloudsql.Logger - connRefuseNotify func(string) + connRefuseNotify func() fuseMount } // NewClient completes the initial setup required to get the proxy to a "steady" // state. -func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf *Config, connRefuseNotify func(string)) (*Client, error) { +func NewClient(ctx context.Context, d cloudsql.Dialer, l cloudsql.Logger, conf *Config, connRefuseNotify func()) (*Client, error) { // Check if the caller has configured a dialer. // Otherwise, initialize a new one. if d == nil { @@ -757,7 +757,7 @@ func (c *Client) serveSocketMount(_ context.Context, s *socketMount) error { if c.conf.MaxConnections > 0 && count > c.conf.MaxConnections { c.logger.Infof("max connections (%v) exceeded, refusing new connection", c.conf.MaxConnections) if c.connRefuseNotify != nil { - go c.connRefuseNotify("too many connections") + go c.connRefuseNotify() } _ = cConn.Close() return From d39c355b8fcdab07de0d1b2843391e60279499b4 Mon Sep 17 00:00:00 2001 From: Che Lin Date: Tue, 22 Oct 2024 21:26:31 -0700 Subject: [PATCH 3/7] Address feedback for PR #2308 - 2 --- internal/proxy/fuse_test.go | 4 ++-- internal/proxy/proxy_other_test.go | 3 ++- internal/proxy/proxy_test.go | 36 ++++++++++++++++++------------ 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/internal/proxy/fuse_test.go b/internal/proxy/fuse_test.go index bb205fe1b..0379c9072 100644 --- a/internal/proxy/fuse_test.go +++ b/internal/proxy/fuse_test.go @@ -47,7 +47,7 @@ func newTestClient(t *testing.T, d cloudsql.Dialer, fuseDir, fuseTempDir string) conf := &proxy.Config{FUSEDir: fuseDir, FUSETempDir: fuseTempDir} // This context is only used to call the Cloud SQL API - c, err := proxy.NewClient(context.Background(), d, testLogger, conf) + c, err := proxy.NewClient(context.Background(), d, testLogger, conf, nil) if err != nil { t.Fatalf("want error = nil, got = %v", err) } @@ -424,7 +424,7 @@ func TestFUSEWithBadDir(t *testing.T) { t.Skip("skipping fuse tests in short mode.") } conf := &proxy.Config{FUSEDir: "/not/a/dir", FUSETempDir: randTmpDir(t)} - _, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, conf) + _, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, conf, nil) if err == nil { t.Fatal("proxy client should fail with bad dir") } diff --git a/internal/proxy/proxy_other_test.go b/internal/proxy/proxy_other_test.go index 56fded49d..b3503848a 100644 --- a/internal/proxy/proxy_other_test.go +++ b/internal/proxy/proxy_other_test.go @@ -51,7 +51,8 @@ func TestFuseClosesGracefully(t *testing.T) { FUSEDir: t.TempDir(), FUSETempDir: t.TempDir(), Token: "mytoken", - }) + }, + nil) if err != nil { t.Fatal(err) } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 8d64510aa..6f1872325 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -324,7 +324,7 @@ func TestClientInitialization(t *testing.T) { for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - c, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, tc.in) + c, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, tc.in, nil) if err != nil { t.Fatalf("want error = nil, got = %v", err) } @@ -370,7 +370,11 @@ func TestClientLimitsMaxConnections(t *testing.T) { }, MaxConnections: 1, } - c, err := proxy.NewClient(context.Background(), d, testLogger, in) + callbackGot := 0 + connRefuseNotify := func() { + callbackGot += 1 + } + c, err := proxy.NewClient(context.Background(), d, testLogger, in, connRefuseNotify) if err != nil { t.Fatalf("proxy.NewClient error: %v", err) } @@ -410,6 +414,10 @@ func TestClientLimitsMaxConnections(t *testing.T) { if got := d.dialAttempts(); got != want { t.Fatalf("dial attempts did not match expected, want = %v, got = %v", want, got) } + + if callbackGot == 0 { + t.Fatal("connRefuseNotifyCallback is not called") + } } func tryTCPDial(t *testing.T, addr string) net.Conn { @@ -442,7 +450,7 @@ func TestClientCloseWaitsForActiveConnections(t *testing.T) { }, WaitOnClose: 1 * time.Second, } - c, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in) + c, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in, nil) if err != nil { t.Fatalf("proxy.NewClient error: %v", err) } @@ -464,7 +472,7 @@ func TestClientClosesCleanly(t *testing.T) { {Name: "proj:reg:inst"}, }, } - c, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in) + c, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in, nil) if err != nil { t.Fatalf("proxy.NewClient error want = nil, got = %v", err) } @@ -486,7 +494,7 @@ func TestClosesWithError(t *testing.T) { {Name: "proj:reg:inst"}, }, } - c, err := proxy.NewClient(context.Background(), &errorDialer{}, testLogger, in) + c, err := proxy.NewClient(context.Background(), &errorDialer{}, testLogger, in, nil) if err != nil { t.Fatalf("proxy.NewClient error want = nil, got = %v", err) } @@ -542,13 +550,13 @@ func TestClientInitializationWorksRepeatedly(t *testing.T) { }, } - c, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, in) + c, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, in, nil) if err != nil { t.Fatalf("want error = nil, got = %v", err) } c.Close() - c, err = proxy.NewClient(ctx, &fakeDialer{}, testLogger, in) + c, err = proxy.NewClient(ctx, &fakeDialer{}, testLogger, in, nil) if err != nil { t.Fatalf("want error = nil, got = %v", err) } @@ -562,7 +570,7 @@ func TestClientNotifiesCallerOnServe(t *testing.T) { {Name: "proj:region:pg"}, }, } - c, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, in) + c, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, in, nil) if err != nil { t.Fatalf("want error = nil, got = %v", err) } @@ -595,7 +603,7 @@ func TestClientConnCount(t *testing.T) { MaxConnections: 10, } - c, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in) + c, err := proxy.NewClient(context.Background(), &fakeDialer{}, testLogger, in, nil) if err != nil { t.Fatalf("proxy.NewClient error: %v", err) } @@ -636,7 +644,7 @@ func TestCheckConnections(t *testing.T) { }, } d := &fakeDialer{} - c, err := proxy.NewClient(context.Background(), d, testLogger, in) + c, err := proxy.NewClient(context.Background(), d, testLogger, in, nil) if err != nil { t.Fatalf("proxy.NewClient error: %v", err) } @@ -664,7 +672,7 @@ func TestCheckConnections(t *testing.T) { }, } ed := &errorDialer{} - c, err = proxy.NewClient(context.Background(), ed, testLogger, in) + c, err = proxy.NewClient(context.Background(), ed, testLogger, in, nil) if err != nil { t.Fatalf("proxy.NewClient error: %v", err) } @@ -690,7 +698,7 @@ func TestRunConnectionCheck(t *testing.T) { RunConnectionTest: true, } d := &fakeDialer{} - c, err := proxy.NewClient(context.Background(), d, testLogger, in) + c, err := proxy.NewClient(context.Background(), d, testLogger, in, nil) if err != nil { t.Fatalf("proxy.NewClient error: %v", err) } @@ -757,7 +765,7 @@ func TestProxyInitializationWithFailedUnixSocket(t *testing.T) { } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - _, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, tc.in) + _, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, tc.in, nil) if err == nil { t.Fatalf("want non nil error, got = %v", err) } @@ -801,7 +809,7 @@ func TestProxyMultiInstances(t *testing.T) { } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - _, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, tc.in) + _, err := proxy.NewClient(ctx, &fakeDialer{}, testLogger, tc.in, nil) if tc.wantSuccess != (err == nil) { t.Fatalf("want return = %v, got = %v", tc.wantSuccess, err == nil) } From 9f2190bc014c151e1712eb6b7bb39afd03a5cb26 Mon Sep 17 00:00:00 2001 From: Che Lin Date: Wed, 23 Oct 2024 11:14:49 -0700 Subject: [PATCH 4/7] Address feedback for PR #2308 - 3 --- cmd/options.go | 4 ++-- internal/healthcheck/healthcheck_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cmd/options.go b/cmd/options.go index f166d8fba..b8ff855f4 100644 --- a/cmd/options.go +++ b/cmd/options.go @@ -98,8 +98,8 @@ func WithLazyRefresh() Option { } } -// WithConnRefuseNotify configures the Proxy to start a goroutine and run the -// given notify callback function in the event of a connection refuse. +// WithConnRefuseNotify configures the Proxy to call the provided function when +// a connection is refused. The notification function is run in a goroutine. func WithConnRefuseNotify(n func()) Option { return func(c *Command) { c.connRefuseNotify = n diff --git a/internal/healthcheck/healthcheck_test.go b/internal/healthcheck/healthcheck_test.go index b4729bb3f..bb86b72ea 100644 --- a/internal/healthcheck/healthcheck_test.go +++ b/internal/healthcheck/healthcheck_test.go @@ -78,7 +78,7 @@ func newProxyWithParams(t *testing.T, maxConns uint64, dialer cloudsql.Dialer, i Instances: instances, MaxConnections: maxConns, } - p, err := proxy.NewClient(context.Background(), dialer, logger, c) + p, err := proxy.NewClient(context.Background(), dialer, logger, c, nil) if err != nil { t.Fatalf("proxy.NewClient: %v", err) } From 769b8cd06dc4e6a791f76f9da56b35a44b3e19ff Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Wed, 23 Oct 2024 14:33:53 -0400 Subject: [PATCH 5/7] chore: fix lint --- internal/proxy/proxy_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 6f1872325..2f8d1d8fd 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -372,7 +372,7 @@ func TestClientLimitsMaxConnections(t *testing.T) { } callbackGot := 0 connRefuseNotify := func() { - callbackGot += 1 + callbackGot++ } c, err := proxy.NewClient(context.Background(), d, testLogger, in, connRefuseNotify) if err != nil { From defaf4dafe3f605edefde647635ab9b9113cf28f Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 23 Oct 2024 18:40:50 +0000 Subject: [PATCH 6/7] chore: run goimports --- cmd/options.go | 6 +++--- internal/proxy/proxy.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cmd/options.go b/cmd/options.go index b8ff855f4..e0007cdf4 100644 --- a/cmd/options.go +++ b/cmd/options.go @@ -101,7 +101,7 @@ func WithLazyRefresh() Option { // WithConnRefuseNotify configures the Proxy to call the provided function when // a connection is refused. The notification function is run in a goroutine. func WithConnRefuseNotify(n func()) Option { - return func(c *Command) { - c.connRefuseNotify = n - } + return func(c *Command) { + c.connRefuseNotify = n + } } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 0016e9468..3081cc3a0 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -757,7 +757,7 @@ func (c *Client) serveSocketMount(_ context.Context, s *socketMount) error { if c.conf.MaxConnections > 0 && count > c.conf.MaxConnections { c.logger.Infof("max connections (%v) exceeded, refusing new connection", c.conf.MaxConnections) if c.connRefuseNotify != nil { - go c.connRefuseNotify() + go c.connRefuseNotify() } _ = cConn.Close() return From 598570b7c812dd04049fe345e62fffe8553495b3 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 23 Oct 2024 18:56:19 +0000 Subject: [PATCH 7/7] chore: use mutex in test func --- internal/proxy/proxy_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 2f8d1d8fd..0faac73a3 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -372,6 +372,8 @@ func TestClientLimitsMaxConnections(t *testing.T) { } callbackGot := 0 connRefuseNotify := func() { + d.mu.Lock() + defer d.mu.Unlock() callbackGot++ } c, err := proxy.NewClient(context.Background(), d, testLogger, in, connRefuseNotify)