diff --git a/app/certificates.go b/app/certificates.go new file mode 100644 index 00000000..f560f909 --- /dev/null +++ b/app/certificates.go @@ -0,0 +1,502 @@ +package app + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "fmt" + "io/ioutil" + "math/big" + "os" + "time" + + "github.com/go-playground/validator/v10" + "github.com/temporalio/tcld/utils" + "github.com/urfave/cli/v2" +) + +const ( + maxCADuration = 365 * 24 * time.Hour + minCADuration = 7 * 24 * time.Hour + + caPrivateKeyFileFlagName = "ca-key-file" + certificateFilterFileFlagName = "certificate-filter-file" + certificateFilterInputFlagName = "certificate-filter-input" + + pemEncodingCertificateType = "CERTIFICATE" + pemEncodingPrivateKeyType = "PRIVATE KEY" +) + +func generateRandomString(n int) (string, error) { + const letters = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-" + ret := make([]byte, n) + for i := 0; i < n; i++ { + num, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters)))) + if err != nil { + return "", err + } + ret[i] = letters[num.Int64()] + } + + return string(ret), nil +} + +type generateCACertificateInput struct { + Organization string `validate:"required"` + ValidityPeriod time.Duration `validate:"required"` + RSAAlgorithm bool +} + +func generateCACertificate( + input generateCACertificateInput, +) (caPEM, caPrivateKeyPEM []byte, err error) { + validator := validator.New() + if err := validator.Struct(input); err != nil { + return nil, nil, err + } + + serialNumber, err := generateSerialNumber() + if err != nil { + return nil, nil, fmt.Errorf("failed to generate a random serial number: %w", err) + } + + randomLetters, err := generateRandomString(4) + if err != nil { + return nil, nil, fmt.Errorf("unable to generate random string for dns name") + } + dnsRoot := fmt.Sprintf("client.root.%s.%s", input.Organization, randomLetters) + + keyUsage := x509.KeyUsageDigitalSignature | x509.KeyUsageCRLSign | x509.KeyUsageCertSign + if input.RSAAlgorithm { + // Only RSA subject keys should have the KeyEncipherment KeyUsage bits set. In + // the context of TLS this KeyUsage is particular to RSA key exchange and + // authentication. + keyUsage |= x509.KeyUsageKeyEncipherment + } + + now := time.Now().UTC() + conf := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{input.Organization}, + }, + NotBefore: now.Add(-time.Minute), // grace of 1 min + NotAfter: now.Add(input.ValidityPeriod), + IsCA: true, + KeyUsage: keyUsage, + BasicConstraintsValid: true, + DNSNames: []string{dnsRoot}, + MaxPathLen: 0, + } + + var key any + if input.RSAAlgorithm { + key, err = rsa.GenerateKey(rand.Reader, 4096) + } else { + key, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + } + if err != nil { + return nil, nil, fmt.Errorf("unable to generate key: %w", err) + } + + var publicKey any + switch k := key.(type) { + case *rsa.PrivateKey: + publicKey = &k.PublicKey + case *ecdsa.PrivateKey: + publicKey = &k.PublicKey + } + + cert, err := x509.CreateCertificate(rand.Reader, conf, conf, publicKey, key) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate certificate: %w", err) + } + caPEMBuffer := new(bytes.Buffer) + err = pem.Encode(caPEMBuffer, &pem.Block{ + Type: pemEncodingCertificateType, + Bytes: cert, + }) + if err != nil { + return nil, nil, err + } + privBytes, err := x509.MarshalPKCS8PrivateKey(key) + if err != nil { + return nil, nil, fmt.Errorf("unable to marshal key: %w", err) + } + caPrivateKeyPEMBuffer := new(bytes.Buffer) + err = pem.Encode(caPrivateKeyPEMBuffer, &pem.Block{ + Type: pemEncodingPrivateKeyType, + Bytes: privBytes, + }) + if err != nil { + return nil, nil, err + } + return caPEMBuffer.Bytes(), caPrivateKeyPEMBuffer.Bytes(), nil +} + +type generateEndEntityCertificateInput struct { + Organization string `validate:"required"` + OrganizationUnit string + + ValidityPeriod time.Duration + CaPem []byte `validate:"required"` + CaPrivateKeyPEM []byte `validate:"required"` +} + +func parseCACerts(caPem, caPrivKeyPem []byte) (*x509.Certificate, any, bool, error) { + + pemBlock, _ := pem.Decode(caPem) + if pemBlock == nil { + return nil, nil, false, fmt.Errorf("decoding ca cert failed") + } + caCert, err := x509.ParseCertificate(pemBlock.Bytes) + if err != nil { + return nil, nil, false, fmt.Errorf("decoding ca cert failed: %w", err) + } + pemBlock, _ = pem.Decode(caPrivKeyPem) + if pemBlock == nil { + return nil, nil, false, fmt.Errorf("decoding ca key failed") + } + caPrivateKey, err := x509.ParsePKCS8PrivateKey(pemBlock.Bytes) + if err != nil { + return nil, nil, false, fmt.Errorf("parsing ca key failed: %w", err) + } + _, isRSA := caPrivateKey.(*rsa.PrivateKey) + return caCert, caPrivateKey, isRSA, nil +} + +func generateSerialNumber() (*big.Int, error) { + max := new(big.Int) + max.Exp(big.NewInt(2), big.NewInt(130), nil).Sub(max, big.NewInt(1)) + // Generate cryptographically strong pseudo-random between 0 - max + n, err := rand.Int(rand.Reader, max) + if err != nil { + return nil, err + } + return n, err +} + +func generateEndEntityCertificate( + input generateEndEntityCertificateInput, +) (certPEM, certPrivateKeyPEM []byte, err error) { + validator := validator.New() + if err := validator.Struct(input); err != nil { + return nil, nil, err + } + caCert, caPrivateKey, isRSA, err := parseCACerts(input.CaPem, input.CaPrivateKeyPEM) + if err != nil { + return nil, nil, err + } + randomLetters, err := generateRandomString(4) + if err != nil { + return nil, nil, fmt.Errorf("unable to generate random string for dns name") + } + dnsRoot := fmt.Sprintf("client.endentity.%s.%s", input.Organization, randomLetters) + serialNumber, err := generateSerialNumber() + if err != nil { + return nil, nil, fmt.Errorf("failed to generate a random serial number: %w", err) + } + subject := pkix.Name{ + Organization: []string{input.Organization}, + OrganizationalUnit: []string{input.OrganizationUnit}, + } + + now := time.Now().UTC() + var notAfter time.Time + if input.ValidityPeriod != 0 { + // a validity period was provided by the user, validate it + notAfter = now.Add(input.ValidityPeriod).UTC() + if notAfter.After(caCert.NotAfter.UTC()) { + return nil, nil, fmt.Errorf("validity period of %s puts certificate's expiry after certificate authority's expiry %s by %s", + input.ValidityPeriod, caCert.NotAfter.UTC().String(), notAfter.Sub(caCert.NotAfter.UTC())) + } + } else { + // set notAfter to ca's notAfter minus one day when validity period is not explicitly set by the user. + notAfter = caCert.NotAfter.UTC().Add(-24 * time.Hour) + } + conf := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: subject, + NotBefore: now.Add(-time.Minute), // grace of 1 min + NotAfter: notAfter, + BasicConstraintsValid: true, + DNSNames: []string{dnsRoot}, + } + var key any + if isRSA { + key, err = rsa.GenerateKey(rand.Reader, 4096) + } else { + key, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + } + if err != nil { + return nil, nil, fmt.Errorf("unable to generate key: %w", err) + } + + var publicKey any + switch k := key.(type) { + case *rsa.PrivateKey: + publicKey = &k.PublicKey + case *ecdsa.PrivateKey: + publicKey = &k.PublicKey + } + cert, err := x509.CreateCertificate(rand.Reader, conf, caCert, publicKey, caPrivateKey) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate certificate: %w", err) + } + + certPEMBuffer := new(bytes.Buffer) + err = pem.Encode(certPEMBuffer, &pem.Block{ + Type: pemEncodingCertificateType, + Bytes: cert, + }) + if err != nil { + return nil, nil, err + } + + privBytes, err := x509.MarshalPKCS8PrivateKey(key) + if err != nil { + return nil, nil, fmt.Errorf("unable to marshal key: %w", err) + } + certPrivateKeyPEMBuffer := new(bytes.Buffer) + err = pem.Encode(certPrivateKeyPEMBuffer, &pem.Block{ + Type: pemEncodingPrivateKeyType, + Bytes: privBytes, + }) + if err != nil { + return nil, nil, err + } + return certPEMBuffer.Bytes(), certPrivateKeyPEMBuffer.Bytes(), nil +} + +func NewCertificatesCommand() (CommandOut, error) { + return CommandOut{ + Command: &cli.Command{ + Name: "generate-certificates", + Aliases: []string{"gen"}, + Usage: "Commands for generating certificate authority and end-entity TLS certificates", + Subcommands: []*cli.Command{ + { + Name: "certificate-authority-certificate", + Usage: "Generate a certificate authority certificate", + Aliases: []string{"ca"}, + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "organization", + Usage: "The name of the organization", + Aliases: []string{"org"}, + Required: true, + }, + &cli.StringFlag{ + Name: "validity-period", + Usage: "The duration for which the certificate is valid for. example: 30d10h (30 days and 10 hrs)", + Aliases: []string{"d"}, + Required: true, + Action: func(_ *cli.Context, v string) error { + d, err := utils.ParseDuration(v) + if err != nil { + return fmt.Errorf("failed to parse validity-period: %w", err) + } + if d > maxCADuration { + return fmt.Errorf("validity-period cannot be more than: %s", maxCADuration) + } + if d <= minCADuration { + return fmt.Errorf("validity-period cannot be less than: %s", minCADuration) + } + return nil + }, + }, + &cli.PathFlag{ + Name: CaCertificateFileFlagName, + Usage: "The path where the generated x509 certificate will be stored", + Aliases: []string{"ca-cert"}, + Required: true, + }, + &cli.PathFlag{ + Name: caPrivateKeyFileFlagName, + Usage: "The path where the certificate's private key will be stored", + Aliases: []string{"ca-key"}, + Required: true, + }, + &cli.BoolFlag{ + Name: "rsa-algorithm", + Aliases: []string{"rsa"}, + Usage: "Generates a 4096-bit RSA keypair instead of an ECDSA P-384 keypair (the recommended default) for the certificate (optional)", + }, + }, + Action: func(ctx *cli.Context) error { + validityPeriod, err := utils.ParseDuration(ctx.String("validity-period")) + if err != nil { + return fmt.Errorf("failed to parse validity-period: %w", err) + } + caPem, caPrivKey, err := generateCACertificate(generateCACertificateInput{ + Organization: ctx.String("organization"), + ValidityPeriod: validityPeriod, + RSAAlgorithm: ctx.Bool("rsa-algorithm"), + }) + if err != nil { + return fmt.Errorf("failed to generate ca certificate: %w", err) + } + + return writeCertificates( + ctx, + "certificate authority", + caPem, + caPrivKey, + ctx.Path(CaCertificateFileFlagName), + ctx.Path(caPrivateKeyFileFlagName), + ) + }, + }, + { + Name: "end-entity-certificate", + Usage: "Generate an end-entity certificate", + Aliases: []string{"leaf"}, + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "organization", + Usage: "The name of the organization", + Aliases: []string{"org"}, + Required: true, + }, + &cli.StringFlag{ + Name: "organization-unit", + Usage: "The name of the organization unit (optional)", + }, + &cli.StringFlag{ + Name: "validity-period", + Usage: "The duration for which the end entity certificate is valid for. example: 30d10h (30 days and 10 hrs). By default the generated certificate expires 24 hours before the certificate authority expires (optional)", + Aliases: []string{"d"}, + Action: func(_ *cli.Context, v string) error { + if _, err := utils.ParseDuration(v); err != nil { + return fmt.Errorf("failed to parse validity-period: %w", err) + } + return nil + }, + }, + &cli.PathFlag{ + Name: CaCertificateFileFlagName, + Usage: "The path of the x509 certificate for the certificate authority", + Aliases: []string{"ca-cert"}, + Required: true, + }, + &cli.PathFlag{ + Name: caPrivateKeyFileFlagName, + Usage: "The path of the private key for the certificate authority", + Aliases: []string{"ca-key"}, + Required: true, + }, + &cli.PathFlag{ + Name: "certificate-file", + Usage: "The path where the generated x509 certificate will be stored", + Aliases: []string{"cert"}, + Required: true, + }, + &cli.PathFlag{ + Name: "key-file", + Usage: "The path where the certificate's private key will be stored", + Aliases: []string{"key"}, + Required: true, + }, + }, + Action: func(ctx *cli.Context) error { + var validityPeriod time.Duration + if s := ctx.String("validity-period"); s != "" { + var err error + validityPeriod, err = utils.ParseDuration(ctx.String("validity-period")) + if err != nil { + return err + } + } + caPem, err := ioutil.ReadFile(ctx.Path(CaCertificateFileFlagName)) + if err != nil { + return fmt.Errorf("failed to read %s: %w", CaCertificateFileFlagName, err) + } + caPrivKey, err := ioutil.ReadFile(ctx.Path(caPrivateKeyFileFlagName)) + if err != nil { + return fmt.Errorf("failed to read %s: %w", caPrivateKeyFileFlagName, err) + } + certPem, certPrivKey, err := generateEndEntityCertificate(generateEndEntityCertificateInput{ + Organization: ctx.String("organization"), + OrganizationUnit: ctx.String("organization-unit"), + + ValidityPeriod: validityPeriod, + CaPem: caPem, + CaPrivateKeyPEM: caPrivKey, + }) + if err != nil { + return fmt.Errorf("failed to generate end-entity certificate: %w", err) + } + return writeCertificates( + ctx, + "end entity certificate", + certPem, + certPrivKey, + ctx.Path("certificate-file"), + ctx.Path("key-file"), + ) + }, + }, + }, + }, + }, nil +} + +func checkPath(ctx *cli.Context, path string) (bool, error) { + if fi, err := os.Stat(path); !errors.Is(err, os.ErrNotExist) { + // the file exists, + switch mode := fi.Mode(); { + case mode.IsRegular(): + yes, err := ConfirmPrompt( + ctx, + fmt.Sprintf("file already exists at path %s, do you want to overwrite:", path), + ) + if err != nil { + return false, fmt.Errorf("failed to confirm: %w", err) + } + return yes, nil + case mode.IsDir(): + return false, fmt.Errorf("path cannot be a directory: %s ", path) + default: + return false, fmt.Errorf("invalid file path: %s (file mode=%s)", path, mode.String()) + } + } + return true, nil +} + +func writeCertificates(ctx *cli.Context, typ string, cert, key []byte, certPath, keyPath string) error { + if cont, err := checkPath(ctx, certPath); err != nil || !cont { + return err + } + if cont, err := checkPath(ctx, keyPath); err != nil || !cont { + return err + } + + yes, err := ConfirmPrompt( + ctx, + fmt.Sprintf("storing the %s (private) key at %s, do not share this key with anyone. confirm:", typ, keyPath), + ) + if err != nil { + return fmt.Errorf("failed to confirm: %w", err) + } + if !yes { + return nil + } + err = ioutil.WriteFile(certPath, cert, 0644) + if err != nil { + return fmt.Errorf("failed to write end-entity certificate: %w", err) + + } + err = ioutil.WriteFile(keyPath, key, 0600) + if err != nil { + return fmt.Errorf("failed to write end-entity key: %w", err) + } + fmt.Printf("%s generated at: %s\n", typ, certPath) + fmt.Printf("%s key generated at: %s\n", typ, keyPath) + return nil +} diff --git a/app/certificates_test.go b/app/certificates_test.go new file mode 100644 index 00000000..cc45d12a --- /dev/null +++ b/app/certificates_test.go @@ -0,0 +1,199 @@ +package app + +import ( + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/google/uuid" + "github.com/stretchr/testify/suite" + "github.com/urfave/cli/v2" +) + +func TestCertificates(t *testing.T) { + suite.Run(t, new(CertificatesTestSuite)) +} + +type CertificatesTestSuite struct { + suite.Suite + cliApp *cli.App + mockCtrl *gomock.Controller +} + +func (s *CertificatesTestSuite) SetupTest() { + s.mockCtrl = gomock.NewController(s.T()) + + out, err := NewCertificatesCommand() + s.Require().NoError(err) + + AutoConfirmFlag.Value = true + s.cliApp = &cli.App{ + Name: "test", + Commands: []*cli.Command{out.Command}, + Flags: []cli.Flag{ + AutoConfirmFlag, + }, + } +} + +func (s *CertificatesTestSuite) RunCmd(args ...string) error { + return s.cliApp.Run(append([]string{"tcld"}, args...)) +} + +func (s *CertificatesTestSuite) AfterTest(suiteName, testName string) { + s.mockCtrl.Finish() +} + +func (s *CertificatesTestSuite) TestCertificateGenerateCore() { + type args struct { + rsa bool + caValidityPeriod time.Duration + endEntityValidityPeriod time.Duration + organization string + } + tests := []struct { + name string + args args + caGenerationErrMsg string + endEntityGenerationErrMsg string + }{ + { + "success - defaults", + args{ + organization: "test-certificate", + caValidityPeriod: 365 * 24 * time.Hour, + }, + "", + "", + }, + { + "success - options", + args{ + rsa: true, + organization: "test-certificate", + caValidityPeriod: 365 * 24 * time.Hour, + endEntityValidityPeriod: 24 * time.Hour, + }, + "", + "", + }, + { + "failure - missing required fields", + args{}, + "Error:Field validation for 'Organization' failed on the 'required' tag", + "", + }, + { + "failure - end-entity validity period too big", + args{ + rsa: true, + organization: "test-certificate", + caValidityPeriod: 365 * 24 * time.Hour, + endEntityValidityPeriod: 500 * 24 * time.Hour, + }, + "", + "validity period of 12000h0m0s puts certificate's expiry after certificate authority's expiry", + }, + } + for _, tt := range tests { + s.Run(tt.name, func() { + caPem, caPrivKeyPem, err := generateCACertificate(generateCACertificateInput{ + Organization: tt.args.organization, + ValidityPeriod: tt.args.caValidityPeriod, + RSAAlgorithm: tt.args.rsa, + }) + + if tt.caGenerationErrMsg == "" { + s.NoError(err, "ca cert generation failed") + } else { + s.Error(err, "expected ca cert generation to fail") + s.ErrorContains(err, tt.caGenerationErrMsg) + return + } + + certBytes, certKeyBytes, err := generateEndEntityCertificate(generateEndEntityCertificateInput{ + Organization: tt.args.organization + "-leaf", + ValidityPeriod: tt.args.endEntityValidityPeriod, + CaPem: caPem, + CaPrivateKeyPEM: caPrivKeyPem, + }) + + if tt.endEntityGenerationErrMsg == "" { + s.NoError(err, "end-entity cert generation failed") + if err != nil { + return + } + + // Even though these are not CA certs, we use this function to make sure + // the leaf certificates we have generated are actually valid + _, _, _, err = parseCACerts(certBytes, certKeyBytes) + s.NoError(err) + + } else { + s.Error(err, "expected end-entity cert generation to fail") + s.ErrorContains(err, tt.endEntityGenerationErrMsg) + } + }) + } +} + +func (s *CertificatesTestSuite) TestGenerateCACertificateCMD() { + tests := []struct { + name string + args []string + expectErrMsg string + }{ + { + name: "generate ca success", + args: []string{"gen", "ca", "--org", "testorg", "-d", "8d", "--ca-cert", "/tmp/" + uuid.NewString(), "--ca-key", "/tmp/" + uuid.NewString()}, + expectErrMsg: "", + }, + { + name: "generate ca failure - validity period too short", + args: []string{"gen", "ca", "--org", "testorg", "-d", "3d", "--ca-cert", "/tmp/" + uuid.NewString(), "--ca-key", "/tmp/" + uuid.NewString()}, + expectErrMsg: "validity-period cannot be less than: 168h0m0s", + }, + { + name: "generate ca failure - validity period too long", + args: []string{"gen", "ca", "--org", "testorg", "-d", "1000d", "--ca-cert", "/tmp/" + uuid.NewString(), "--ca-key", "/tmp/" + uuid.NewString()}, + expectErrMsg: "validity-period cannot be more than: 8760h0m0s", + }, + { + name: "generate ca failure - validity period malformed", + args: []string{"gen", "ca", "--org", "testorg", "-d", "malformed", "--ca-cert", "/tmp/" + uuid.NewString(), "--ca-key", "/tmp/" + uuid.NewString()}, + expectErrMsg: "failed to parse validity-period: time: invalid duration", + }, + } + + for _, tc := range tests { + s.Run(tc.name, func() { + err := s.RunCmd(tc.args...) + if tc.expectErrMsg != "" { + s.Error(err) + s.ErrorContains(err, tc.expectErrMsg) + } else { + s.NoError(err) + } + }) + } +} + +func (s *CertificatesTestSuite) TestGenerateCertificateCMDEndToEnd() { + caCertFile := "/tmp/" + uuid.NewString() + caKeyFile := "/tmp/" + uuid.NewString() + leafCertFile := "/tmp/" + uuid.NewString() + leafKeyFile := "/tmp/" + uuid.NewString() + + s.NoError(s.RunCmd([]string{"gen", "ca", "--org", "testorg", "-d", "8d", "--ca-cert", caCertFile, "--ca-key", caKeyFile}...)) + s.NoError(s.RunCmd([]string{"gen", "leaf", "--org", "testorg", "-d", "1d", "--ca-cert", caCertFile, "--ca-key", caKeyFile, "--cert", leafCertFile, "--key", leafKeyFile}...)) + + s.ErrorContains( + s.RunCmd([]string{"gen", "leaf", "--org", "testorg", "-d", "malformed", "--ca-cert", caCertFile, "--ca-key", caKeyFile, "--cert", leafCertFile, "--key", leafKeyFile}...), + "failed to parse validity-period: time: invalid duration", + ) + + s.ErrorContains( + s.RunCmd([]string{"gen", "leaf", "--org", "testorg", "-d", "100d", "--ca-cert", caCertFile, "--ca-key", caKeyFile, "--cert", leafCertFile, "--key", leafKeyFile}...), + "failed to generate end-entity certificate: validity period of 2400h0m0s puts certificate's expiry after certificate authority's expiry", + ) +} diff --git a/app/namespace.go b/app/namespace.go index 34179c6b..a173d6f4 100644 --- a/app/namespace.go +++ b/app/namespace.go @@ -5,12 +5,13 @@ import ( "encoding/base64" "errors" "fmt" - "github.com/temporalio/tcld/protogen/api/auth/v1" - "go.uber.org/multierr" "io/ioutil" "net/mail" "strings" + "github.com/temporalio/tcld/protogen/api/auth/v1" + "go.uber.org/multierr" + "github.com/kylelemons/godebug/diff" "github.com/temporalio/tcld/protogen/api/authservice/v1" "github.com/temporalio/tcld/protogen/api/namespace/v1" @@ -24,8 +25,6 @@ const ( CaCertificateFlagName = "ca-certificate" CaCertificateFileFlagName = "ca-certificate-file" caCertificateFingerprintFlagName = "ca-certificate-fingerprint" - certificateFilterFileFlagName = "certificate-filter-file" - certificateFilterInputFlagName = "certificate-filter-input" searchAttributeFlagName = "search-attribute" userNamespacePermissionFlagName = "user-namespace-permission" ) @@ -477,25 +476,26 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, Name: "accepted-client-ca", Usage: "Manage client ca certificate used to verify client connections", Aliases: []string{"ca"}, - Subcommands: []*cli.Command{{ - Name: "list", - Aliases: []string{"l"}, - Usage: "List the accepted client ca certificates currently configured for the namespace", - Flags: []cli.Flag{ - NamespaceFlag, - }, - Action: func(ctx *cli.Context) error { - n, err := c.getNamespace(ctx.String(NamespaceFlagName)) - if err != nil { - return err - } - out, err := parseCertificates(n.Spec.AcceptedClientCa) - if err != nil { - return err - } - return PrintObj(out) + Subcommands: []*cli.Command{ + { + Name: "list", + Aliases: []string{"l"}, + Usage: "List the accepted client ca certificates currently configured for the namespace", + Flags: []cli.Flag{ + NamespaceFlag, + }, + Action: func(ctx *cli.Context) error { + n, err := c.getNamespace(ctx.String(NamespaceFlagName)) + if err != nil { + return err + } + out, err := parseCertificates(n.Spec.AcceptedClientCa) + if err != nil { + return err + } + return PrintObj(out) + }, }, - }, { Name: "add", Aliases: []string{"a"}, diff --git a/cmd/tcld/fx.go b/cmd/tcld/fx.go index 33b69d66..31460b2c 100644 --- a/cmd/tcld/fx.go +++ b/cmd/tcld/fx.go @@ -21,6 +21,7 @@ func fxOptions() fx.Option { app.GetLoginClient, app.NewLoginCommand, app.NewLogoutCommand, + app.NewCertificatesCommand, func() app.GetNamespaceClientFn { return app.GetNamespaceClient }, diff --git a/go.mod b/go.mod index 047dcea5..f8c71353 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,13 @@ module github.com/temporalio/tcld go 1.18 require ( + github.com/go-playground/validator/v10 v10.13.0 github.com/gogo/protobuf v1.3.2 github.com/golang/mock v1.6.0 + github.com/google/uuid v1.3.0 github.com/kylelemons/godebug v1.1.0 github.com/stretchr/testify v1.8.2 - github.com/urfave/cli/v2 v2.25.1 + github.com/urfave/cli/v2 v2.25.3 go.uber.org/fx v1.19.2 go.uber.org/multierr v1.6.0 google.golang.org/grpc v1.54.0 @@ -16,13 +18,17 @@ require ( require ( github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect github.com/golang/protobuf v1.5.2 // indirect + github.com/leodido/go-urn v1.2.3 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect go.uber.org/atomic v1.7.0 // indirect go.uber.org/dig v1.16.1 // indirect go.uber.org/zap v1.23.0 // indirect + golang.org/x/crypto v0.7.0 // indirect golang.org/x/net v0.8.0 // indirect golang.org/x/sys v0.6.0 // indirect golang.org/x/text v0.8.0 // indirect diff --git a/go.sum b/go.sum index 1a06b7e6..23c9ffa0 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,13 @@ github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.13.0 h1:cFRQdfaSMCOSfGCCLB20MHvuoHb/s5G8L5pu2ppK5AQ= +github.com/go-playground/validator/v10 v10.13.0/go.mod h1:dwu7+CG8/CtBiJFZDz4e+5Upb6OLw04gtBYw0mcG/z4= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= @@ -13,10 +20,14 @@ github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/leodido/go-urn v1.2.3 h1:6BE2vPT0lqoz3fmOesHZiaiFh7889ssCo2GMvLCfiuA= +github.com/leodido/go-urn v1.2.3/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -30,8 +41,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/urfave/cli/v2 v2.25.1 h1:zw8dSP7ghX0Gmm8vugrs6q9Ku0wzweqPyshy+syu9Gw= -github.com/urfave/cli/v2 v2.25.1/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc= +github.com/urfave/cli/v2 v2.25.3 h1:VJkt6wvEBOoSjPFQvOkv6iWIrsJyCrKGtCtxXWwmGeY= +github.com/urfave/cli/v2 v2.25.3/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -51,6 +62,8 @@ go.uber.org/zap v1.23.0/go.mod h1:D+nX8jyLsMHMYrln8A0rJjFt/T/9/bGgIhAqxv5URuY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= +golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= diff --git a/utils/duration.go b/utils/duration.go new file mode 100644 index 00000000..8b527087 --- /dev/null +++ b/utils/duration.go @@ -0,0 +1,214 @@ +package utils + +import ( + "errors" + "time" +) + +const ( + lowerhex = "0123456789abcdef" + runeSelf = 0x80 + runeError = '\uFFFD' +) + +func quote(s string) string { + buf := make([]byte, 1, len(s)+2) // slice will be at least len(s) + quotes + buf[0] = '"' + for i, c := range s { + if c >= runeSelf || c < ' ' { + // This means you are asking us to parse a time.Duration or + // time.Location with unprintable or non-ASCII characters in it. + // We don't expect to hit this case very often. We could try to + // reproduce strconv.Quote's behavior with full fidelity but + // given how rarely we expect to hit these edge cases, speed and + // conciseness are better. + var width int + if c == runeError { + width = 1 + if i+2 < len(s) && s[i:i+3] == string(runeError) { + width = 3 + } + } else { + width = len(string(c)) + } + for j := 0; j < width; j++ { + buf = append(buf, `\x`...) + buf = append(buf, lowerhex[s[i+j]>>4]) + buf = append(buf, lowerhex[s[i+j]&0xF]) + } + } else { + if c == '"' || c == '\\' { + buf = append(buf, '\\') + } + buf = append(buf, string(c)...) + } + } + buf = append(buf, '"') + return string(buf) +} + +var errLeadingInt = errors.New("time: bad [0-9]*") // never printed + +// leadingInt consumes the leading [0-9]* from s. +func leadingInt[bytes []byte | string](s bytes) (x uint64, rem bytes, err error) { + i := 0 + for ; i < len(s); i++ { + c := s[i] + if c < '0' || c > '9' { + break + } + if x > 1<<63/10 { + // overflow + return 0, rem, errLeadingInt + } + x = x*10 + uint64(c) - '0' + if x > 1<<63 { + // overflow + return 0, rem, errLeadingInt + } + } + return x, s[i:], nil +} + +// leadingFraction consumes the leading [0-9]* from s. +// It is used only for fractions, so does not return an error on overflow, +// it just stops accumulating precision. +func leadingFraction(s string) (x uint64, scale float64, rem string) { + i := 0 + scale = 1 + overflow := false + for ; i < len(s); i++ { + c := s[i] + if c < '0' || c > '9' { + break + } + if overflow { + continue + } + if x > (1<<63-1)/10 { + // It's possible for overflow to give a positive number, so take care. + overflow = true + continue + } + y := x*10 + uint64(c) - '0' + if y > 1<<63 { + overflow = true + continue + } + x = y + scale *= 10 + } + return x, scale, s[i:] +} + +var unitMap = map[string]uint64{ + "s": uint64(time.Second), + "m": uint64(time.Minute), + "h": uint64(time.Hour), + "d": uint64(24 * time.Hour), + "y": uint64(365 * 24 * time.Hour), +} + +/* + A parser to parse time.Duration that can support days and year units. + Copied over from: https://cs.opensource.google/go/go/+/refs/tags/go1.20.4:src/time/format.go;l=1589 +*/ + +func ParseDuration(s string) (time.Duration, error) { + // [-+]?([0-9]*(\.[0-9]*)?[a-z]+)+ + orig := s + var d uint64 + neg := false + + // Consume [-+]? + if s != "" { + c := s[0] + if c == '-' || c == '+' { + neg = c == '-' + s = s[1:] + } + } + // Special case: if all that is left is "0", this is zero. + if s == "0" { + return 0, nil + } + if s == "" { + return 0, errors.New("time: invalid duration " + quote(orig)) + } + for s != "" { + var ( + v, f uint64 // integers before, after decimal point + scale float64 = 1 // value = v + f/scale + ) + + var err error + + // The next character must be [0-9.] + if !(s[0] == '.' || '0' <= s[0] && s[0] <= '9') { + return 0, errors.New("time: invalid duration " + quote(orig)) + } + // Consume [0-9]* + pl := len(s) + v, s, err = leadingInt(s) + if err != nil { + return 0, errors.New("time: invalid duration " + quote(orig)) + } + pre := pl != len(s) // whether we consumed anything before a period + + // Consume (\.[0-9]*)? + post := false + if s != "" && s[0] == '.' { + s = s[1:] + pl := len(s) + f, scale, s = leadingFraction(s) + post = pl != len(s) + } + if !pre && !post { + // no digits (e.g. ".s" or "-.s") + return 0, errors.New("time: invalid duration " + quote(orig)) + } + + // Consume unit. + i := 0 + for ; i < len(s); i++ { + c := s[i] + if c == '.' || '0' <= c && c <= '9' { + break + } + } + if i == 0 { + return 0, errors.New("time: missing unit in duration " + quote(orig)) + } + u := s[:i] + s = s[i:] + unit, ok := unitMap[u] + if !ok { + return 0, errors.New("time: unknown unit " + quote(u) + " in duration " + quote(orig)) + } + if v > 1<<63/unit { + // overflow + return 0, errors.New("time: invalid duration " + quote(orig)) + } + v *= unit + if f > 0 { + // float64 is needed to be nanosecond accurate for fractions of hours. + // v >= 0 && (f*unit/scale) <= 3.6e+12 (ns/h, h is the largest unit) + v += uint64(float64(f) * (float64(unit) / scale)) + if v > 1<<63 { + // overflow + return 0, errors.New("time: invalid duration " + quote(orig)) + } + } + d += v + if d > 1<<63 { + return 0, errors.New("time: invalid duration " + quote(orig)) + } + } + if neg { + return -time.Duration(d), nil + } + if d > 1<<63-1 { + return 0, errors.New("time: invalid duration " + quote(orig)) + } + return time.Duration(d), nil +} diff --git a/utils/duration_test.go b/utils/duration_test.go new file mode 100644 index 00000000..7de35ae1 --- /dev/null +++ b/utils/duration_test.go @@ -0,0 +1,66 @@ +package utils_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/temporalio/tcld/utils" +) + +func TestParseDuration(t *testing.T) { + + duration, err := utils.ParseDuration("99s") + assert.NoError(t, err) + assert.Equal(t, 99*time.Second, duration) + + duration, err = utils.ParseDuration("99m") + assert.NoError(t, err) + assert.Equal(t, 99*time.Minute, duration) + + duration, err = utils.ParseDuration("99m99s") + assert.NoError(t, err) + assert.Equal(t, 99*time.Minute+99*time.Second, duration) + + duration, err = utils.ParseDuration("99h") + assert.NoError(t, err) + assert.Equal(t, 99*time.Hour, duration) + + duration, err = utils.ParseDuration("99h99m99s") + assert.NoError(t, err) + assert.Equal(t, 99*time.Hour+99*time.Minute+99*time.Second, duration) + + duration, err = utils.ParseDuration("99d") + assert.NoError(t, err) + assert.Equal(t, 99*24*time.Hour, duration) + + duration, err = utils.ParseDuration("99d99h99m99s") + assert.NoError(t, err) + assert.Equal(t, 99*24*time.Hour+99*time.Hour+99*time.Minute+99*time.Second, duration) + + duration, err = utils.ParseDuration("99y") + assert.NoError(t, err) + assert.Equal(t, 99*365*24*time.Hour, duration) + + duration, err = utils.ParseDuration("99y99d99h99m99s") + assert.NoError(t, err) + assert.Equal(t, 99*365*24*time.Hour+99*24*time.Hour+99*time.Hour+99*time.Minute+99*time.Second, duration) + + duration, err = utils.ParseDuration("99.9y") + assert.NoError(t, err) + assert.Equal(t, 99.9*365*24*time.Hour, duration) + + // error scenarios + _, err = utils.ParseDuration("y") + assert.Error(t, err) + + _, err = utils.ParseDuration("12") + assert.Error(t, err) + + _, err = utils.ParseDuration("99yy") + assert.Error(t, err) + + _, err = utils.ParseDuration("99y45") + assert.Error(t, err) + +}