From a15311a7e6f2cd5444fb7bdb91718fb7fad575ce Mon Sep 17 00:00:00 2001 From: Jawad Zaheer Date: Thu, 11 May 2023 10:35:27 +0000 Subject: [PATCH] Added tls configuration flags for ipam --- main.go | 116 +++++++++++++++++++++++++++++++++++++++++++++++++++ main_test.go | 111 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 227 insertions(+) create mode 100644 main_test.go diff --git a/main.go b/main.go index 5786c78a..909dc2cf 100644 --- a/main.go +++ b/main.go @@ -18,9 +18,11 @@ package main import ( "context" + "crypto/tls" "flag" "fmt" "os" + "strings" "time" ipamv1 "github.com/metal3-io/ip-address-manager/api/v1alpha1" @@ -30,6 +32,7 @@ import ( "k8s.io/client-go/kubernetes/scheme" _ "k8s.io/client-go/plugin/pkg/client/auth/gcp" "k8s.io/client-go/tools/leaderelection/resourcelock" + cliflag "k8s.io/component-base/cli/flag" "k8s.io/component-base/logs" logsv1 "k8s.io/component-base/logs/api/v1" "k8s.io/klog/v2" @@ -39,6 +42,20 @@ import ( // +kubebuilder:scaffold:imports ) +type TLSVersion string + +// Constants for TLS versions. +const ( + TLSVersion12 TLSVersion = "TLS12" + TLSVersion13 TLSVersion = "TLS13" +) + +type TLSOptions struct { + TLSMaxVersion string + TLSMinVersion string + TLSCipherSuites string +} + var ( myscheme = runtime.NewScheme() setupLog = ctrl.Log.WithName("setup") @@ -52,6 +69,8 @@ var ( webhookCertDir string watchFilterValue string logOptions = logs.NewOptions() + tlsOptions = TLSOptions{} + tlsSupportedVersions = []string{"TLS12", "TLS13"} ) func init() { @@ -89,6 +108,23 @@ func main() { flag.IntVar(&ippoolConcurrency, "ippool-concurrency", 10, "Number of ippools to process simultaneously") + flag.StringVar(&tlsOptions.TLSMinVersion, "tls-min-version", "TLS12", + "The minimum TLS version in use by the webhook server.\n"+ + fmt.Sprintf("Possible values are %s.", strings.Join(tlsSupportedVersions, ", ")), + ) + + flag.StringVar(&tlsOptions.TLSMaxVersion, "tls-max-version", "TLS13", + "The maximum TLS version in use by the webhook server.\n"+ + fmt.Sprintf("Possible values are %s.", strings.Join(tlsSupportedVersions, ", ")), + ) + + tlsCipherPreferredValues := cliflag.PreferredTLSCipherNames() + tlsCipherInsecureValues := cliflag.InsecureTLSCipherNames() + flag.StringVar(&tlsOptions.TLSCipherSuites, "tls-cipher-suites", "", + "Comma-separated list of cipher suites for the webhook server. "+ + "If omitted, the default Go cipher suites will be used. \n"+ + "Preferred values: "+strings.Join(tlsCipherPreferredValues, ", ")+". \n"+ + "Insecure values: "+strings.Join(tlsCipherInsecureValues, ", ")+".") flag.Parse() if err := logsv1.ValidateAndApply(logOptions, nil); err != nil { @@ -98,6 +134,11 @@ func main() { // klog.Background will automatically use the right logger. ctrl.SetLogger(klog.Background()) + tlsOptionOverrides, err := GetTLSOptionOverrideFuncs(tlsOptions) + if err != nil { + setupLog.Error(err, "unable to add TLS settings to the webhook server") + os.Exit(1) + } mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), ctrl.Options{ Scheme: myscheme, @@ -110,6 +151,7 @@ func main() { HealthProbeBindAddress: healthAddr, Namespace: watchNamespace, CertDir: webhookCertDir, + TLSOpts: tlsOptionOverrides, }) if err != nil { setupLog.Error(err, "unable to start manager") @@ -174,3 +216,77 @@ func setupWebhooks(mgr ctrl.Manager) { func concurrency(c int) controller.Options { return controller.Options{MaxConcurrentReconciles: c} } + +// GetTLSOptionOverrideFuncs returns a list of TLS configuration overrides to be used +// by the webhook server. +func GetTLSOptionOverrideFuncs(options TLSOptions) ([]func(*tls.Config), error) { + var tlsOptions []func(config *tls.Config) + + tlsMinVersion, err := GetTLSVersion(options.TLSMinVersion) + if err != nil { + return nil, err + } + + tlsMaxVersion, err := GetTLSVersion(options.TLSMaxVersion) + if err != nil { + return nil, err + } + + if tlsMaxVersion != 0 && tlsMinVersion > tlsMaxVersion { + return nil, fmt.Errorf("TLS version flag min version (%s) is greater than max version (%s)", + options.TLSMinVersion, options.TLSMaxVersion) + } + + tlsOptions = append(tlsOptions, func(cfg *tls.Config) { + cfg.MinVersion = tlsMinVersion + }) + + tlsOptions = append(tlsOptions, func(cfg *tls.Config) { + cfg.MaxVersion = tlsMaxVersion + }) + // Cipher suites should not be set if empty. + if options.TLSMinVersion == string(TLSVersion13) && + options.TLSMaxVersion == string(TLSVersion13) && + options.TLSCipherSuites != "" { + setupLog.Info("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers") + options.TLSCipherSuites = "" + } + + if options.TLSCipherSuites != "" { + tlsCipherSuites := strings.Split(options.TLSCipherSuites, ",") + suites, err := cliflag.TLSCipherSuites(tlsCipherSuites) + if err != nil { + return nil, err + } + + insecureCipherValues := cliflag.InsecureTLSCipherNames() + for _, cipher := range tlsCipherSuites { + for _, insecureCipherName := range insecureCipherValues { + if insecureCipherName == cipher { + setupLog.Info(fmt.Sprintf("warning: use of insecure cipher '%s' detected.", cipher)) + } + } + } + tlsOptions = append(tlsOptions, func(cfg *tls.Config) { + cfg.CipherSuites = suites + }) + } + + return tlsOptions, nil +} + +// GetTLSVersion returns the corresponding tls.Version or error. +func GetTLSVersion(version string) (uint16, error) { + var v uint16 + + switch version { + case string(TLSVersion12): + v = tls.VersionTLS12 + case string(TLSVersion13): + v = tls.VersionTLS13 + default: + return 0, fmt.Errorf("unexpected TLS version %q (must be one of: TLS12, TLS13)", version) + } + + return v, nil +} diff --git a/main_test.go b/main_test.go new file mode 100644 index 00000000..f3a11311 --- /dev/null +++ b/main_test.go @@ -0,0 +1,111 @@ +/* +Copyright 2023 The Metal3 Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "bytes" + "testing" + + . "github.com/onsi/gomega" + "k8s.io/klog/v2" + ctrl "sigs.k8s.io/controller-runtime" +) + +func TestTLSInsecureCiperSuite(t *testing.T) { + t.Run("test insecure cipher suite passed as TLS flag", func(t *testing.T) { + g := NewWithT(t) + tlsMockOptions := TLSOptions{ + TLSMaxVersion: "TLS13", + TLSMinVersion: "TLS12", + TLSCipherSuites: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", + } + ctrl.Log.WithName("setup") + ctrl.SetLogger(klog.Background()) + + bufWriter := bytes.NewBuffer(nil) + klog.SetOutput(bufWriter) + klog.LogToStderr(false) // this is important, because klog by default logs to stderr only + _, err := GetTLSOptionOverrideFuncs(tlsMockOptions) + g.Expect(err).Should(BeNil()) + g.Expect(bufWriter.String()).Should(ContainSubstring("use of insecure cipher 'TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256' detected.")) + }) +} + +func TestTLSMinAndMaxVersion(t *testing.T) { + t.Run("should fail if TLS min version is greater than max version.", func(t *testing.T) { + g := NewWithT(t) + tlsMockOptions := TLSOptions{ + TLSMaxVersion: "TLS12", + TLSMinVersion: "TLS13", + } + _, err := GetTLSOptionOverrideFuncs(tlsMockOptions) + g.Expect(err.Error()).To(Equal("TLS version flag min version (TLS13) is greater than max version (TLS12)")) + }) +} + +func Test13CipherSuite(t *testing.T) { + t.Run("should reset ciphersuite flag if TLS min and max version are set to 1.3", func(t *testing.T) { + g := NewWithT(t) + + // Here TLS_RSA_WITH_AES_128_CBC_SHA is a tls12 cipher suite. + tlsMockOptions := TLSOptions{ + TLSMaxVersion: "TLS13", + TLSMinVersion: "TLS13", + TLSCipherSuites: "TLS_RSA_WITH_AES_128_CBC_SHA,TLS_AES_256_GCM_SHA384", + } + + ctrl.Log.WithName("setup") + ctrl.SetLogger(klog.Background()) + + bufWriter := bytes.NewBuffer(nil) + klog.SetOutput(bufWriter) + klog.LogToStderr(false) // this is important, because klog by default logs to stderr only + _, err := GetTLSOptionOverrideFuncs(tlsMockOptions) + g.Expect(bufWriter.String()).Should(ContainSubstring("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers")) + g.Expect(err).Should(BeNil()) + }) +} + +func TestGetTLSVersion(t *testing.T) { + t.Run("should error out when incorrect tls version passed", func(t *testing.T) { + g := NewWithT(t) + tlsVersion := "TLS11" + _, err := GetTLSVersion(tlsVersion) + g.Expect(err.Error()).Should(Equal("unexpected TLS version \"TLS11\" (must be one of: TLS12, TLS13)")) + }) + t.Run("should pass and output correct tls version", func(t *testing.T) { + const VersionTLS12 uint16 = 771 + g := NewWithT(t) + tlsVersion := "TLS12" + version, err := GetTLSVersion(tlsVersion) + g.Expect(version).To(Equal(VersionTLS12)) + g.Expect(err).Should(BeNil()) + }) +} + +func TestTLSOptions(t *testing.T) { + t.Run("should pass with all the correct options below with no error.", func(t *testing.T) { + g := NewWithT(t) + tlsMockOptions := TLSOptions{ + TLSMinVersion: "TLS12", + TLSMaxVersion: "TLS13", + TLSCipherSuites: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", + } + _, err := GetTLSOptionOverrideFuncs(tlsMockOptions) + g.Expect(err).Should(BeNil()) + }) +}