From bc7ea8feeb7d5eaed08266f29aa58f74a3083069 Mon Sep 17 00:00:00 2001 From: Artem Gavrilov Date: Thu, 7 Jul 2022 16:08:03 +0200 Subject: [PATCH] PMM-10078 Extract portal client, add dev env variables for portal address overwriting (#958) * PMM-10078 Extract portal client, add dev env variables for portal address overwriting * PMM-10078 Refactoring * PMM-10078 Refactoring * PMM-10078 Refactoring * Fix DBAAS dependency version * PMM-10078 Use warns instead of errors for removed test env variables * PMM-10078 Fix tests * PMM-10078 Add test env variables to CONTRIBUTING.md --- docker-compose.yml | 4 +- managed/CONTRIBUTING.md | 20 + managed/main.go | 22 +- managed/services/checks/checks.go | 70 ++-- managed/services/checks/checks_test.go | 91 ++-- managed/services/config/config.go | 3 - managed/services/config/pmm-managed.yaml | 6 - .../management/ia/alerts_service_test.go | 3 +- .../management/ia/rules_service_test.go | 5 +- .../management/ia/templates_service.go | 68 ++- .../management/ia/templates_service_test.go | 42 +- managed/services/platform/config.go | 26 -- managed/services/platform/platform.go | 395 +++--------------- managed/services/telemetry/config.go | 37 +- managed/services/telemetry/config_test.go | 18 +- managed/services/telemetry/telemetry.go | 46 +- managed/utils/envvars/parser.go | 77 ++-- managed/utils/envvars/parser_test.go | 55 ++- managed/utils/platform/client.go | 344 +++++++++++++++ managed/utils/saasreq/request.go | 91 ---- utils/pdeathsig/pdeathsig_linux.go | 2 +- 21 files changed, 665 insertions(+), 760 deletions(-) delete mode 100644 managed/services/platform/config.go create mode 100644 managed/utils/platform/client.go delete mode 100644 managed/utils/saasreq/request.go diff --git a/docker-compose.yml b/docker-compose.yml index 5937ed06ef..f768fa460d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -16,7 +16,9 @@ services: - AWS_SECRET_KEY=${AWS_SECRET_KEY} - ENABLE_ALERTING=1 - ENABLE_BACKUP_MANAGEMENT=1 -# - PERCONA_TEST_SAAS_HOST=check.localhost +# - PERCONA_TEST_PLATFORM_ADDRESS=https://check.localhost +# - PERCONA_TEST_PLATFORM_INSECURE=1 +# - PERCONA_TEST_PLATFORM_PUBLIC_KEY= # - PERCONA_TEST_TELEMETRY_INTERVAL=10s # - PERCONA_TEST_TELEMETRY_RETRY_BACKOFF=10s # - PMM_DEBUG=1 diff --git a/managed/CONTRIBUTING.md b/managed/CONTRIBUTING.md index aa4c499db1..76cb35a888 100644 --- a/managed/CONTRIBUTING.md +++ b/managed/CONTRIBUTING.md @@ -48,6 +48,26 @@ go test -timeout=30s -p 1 ./... # Advanced Setup +## Available test environment variables: +| Variable | Description | Default | +|-----------------------------------------|------------------------------------------------------------------------------------------------|------------------------------------------| +| PERCONA_TEST_PMM_CLICKHOUSE_ADDR | Sets Clickhouse address | 127.0.0.1:9000 | +| PERCONA_TEST_PMM_CLICKHOUSE_DATABASE | Sets Clickhouse database | pmm | +| PERCONA_TEST_PMM_CLICKHOUSE_POOL_SIZE | Sets Clickhouse connections pool size | none | +| PERCONA_TEST_PMM_CLICKHOUSE_BLOCK_SIZE | Sets Clickhouse block size | none | +| PERCONA_TEST_STARLARK_ALLOW_RECURSION | Allows recursive functions in checks scripts | false | +| PERCONA_TEST_NICER_API | Enables nicer API with default/zero values in response. | false | +| PERCONA_TEST_VERSION_SERVICE_URL | Sets versions service URL | https://check.percona.com/versions/v1 | +| PERCONA_TEST_CHECKS_FILE | Specifies path to local checks file and disables downlading checks files from Percona Platform | none | +| PERCONA_TEST_CHECKS_RESEND_INTERVAL | Sets how often checks alerts resent to Alertmanager | 2 seconds | +| PERCONA_TEST_CHECKS_DISABLE_START_DELAY | Disables checks service startup delay | false | +| PERCONA_TEST_TELEMETRY_INTERVAL | ## TODO | | +| PERCONA_TEST_TELEMETRY_RETRY_BACKOFF | ## TODO | | +| PERCONA_TEST_DBAAS_KUBECONFIG | ## TODO | | +| PERCONA_TEST_PLATFORM_ADDRESS | Sets Percona Platform address | https://check.percona.com | +| PERCONA_TEST_PLATFORM_INSECURE | Allows insecure TLS connections to Percona Platform | false | +| PERCONA_TEST_PLATFORM_PUBLIC_KEY | Sets Percona Platform public key (Minisign) | set of keys embedded into managed binary | + ## Add instances for monitoring `make env-up` just starts the PMM server but it doesn't setup anything to be monitored. We can use [pmm-admin](https://github.com/percona/pmm-admin) and [pmm-agent](https://github.com/percona/pmm-agent) to add instances to be monitored to pmm-managed. diff --git a/managed/main.go b/managed/main.go index f46f6b9607..a05d65ef92 100644 --- a/managed/main.go +++ b/managed/main.go @@ -90,8 +90,10 @@ import ( "github.com/percona/pmm/managed/services/victoriametrics" "github.com/percona/pmm/managed/services/vmalert" "github.com/percona/pmm/managed/utils/clean" + "github.com/percona/pmm/managed/utils/envvars" "github.com/percona/pmm/managed/utils/interceptors" "github.com/percona/pmm/managed/utils/logger" + platformClient "github.com/percona/pmm/managed/utils/platform" pmmerrors "github.com/percona/pmm/utils/errors" "github.com/percona/pmm/utils/sqlmetrics" "github.com/percona/pmm/version" @@ -132,6 +134,7 @@ func addLogsHandler(mux *http.ServeMux, logs *supervisord.Logs) { type gRPCServerDeps struct { db *reform.DB vmdb *victoriametrics.Service + platformClient *platformClient.Client server *server.Server agentsRegistry *agents.Registry handler *agents.Handler @@ -228,7 +231,7 @@ func runGRPCServer(ctx context.Context, deps *gRPCServerDeps) { dbaasv1beta1.RegisterLogsAPIServer(gRPCServer, managementdbaas.NewLogsService(deps.db, deps.dbaasClient)) dbaasv1beta1.RegisterComponentsServer(gRPCServer, managementdbaas.NewComponentsService(deps.db, deps.dbaasClient, deps.versionServiceClient)) - platformService, err := platform.New(deps.db, deps.supervisord, deps.checksService, deps.grafanaClient, deps.config.Services.Platform) + platformService, err := platform.New(deps.platformClient, deps.db, deps.supervisord, deps.checksService, deps.grafanaClient) if err == nil { platformpb.RegisterPlatformServer(gRPCServer, platformService) } else { @@ -699,7 +702,17 @@ func main() { logs := supervisord.NewLogs(version.FullInfo(), pmmUpdateCheck) supervisord := supervisord.New(*supervisordConfigDirF, pmmUpdateCheck, vmParams) - telemetry, err := telemetry.NewService(db, version.Version, cfg.Config.Services.Telemetry) + platformAddress, err := envvars.GetPlatformAddress() + if err != nil { + l.Fatal(err) + } + + platformClient, err := platformClient.NewClient(db, platformAddress) + if err != nil { + l.Fatalf("Could not create Percona Portal client: %s", err) + } + + telemetry, err := telemetry.NewService(db, platformClient, version.Version, cfg.Config.Services.Telemetry) if err != nil { l.Fatalf("Could not create telemetry service: %s", err) } @@ -714,7 +727,7 @@ func main() { actionsService := agents.NewActionsService(qanClient, agentsRegistry) - checksService, err := checks.New(actionsService, alertManager, db, *victoriaMetricsURLF) + checksService, err := checks.New(db, platformClient, actionsService, alertManager, *victoriaMetricsURLF) if err != nil { l.Fatalf("Could not create checks service: %s", err) } @@ -722,7 +735,7 @@ func main() { prom.MustRegister(checksService) // Integrated alerts services - templatesService, err := ia.NewTemplatesService(db) + templatesService, err := ia.NewTemplatesService(db, platformClient) if err != nil { l.Fatalf("Could not create templates service: %s", err) } @@ -918,6 +931,7 @@ func main() { &gRPCServerDeps{ db: db, vmdb: vmdb, + platformClient: platformClient, server: server, agentsRegistry: agentsRegistry, handler: agentsHandler, diff --git a/managed/services/checks/checks.go b/managed/services/checks/checks.go index c06ffb3e2a..623516da1d 100644 --- a/managed/services/checks/checks.go +++ b/managed/services/checks/checks.go @@ -21,9 +21,7 @@ import ( "bytes" "context" "encoding/json" - "fmt" "io/ioutil" - "net/http" "os" "os/exec" "strconv" @@ -33,7 +31,6 @@ import ( "text/template" "time" - api "github.com/percona-platform/saas/gen/check/retrieval" "github.com/percona-platform/saas/pkg/check" "github.com/percona-platform/saas/pkg/common" "github.com/pkg/errors" @@ -49,7 +46,7 @@ import ( "github.com/percona/pmm/managed/models" "github.com/percona/pmm/managed/services" "github.com/percona/pmm/managed/utils/envvars" - "github.com/percona/pmm/managed/utils/saasreq" + "github.com/percona/pmm/managed/utils/platform" "github.com/percona/pmm/managed/utils/signatures" "github.com/percona/pmm/utils/pdeathsig" "github.com/percona/pmm/version" @@ -64,7 +61,7 @@ const ( envDisableStartDelay = "PERCONA_TEST_CHECKS_DISABLE_START_DELAY" checkExecutionTimeout = 5 * time.Minute // limits execution time for every single check - platformRequestTimeout = 2 * time.Minute // time limit to get checks list from the platform + platformRequestTimeout = 2 * time.Minute // time limit to get checks list from the portal resultAwaitTimeout = 20 * time.Second // should be greater than agents.defaultQueryActionTimeout scriptExecutionTimeout = 5 * time.Second // time limit for running pmm-managed-starlark resultCheckInterval = time.Second @@ -91,18 +88,18 @@ var ( // Service is responsible for interactions with Percona Check service. type Service struct { + platformClient *platform.Client agentsRegistry agentsRegistry alertmanagerService alertmanagerService db *reform.DB alertsRegistry *registry vmClient v1.API - l *logrus.Entry - host string - publicKeys []string - startDelay time.Duration - resendInterval time.Duration - localChecksFile string // For testing + l *logrus.Entry + startDelay time.Duration + resendInterval time.Duration + platformPublicKeys []string + localChecksFile string // For testing cm sync.Mutex checks map[string]check.Check @@ -117,7 +114,7 @@ type Service struct { } // New returns Service with given PMM version. -func New(agentsRegistry agentsRegistry, alertmanagerService alertmanagerService, db *reform.DB, VMAddress string) (*Service, error) { +func New(db *reform.DB, platformClient *platform.Client, agentsRegistry agentsRegistry, alertmanagerService alertmanagerService, VMAddress string) (*Service, error) { l := logrus.WithField("component", "checks") resendInterval := defaultResendInterval @@ -126,28 +123,30 @@ func New(agentsRegistry agentsRegistry, alertmanagerService alertmanagerService, resendInterval = d } - host, err := envvars.GetSAASHost() + vmClient, err := metrics.NewClient(metrics.Config{Address: VMAddress}) if err != nil { return nil, err } - vmClient, err := metrics.NewClient(metrics.Config{Address: VMAddress}) - if err != nil { - return nil, err + var platformPublicKeys []string + if k := envvars.GetPlatformPublicKeys(); k != nil { + l.Warnf("Percona Platform public keys changed to %q.", k) + platformPublicKeys = k } s := &Service{ + db: db, agentsRegistry: agentsRegistry, alertmanagerService: alertmanagerService, - db: db, alertsRegistry: newRegistry(resolveTimeoutFactor * resendInterval), vmClient: v1.NewAPI(vmClient), - l: l, - host: host, - startDelay: defaultStartDelay, - resendInterval: resendInterval, - localChecksFile: os.Getenv(envCheckFile), + l: l, + platformClient: platformClient, + startDelay: defaultStartDelay, + resendInterval: resendInterval, + platformPublicKeys: platformPublicKeys, + localChecksFile: os.Getenv(envCheckFile), mScriptsExecuted: prom.NewCounterVec(prom.CounterOpts{ Namespace: prometheusNamespace, @@ -164,10 +163,6 @@ func New(agentsRegistry agentsRegistry, alertmanagerService alertmanagerService, }, []string{"service_type", "check_type"}), } - if k := envvars.GetPublicKeys(); k != nil { - l.Warnf("Public keys changed to %q.", k) - s.publicKeys = k - } if d, _ := strconv.ParseBool(os.Getenv(envDisableStartDelay)); d { l.Warn("Start delay disabled.") s.startDelay = 0 @@ -1422,29 +1417,15 @@ func (s *Service) downloadChecks(ctx context.Context) ([]check.Check, error) { return nil, nil } - s.l.Infof("Downloading checks from %s ...", s.host) - nCtx, cancel := context.WithTimeout(ctx, platformRequestTimeout) defer cancel() - var accessToken string - if ssoDetails, err := models.GetPerconaSSODetails(nCtx, s.db.Querier); err == nil { - accessToken = ssoDetails.AccessToken.AccessToken - } - - endpoint := fmt.Sprintf("https://%s/v1/check/GetAllChecks", s.host) - bodyBytes, err := saasreq.MakeRequest(nCtx, http.MethodPost, endpoint, accessToken, nil, - &saasreq.SaasRequestOptions{}) + resp, err := s.platformClient.GetChecks(nCtx) if err != nil { - return nil, errors.Wrap(err, "failed to dial") - } - - var resp *api.GetAllChecksResponse - if err := json.Unmarshal(bodyBytes, &resp); err != nil { - return nil, err + return nil, errors.WithStack(err) } - if err = signatures.Verify(s.l, resp.File, resp.Signatures, s.publicKeys); err != nil { + if err = signatures.Verify(s.l, resp.File, resp.Signatures, s.platformPublicKeys); err != nil { return nil, err } @@ -1453,9 +1434,10 @@ func (s *Service) downloadChecks(ctx context.Context) ([]check.Check, error) { DisallowUnknownFields: false, DisallowInvalidChecks: false, } + checks, err := check.Parse(strings.NewReader(resp.File), params) if err != nil { - return nil, err + return nil, errors.WithStack(err) } return checks, nil diff --git a/managed/services/checks/checks_test.go b/managed/services/checks/checks_test.go index f1648fdefb..d7d12b565f 100644 --- a/managed/services/checks/checks_test.go +++ b/managed/services/checks/checks_test.go @@ -37,16 +37,17 @@ import ( "github.com/percona/pmm/api/alertmanager/ammodels" "github.com/percona/pmm/managed/models" "github.com/percona/pmm/managed/services" + "github.com/percona/pmm/managed/utils/platform" "github.com/percona/pmm/managed/utils/testdb" "github.com/percona/pmm/version" ) const ( - devChecksHost = "check-dev.percona.com" - devChecksPublicKey = "RWTg+ZmCCjt7O8eWeAmTLAqW+1ozUbpRSKSwNTmO+exlS5KEIPYWuYdX" - testChecksFile = "../../testdata/checks/checks.yml" - issuerURL = "https://id-dev.percona.com/oauth2/aus15pi5rjdtfrcH51d7/v1" - vmAddress = "http://127.0.0.1:9090/prometheus/" + devPlatformAddress = "https://check-dev.percona.com" + devPlatformPublicKey = "RWTg+ZmCCjt7O8eWeAmTLAqW+1ozUbpRSKSwNTmO+exlS5KEIPYWuYdX" + testChecksFile = "../../testdata/checks/checks.yml" + issuerURL = "https://id-dev.percona.com/oauth2/aus15pi5rjdtfrcH51d7/v1" + vmAddress = "http://127.0.0.1:9090/prometheus/" ) func TestDownloadChecks(t *testing.T) { @@ -57,6 +58,12 @@ func TestDownloadChecks(t *testing.T) { sqlDB := testdb.Open(t, models.SkipFixtures, nil) db := reform.NewDB(sqlDB, postgresql.Dialect, nil) + platformClient, err := platform.NewClient(db, devPlatformAddress) + require.NoError(t, err) + + s, err := New(db, platformClient, nil, nil, vmAddress) + s.platformPublicKeys = []string{devPlatformPublicKey} + require.NoError(t, err) insertSSODetails := &models.PerconaSSODetailsInsert{ IssuerURL: issuerURL, @@ -64,14 +71,9 @@ func TestDownloadChecks(t *testing.T) { PMMManagedClientSecret: clientSecret, Scope: "percona", } - err := models.InsertPerconaSSODetails(db.Querier, insertSSODetails) + err = models.InsertPerconaSSODetails(db.Querier, insertSSODetails) require.NoError(t, err) - s, err := New(nil, nil, db, vmAddress) - require.NoError(t, err) - s.host = devChecksHost - s.publicKeys = []string{devChecksPublicKey} - t.Run("normal", func(t *testing.T) { checks, err := s.GetChecks() require.NoError(t, err) @@ -108,7 +110,7 @@ func TestDownloadChecks(t *testing.T) { } func TestLoadLocalChecks(t *testing.T) { - s, err := New(nil, nil, nil, vmAddress) + s, err := New(nil, nil, nil, nil, vmAddress) require.NoError(t, err) checks, err := s.loadLocalChecks(testChecksFile) @@ -144,10 +146,14 @@ func TestLoadLocalChecks(t *testing.T) { } func TestCollectChecks(t *testing.T) { + sqlDB := testdb.Open(t, models.SkipFixtures, nil) + db := reform.NewDB(sqlDB, postgresql.Dialect, nil) + + platformClient, err := platform.NewClient(db, devPlatformAddress) + require.NoError(t, err) + t.Run("collect local checks", func(t *testing.T) { - sqlDB := testdb.Open(t, models.SkipFixtures, nil) - db := reform.NewDB(sqlDB, postgresql.Dialect, nil) - s, err := New(nil, nil, db, vmAddress) + s, err := New(db, platformClient, nil, nil, vmAddress) require.NoError(t, err) s.localChecksFile = testChecksFile @@ -171,11 +177,9 @@ func TestCollectChecks(t *testing.T) { }) t.Run("download checks", func(t *testing.T) { - sqlDB := testdb.Open(t, models.SkipFixtures, nil) - db := reform.NewDB(sqlDB, postgresql.Dialect, nil) - s, err := New(nil, nil, db, vmAddress) + s, err := New(db, platformClient, nil, nil, vmAddress) + s.platformPublicKeys = []string{devPlatformPublicKey} require.NoError(t, err) - s.localChecksFile = testChecksFile s.CollectChecks(context.Background()) assert.NotEmpty(t, s.checks) @@ -186,7 +190,8 @@ func TestDisableChecks(t *testing.T) { t.Run("normal", func(t *testing.T) { sqlDB := testdb.Open(t, models.SkipFixtures, nil) db := reform.NewDB(sqlDB, postgresql.Dialect, nil) - s, err := New(nil, nil, db, vmAddress) + + s, err := New(db, nil, nil, nil, vmAddress) require.NoError(t, err) s.localChecksFile = testChecksFile @@ -211,7 +216,8 @@ func TestDisableChecks(t *testing.T) { t.Run("disable same check twice", func(t *testing.T) { sqlDB := testdb.Open(t, models.SkipFixtures, nil) db := reform.NewDB(sqlDB, postgresql.Dialect, nil) - s, err := New(nil, nil, db, vmAddress) + + s, err := New(db, nil, nil, nil, vmAddress) require.NoError(t, err) s.localChecksFile = testChecksFile @@ -239,7 +245,8 @@ func TestDisableChecks(t *testing.T) { t.Run("disable unknown check", func(t *testing.T) { sqlDB := testdb.Open(t, models.SkipFixtures, nil) db := reform.NewDB(sqlDB, postgresql.Dialect, nil) - s, err := New(nil, nil, db, vmAddress) + + s, err := New(db, nil, nil, nil, vmAddress) require.NoError(t, err) s.localChecksFile = testChecksFile @@ -258,7 +265,8 @@ func TestEnableChecks(t *testing.T) { t.Run("normal", func(t *testing.T) { sqlDB := testdb.Open(t, models.SkipFixtures, nil) db := reform.NewDB(sqlDB, postgresql.Dialect, nil) - s, err := New(nil, nil, db, vmAddress) + + s, err := New(db, nil, nil, nil, vmAddress) require.NoError(t, err) s.localChecksFile = testChecksFile @@ -289,7 +297,8 @@ func TestChangeInterval(t *testing.T) { ams.On("SendAlerts", mock.Anything, mock.Anything).Return() sqlDB := testdb.Open(t, models.SkipFixtures, nil) db := reform.NewDB(sqlDB, postgresql.Dialect, nil) - s, err := New(nil, &ams, db, vmAddress) + + s, err := New(db, nil, nil, &ams, vmAddress) require.NoError(t, err) s.localChecksFile = testChecksFile @@ -332,7 +341,7 @@ func TestChangeInterval(t *testing.T) { // method and test for recorded metrics. func TestSTTMetrics(t *testing.T) { t.Run("check for recorded metrics", func(t *testing.T) { - s, err := New(nil, nil, nil, vmAddress) + s, err := New(nil, nil, nil, nil, vmAddress) require.NoError(t, err) expected := strings.NewReader(` # HELP pmm_managed_checks_alerts_generated_total Counter of alerts generated per service type per check type @@ -361,7 +370,7 @@ func TestGetSecurityCheckResults(t *testing.T) { db := reform.NewDB(sqlDB, postgresql.Dialect, nil) t.Run("STT enabled", func(t *testing.T) { - s, err := New(nil, nil, db, vmAddress) + s, err := New(db, nil, nil, nil, vmAddress) require.NoError(t, err) results, err := s.GetSecurityCheckResults() @@ -370,7 +379,7 @@ func TestGetSecurityCheckResults(t *testing.T) { }) t.Run("STT disabled", func(t *testing.T) { - s, err := New(nil, nil, db, vmAddress) + s, err := New(db, nil, nil, nil, vmAddress) require.NoError(t, err) settings, err := models.GetSettings(db) @@ -391,8 +400,9 @@ func TestStartChecks(t *testing.T) { db := reform.NewDB(sqlDB, postgresql.Dialect, nil) t.Run("unknown interval", func(t *testing.T) { - s, err := New(nil, nil, db, vmAddress) + s, err := New(db, nil, nil, nil, vmAddress) require.NoError(t, err) + s.localChecksFile = testChecksFile err = s.runChecksGroup(context.Background(), check.Interval("unknown")) assert.EqualError(t, err, "unknown check interval: unknown") @@ -402,7 +412,7 @@ func TestStartChecks(t *testing.T) { var ams mockAlertmanagerService ams.On("SendAlerts", mock.Anything, mock.Anything).Return() - s, err := New(nil, &ams, db, vmAddress) + s, err := New(db, nil, nil, &ams, vmAddress) require.NoError(t, err) s.localChecksFile = testChecksFile @@ -414,7 +424,7 @@ func TestStartChecks(t *testing.T) { }) t.Run("stt disabled", func(t *testing.T) { - s, err := New(nil, nil, db, vmAddress) + s, err := New(db, nil, nil, nil, vmAddress) require.NoError(t, err) settings, err := models.GetSettings(db) @@ -455,7 +465,7 @@ func TestFilterChecks(t *testing.T) { checks := append(valid, invalid...) - s, err := New(nil, nil, nil, vmAddress) + s, err := New(nil, nil, nil, nil, vmAddress) require.NoError(t, err) actual := s.filterSupportedChecks(checks) assert.ElementsMatch(t, valid, actual) @@ -482,7 +492,7 @@ func TestGroupChecksByDB(t *testing.T) { "missing family": {Name: "missing family", Version: 2}, } - s, err := New(nil, nil, nil, vmAddress) + s, err := New(nil, nil, nil, nil, vmAddress) require.NoError(t, err) mySQLChecks, postgreSQLChecks, mongoDBChecks := s.groupChecksByDB(checks) @@ -529,7 +539,7 @@ func TestMinPMMAgents(t *testing.T) { {name: "PostgreSQL Family", minVersion: pmmAgent2_6_0, check: check.Check{Version: 2, Queries: []check.Query{{Type: check.PostgreSQLShow}, {Type: check.PostgreSQLSelect}}}}, } - s, err := New(nil, nil, nil, vmAddress) + s, err := New(nil, nil, nil, nil, vmAddress) require.NoError(t, err) for _, test := range tests { @@ -568,7 +578,8 @@ func setup(t *testing.T, db *reform.DB, serviceName, nodeID, pmmAgentVersion str func TestFindTargets(t *testing.T) { sqlDB := testdb.Open(t, models.SetupFixtures, nil) db := reform.NewDB(sqlDB, postgresql.Dialect, reform.NewPrintfLogger(t.Logf)) - s, err := New(nil, nil, db, vmAddress) + + s, err := New(db, nil, nil, nil, vmAddress) require.NoError(t, err) t.Run("unknown service", func(t *testing.T) { @@ -622,7 +633,7 @@ func TestFindTargets(t *testing.T) { func TestFilterChecksByInterval(t *testing.T) { t.Parallel() - s, err := New(nil, nil, nil, vmAddress) + s, err := New(nil, nil, nil, nil, vmAddress) require.NoError(t, err) rareCheck := check.Check{Name: "rareCheck", Interval: check.Rare} @@ -658,7 +669,7 @@ func TestGetFailedChecks(t *testing.T) { ctx := context.Background() ams.On("GetAlerts", ctx, mock.Anything).Return([]*ammodels.GettableAlert{}, nil) - s, err := New(nil, &ams, db, vmAddress) + s, err := New(db, nil, nil, &ams, vmAddress) require.NoError(t, err) results, err := s.GetChecksResults(context.Background(), "test_svc") @@ -713,7 +724,7 @@ func TestGetFailedChecks(t *testing.T) { ctx := context.Background() ams.On("GetAlerts", ctx, mock.Anything).Return([]*ammodels.GettableAlert{&testAlert}, nil) - s, err := New(nil, &ams, db, vmAddress) + s, err := New(db, nil, nil, &ams, vmAddress) require.NoError(t, err) response, err := s.GetChecksResults(ctx, "test_svc") @@ -726,7 +737,7 @@ func TestGetFailedChecks(t *testing.T) { ctx := context.Background() ams.On("GetAlerts", ctx, mock.Anything).Return(nil, services.ErrSTTDisabled) - s, err := New(nil, &ams, db, vmAddress) + s, err := New(db, nil, nil, &ams, vmAddress) require.NoError(t, err) settings, err := models.GetSettings(db) @@ -762,7 +773,7 @@ func TestToggleCheckAlert(t *testing.T) { ams.On("GetAlerts", ctx, mock.Anything).Return([]*ammodels.GettableAlert{testAlert}, nil) ams.On("SilenceAlerts", ctx, []*ammodels.GettableAlert{testAlert}).Return(nil) - s, err := New(nil, &ams, nil, vmAddress) + s, err := New(nil, nil, nil, &ams, vmAddress) require.NoError(t, err) active := len(testAlert.Status.SilencedBy) == 0 @@ -788,7 +799,7 @@ func TestToggleCheckAlert(t *testing.T) { ams.On("GetAlerts", ctx, mock.Anything).Return([]*ammodels.GettableAlert{testAlert}, nil) ams.On("UnsilenceAlerts", ctx, []*ammodels.GettableAlert{testAlert}).Return(nil) - s, err := New(nil, &ams, nil, vmAddress) + s, err := New(nil, nil, nil, &ams, vmAddress) require.NoError(t, err) active := len(testAlert.Status.SilencedBy) == 0 diff --git a/managed/services/config/config.go b/managed/services/config/config.go index 624bebc4c3..8aa4245a75 100644 --- a/managed/services/config/config.go +++ b/managed/services/config/config.go @@ -26,7 +26,6 @@ import ( "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" - "github.com/percona/pmm/managed/services/platform" "github.com/percona/pmm/managed/services/telemetry" ) @@ -47,7 +46,6 @@ type Service struct { // Config application config. type Config struct { Services struct { - Platform platform.Config `yaml:"platform"` Telemetry telemetry.ServiceConfig `yaml:"telemetry"` } `yaml:"services"` } @@ -90,7 +88,6 @@ func (s *Service) Load() error { } } - cfg.Services.Platform.Init() if err := cfg.Services.Telemetry.Init(s.l); err != nil { return err } diff --git a/managed/services/config/pmm-managed.yaml b/managed/services/config/pmm-managed.yaml index f80c1d98d0..adcbc87711 100644 --- a/managed/services/config/pmm-managed.yaml +++ b/managed/services/config/pmm-managed.yaml @@ -1,12 +1,7 @@ services: - platform: - skip_tls_verification: false telemetry: enabled: true load_defaults: true - endpoints: - # %s is substituted with `saas_hostname` - report: https://%s/v1/telemetry/Report datasources: VM: enabled: true @@ -24,7 +19,6 @@ services: username: pmm password: pmm reporting: - skip_tls_verification: false send_on_start: false interval: 24h interval_env: "PERCONA_TEST_TELEMETRY_INTERVAL" diff --git a/managed/services/management/ia/alerts_service_test.go b/managed/services/management/ia/alerts_service_test.go index f0471bf5b0..81864d79f0 100644 --- a/managed/services/management/ia/alerts_service_test.go +++ b/managed/services/management/ia/alerts_service_test.go @@ -288,9 +288,8 @@ func TestListAlerts(t *testing.T) { mockAlert.On("GetAlerts", ctx, mock.Anything).Return(mockedAlerts, nil) - tmplSvc, err := NewTemplatesService(db) + tmplSvc, err := NewTemplatesService(db, nil) require.NoError(t, err) - tmplSvc.CollectTemplates(ctx) svc := NewAlertsService(db, mockAlert, tmplSvc) findAlerts := func(alerts []*iav1beta1.Alert, alertIDs ...string) bool { diff --git a/managed/services/management/ia/rules_service_test.go b/managed/services/management/ia/rules_service_test.go index ea47990026..57d352f668 100644 --- a/managed/services/management/ia/rules_service_test.go +++ b/managed/services/management/ia/rules_service_test.go @@ -46,10 +46,11 @@ func TestCreateAlertRule(t *testing.T) { sqlDB := testdb.Open(t, models.SkipFixtures, nil) db := reform.NewDB(sqlDB, postgresql.Dialect, reform.NewPrintfLogger(t.Logf)) - // Enable IA + // Enable IA and disable telemetry to prevent network calls settings, err := models.GetSettings(db) require.NoError(t, err) settings.IntegratedAlerting.Enabled = true + settings.Telemetry.Disabled = true err = models.SaveSettings(db, settings) require.NoError(t, err) @@ -72,7 +73,7 @@ func TestCreateAlertRule(t *testing.T) { channelID := respC.ChannelId // Load test templates - templates, err := NewTemplatesService(db) + templates, err := NewTemplatesService(db, nil) require.NoError(t, err) templates.userTemplatesPath = testTemplates2 templates.CollectTemplates(ctx) diff --git a/managed/services/management/ia/templates_service.go b/managed/services/management/ia/templates_service.go index cbbf2a1480..0af6e245c9 100644 --- a/managed/services/management/ia/templates_service.go +++ b/managed/services/management/ia/templates_service.go @@ -19,11 +19,8 @@ package ia import ( "bytes" "context" - "encoding/json" - "fmt" "io/fs" "io/ioutil" - "net/http" "path/filepath" "sort" "strings" @@ -31,7 +28,6 @@ import ( "text/template" "time" - api "github.com/percona-platform/saas/gen/check/retrieval" "github.com/percona-platform/saas/pkg/alert" "github.com/percona-platform/saas/pkg/common" "github.com/percona/promconfig" @@ -49,12 +45,14 @@ import ( "github.com/percona/pmm/managed/models" "github.com/percona/pmm/managed/utils/dir" "github.com/percona/pmm/managed/utils/envvars" - "github.com/percona/pmm/managed/utils/saasreq" + "github.com/percona/pmm/managed/utils/platform" "github.com/percona/pmm/managed/utils/signatures" ) const ( - templatesDir = "/srv/ia/templates" + templatesDir = "/srv/ia/templates" + portalRequestTimeout = 2 * time.Minute // time limit to get templates list from the portal + ) // templateInfo represents alerting rule template information from various sources. @@ -70,12 +68,11 @@ type templateInfo struct { // TemplatesService is responsible for interactions with IA rule templates. type TemplatesService struct { - db *reform.DB - l *logrus.Entry - userTemplatesPath string - - host string - publicKeys []string + db *reform.DB + l *logrus.Entry + platformClient *platform.Client + userTemplatesPath string + platformPublicKeys []string rw sync.RWMutex templates map[string]templateInfo @@ -84,7 +81,7 @@ type TemplatesService struct { } // NewTemplatesService creates a new TemplatesService. -func NewTemplatesService(db *reform.DB) (*TemplatesService, error) { +func NewTemplatesService(db *reform.DB, platformClient *platform.Client) (*TemplatesService, error) { l := logrus.WithField("component", "management/ia/templates") err := dir.CreateDataDir(templatesDir, "pmm", "pmm", dirPerm) @@ -92,22 +89,19 @@ func NewTemplatesService(db *reform.DB) (*TemplatesService, error) { l.Error(err) } - host, err := envvars.GetSAASHost() - if err != nil { - return nil, err + var platformPublicKeys []string + if k := envvars.GetPlatformPublicKeys(); k != nil { + l.Warnf("Percona Platform public keys changed to %q.", k) + platformPublicKeys = k } s := &TemplatesService{ - db: db, - l: l, - userTemplatesPath: templatesDir, - host: host, - templates: make(map[string]templateInfo), - } - - if k := envvars.GetPublicKeys(); k != nil { - l.Warnf("Public keys changed to %q.", k) - s.publicKeys = k + db: db, + l: l, + platformClient: platformClient, + userTemplatesPath: templatesDir, + platformPublicKeys: platformPublicKeys, + templates: make(map[string]templateInfo), } return s, nil @@ -382,7 +376,6 @@ func (s *TemplatesService) loadTemplatesFromDB() ([]templateInfo, error) { }, ) } - return res, nil } @@ -398,26 +391,15 @@ func (s *TemplatesService) downloadTemplates(ctx context.Context) ([]alert.Templ return nil, nil } - s.l.Infof("Downloading templates from %s ...", s.host) - - var accessToken string - if ssoDetails, err := models.GetPerconaSSODetails(ctx, s.db.Querier); err == nil { - accessToken = ssoDetails.AccessToken.AccessToken - } + nCtx, cancel := context.WithTimeout(ctx, portalRequestTimeout) + defer cancel() - endpoint := fmt.Sprintf("https://%s/v1/check/GetAllAlertRuleTemplates", s.host) - bodyBytes, err := saasreq.MakeRequest(ctx, http.MethodPost, endpoint, accessToken, nil, - &saasreq.SaasRequestOptions{SkipTLSVerification: false}) + resp, err := s.platformClient.GetTemplates(nCtx) if err != nil { - return nil, errors.Wrap(err, "failed to dial") - } - - var resp *api.GetAllAlertRuleTemplatesResponse - if err := json.Unmarshal(bodyBytes, &resp); err != nil { - return nil, err + return nil, errors.WithStack(err) } - if err = signatures.Verify(s.l, resp.File, resp.Signatures, s.publicKeys); err != nil { + if err = signatures.Verify(s.l, resp.File, resp.Signatures, s.platformPublicKeys); err != nil { return nil, err } diff --git a/managed/services/management/ia/templates_service_test.go b/managed/services/management/ia/templates_service_test.go index b9352af392..59f06e83b0 100644 --- a/managed/services/management/ia/templates_service_test.go +++ b/managed/services/management/ia/templates_service_test.go @@ -29,16 +29,17 @@ import ( iav1beta1 "github.com/percona/pmm/api/managementpb/ia" "github.com/percona/pmm/managed/models" + "github.com/percona/pmm/managed/utils/platform" "github.com/percona/pmm/managed/utils/testdb" ) const ( - devPortalHost = "check-dev.percona.com" - devPortalPublicKey = "RWTg+ZmCCjt7O8eWeAmTLAqW+1ozUbpRSKSwNTmO+exlS5KEIPYWuYdX" - testBadTemplates = "../../../testdata/ia/bad" - testTemplates = "../../../testdata/ia/user2" - testTemplates2 = "../../../testdata/ia/user" - issuerURL = "https://id-dev.percona.com/oauth2/aus15pi5rjdtfrcH51d7/v1" + devPlatformAddress = "https://check-dev.percona.com" + devPlatformPublicKey = "RWTg+ZmCCjt7O8eWeAmTLAqW+1ozUbpRSKSwNTmO+exlS5KEIPYWuYdX" + testBadTemplates = "../../../testdata/ia/bad" + testTemplates = "../../../testdata/ia/user2" + testTemplates2 = "../../../testdata/ia/user" + issuerURL = "https://id-dev.percona.com/oauth2/aus15pi5rjdtfrcH51d7/v1" ) func TestCollect(t *testing.T) { @@ -50,6 +51,8 @@ func TestCollect(t *testing.T) { ctx := context.Background() sqlDB := testdb.Open(t, models.SkipFixtures, nil) db := reform.NewDB(sqlDB, postgresql.Dialect, reform.NewPrintfLogger(t.Logf)) + platformClient, err := platform.NewClient(db, devPlatformAddress) + require.NoError(t, err) insertSSODetails := &models.PerconaSSODetailsInsert{ IssuerURL: issuerURL, @@ -57,13 +60,13 @@ func TestCollect(t *testing.T) { PMMManagedClientSecret: clientSecret, Scope: "percona", } - err := models.InsertPerconaSSODetails(db.Querier, insertSSODetails) + err = models.InsertPerconaSSODetails(db.Querier, insertSSODetails) require.NoError(t, err) t.Run("builtin are valid", func(t *testing.T) { t.Parallel() - svc, err := NewTemplatesService(db) + svc, err := NewTemplatesService(db, platformClient) require.NoError(t, err) _, err = svc.loadTemplatesFromAssets(ctx) require.NoError(t, err) @@ -72,7 +75,7 @@ func TestCollect(t *testing.T) { t.Run("bad template paths", func(t *testing.T) { t.Parallel() - svc, err := NewTemplatesService(db) + svc, err := NewTemplatesService(db, platformClient) require.NoError(t, err) svc.userTemplatesPath = testBadTemplates templates, err := svc.loadTemplatesFromUserFiles(ctx) @@ -83,7 +86,7 @@ func TestCollect(t *testing.T) { t.Run("valid template paths", func(t *testing.T) { t.Parallel() - svc, err := NewTemplatesService(db) + svc, err := NewTemplatesService(db, platformClient) require.NoError(t, err) svc.userTemplatesPath = testTemplates2 svc.CollectTemplates(ctx) @@ -114,6 +117,12 @@ func TestDownloadTemplates(t *testing.T) { sqlDB := testdb.Open(t, models.SkipFixtures, nil) db := reform.NewDB(sqlDB, postgresql.Dialect, reform.NewPrintfLogger(t.Logf)) + platformClient, err := platform.NewClient(db, devPlatformAddress) + require.NoError(t, err) + + svc, err := NewTemplatesService(db, platformClient) + svc.platformPublicKeys = []string{devPlatformPublicKey} + require.NoError(t, err) insertSSODetails := &models.PerconaSSODetailsInsert{ IssuerURL: issuerURL, @@ -121,14 +130,9 @@ func TestDownloadTemplates(t *testing.T) { PMMManagedClientSecret: clientSecret, Scope: "percona", } - err := models.InsertPerconaSSODetails(db.Querier, insertSSODetails) + err = models.InsertPerconaSSODetails(db.Querier, insertSSODetails) require.NoError(t, err) - svc, err := NewTemplatesService(db) - require.NoError(t, err) - svc.host = devPortalHost - svc.publicKeys = []string{devPortalPublicKey} - t.Run("normal", func(t *testing.T) { assert.Empty(t, svc.getTemplates()) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -160,6 +164,8 @@ func TestTemplateValidation(t *testing.T) { ctx := context.Background() sqlDB := testdb.Open(t, models.SkipFixtures, nil) db := reform.NewDB(sqlDB, postgresql.Dialect, reform.NewPrintfLogger(t.Logf)) + platformClient, err := platform.NewClient(db, devPlatformAddress) + require.NoError(t, err) // Enable IA settings, err := models.GetSettings(db) @@ -208,7 +214,7 @@ templates: summary: MySQL too many connections (instance {{ $labels.instance }}) ` - svc, err := NewTemplatesService(db) + svc, err := NewTemplatesService(db, platformClient) require.NoError(t, err) resp, err := svc.CreateTemplate(ctx, &iav1beta1.CreateTemplateRequest{ Yaml: templateWithMissingParam, @@ -301,7 +307,7 @@ templates: summary: MySQL too many connections (instance {{ $labels.instance }}) ` - svc, err := NewTemplatesService(db) + svc, err := NewTemplatesService(db, platformClient) require.NoError(t, err) createResp, err := svc.CreateTemplate(ctx, &iav1beta1.CreateTemplateRequest{ Yaml: validTemplate, diff --git a/managed/services/platform/config.go b/managed/services/platform/config.go deleted file mode 100644 index 8f26e087cb..0000000000 --- a/managed/services/platform/config.go +++ /dev/null @@ -1,26 +0,0 @@ -// pmm-managed -// Copyright (C) 2017 Percona LLC -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package platform - -// Config platform config. -type Config struct { - SkipTLSVerification bool `yaml:"skip_tls_verification"` //nolint:tagliatelle -} - -// Init platform config init. -func (c *Config) Init() { -} diff --git a/managed/services/platform/platform.go b/managed/services/platform/platform.go index 2234fee4aa..5242f3cfc2 100644 --- a/managed/services/platform/platform.go +++ b/managed/services/platform/platform.go @@ -18,12 +18,8 @@ package platform import ( - "bytes" "context" - "crypto/tls" - "encoding/json" "fmt" - "net/http" "time" "github.com/pkg/errors" @@ -37,22 +33,18 @@ import ( "github.com/percona/pmm/api/platformpb" "github.com/percona/pmm/managed/models" "github.com/percona/pmm/managed/services/grafana" - "github.com/percona/pmm/managed/utils/envvars" + "github.com/percona/pmm/managed/utils/platform" ) const rollbackFailed = "Failed to rollback:" -var ( - errInternalServer = status.Error(codes.Internal, "Internal server error") - errGetSSODetailsFailed = status.Error(codes.Aborted, "Failed to fetch SSO details.") -) +var errGetSSODetailsFailed = status.Error(codes.Aborted, "Failed to fetch SSO details.") // Service is responsible for interactions with Percona Platform. type Service struct { db *reform.DB - host string l *logrus.Entry - client http.Client + client *platform.Client grafanaClient grafanaClient supervisord supervisordService checksService checksService @@ -61,30 +53,16 @@ type Service struct { } // New returns platform Service. -func New(db *reform.DB, supervisord supervisordService, checksService checksService, grafanaClient grafanaClient, c Config) (*Service, error) { +func New(client *platform.Client, db *reform.DB, supervisord supervisordService, checksService checksService, grafanaClient grafanaClient) (*Service, error) { l := logrus.WithField("component", "platform") - host, err := envvars.GetSAASHost() - if err != nil { - return nil, err - } - - timeout := envvars.GetPlatformAPITimeout(l) - s := Service{ - host: host, db: db, + client: client, l: l, supervisord: supervisord, checksService: checksService, - client: http.Client{ - Timeout: timeout, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: c.SkipTLSVerification, //nolint:gosec - }, - }, - }, + grafanaClient: grafanaClient, } @@ -100,36 +78,32 @@ func (s *Service) Connect(ctx context.Context, req *platformpb.ConnectRequest) ( settings, err := models.GetSettings(s.db) if err != nil { s.l.Errorf("Failed to fetch PMM server ID and address: %s", err) - return nil, errInternalServer + return nil, err } if settings.PMMPublicAddress == "" { return nil, status.Error(codes.FailedPrecondition, "The address of PMM server is not set") } + pmmServerURL := fmt.Sprintf("https://%s/graph", settings.PMMPublicAddress) + pmmServerOAuthCallbackURL := fmt.Sprintf("%s/login/generic_oauth", pmmServerURL) - connectResp, err := s.connect(ctx, &connectPMMParams{ - serverName: req.ServerName, - pmmServerURL: pmmServerURL, - pmmServerOAuthCallbackURL: fmt.Sprintf("%s/login/generic_oauth", pmmServerURL), - pmmServerID: settings.PMMServerID, - personalAccessToken: req.PersonalAccessToken, - }) + resp, err := s.client.Connect(ctx, req.PersonalAccessToken, settings.PMMServerID, req.ServerName, pmmServerURL, pmmServerOAuthCallbackURL) if err != nil { - return nil, err // this is already a status error + return nil, err } err = models.InsertPerconaSSODetails(s.db.Querier, &models.PerconaSSODetailsInsert{ - PMMManagedClientID: connectResp.SSODetails.PMMManagedClientID, - PMMManagedClientSecret: connectResp.SSODetails.PMMManagedClientSecret, - GrafanaClientID: connectResp.SSODetails.GrafanaClientID, - IssuerURL: connectResp.SSODetails.IssuerURL, - Scope: connectResp.SSODetails.Scope, - OrganizationID: connectResp.OrganizationID, + PMMManagedClientID: resp.SSODetails.PMMManagedClientID, + PMMManagedClientSecret: resp.SSODetails.PMMManagedClientSecret, + GrafanaClientID: resp.SSODetails.GrafanaClientID, + IssuerURL: resp.SSODetails.IssuerURL, + Scope: resp.SSODetails.Scope, + OrganizationID: resp.OrganizationID, PMMServerName: req.ServerName, }) if err != nil { s.l.Errorf("Failed to insert SSO details: %s", err) - return nil, errInternalServer + return nil, err } if !settings.SaaS.STTDisabled { @@ -138,7 +112,7 @@ func (s *Service) Connect(ctx context.Context, req *platformpb.ConnectRequest) ( if err := s.UpdateSupervisordConfigurations(ctx); err != nil { s.l.Errorf("Failed to update configuration of grafana after connecting PMM to Portal: %s", err) - return nil, errInternalServer + return nil, err } return &platformpb.ConnectResponse{}, nil } @@ -154,7 +128,7 @@ func (s *Service) Disconnect(ctx context.Context, req *platformpb.DisconnectRequ settings, err := models.GetSettings(s.db) if err != nil { s.l.Errorf("Failed to fetch PMM server ID and address: %s", err) - return nil, errInternalServer + return nil, err } err = models.DeletePerconaSSODetails(s.db.Querier) @@ -163,12 +137,19 @@ func (s *Service) Disconnect(ctx context.Context, req *platformpb.DisconnectRequ if e := s.UpdateSupervisordConfigurations(ctx); e != nil { s.l.Errorf("%s %s", rollbackFailed, e) } - return nil, errInternalServer + return nil, err + } + + userAccessToken, err := s.grafanaClient.GetCurrentUserAccessToken(ctx) + if err != nil { + if errors.Is(err, grafana.ErrFailedToGetToken) { + return nil, status.Error(codes.FailedPrecondition, "Failed to get access token. Please sign in using your Percona Account.") + } + s.l.Errorf("Disconnect to Platform request failed: %s", err) + return nil, err } - err = s.disconnect(ctx, &disconnectPMMParams{ - PMMServerID: settings.PMMServerID, - }) + err = s.client.Disconnect(ctx, userAccessToken, settings.PMMServerID) needRecover := err != nil && !req.Force if needRecover { @@ -196,7 +177,7 @@ func (s *Service) Disconnect(ctx context.Context, req *platformpb.DisconnectRequ if err = s.UpdateSupervisordConfigurations(ctx); err != nil { s.l.Errorf("Failed to update configuration of grafana after disconnect from Platform: %s", err) - return nil, errInternalServer + return nil, err } return &platformpb.DisconnectResponse{}, nil @@ -219,149 +200,15 @@ func (s *Service) UpdateSupervisordConfigurations(ctx context.Context) error { return nil } -type connectPMMParams struct { - pmmServerURL, pmmServerOAuthCallbackURL, pmmServerID, serverName, personalAccessToken string -} - -type connectPMMRequest struct { - PMMServerID string `json:"pmm_server_id"` - PMMServerName string `json:"pmm_server_name"` - PMMServerURL string `json:"pmm_server_url"` - PMMServerOAuthCallbackURL string `json:"pmm_server_oauth_callback_url"` -} - -type disconnectPMMParams struct { - PMMServerID string -} - -type ssoDetails struct { - GrafanaClientID string `json:"grafana_client_id"` //nolint:tagliatelle - PMMManagedClientID string `json:"pmm_managed_client_id"` //nolint:tagliatelle - PMMManagedClientSecret string `json:"pmm_managed_client_secret"` //nolint:tagliatelle - Scope string `json:"scope"` - IssuerURL string `json:"issuer_url"` //nolint:tagliatelle -} - -type connectPMMResponse struct { - SSODetails *ssoDetails `json:"sso_details"` - OrganizationID string `json:"org_id"` -} - -type grpcGatewayError struct { - Message string `json:"message"` - Code uint32 `json:"code"` -} - -func (s *Service) connect(ctx context.Context, params *connectPMMParams) (*connectPMMResponse, error) { - endpoint := fmt.Sprintf("https://%s/v1/orgs/inventory", s.host) - marshaled, err := json.Marshal(connectPMMRequest{ - PMMServerID: params.pmmServerID, - PMMServerName: params.serverName, - PMMServerURL: params.pmmServerURL, - PMMServerOAuthCallbackURL: params.pmmServerOAuthCallbackURL, - }) - if err != nil { - s.l.Errorf("Failed to marshal request data: %s", err) - return nil, errInternalServer - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(marshaled)) - if err != nil { - s.l.Errorf("Failed to build Connect to Platform request: %s", err) - return nil, errInternalServer - } - h := req.Header - h.Add("Authorization", fmt.Sprintf("Bearer %s", params.personalAccessToken)) - resp, err := s.client.Do(req) - if err != nil { - s.l.Errorf("Connect to Platform request failed: %s", err) - return nil, errInternalServer - } - defer resp.Body.Close() //nolint:errcheck - - decoder := json.NewDecoder(resp.Body) - if resp.StatusCode != http.StatusOK { - var gwErr grpcGatewayError - if err := decoder.Decode(&gwErr); err != nil { - s.l.Errorf("Connect to Platform request failed and we failed to decode error message: %s", err) - return nil, errInternalServer - } - return nil, status.Error(codes.Code(gwErr.Code), gwErr.Message) - } - - response := &connectPMMResponse{} - if err := decoder.Decode(response); err != nil { - s.l.Errorf("Failed to decode response into SSO details: %s", err) - return nil, errInternalServer - } - return response, nil -} - -func (s *Service) disconnect(ctx context.Context, params *disconnectPMMParams) error { - userAccessToken, err := s.grafanaClient.GetCurrentUserAccessToken(ctx) - if err != nil { - if errors.Is(err, grafana.ErrFailedToGetToken) { - return status.Error(codes.FailedPrecondition, "Failed to get access token. Please sign in using your Percona Account.") - } - s.l.Errorf("Disconnect to Platform request failed: %s", err) - return errInternalServer - } - - endpoint := fmt.Sprintf("https://%s/v1/orgs/inventory/%s:disconnect", s.host, params.PMMServerID) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, nil) - if err != nil { - s.l.Errorf("Failed to build Disconnect to Platform request: %s", err) - return errInternalServer - } - - h := req.Header - h.Add("Authorization", fmt.Sprintf("Bearer %s", userAccessToken)) - - resp, err := s.client.Do(req) - if err != nil { - s.l.Errorf("Disconnect to Platform request failed: %s", err) - return errInternalServer - } - defer resp.Body.Close() //nolint:errcheck - - decoder := json.NewDecoder(resp.Body) - if resp.StatusCode != http.StatusOK { - var gwErr grpcGatewayError - if err := decoder.Decode(&gwErr); err != nil { - s.l.Errorf("Disconnect to Platform request failed and we failed to decode error message: %s", err) - return errInternalServer - } - return status.Error(codes.Code(gwErr.Code), gwErr.Message) - } - - return nil -} - -type searchOrganizationTicketsResponse struct { - Tickets []*ticketResponse `json:"tickets"` -} - -type ticketResponse struct { - Number string `json:"number"` - ShortDescription string `json:"short_description"` //nolint:tagliatelle - Priority string `json:"priority"` - State string `json:"state"` - CreateTime string `json:"create_time"` //nolint:tagliatelle - Department string `json:"department"` - Requester string `json:"requestor"` - TaskType string `json:"task_type"` //nolint:tagliatelle - URL string `json:"url"` -} - // SearchOrganizationTickets fetches the list of ticket associated with the Portal organization this PMM server is registered with. func (s *Service) SearchOrganizationTickets(ctx context.Context, req *platformpb.SearchOrganizationTicketsRequest) (*platformpb.SearchOrganizationTicketsResponse, error) { - userAccessToken, err := s.grafanaClient.GetCurrentUserAccessToken(ctx) + accessToken, err := s.grafanaClient.GetCurrentUserAccessToken(ctx) if err != nil { if errors.Is(err, grafana.ErrFailedToGetToken) { return nil, status.Error(codes.Unauthenticated, "Failed to get access token. Please sign in using your Percona Account.") } s.l.Errorf("SearchOrganizationTickets request failed: %s", err) - return nil, errInternalServer + return nil, err } ssoDetails, err := models.GetPerconaSSODetails(ctx, s.db.Querier) @@ -370,49 +217,17 @@ func (s *Service) SearchOrganizationTickets(ctx context.Context, req *platformpb return nil, errGetSSODetailsFailed } - endpoint := fmt.Sprintf("https://%s/v1/orgs/%s/tickets:search", s.host, ssoDetails.OrganizationID) - - r, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, nil) + resp, err := s.client.SearchOrgTickets(ctx, accessToken, ssoDetails.OrganizationID) if err != nil { - s.l.Errorf("Failed to build SearchOrganizationTickets request: %s", err) - return nil, errInternalServer - } - - h := r.Header - h.Add("Authorization", fmt.Sprintf("Bearer %s", userAccessToken)) - - resp, err := s.client.Do(r) - if err != nil { - s.l.Errorf("SearchOrganizationTickets request failed: %s", err) - return nil, errInternalServer - } - defer resp.Body.Close() //nolint:errcheck - - decoder := json.NewDecoder(resp.Body) - if resp.StatusCode != http.StatusOK { - var gwErr grpcGatewayError - if err := decoder.Decode(&gwErr); err != nil { - s.l.Errorf("SearchOrganizationRequest failed to decode error message: %s", err) - return nil, errInternalServer - } - return nil, status.Error(codes.Code(gwErr.Code), gwErr.Message) - } - - // the response from portal contains the timestamp as a string - // so we first unmarshal the response to an internal type with a string - // timestamp field and then convert it to the type used by the public API. - platformResponse := &searchOrganizationTicketsResponse{} - if err := decoder.Decode(platformResponse); err != nil { - s.l.Errorf("Failed to decode response into OrganizationTickets: %s", err) - return nil, errInternalServer + return nil, err } response := &platformpb.SearchOrganizationTicketsResponse{} - for _, t := range platformResponse.Tickets { + for _, t := range resp.Tickets { ticket, err := convertTicket(t) if err != nil { s.l.Errorf("Failed to convert OrganizationTickets: %s", err) - return nil, errInternalServer + return nil, err } response.Tickets = append(response.Tickets, ticket) } @@ -420,7 +235,7 @@ func (s *Service) SearchOrganizationTickets(ctx context.Context, req *platformpb return response, nil } -func convertTicket(t *ticketResponse) (*platformpb.OrganizationTicket, error) { +func convertTicket(t *platform.TicketResponse) (*platformpb.OrganizationTicket, error) { createTime, err := time.Parse(time.RFC3339, t.CreateTime) if err != nil { return nil, err @@ -439,38 +254,15 @@ func convertTicket(t *ticketResponse) (*platformpb.OrganizationTicket, error) { }, nil } -type searchOrganizationEntitlementsResponse struct { - Entitlement []*entitlementResponse `json:"entitlements"` -} - -type entitlementResponse struct { - Number string `json:"number"` - Name string `json:"name"` - Summary string `json:"summary"` - Tier string `json:"tier"` - TotalUnits string `json:"total_units"` //nolint:tagliatelle - UnlimitedUnits bool `json:"unlimited_units"` //nolint:tagliatelle - SupportLevel string `json:"support_level"` //nolint:tagliatelle - SoftwareFamilies []string `json:"software_families"` //nolint:tagliatelle - StartDate string `json:"start_date"` //nolint:tagliatelle - EndDate string `json:"end_date"` //nolint:tagliatelle - Platform platformResponse `json:"platform"` -} - -type platformResponse struct { - SecurityAdvisor string `json:"security_advisor"` //nolint:tagliatelle - ConfigAdvisor string `json:"config_advisor"` //nolint:tagliatelle -} - // SearchOrganizationEntitlements fetches customer entitlements for a particular organization. func (s *Service) SearchOrganizationEntitlements(ctx context.Context, req *platformpb.SearchOrganizationEntitlementsRequest) (*platformpb.SearchOrganizationEntitlementsResponse, error) { - userAccessToken, err := s.grafanaClient.GetCurrentUserAccessToken(ctx) + accessToken, err := s.grafanaClient.GetCurrentUserAccessToken(ctx) if err != nil { if errors.Is(err, grafana.ErrFailedToGetToken) { return nil, status.Error(codes.Unauthenticated, "Failed to get access token. Please sign in using your Percona Account.") } s.l.Errorf("SearchOrganizationEntitlements request failed: %s", err) - return nil, errInternalServer + return nil, err } ssoDetails, err := models.GetPerconaSSODetails(ctx, s.db.Querier) @@ -479,49 +271,17 @@ func (s *Service) SearchOrganizationEntitlements(ctx context.Context, req *platf return nil, errGetSSODetailsFailed } - endpoint := fmt.Sprintf("https://%s/v1/orgs/%s/entitlements:search", s.host, ssoDetails.OrganizationID) - - r, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, nil) + resp, err := s.client.SearchOrgEntitlements(ctx, accessToken, ssoDetails.OrganizationID) if err != nil { - s.l.Errorf("Failed to build SearchOrganizationEntitlements request: %s", err) - return nil, errInternalServer - } - - h := r.Header - h.Add("Authorization", fmt.Sprintf("Bearer %s", userAccessToken)) - - resp, err := s.client.Do(r) - if err != nil { - s.l.Errorf("SearchOrganizationEntitlements request failed: %s", err) - return nil, errInternalServer - } - defer resp.Body.Close() //nolint:errcheck - - decoder := json.NewDecoder(resp.Body) - if resp.StatusCode != http.StatusOK { - var gwErr grpcGatewayError - if err := decoder.Decode(&gwErr); err != nil { - s.l.Errorf("Failed to decode error message: %s", err) - return nil, errInternalServer - } - return nil, status.Error(codes.Code(gwErr.Code), gwErr.Message) - } - - // the response from portal contains the timestamp as a string - // so we first unmarshal the response to an internal type with a string - // timestamp field and then convert it to the type used by the public API. - platformResp := &searchOrganizationEntitlementsResponse{} - if err := decoder.Decode(platformResp); err != nil { - s.l.Errorf("Failed to decode response into OrganizationTickets: %s", err) - return nil, errInternalServer + return nil, err } response := &platformpb.SearchOrganizationEntitlementsResponse{} - for _, e := range platformResp.Entitlement { + for _, e := range resp.Entitlement { entitlement, err := convertEntitlement(e) if err != nil { s.l.Errorf("Failed to convert OrganizationEntitlements: %s", err) - return nil, errInternalServer + return nil, err } response.Entitlements = append(response.Entitlements, entitlement) } @@ -529,7 +289,7 @@ func (s *Service) SearchOrganizationEntitlements(ctx context.Context, req *platf return response, nil } -func convertEntitlement(ent *entitlementResponse) (*platformpb.OrganizationEntitlement, error) { +func convertEntitlement(ent *platform.EntitlementResponse) (*platformpb.OrganizationEntitlement, error) { startDate, err := time.Parse(time.RFC3339, ent.StartDate) if err != nil { return nil, err @@ -558,26 +318,16 @@ func convertEntitlement(ent *entitlementResponse) (*platformpb.OrganizationEntit }, nil } -type contactInformation struct { - Contacts struct { - CustomerSuccess struct { - Name string `json:"name"` - Email string `json:"email"` - } `json:"customer_success"` //nolint:tagliatelle - NewTicketURL string `json:"new_ticket_url"` //nolint:tagliatelle - } `json:"contacts"` -} - // GetContactInformation fetches contact information of the Customer Success employee assigned to the Percona customer from Percona Portal. func (s *Service) GetContactInformation(ctx context.Context, req *platformpb.GetContactInformationRequest) (*platformpb.GetContactInformationResponse, error) { - userAccessToken, err := s.grafanaClient.GetCurrentUserAccessToken(ctx) + accessToken, err := s.grafanaClient.GetCurrentUserAccessToken(ctx) if err != nil { if errors.Is(err, grafana.ErrFailedToGetToken) { s.l.Error("Failed to get access token.") return nil, status.Error(codes.Unauthenticated, "Failed to get access token. Please sign in using your Percona Account.") } s.l.Errorf("GetContactInformation request failed: %s", err) - return nil, errInternalServer + return nil, err } ssoDetails, err := models.GetPerconaSSODetails(ctx, s.db.Querier) @@ -586,62 +336,33 @@ func (s *Service) GetContactInformation(ctx context.Context, req *platformpb.Get return nil, status.Error(codes.Aborted, "PMM server is not connected to Portal") } - endpoint := fmt.Sprintf("https://%s/v1/orgs/%s", s.host, ssoDetails.OrganizationID) - - r, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) - if err != nil { - s.l.Errorf("Failed to build GetContactInformation request: %s", err) - return nil, errInternalServer - } - - h := r.Header - h.Add("Authorization", fmt.Sprintf("Bearer %s", userAccessToken)) - - resp, err := s.client.Do(r) + resp, err := s.client.GetContactInformation(ctx, accessToken, ssoDetails.OrganizationID) if err != nil { - s.l.Errorf("GetContactInformation request failed: %s", err) - return nil, errInternalServer - } - defer resp.Body.Close() //nolint:errcheck - - decoder := json.NewDecoder(resp.Body) - if resp.StatusCode != http.StatusOK { - var gwErr grpcGatewayError - if err := decoder.Decode(&gwErr); err != nil { - s.l.Errorf("Failed to decode error message: %s", err) - return nil, errInternalServer - } - return nil, status.Error(codes.Code(gwErr.Code), gwErr.Message) - } - - var platformResp contactInformation - if err := decoder.Decode(&platformResp); err != nil { - s.l.Errorf("Failed to decode response : %s", err) - return nil, errInternalServer + return nil, err } - res := &platformpb.GetContactInformationResponse{ + response := &platformpb.GetContactInformationResponse{ CustomerSuccess: &platformpb.GetContactInformationResponse_CustomerSuccess{ - Name: platformResp.Contacts.CustomerSuccess.Name, - Email: platformResp.Contacts.CustomerSuccess.Email, + Name: resp.Contacts.CustomerSuccess.Name, + Email: resp.Contacts.CustomerSuccess.Email, }, - NewTicketUrl: platformResp.Contacts.NewTicketURL, + NewTicketUrl: resp.Contacts.NewTicketURL, } // Platform account is not linked to ServiceNow. - if res.CustomerSuccess.Email == "" { + if response.CustomerSuccess.Email == "" { s.l.Error("Failed to find contact information, non-customer account.") return nil, status.Error(codes.FailedPrecondition, "Platform account user is not a Percona customer.") } - return res, nil + return response, nil } func (s *Service) ServerInfo(ctx context.Context, req *platformpb.ServerInfoRequest) (*platformpb.ServerInfoResponse, error) { settings, err := models.GetSettings(s.db) if err != nil { s.l.Errorf("Failed to fetch PMM server ID: %s", err) - return nil, errInternalServer + return nil, err } serverName := "" @@ -678,7 +399,7 @@ func (s *Service) UserStatus(ctx context.Context, req *platformpb.UserStatusRequ return nil, status.Error(codes.Unauthenticated, "Failed to get access token. Please sign in using your Percona Account.") } s.l.Errorf("UserStatus request failed: %s", err) - return nil, errInternalServer + return nil, err } return &platformpb.UserStatusResponse{ diff --git a/managed/services/telemetry/config.go b/managed/services/telemetry/config.go index e1cbe109b5..83dccc6893 100644 --- a/managed/services/telemetry/config.go +++ b/managed/services/telemetry/config.go @@ -19,7 +19,6 @@ package telemetry import ( _ "embed" //nolint:golint - "fmt" "os" "time" @@ -33,11 +32,10 @@ import ( // ServiceConfig telemetry config. type ServiceConfig struct { l *logrus.Entry - Enabled bool `yaml:"enabled"` - LoadDefaults bool `yaml:"load_defaults"` //nolint:tagliatelle - telemetry []Config `yaml:"-"` - Endpoints EndpointsConfig `yaml:"endpoints"` - SaasHostname string `yaml:"saas_hostname"` //nolint:tagliatelle + Enabled bool `yaml:"enabled"` + LoadDefaults bool `yaml:"load_defaults"` //nolint:tagliatelle + telemetry []Config `yaml:"-"` + SaasHostname string `yaml:"saas_hostname"` //nolint:tagliatelle DataSources struct { VM *DataSourceVictoriaMetrics `yaml:"VM"` QanDBSelect *DSConfigQAN `yaml:"QANDB_SELECT"` //nolint:tagliatelle @@ -51,16 +49,6 @@ type FileConfig struct { Telemetry []Config `yaml:"telemetry"` } -// EndpointsConfig telemetry endpoint config. -type EndpointsConfig struct { - Report string `yaml:"report"` -} - -// ReportEndpointURL returns reporting endpoint URL. -func (c *ServiceConfig) ReportEndpointURL() string { - return fmt.Sprintf(c.Endpoints.Report, c.SaasHostname) -} - // DSConfigQAN telemetry config. type DSConfigQAN struct { Enabled bool `yaml:"enabled"` @@ -124,14 +112,13 @@ func (c *Config) mapByColumn() map[string][]ConfigData { // ReportingConfig reporting config. type ReportingConfig struct { - SkipTLSVerification bool `yaml:"skip_tls_verification"` //nolint:tagliatelle - SendOnStart bool `yaml:"send_on_start"` //nolint:tagliatelle - IntervalEnv string `yaml:"interval_env"` //nolint:tagliatelle - Interval time.Duration `yaml:"interval"` - RetryBackoffEnv string `yaml:"retry_backoff_env"` //nolint:tagliatelle - RetryBackoff time.Duration `yaml:"retry_backoff"` //nolint:tagliatelle - SendTimeout time.Duration `yaml:"send_timeout"` //nolint:tagliatelle - RetryCount int `yaml:"retry_count"` //nolint:tagliatelle + SendOnStart bool `yaml:"send_on_start"` //nolint:tagliatelle + IntervalEnv string `yaml:"interval_env"` //nolint:tagliatelle + Interval time.Duration `yaml:"interval"` + RetryBackoffEnv string `yaml:"retry_backoff_env"` //nolint:tagliatelle + RetryBackoff time.Duration `yaml:"retry_backoff"` //nolint:tagliatelle + SendTimeout time.Duration `yaml:"send_timeout"` //nolint:tagliatelle + RetryCount int `yaml:"retry_count"` //nolint:tagliatelle } //go:embed config.default.yml @@ -157,7 +144,7 @@ func (c *ServiceConfig) Init(l *logrus.Entry) error { //nolint:gocognit } if c.SaasHostname == "" { - host, err := envvars.GetSAASHost() + host, err := envvars.GetPlatformAddress() c.SaasHostname = host if err != nil { return errors.Wrap(err, "failed to get SaaSHost") diff --git a/managed/services/telemetry/config_test.go b/managed/services/telemetry/config_test.go index 73ad1631ca..17ef609ed5 100644 --- a/managed/services/telemetry/config_test.go +++ b/managed/services/telemetry/config_test.go @@ -70,18 +70,14 @@ reporting: Enabled: true, LoadDefaults: true, SaasHostname: "check.localhost", - Endpoints: EndpointsConfig{ - Report: "https://%s/v1/telemetry/Report", - }, Reporting: ReportingConfig{ - SkipTLSVerification: true, - SendOnStart: true, - Interval: time.Second * 10, - IntervalEnv: "PERCONA_TEST_TELEMETRY_INTERVAL", - RetryBackoff: time.Second * 1, - RetryBackoffEnv: "PERCONA_TEST_TELEMETRY_RETRY_BACKOFF", - RetryCount: 2, - SendTimeout: time.Second * 10, + SendOnStart: true, + Interval: time.Second * 10, + IntervalEnv: "PERCONA_TEST_TELEMETRY_INTERVAL", + RetryBackoff: time.Second * 1, + RetryBackoffEnv: "PERCONA_TEST_TELEMETRY_RETRY_BACKOFF", + RetryCount: 2, + SendTimeout: time.Second * 10, }, DataSources: struct { VM *DataSourceVictoriaMetrics `yaml:"VM"` diff --git a/managed/services/telemetry/telemetry.go b/managed/services/telemetry/telemetry.go index 7025c50672..6bf6f6eba8 100644 --- a/managed/services/telemetry/telemetry.go +++ b/managed/services/telemetry/telemetry.go @@ -22,7 +22,6 @@ import ( "context" "encoding/hex" "io/ioutil" - "net/http" "regexp" "strings" "time" @@ -32,14 +31,13 @@ import ( reporter "github.com/percona-platform/saas/gen/telemetry/reporter" "github.com/pkg/errors" "github.com/sirupsen/logrus" - "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" "gopkg.in/reform.v1" "github.com/percona/pmm/api/serverpb" "github.com/percona/pmm/managed/models" - "github.com/percona/pmm/managed/utils/saasreq" + "github.com/percona/pmm/managed/utils/platform" ) const ( @@ -51,6 +49,7 @@ const ( type Service struct { db *reform.DB l *logrus.Entry + portalClient *platform.Client start time.Time config ServiceConfig dsRegistry DataSourceLocator @@ -71,7 +70,7 @@ var ( ) // NewService creates a new service. -func NewService(db *reform.DB, pmmVersion string, config ServiceConfig) (*Service, error) { +func NewService(db *reform.DB, portalClient *platform.Client, pmmVersion string, config ServiceConfig) (*Service, error) { if config.SaasHostname == "" { return nil, errors.New("empty host") } @@ -83,12 +82,13 @@ func NewService(db *reform.DB, pmmVersion string, config ServiceConfig) (*Servic return nil, err } s := &Service{ - db: db, - l: l, - pmmVersion: pmmVersion, - start: time.Now(), - config: config, - dsRegistry: registry, + db: db, + l: l, + portalClient: portalClient, + pmmVersion: pmmVersion, + start: time.Now(), + config: config, + dsRegistry: registry, } s.sDistributionMethod, s.tDistributionMethod, s.os = getDistributionMethodAndOS(l) @@ -307,7 +307,8 @@ func (s *Service) send(ctx context.Context, report *reporter.ReportRequest) erro var err error var attempt int for { - err = s.sendRequest(ctx, report) + s.l.Debugf("Using %s as telemetry host.", s.config.SaasHostname) + err = s.portalClient.SendTelemetry(ctx, report) attempt++ s.l.Debugf("sendV2Request (attempt %d/%d) result: %v", attempt, s.config.Reporting.RetryCount, err) if err == nil { @@ -329,26 +330,3 @@ func (s *Service) send(ctx context.Context, report *reporter.ReportRequest) erro } } } - -func (s *Service) sendRequest(ctx context.Context, req *reporter.ReportRequest) error { - s.l.Debugf("Using %s as telemetry host.", s.config.SaasHostname) - - var accessToken string - if ssoDetails, err := models.GetPerconaSSODetails(ctx, s.db.Querier); err == nil { - accessToken = ssoDetails.AccessToken.AccessToken - } - - reqByte, err := protojson.Marshal(req) - if err != nil { - return err - } - - _, err = saasreq.MakeRequest(ctx, http.MethodPost, s.config.ReportEndpointURL(), accessToken, bytes.NewReader(reqByte), &saasreq.SaasRequestOptions{ - SkipTLSVerification: s.config.Reporting.SkipTLSVerification, - }) - if err != nil { - return errors.Wrap(err, "failed to dial") - } - - return nil -} diff --git a/managed/utils/envvars/parser.go b/managed/utils/envvars/parser.go index af08c39fb5..e3814c8184 100644 --- a/managed/utils/envvars/parser.go +++ b/managed/utils/envvars/parser.go @@ -19,7 +19,7 @@ package envvars import ( "fmt" - "net" + "net/url" "os" "strconv" "strings" @@ -32,13 +32,14 @@ import ( ) const ( - defaultSaaSHost = "check.percona.com" - envSaaSHost = "PERCONA_TEST_SAAS_HOST" - envPublicKey = "PERCONA_TEST_CHECKS_PUBLIC_KEY" + defaultPlatformAddress = "https://check.percona.com" + envPlatformAddress = "PERCONA_TEST_PLATFORM_ADDRESS" + envPlatformInsecure = "PERCONA_TEST_PLATFORM_INSECURE" + envPlatformPublicKey = "PERCONA_TEST_PLATFORM_PUBLIC_KEY" // TODO REMOVE PERCONA_TEST_DBAAS IN FUTURE RELEASES. envTestDbaas = "PERCONA_TEST_DBAAS" envEnableDbaas = "ENABLE_DBAAS" - envPlatfromAPITimeout = "PERCONA_PLATFORM_API_TIMEOUT" + envPlatformAPITimeout = "PERCONA_PLATFORM_API_TIMEOUT" defaultPlatformAPITimeout = 30 * time.Second ) @@ -142,8 +143,11 @@ func ParseEnvVars(envs []string) (envSettings *models.ChangeSettingsParams, errs err = fmt.Errorf("invalid value %q for environment variable %q", v, k) } - case "PERCONA_TEST_AUTH_HOST", "PERCONA_TEST_CHECKS_HOST", "PERCONA_TEST_TELEMETRY_HOST": - err = fmt.Errorf("environment variable %q is removed and replaced by %q", k, envSaaSHost) + case "PERCONA_TEST_AUTH_HOST", "PERCONA_TEST_CHECKS_HOST", "PERCONA_TEST_TELEMETRY_HOST", "PERCONA_TEST_SAAS_HOST": + warns = append(warns, fmt.Sprintf("environment variable %q is removed and replaced by %q", k, envPlatformAddress)) + + case "PERCONA_TEST_CHECKS_PUBLIC_KEY": + warns = append(warns, fmt.Sprintf("environment variable %q is removed and replaced by %q", k, envPlatformPublicKey)) case "PMM_PUBLIC_ADDRESS": envSettings.PMMPublicAddress = v @@ -160,7 +164,7 @@ func ParseEnvVars(envs []string) (envSettings *models.ChangeSettingsParams, errs warns = append(warns, fmt.Sprintf("environment variable %q IS DEPRECATED AND WILL BE REMOVED, USE %q INSTEAD", envTestDbaas, envEnableDbaas)) } - case envPlatfromAPITimeout: + case envPlatformAPITimeout: // This variable is not part of the settings and is parsed separately. continue @@ -220,7 +224,7 @@ func parseStringDuration(value string) (time.Duration, error) { func parsePlatformAPITimeout(d string) (time.Duration, string) { if d == "" { - msg := fmt.Sprintf("Environment variable %q is not set, using %q as a default timeout for platform API.", envPlatfromAPITimeout, defaultPlatformAPITimeout.String()) + msg := fmt.Sprintf("Environment variable %q is not set, using %q as a default timeout for platform API.", envPlatformAPITimeout, defaultPlatformAPITimeout.String()) return defaultPlatformAPITimeout, msg } duration, err := parseStringDuration(d) @@ -234,52 +238,43 @@ func parsePlatformAPITimeout(d string) (time.Duration, string) { // GetPlatformAPITimeout returns timeout duration for requests to Platform. func GetPlatformAPITimeout(l *logrus.Entry) time.Duration { - d := os.Getenv(envPlatfromAPITimeout) + d := os.Getenv(envPlatformAPITimeout) duration, msg := parsePlatformAPITimeout(d) l.Info(msg) return duration } -// GetSAASHost returns SaaS host env variable value if it's present and valid. -// Otherwise returns defaultSaaSHost. -func GetSAASHost() (string, error) { - v := os.Getenv(envSaaSHost) - host, err := parseSAASHost(v) - if err != nil { - return "", err +// GetPlatformAddress returns Percona Platform address env variable value if it's present and valid. +// Otherwise returns default Percona Platform address. +func GetPlatformAddress() (string, error) { + address := os.Getenv(envPlatformAddress) + if address == "" { + logrus.Infof("Using default Percona Platform address %q.", defaultPlatformAddress) + return defaultPlatformAddress, nil + } + + if _, err := url.Parse(address); err != nil { + return "", errors.Errorf("invalid percona platform address: %s", err) } - logrus.Infof("Using SaaS host %q.", host) - return host, nil + logrus.Infof("Using Percona Platform address %q.", address) + return address, nil } -// GetPublicKeys returns public keys used to dowload checks from SaaS. -func GetPublicKeys() []string { - if v := os.Getenv(envPublicKey); v != "" { - return strings.Split(v, ",") - } +// GetPlatformInsecure returns true if invalid/self-signed TLS certificates allowed. Default is false. +func GetPlatformInsecure() bool { + insecure, _ := strconv.ParseBool(os.Getenv(envPlatformInsecure)) - return nil + return insecure } -// parseSAASHost parses, validates and returns SAAS host, otherwise returns error. -func parseSAASHost(v string) (string, error) { - if v == "" { - logrus.Infof("Using default SaaS host %q.", defaultSaaSHost) - return defaultSaaSHost, nil - } - if strings.HasPrefix(v, ":") { - return "", fmt.Errorf("environment variable %q has invalid format %q. Expected host[:port]", envSaaSHost, v) +// GetPlatformPublicKeys returns public keys used to verify signatures of files downloaded form Percona Portal. +func GetPlatformPublicKeys() []string { + if v := os.Getenv(envPlatformPublicKey); v != "" { + return strings.Split(v, ",") } - host, _, err := net.SplitHostPort(v) - if err != nil && strings.Count(v, ":") >= 1 { - return "", err - } - if host == "" { - host = v - } - return host, nil + return nil } func formatEnvVariableError(err error, env, value string) error { diff --git a/managed/utils/envvars/parser_test.go b/managed/utils/envvars/parser_test.go index 2b2c31c66c..da79883c18 100644 --- a/managed/utils/envvars/parser_test.go +++ b/managed/utils/envvars/parser_test.go @@ -123,15 +123,32 @@ func TestEnvVarValidator(t *testing.T) { assert.Nil(t, gotWarns) }) - t.Run("SAAS env vars with warnings", func(t *testing.T) { + t.Run("PERCONA_TEST_PLATFORM_ADDRESS env vars with warnings", func(t *testing.T) { t.Parallel() envs := []string{ - "PERCONA_TEST_SAAS_HOST=host:333", + "PERCONA_TEST_PLATFORM_ADDRESS=https://host:333", } expectedEnvVars := &models.ChangeSettingsParams{} expectedWarns := []string{ - `environment variable "PERCONA_TEST_SAAS_HOST" IS NOT SUPPORTED and WILL BE REMOVED IN THE FUTURE`, + `environment variable "PERCONA_TEST_PLATFORM_ADDRESS" IS NOT SUPPORTED and WILL BE REMOVED IN THE FUTURE`, + } + + gotEnvVars, gotErrs, gotWarns := ParseEnvVars(envs) + assert.Nil(t, gotErrs) + assert.Equal(t, expectedEnvVars, gotEnvVars) + assert.Equal(t, expectedWarns, gotWarns) + }) + + t.Run("PERCONA_TEST_CHECKS_PUBLIC_KEY env vars with warnings", func(t *testing.T) { + t.Parallel() + + envs := []string{ + "PERCONA_TEST_CHECKS_PUBLIC_KEY=some key", + } + expectedEnvVars := &models.ChangeSettingsParams{} + expectedWarns := []string{ + `environment variable "PERCONA_TEST_CHECKS_PUBLIC_KEY" is removed and replaced by "PERCONA_TEST_PLATFORM_PUBLIC_KEY"`, } gotEnvVars, gotErrs, gotWarns := ParseEnvVars(envs) @@ -147,37 +164,13 @@ func TestEnvVarValidator(t *testing.T) { "PERCONA_TEST_AUTH_HOST", "PERCONA_TEST_CHECKS_HOST", "PERCONA_TEST_TELEMETRY_HOST", + "PERCONA_TEST_SAAS_HOST", } { - expected := fmt.Errorf(`environment variable %q is removed and replaced by "PERCONA_TEST_SAAS_HOST"`, k) + expected := fmt.Sprintf(`environment variable %q is removed and replaced by "PERCONA_TEST_PLATFORM_ADDRESS"`, k) envs := []string{k + "=host:333"} _, gotErrs, gotWarns := ParseEnvVars(envs) - assert.Equal(t, []error{expected}, gotErrs) - assert.Nil(t, gotWarns) - } - }) - - t.Run("Parse SAAS host", func(t *testing.T) { - t.Parallel() - - userCase := []struct { - value string - err string - respVal string - }{ - {value: "host", err: "", respVal: "host"}, - {value: ":111", err: `environment variable "PERCONA_TEST_SAAS_HOST" has invalid format ":111". Expected host[:port]`, respVal: ""}, - {value: "host:555", err: "", respVal: "host"}, - {value: "[2001:cafe:8221:9a0f:4dc7:4bb:8581:d186]:333", err: "", respVal: "2001:cafe:8221:9a0f:4dc7:4bb:8581:d186"}, - {value: "ho:st:444", err: "address ho:st:444: too many colons in address", respVal: ""}, - } - for _, c := range userCase { - value, err := parseSAASHost(c.value) - assert.Equal(t, c.respVal, value) - if c.err == "" { - assert.NoError(t, err) - } else { - assert.Equal(t, c.err, err.Error()) - } + assert.Equal(t, []string{expected}, gotWarns) + assert.Nil(t, gotErrs) } }) diff --git a/managed/utils/platform/client.go b/managed/utils/platform/client.go new file mode 100644 index 0000000000..00c1d70232 --- /dev/null +++ b/managed/utils/platform/client.go @@ -0,0 +1,344 @@ +// pmm-managed +// Copyright (C) 2017 Percona LLC +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package platform implements HTTP client for Percona Platform. +package platform + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + + api "github.com/percona-platform/saas/gen/check/retrieval" + reporter "github.com/percona-platform/saas/gen/telemetry/reporter" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/encoding/protojson" + "gopkg.in/reform.v1" + + "github.com/percona/pmm/managed/models" + "github.com/percona/pmm/managed/utils/envvars" + "github.com/percona/pmm/utils/tlsconfig" +) + +// Client is HTTP Percona Platform client. +// TODO: Replace this client with generated one https://jira.percona.com/browse/SAAS-956 +type Client struct { + db *reform.DB + + address string + l *logrus.Entry + client http.Client +} + +// NewClient creates new Percona Platform client. +func NewClient(db *reform.DB, address string) (*Client, error) { + l := logrus.WithField("component", "portal client") + + tlsConfig := tlsconfig.Get() + tlsConfig.InsecureSkipVerify = envvars.GetPlatformInsecure() + + return &Client{ + db: db, + l: l, + address: address, + client: http.Client{ + Timeout: envvars.GetPlatformAPITimeout(l), + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + }, + }, nil +} + +// GetChecks download checks from Percona Platform. It also validates content and checks signatures. +func (c *Client) GetChecks(ctx context.Context) (*api.GetAllChecksResponse, error) { + const path = "/v1/check/GetAllChecks" + + var accessToken string + if ssoDetails, err := models.GetPerconaSSODetails(ctx, c.db.Querier); err == nil { + accessToken = ssoDetails.AccessToken.AccessToken + } + + c.l.Infof("Downloading checks from %s ...", c.address) + bodyBytes, err := c.makeRequest(ctx, accessToken, http.MethodPost, path, nil) + if err != nil { + return nil, errors.Wrap(err, "failed to download checks") + } + + var resp api.GetAllChecksResponse + if err := json.Unmarshal(bodyBytes, &resp); err != nil { + return nil, err + } + + return &resp, nil +} + +// GetTemplates download templates from Percona Platform. It also validates content and checks signatures. +func (c *Client) GetTemplates(ctx context.Context) (*api.GetAllAlertRuleTemplatesResponse, error) { + const path = "/v1/check/GetAllAlertRuleTemplates" + + var accessToken string + if ssoDetails, err := models.GetPerconaSSODetails(ctx, c.db.Querier); err == nil { + accessToken = ssoDetails.AccessToken.AccessToken + } + + c.l.Infof("Downloading templates from %s ...", c.address) + bodyBytes, err := c.makeRequest(ctx, accessToken, http.MethodPost, path, nil) + if err != nil { + return nil, errors.Wrap(err, "failed to download checks") + } + + var resp api.GetAllAlertRuleTemplatesResponse + if err := json.Unmarshal(bodyBytes, &resp); err != nil { + return nil, err + } + + return &resp, nil +} + +// SendTelemetry sends telemetry data to Percona Platform. +func (c *Client) SendTelemetry(ctx context.Context, report *reporter.ReportRequest) error { + const path = "/v1/telemetry/Report" + + var accessToken string + if ssoDetails, err := models.GetPerconaSSODetails(ctx, c.db.Querier); err == nil { + accessToken = ssoDetails.AccessToken.AccessToken + } + + body, err := protojson.Marshal(report) + if err != nil { + return err + } + + _, err = c.makeRequest(ctx, accessToken, http.MethodPost, path, bytes.NewReader(body)) + if err != nil { + return errors.Wrap(err, "failed to send telemetry data") + } + + return nil +} + +// Connect send connect request to Percona Platform. +func (c *Client) Connect(ctx context.Context, accessToken, pmmServerID, pmmServerName, pmmServerURL, pmmServerOAuthCallbackURL string) (*ConnectPMMResponse, error) { + const path = "/v1/orgs/inventory" + + body, err := json.Marshal(struct { + PMMServerID string `json:"pmm_server_id"` + PMMServerName string `json:"pmm_server_name"` + PMMServerURL string `json:"pmm_server_url"` + PMMServerOAuthCallbackURL string `json:"pmm_server_oauth_callback_url"` + }{ + PMMServerID: pmmServerID, + PMMServerName: pmmServerName, + PMMServerURL: pmmServerURL, + PMMServerOAuthCallbackURL: pmmServerOAuthCallbackURL, + }) + if err != nil { + c.l.Errorf("Failed to marshal request data: %s", err) + return nil, err + } + + bodyBytes, err := c.makeRequest(ctx, accessToken, http.MethodPost, path, bytes.NewReader(body)) + if err != nil { + c.l.Errorf("Failed to build Connect to Platform request: %s", err) + return nil, err + } + + var resp ConnectPMMResponse + if err := json.Unmarshal(bodyBytes, &resp); err != nil { + c.l.Errorf("Failed to decode response into SSO details: %s", err) + return nil, err + } + + return &resp, nil +} + +// Disconnect send disconnect request to Percona Platform. +func (c *Client) Disconnect(ctx context.Context, accessToken, pmmServerID string) error { + const path = "/v1/orgs/inventory/%s:disconnect" + + _, err := c.makeRequest(ctx, accessToken, http.MethodPost, fmt.Sprintf(path, pmmServerID), nil) + if err != nil { + return err + } + + return nil +} + +// SearchOrgTickets searches tickets for given organization ID. +func (c *Client) SearchOrgTickets(ctx context.Context, accessToken, orgID string) (*SearchOrganizationTicketsResponse, error) { + const path = "/v1/orgs/%s/tickets:search" + + resp, err := c.makeRequest(ctx, accessToken, http.MethodPost, fmt.Sprintf(path, orgID), nil) + if err != nil { + return nil, err + } + + var res SearchOrganizationTicketsResponse + if err := json.Unmarshal(resp, &res); err != nil { + c.l.Errorf("Failed to decode response into OrganizationTickets: %s", err) + return nil, err + } + + return &res, nil +} + +// SearchOrgEntitlements searches entitlements for given organization ID. +func (c *Client) SearchOrgEntitlements(ctx context.Context, accessToken, orgID string) (*SearchOrganizationEntitlementsResponse, error) { + const path = "/v1/orgs/%s/entitlements:search" + + resp, err := c.makeRequest(ctx, accessToken, http.MethodPost, fmt.Sprintf(path, orgID), nil) + if err != nil { + return nil, err + } + + var res SearchOrganizationEntitlementsResponse + if err := json.Unmarshal(resp, &res); err != nil { + c.l.Errorf("Failed to decode response into OrganizationTickets: %s", err) + return nil, err + } + + return &res, nil +} + +// GetContactInformation returns contact information for given organization ID. +func (c *Client) GetContactInformation(ctx context.Context, accessToken, orgID string) (*ContactInformation, error) { + const path = "/v1/orgs/%s" + + resp, err := c.makeRequest(ctx, accessToken, http.MethodGet, fmt.Sprintf(path, orgID), nil) + if err != nil { + return nil, err + } + + var res ContactInformation + if err := json.Unmarshal(resp, &res); err != nil { + c.l.Errorf("Failed to decode response : %s", err) + return nil, err + } + + return &res, nil +} + +// MakeRequest makes request to Percona Platform. +func (c *Client) makeRequest(ctx context.Context, accessToken, method, path string, body io.Reader) ([]byte, error) { + endpoint := c.address + path + req, err := http.NewRequestWithContext(ctx, method, endpoint, body) + if err != nil { + return nil, err + } + + h := req.Header + h.Add("Content-Type", "application/json") + if accessToken != "" { + h.Add("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, err + } + + defer resp.Body.Close() //nolint:errcheck + + bodyBytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + var gwErr struct { + Message string `json:"message"` + Code uint32 `json:"code"` + } + + if err := json.Unmarshal(bodyBytes, &gwErr); err != nil { + c.l.Errorf("Failed to dial Percona Portal and we failed to decode error message: %s", err) + return nil, err + } + return nil, status.Error(codes.Code(gwErr.Code), gwErr.Message) + } + + return bodyBytes, nil +} + +type SsoDetails struct { + GrafanaClientID string `json:"grafana_client_id"` //nolint:tagliatelle + PMMManagedClientID string `json:"pmm_managed_client_id"` //nolint:tagliatelle + PMMManagedClientSecret string `json:"pmm_managed_client_secret"` //nolint:tagliatelle + Scope string `json:"scope"` + IssuerURL string `json:"issuer_url"` //nolint:tagliatelle +} + +type ConnectPMMResponse struct { + SSODetails *SsoDetails `json:"sso_details"` + OrganizationID string `json:"org_id"` +} + +type SearchOrganizationEntitlementsResponse struct { + Entitlement []*EntitlementResponse `json:"entitlements"` +} + +type EntitlementResponse struct { + Number string `json:"number"` + Name string `json:"name"` + Summary string `json:"summary"` + Tier string `json:"tier"` + TotalUnits string `json:"total_units"` //nolint:tagliatelle + UnlimitedUnits bool `json:"unlimited_units"` //nolint:tagliatelle + SupportLevel string `json:"support_level"` //nolint:tagliatelle + SoftwareFamilies []string `json:"software_families"` //nolint:tagliatelle + StartDate string `json:"start_date"` //nolint:tagliatelle + EndDate string `json:"end_date"` //nolint:tagliatelle + Platform PlatformResponse `json:"platform"` +} + +type PlatformResponse struct { + SecurityAdvisor string `json:"security_advisor"` //nolint:tagliatelle + ConfigAdvisor string `json:"config_advisor"` //nolint:tagliatelle +} + +type SearchOrganizationTicketsResponse struct { + Tickets []*TicketResponse `json:"tickets"` +} + +type TicketResponse struct { + Number string `json:"number"` + ShortDescription string `json:"short_description"` //nolint:tagliatelle + Priority string `json:"priority"` + State string `json:"state"` + CreateTime string `json:"create_time"` //nolint:tagliatelle + Department string `json:"department"` + Requester string `json:"requestor"` + TaskType string `json:"task_type"` //nolint:tagliatelle + URL string `json:"url"` +} + +type ContactInformation struct { + Contacts struct { + CustomerSuccess struct { + Name string `json:"name"` + Email string `json:"email"` + } `json:"customer_success"` //nolint:tagliatelle + NewTicketURL string `json:"new_ticket_url"` //nolint:tagliatelle + } `json:"contacts"` +} diff --git a/managed/utils/saasreq/request.go b/managed/utils/saasreq/request.go deleted file mode 100644 index fe00c6f134..0000000000 --- a/managed/utils/saasreq/request.go +++ /dev/null @@ -1,91 +0,0 @@ -// pmm-managed -// Copyright (C) 2017 Percona LLC -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -// Package saasreq provides http/https connection setup for Percona Platform. -package saasreq - -import ( - "context" - "fmt" - "io" - "io/ioutil" - "net/http" - "net/url" - "time" - - "github.com/percona/pmm/managed/utils/envvars" - "github.com/percona/pmm/managed/utils/logger" - "github.com/percona/pmm/utils/tlsconfig" -) - -var dialTimeout time.Duration - -func init() { - l := logger.Get(logger.Set(context.Background(), "saasreq init")) - dialTimeout = envvars.GetPlatformAPITimeout(l) -} - -// SaasRequestOptions config. -type SaasRequestOptions struct { - SkipTLSVerification bool -} - -// MakeRequest creates http/https POST request to Percona Platform. -func MakeRequest(ctx context.Context, method string, endpoint, accessToken string, body io.Reader, options *SaasRequestOptions) ([]byte, error) { - if _, err := url.Parse(endpoint); err != nil { - return nil, err - } - - tlsConfig := tlsconfig.Get() - tlsConfig.InsecureSkipVerify = options.SkipTLSVerification - - ctx, cancel := context.WithTimeout(ctx, dialTimeout) - defer cancel() - - req, err := http.NewRequestWithContext(ctx, method, endpoint, body) - if err != nil { - return nil, err - } - - h := req.Header - h.Add("Content-Type", "application/json") - if accessToken != "" { - h.Add("Authorization", fmt.Sprintf("Bearer %s", accessToken)) - } - - client := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: tlsConfig, - }, - } - res, err := client.Do(req) - if err != nil { - return nil, err - } - - defer res.Body.Close() //nolint:errcheck - - bodyBytes, err := ioutil.ReadAll(res.Body) - if err != nil { - return nil, err - } - - if res.StatusCode != http.StatusOK { - return nil, fmt.Errorf("failed to dial %s, response body: %s", endpoint, bodyBytes) - } - - return bodyBytes, nil -} diff --git a/utils/pdeathsig/pdeathsig_linux.go b/utils/pdeathsig/pdeathsig_linux.go index e745164d8b..a62959c8b3 100644 --- a/utils/pdeathsig/pdeathsig_linux.go +++ b/utils/pdeathsig/pdeathsig_linux.go @@ -10,7 +10,7 @@ import ( // See http://man7.org/linux/man-pages/man2/prctl.2.html, section PR_SET_PDEATHSIG. func Set(cmd *exec.Cmd, s unix.Signal) { if cmd.SysProcAttr == nil { - cmd.SysProcAttr = new(unix.SysProcAttr) + cmd.SysProcAttr = &unix.SysProcAttr{} } cmd.SysProcAttr.Pdeathsig = s }