diff --git a/clients/ui/bff/Makefile b/clients/ui/bff/Makefile index 96f26131c..f2d67cd46 100644 --- a/clients/ui/bff/Makefile +++ b/clients/ui/bff/Makefile @@ -8,6 +8,7 @@ STANDALONE_MODE ?= true STATIC_ASSETS_DIR ?= ./static # ENVTEST_K8S_VERSION refers to the version of kubebuilder assets to be downloaded by envtest binary. ENVTEST_K8S_VERSION = 1.29.0 +LOG_LEVEL ?= info .PHONY: all all: build @@ -48,7 +49,7 @@ build: fmt vet test ## Builds the project to produce a binary executable. .PHONY: run run: fmt vet envtest ## Runs the project. ENVTEST_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" \ - go run ./cmd/main.go --port=$(PORT) --static-assets-dir=$(STATIC_ASSETS_DIR) --mock-k8s-client=$(MOCK_K8S_CLIENT) --mock-mr-client=$(MOCK_MR_CLIENT) --dev-mode=$(DEV_MODE) --dev-mode-port=$(DEV_MODE_PORT) --standalone-mode=$(STANDALONE_MODE) + go run ./cmd/main.go --port=$(PORT) --static-assets-dir=$(STATIC_ASSETS_DIR) --mock-k8s-client=$(MOCK_K8S_CLIENT) --mock-mr-client=$(MOCK_MR_CLIENT) --dev-mode=$(DEV_MODE) --dev-mode-port=$(DEV_MODE_PORT) --standalone-mode=$(STANDALONE_MODE) --log-level=$(LOG_LEVEL) ##@ Dependencies diff --git a/clients/ui/bff/README.md b/clients/ui/bff/README.md index c6adf2512..37bda0b59 100644 --- a/clients/ui/bff/README.md +++ b/clients/ui/bff/README.md @@ -35,6 +35,11 @@ If you want to use a different port, mock kubernetes client or model registry cl ```shell make run PORT=8000 MOCK_K8S_CLIENT=true MOCK_MR_CLIENT=true ``` +If you want to change the log level on deployment, add the LOG_LEVEL argument when running, supported levels are: ERROR, WARN, INFO, DEBUG. The default level is INFO. +```shell +# Run with debug logging +make run LOG_LEVEL=DEBUG +``` # Building and Deploying diff --git a/clients/ui/bff/cmd/main.go b/clients/ui/bff/cmd/main.go index a38ae341c..bb2aecaf5 100644 --- a/clients/ui/bff/cmd/main.go +++ b/clients/ui/bff/cmd/main.go @@ -5,6 +5,7 @@ import ( "flag" "fmt" "os/signal" + "strings" "syscall" "github.com/kubeflow/model-registry/ui/bff/internal/api" @@ -26,9 +27,12 @@ func main() { flag.IntVar(&cfg.DevModePort, "dev-mode-port", getEnvAsInt("DEV_MODE_PORT", 8080), "Use port when in development mode") flag.BoolVar(&cfg.StandaloneMode, "standalone-mode", false, "Use standalone mode for enabling endpoints in standalone mode") flag.StringVar(&cfg.StaticAssetsDir, "static-assets-dir", "./static", "Configure frontend static assets root directory") + flag.StringVar(&cfg.LogLevel, "log-level", getEnvAsString("LOG_LEVEL", "info"), "Sets server log level, possible values: debug, info, warn, error, fatal") flag.Parse() - logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: getLogLevelFromString(cfg.LogLevel), + })) app, err := api.NewApp(cfg, logger) if err != nil { @@ -88,3 +92,27 @@ func getEnvAsInt(name string, defaultVal int) int { } return defaultVal } + +func getEnvAsString(name string, defaultVal string) string { + if value, exists := os.LookupEnv(name); exists { + return value + } + return defaultVal +} + +func getLogLevelFromString(level string) slog.Level { + switch strings.ToLower(level) { + case "debug": + return slog.LevelDebug + case "info": + return slog.LevelInfo + case "warn": + return slog.LevelWarn + case "error": + return slog.LevelError + case "fatal": + return slog.LevelError + + } + return slog.LevelInfo +} diff --git a/clients/ui/bff/internal/api/app.go b/clients/ui/bff/internal/api/app.go index 129b41c8d..bdbdbad27 100644 --- a/clients/ui/bff/internal/api/app.go +++ b/clients/ui/bff/internal/api/app.go @@ -138,5 +138,5 @@ func (app *App) Routes() http.Handler { http.ServeFile(w, r, path.Join(app.config.StaticAssetsDir, "index.html")) }) - return app.RecoverPanic(app.enableCORS(app.InjectUserHeaders(appMux))) + return app.RecoverPanic(app.EnableTelemetry(app.enableCORS(app.InjectUserHeaders(appMux)))) } diff --git a/clients/ui/bff/internal/api/healthcheck__handler_test.go b/clients/ui/bff/internal/api/healthcheck__handler_test.go index 0212a58cd..e830afb21 100644 --- a/clients/ui/bff/internal/api/healthcheck__handler_test.go +++ b/clients/ui/bff/internal/api/healthcheck__handler_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "github.com/kubeflow/model-registry/ui/bff/internal/config" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "github.com/kubeflow/model-registry/ui/bff/internal/mocks" "github.com/kubeflow/model-registry/ui/bff/internal/models" "github.com/kubeflow/model-registry/ui/bff/internal/repositories" @@ -26,7 +27,7 @@ func TestHealthCheckHandler(t *testing.T) { rr := httptest.NewRecorder() req, err := http.NewRequest(http.MethodGet, HealthCheckPath, nil) - ctx := context.WithValue(req.Context(), KubeflowUserIdKey, mocks.KubeflowUserIDHeaderValue) + ctx := context.WithValue(req.Context(), constants.KubeflowUserIdKey, mocks.KubeflowUserIDHeaderValue) req = req.WithContext(ctx) assert.NoError(t, err) diff --git a/clients/ui/bff/internal/api/healthcheck_handler.go b/clients/ui/bff/internal/api/healthcheck_handler.go index 6ee2049a2..4692d6be7 100644 --- a/clients/ui/bff/internal/api/healthcheck_handler.go +++ b/clients/ui/bff/internal/api/healthcheck_handler.go @@ -3,12 +3,13 @@ package api import ( "errors" "github.com/julienschmidt/httprouter" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "net/http" ) func (app *App) HealthcheckHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - userId, ok := r.Context().Value(KubeflowUserIdKey).(string) + userId, ok := r.Context().Value(constants.KubeflowUserIdKey).(string) if !ok || userId == "" { app.serverErrorResponse(w, r, errors.New("failed to retrieve kubeflow-userid from context")) return diff --git a/clients/ui/bff/internal/api/middleware.go b/clients/ui/bff/internal/api/middleware.go index d70fe9acb..6e4e299c5 100644 --- a/clients/ui/bff/internal/api/middleware.go +++ b/clients/ui/bff/internal/api/middleware.go @@ -4,27 +4,15 @@ import ( "context" "errors" "fmt" - "net/http" - "strings" - + "github.com/google/uuid" "github.com/julienschmidt/httprouter" "github.com/kubeflow/model-registry/ui/bff/internal/config" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "github.com/kubeflow/model-registry/ui/bff/internal/integrations" -) - -type contextKey string - -const ( - ModelRegistryHttpClientKey contextKey = "ModelRegistryHttpClientKey" - NamespaceHeaderParameterKey contextKey = "namespace" - - //Kubeflow authorization operates using custom authentication headers: - // Note: The functionality for `kubeflow-groups` is not fully operational at Kubeflow platform at this time - // but it's supported on Model Registry BFF - KubeflowUserIdKey contextKey = "kubeflowUserId" // kubeflow-userid :contains the user's email address - KubeflowUserIDHeader = "kubeflow-userid" - KubeflowUserGroupsKey contextKey = "kubeflowUserGroups" // kubeflow-groups : Holds a comma-separated list of user groups - KubeflowUserGroupsIdHeader = "kubeflow-groups" + "log/slog" + "net/http" + "runtime/debug" + "strings" ) func (app *App) RecoverPanic(next http.Handler) http.Handler { @@ -33,6 +21,7 @@ func (app *App) RecoverPanic(next http.Handler) http.Handler { if err := recover(); err != nil { w.Header().Set("Connection", "close") app.serverErrorResponse(w, r, fmt.Errorf("%s", err)) + app.logger.Error("Recover from panic: " + string(debug.Stack())) } }() @@ -49,8 +38,8 @@ func (app *App) InjectUserHeaders(next http.Handler) http.Handler { return } - userIdHeader := r.Header.Get(KubeflowUserIDHeader) - userGroupsHeader := r.Header.Get(KubeflowUserGroupsIdHeader) + userIdHeader := r.Header.Get(constants.KubeflowUserIDHeader) + userGroupsHeader := r.Header.Get(constants.KubeflowUserGroupsIdHeader) //`kubeflow-userid`: Contains the user's email address. if userIdHeader == "" { app.badRequestResponse(w, r, errors.New("missing required header: kubeflow-userid")) @@ -70,8 +59,8 @@ func (app *App) InjectUserHeaders(next http.Handler) http.Handler { } ctx := r.Context() - ctx = context.WithValue(ctx, KubeflowUserIdKey, userIdHeader) - ctx = context.WithValue(ctx, KubeflowUserGroupsKey, userGroups) + ctx = context.WithValue(ctx, constants.KubeflowUserIdKey, userIdHeader) + ctx = context.WithValue(ctx, constants.KubeflowUserGroupsKey, userGroups) next.ServeHTTP(w, r.WithContext(ctx)) }) @@ -87,35 +76,72 @@ func (app *App) enableCORS(next http.Handler) http.Handler { }) } +func (app *App) EnableTelemetry(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Adds a unique id to the context to allow tracing of requests + traceId := uuid.NewString() + ctx := context.WithValue(r.Context(), constants.TraceIdKey, traceId) + + // logger will only be nil in tests. + if app.logger != nil { + traceLogger := app.logger.With(slog.String("trace_id", traceId)) + ctx = context.WithValue(ctx, constants.TraceLoggerKey, traceLogger) + + if traceLogger.Enabled(ctx, slog.LevelDebug) { + cloneBody, err := integrations.CloneBody(r) + if err != nil { + traceLogger.Debug("Error reading request body for debug logging", "error", err) + } + ////TODO (Alex) Log headers, BUT we must ensure we don't log confidential data like tokens etc. + traceLogger.Debug("Incoming HTTP request", "method", r.Method, "url", r.URL.String(), "body", cloneBody) + } + } + + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + func (app *App) AttachRESTClient(next func(http.ResponseWriter, *http.Request, httprouter.Params)) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { modelRegistryID := ps.ByName(ModelRegistryId) - namespace, ok := r.Context().Value(NamespaceHeaderParameterKey).(string) + namespace, ok := r.Context().Value(constants.NamespaceHeaderParameterKey).(string) if !ok || namespace == "" { app.badRequestResponse(w, r, fmt.Errorf("missing namespace in the context")) } - modelRegistryBaseURL, err := resolveModelRegistryURL(namespace, modelRegistryID, app.kubernetesClient, app.config) + modelRegistryBaseURL, err := resolveModelRegistryURL(r.Context(), namespace, modelRegistryID, app.kubernetesClient, app.config) if err != nil { app.notFoundResponse(w, r) return } - client, err := integrations.NewHTTPClient(modelRegistryID, modelRegistryBaseURL) + // Set up a child logger for the rest client that automatically adds the request id to all statements for + // tracing. + restClientLogger := app.logger + traceId, ok := r.Context().Value(constants.TraceIdKey).(string) + if app.logger != nil { + if ok { + restClientLogger = app.logger.With(slog.String("trace_id", traceId)) + } else { + app.logger.Warn("Failed to set trace_id for tracing") + } + } + + client, err := integrations.NewHTTPClient(restClientLogger, modelRegistryID, modelRegistryBaseURL) if err != nil { app.serverErrorResponse(w, r, fmt.Errorf("failed to create Kubernetes client: %v", err)) return } - ctx := context.WithValue(r.Context(), ModelRegistryHttpClientKey, client) + ctx := context.WithValue(r.Context(), constants.ModelRegistryHttpClientKey, client) next(w, r.WithContext(ctx), ps) } } -func resolveModelRegistryURL(namespace string, serviceName string, client integrations.KubernetesClientInterface, config config.EnvConfig) (string, error) { +func resolveModelRegistryURL(sessionCtx context.Context, namespace string, serviceName string, client integrations.KubernetesClientInterface, config config.EnvConfig) (string, error) { - serviceDetails, err := client.GetServiceDetailsByName(namespace, serviceName) + serviceDetails, err := client.GetServiceDetailsByName(sessionCtx, namespace, serviceName) if err != nil { return "", err } @@ -131,13 +157,13 @@ func resolveModelRegistryURL(namespace string, serviceName string, client integr func (app *App) AttachNamespace(next func(http.ResponseWriter, *http.Request, httprouter.Params)) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - namespace := r.URL.Query().Get(string(NamespaceHeaderParameterKey)) + namespace := r.URL.Query().Get(string(constants.NamespaceHeaderParameterKey)) if namespace == "" { - app.badRequestResponse(w, r, fmt.Errorf("missing required query parameter: %s", NamespaceHeaderParameterKey)) + app.badRequestResponse(w, r, fmt.Errorf("missing required query parameter: %s", constants.NamespaceHeaderParameterKey)) return } - ctx := context.WithValue(r.Context(), NamespaceHeaderParameterKey, namespace) + ctx := context.WithValue(r.Context(), constants.NamespaceHeaderParameterKey, namespace) r = r.WithContext(ctx) next(w, r, ps) @@ -146,19 +172,19 @@ func (app *App) AttachNamespace(next func(http.ResponseWriter, *http.Request, ht func (app *App) PerformSARonGetListServicesByNamespace(next func(http.ResponseWriter, *http.Request, httprouter.Params)) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - user, ok := r.Context().Value(KubeflowUserIdKey).(string) + user, ok := r.Context().Value(constants.KubeflowUserIdKey).(string) if !ok || user == "" { app.badRequestResponse(w, r, fmt.Errorf("missing user in context")) return } - namespace, ok := r.Context().Value(NamespaceHeaderParameterKey).(string) + namespace, ok := r.Context().Value(constants.NamespaceHeaderParameterKey).(string) if !ok || namespace == "" { app.badRequestResponse(w, r, fmt.Errorf("missing namespace in context")) return } var userGroups []string - if groups, ok := r.Context().Value(KubeflowUserGroupsKey).([]string); ok { + if groups, ok := r.Context().Value(constants.KubeflowUserGroupsKey).([]string); ok { userGroups = groups } else { userGroups = []string{} @@ -181,13 +207,13 @@ func (app *App) PerformSARonGetListServicesByNamespace(next func(http.ResponseWr func (app *App) PerformSARonSpecificService(next func(http.ResponseWriter, *http.Request, httprouter.Params)) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - user, ok := r.Context().Value(KubeflowUserIdKey).(string) + user, ok := r.Context().Value(constants.KubeflowUserIdKey).(string) if !ok || user == "" { app.badRequestResponse(w, r, fmt.Errorf("missing user in context")) return } - namespace, ok := r.Context().Value(NamespaceHeaderParameterKey).(string) + namespace, ok := r.Context().Value(constants.NamespaceHeaderParameterKey).(string) if !ok || namespace == "" { app.badRequestResponse(w, r, fmt.Errorf("missing namespace in context")) return @@ -200,7 +226,7 @@ func (app *App) PerformSARonSpecificService(next func(http.ResponseWriter, *http } var userGroups []string - if groups, ok := r.Context().Value(KubeflowUserGroupsKey).([]string); ok { + if groups, ok := r.Context().Value(constants.KubeflowUserGroupsKey).([]string); ok { userGroups = groups } else { userGroups = []string{} diff --git a/clients/ui/bff/internal/api/model_registry_handler.go b/clients/ui/bff/internal/api/model_registry_handler.go index 1600d0045..d5ee69e0c 100644 --- a/clients/ui/bff/internal/api/model_registry_handler.go +++ b/clients/ui/bff/internal/api/model_registry_handler.go @@ -3,6 +3,7 @@ package api import ( "fmt" "github.com/julienschmidt/httprouter" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "github.com/kubeflow/model-registry/ui/bff/internal/models" "net/http" ) @@ -11,12 +12,12 @@ type ModelRegistryListEnvelope Envelope[[]models.ModelRegistryModel, None] func (app *App) ModelRegistryHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - namespace, ok := r.Context().Value(NamespaceHeaderParameterKey).(string) + namespace, ok := r.Context().Value(constants.NamespaceHeaderParameterKey).(string) if !ok || namespace == "" { app.badRequestResponse(w, r, fmt.Errorf("missing namespace in the context")) } - registries, err := app.repositories.ModelRegistry.GetAllModelRegistries(app.kubernetesClient, namespace) + registries, err := app.repositories.ModelRegistry.GetAllModelRegistries(r.Context(), app.kubernetesClient, namespace) if err != nil { app.serverErrorResponse(w, r, err) return diff --git a/clients/ui/bff/internal/api/model_registry_handler_test.go b/clients/ui/bff/internal/api/model_registry_handler_test.go index 872121ce0..2e37d080c 100644 --- a/clients/ui/bff/internal/api/model_registry_handler_test.go +++ b/clients/ui/bff/internal/api/model_registry_handler_test.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "fmt" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" + "github.com/kubeflow/model-registry/ui/bff/internal/mocks" "github.com/kubeflow/model-registry/ui/bff/internal/models" "github.com/kubeflow/model-registry/ui/bff/internal/repositories" . "github.com/onsi/ginkgo/v2" @@ -28,7 +30,8 @@ var _ = Describe("TestModelRegistryHandler", func() { requestPath := fmt.Sprintf(" %s?namespace=kubeflow", ModelRegistryListPath) req, err := http.NewRequest(http.MethodGet, requestPath, nil) - ctx := context.WithValue(req.Context(), NamespaceHeaderParameterKey, "kubeflow") + ctx := mocks.NewMockSessionContext(req.Context()) + ctx = context.WithValue(ctx, constants.NamespaceHeaderParameterKey, "kubeflow") req = req.WithContext(ctx) Expect(err).NotTo(HaveOccurred()) diff --git a/clients/ui/bff/internal/api/model_versions_handler.go b/clients/ui/bff/internal/api/model_versions_handler.go index 7c74377bf..27eedd61c 100644 --- a/clients/ui/bff/internal/api/model_versions_handler.go +++ b/clients/ui/bff/internal/api/model_versions_handler.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/julienschmidt/httprouter" "github.com/kubeflow/model-registry/pkg/openapi" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "github.com/kubeflow/model-registry/ui/bff/internal/integrations" "github.com/kubeflow/model-registry/ui/bff/internal/validation" "net/http" @@ -19,7 +20,7 @@ type ModelArtifactListEnvelope Envelope[*openapi.ModelArtifactList, None] type ModelArtifactEnvelope Envelope[*openapi.ModelArtifact, None] func (app *App) GetAllModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -43,7 +44,7 @@ func (app *App) GetAllModelVersionHandler(w http.ResponseWriter, r *http.Request } func (app *App) GetModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -71,7 +72,7 @@ func (app *App) GetModelVersionHandler(w http.ResponseWriter, r *http.Request, p } func (app *App) CreateModelVersionHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -125,7 +126,7 @@ func (app *App) CreateModelVersionHandler(w http.ResponseWriter, r *http.Request } func (app *App) UpdateModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -175,7 +176,7 @@ func (app *App) UpdateModelVersionHandler(w http.ResponseWriter, r *http.Request } func (app *App) GetAllModelArtifactsByModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -198,7 +199,7 @@ func (app *App) GetAllModelArtifactsByModelVersionHandler(w http.ResponseWriter, } func (app *App) CreateModelArtifactByModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return diff --git a/clients/ui/bff/internal/api/namespaces_handler.go b/clients/ui/bff/internal/api/namespaces_handler.go index 80fb47014..88550531a 100644 --- a/clients/ui/bff/internal/api/namespaces_handler.go +++ b/clients/ui/bff/internal/api/namespaces_handler.go @@ -2,6 +2,7 @@ package api import ( "errors" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "github.com/kubeflow/model-registry/ui/bff/internal/models" "net/http" @@ -12,14 +13,14 @@ type NamespacesEnvelope Envelope[[]models.NamespaceModel, None] func (app *App) GetNamespacesHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - userId, ok := r.Context().Value(KubeflowUserIdKey).(string) + userId, ok := r.Context().Value(constants.KubeflowUserIdKey).(string) if !ok || userId == "" { app.serverErrorResponse(w, r, errors.New("failed to retrieve kubeflow-userid from context")) return } var userGroups []string - if groups, ok := r.Context().Value(KubeflowUserGroupsKey).([]string); ok { + if groups, ok := r.Context().Value(constants.KubeflowUserGroupsKey).([]string); ok { userGroups = groups } else { userGroups = []string{} diff --git a/clients/ui/bff/internal/api/namespaces_handler_test.go b/clients/ui/bff/internal/api/namespaces_handler_test.go index b4869058f..505956975 100644 --- a/clients/ui/bff/internal/api/namespaces_handler_test.go +++ b/clients/ui/bff/internal/api/namespaces_handler_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "github.com/kubeflow/model-registry/ui/bff/internal/config" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "github.com/kubeflow/model-registry/ui/bff/internal/mocks" "github.com/kubeflow/model-registry/ui/bff/internal/models" "github.com/kubeflow/model-registry/ui/bff/internal/repositories" @@ -31,7 +32,7 @@ var _ = Describe("TestNamespacesHandler", func() { It("should return only dora-namespace for doraNonAdmin@example.com", func() { By("creating the HTTP request with the kubeflow-userid header") req, err := http.NewRequest(http.MethodGet, NamespaceListPath, nil) - ctx := context.WithValue(req.Context(), KubeflowUserIdKey, mocks.DoraNonAdminUser) + ctx := context.WithValue(req.Context(), constants.KubeflowUserIdKey, mocks.DoraNonAdminUser) req = req.WithContext(ctx) Expect(err).NotTo(HaveOccurred()) rr := httptest.NewRecorder() @@ -57,7 +58,7 @@ var _ = Describe("TestNamespacesHandler", func() { It("should return all namespaces for user@example.com", func() { By("creating the HTTP request with the kubeflow-userid header") req, err := http.NewRequest(http.MethodGet, NamespaceListPath, nil) - ctx := context.WithValue(req.Context(), KubeflowUserIdKey, mocks.KubeflowUserIDHeaderValue) + ctx := context.WithValue(req.Context(), constants.KubeflowUserIdKey, mocks.KubeflowUserIDHeaderValue) req = req.WithContext(ctx) Expect(err).NotTo(HaveOccurred()) req.Header.Set("kubeflow-userid", "user@example.com") @@ -87,7 +88,7 @@ var _ = Describe("TestNamespacesHandler", func() { It("should return no namespaces for non-existent user", func() { By("creating the HTTP request with a non-existent kubeflow-userid") req, err := http.NewRequest(http.MethodGet, NamespaceListPath, nil) - ctx := context.WithValue(req.Context(), KubeflowUserIdKey, "nonexistent@example.com") + ctx := context.WithValue(req.Context(), constants.KubeflowUserIdKey, "nonexistent@example.com") req = req.WithContext(ctx) Expect(err).NotTo(HaveOccurred()) rr := httptest.NewRecorder() diff --git a/clients/ui/bff/internal/api/registered_models_handler.go b/clients/ui/bff/internal/api/registered_models_handler.go index 7f781f69b..ccd2469dd 100644 --- a/clients/ui/bff/internal/api/registered_models_handler.go +++ b/clients/ui/bff/internal/api/registered_models_handler.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/julienschmidt/httprouter" "github.com/kubeflow/model-registry/pkg/openapi" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "github.com/kubeflow/model-registry/ui/bff/internal/integrations" "github.com/kubeflow/model-registry/ui/bff/internal/validation" "net/http" @@ -16,7 +17,7 @@ type RegisteredModelListEnvelope Envelope[*openapi.RegisteredModelList, None] type RegisteredModelUpdateEnvelope Envelope[*openapi.RegisteredModelUpdate, None] func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -39,7 +40,7 @@ func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Req } func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -93,7 +94,7 @@ func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ } func (app *App) GetRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -121,7 +122,7 @@ func (app *App) GetRegisteredModelHandler(w http.ResponseWriter, r *http.Request } func (app *App) UpdateRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -171,7 +172,7 @@ func (app *App) UpdateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ } func (app *App) GetAllModelVersionsForRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return @@ -195,7 +196,7 @@ func (app *App) GetAllModelVersionsForRegisteredModelHandler(w http.ResponseWrit } func (app *App) CreateModelVersionForRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - client, ok := r.Context().Value(ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) + client, ok := r.Context().Value(constants.ModelRegistryHttpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return diff --git a/clients/ui/bff/internal/api/test_utils.go b/clients/ui/bff/internal/api/test_utils.go index 952a03337..1f8f0cc93 100644 --- a/clients/ui/bff/internal/api/test_utils.go +++ b/clients/ui/bff/internal/api/test_utils.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" k8s "github.com/kubeflow/model-registry/ui/bff/internal/integrations" "github.com/kubeflow/model-registry/ui/bff/internal/mocks" "github.com/kubeflow/model-registry/ui/bff/internal/repositories" @@ -46,16 +47,16 @@ func setupApiTest[T any](method string, url string, body interface{}, k8sClient } // Set the kubeflow-userid header - req.Header.Set(KubeflowUserIDHeader, kubeflowUserIDHeaderValue) + req.Header.Set(constants.KubeflowUserIDHeader, kubeflowUserIDHeaderValue) - ctx := req.Context() - ctx = context.WithValue(ctx, ModelRegistryHttpClientKey, mockClient) - ctx = context.WithValue(ctx, KubeflowUserIdKey, kubeflowUserIDHeaderValue) - ctx = context.WithValue(ctx, NamespaceHeaderParameterKey, namespace) + ctx := mocks.NewMockSessionContext(req.Context()) + ctx = context.WithValue(ctx, constants.ModelRegistryHttpClientKey, mockClient) + ctx = context.WithValue(ctx, constants.KubeflowUserIdKey, kubeflowUserIDHeaderValue) + ctx = context.WithValue(ctx, constants.NamespaceHeaderParameterKey, namespace) mrHttpClient := k8s.HTTPClient{ ModelRegistryID: "model-registry", } - ctx = context.WithValue(ctx, ModelRegistryHttpClientKey, mrHttpClient) + ctx = context.WithValue(ctx, constants.ModelRegistryHttpClientKey, mrHttpClient) req = req.WithContext(ctx) rr := httptest.NewRecorder() diff --git a/clients/ui/bff/internal/api/user_handler.go b/clients/ui/bff/internal/api/user_handler.go index 9ec135ccc..1359ab18b 100644 --- a/clients/ui/bff/internal/api/user_handler.go +++ b/clients/ui/bff/internal/api/user_handler.go @@ -3,6 +3,7 @@ package api import ( "errors" "github.com/julienschmidt/httprouter" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "github.com/kubeflow/model-registry/ui/bff/internal/models" "net/http" ) @@ -11,7 +12,7 @@ type UserEnvelope Envelope[*models.User, None] func (app *App) UserHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - userId, ok := r.Context().Value(KubeflowUserIdKey).(string) + userId, ok := r.Context().Value(constants.KubeflowUserIdKey).(string) if !ok || userId == "" { app.serverErrorResponse(w, r, errors.New("failed to retrieve kubeflow-userid from context")) return diff --git a/clients/ui/bff/internal/api/user_handler_test.go b/clients/ui/bff/internal/api/user_handler_test.go index 13cbf95a8..2927fa101 100644 --- a/clients/ui/bff/internal/api/user_handler_test.go +++ b/clients/ui/bff/internal/api/user_handler_test.go @@ -3,6 +3,7 @@ package api import ( "context" "encoding/json" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" "github.com/kubeflow/model-registry/ui/bff/internal/mocks" "io" "net/http" @@ -34,7 +35,7 @@ var _ = Describe("TestUserHandler", func() { It("should show that KubeflowUserIDHeaderValue (user@example.com) is a cluster-admin", func() { By("creating the http request") req, err := http.NewRequest(http.MethodGet, UserPath, nil) - ctx := context.WithValue(req.Context(), KubeflowUserIdKey, mocks.KubeflowUserIDHeaderValue) + ctx := context.WithValue(req.Context(), constants.KubeflowUserIdKey, mocks.KubeflowUserIDHeaderValue) req = req.WithContext(ctx) Expect(err).NotTo(HaveOccurred()) @@ -62,7 +63,7 @@ var _ = Describe("TestUserHandler", func() { It("should show that DoraNonAdminUser (doraNonAdmin@example.com) is not a cluster-admin", func() { By("creating the http request") req, err := http.NewRequest(http.MethodGet, UserPath, nil) - ctx := context.WithValue(req.Context(), KubeflowUserIdKey, DoraNonAdminUser) + ctx := context.WithValue(req.Context(), constants.KubeflowUserIdKey, DoraNonAdminUser) req = req.WithContext(ctx) Expect(err).NotTo(HaveOccurred()) @@ -92,7 +93,7 @@ var _ = Describe("TestUserHandler", func() { By("creating the http request") req, err := http.NewRequest(http.MethodGet, UserPath, nil) - ctx := context.WithValue(req.Context(), KubeflowUserIdKey, randomUser) + ctx := context.WithValue(req.Context(), constants.KubeflowUserIdKey, randomUser) req = req.WithContext(ctx) Expect(err).NotTo(HaveOccurred()) diff --git a/clients/ui/bff/internal/config/environment.go b/clients/ui/bff/internal/config/environment.go index 6017701c8..80449363e 100644 --- a/clients/ui/bff/internal/config/environment.go +++ b/clients/ui/bff/internal/config/environment.go @@ -8,4 +8,5 @@ type EnvConfig struct { StandaloneMode bool DevModePort int StaticAssetsDir string + LogLevel string } diff --git a/clients/ui/bff/internal/constants/keys.go b/clients/ui/bff/internal/constants/keys.go new file mode 100644 index 000000000..3051679f5 --- /dev/null +++ b/clients/ui/bff/internal/constants/keys.go @@ -0,0 +1,19 @@ +package constants + +type contextKey string + +const ( + ModelRegistryHttpClientKey contextKey = "ModelRegistryHttpClientKey" + NamespaceHeaderParameterKey contextKey = "namespace" + + //Kubeflow authorization operates using custom authentication headers: + // Note: The functionality for `kubeflow-groups` is not fully operational at Kubeflow platform at this time + // but it's supported on Model Registry BFF + KubeflowUserIdKey contextKey = "kubeflowUserId" // kubeflow-userid :contains the user's email address + KubeflowUserIDHeader = "kubeflow-userid" + KubeflowUserGroupsKey contextKey = "kubeflowUserGroups" // kubeflow-groups : Holds a comma-separated list of user groups + KubeflowUserGroupsIdHeader = "kubeflow-groups" + + TraceIdKey contextKey = "TraceIdKey" + TraceLoggerKey contextKey = "TraceLoggerKey" +) diff --git a/clients/ui/bff/internal/integrations/http.go b/clients/ui/bff/internal/integrations/http.go index 712b8556f..fea2bf54e 100644 --- a/clients/ui/bff/internal/integrations/http.go +++ b/clients/ui/bff/internal/integrations/http.go @@ -1,16 +1,18 @@ package integrations import ( + "context" "crypto/tls" "encoding/json" "fmt" + "github.com/google/uuid" "io" + "log/slog" "net/http" "strconv" ) type HTTPClientInterface interface { - GetModelRegistryID() (modelRegistryService string) GET(url string) ([]byte, error) POST(url string, body io.Reader) ([]byte, error) PATCH(url string, body io.Reader) ([]byte, error) @@ -20,6 +22,7 @@ type HTTPClient struct { client *http.Client baseURL string ModelRegistryID string + logger *slog.Logger } type ErrorResponse struct { @@ -36,7 +39,7 @@ func (e *HTTPError) Error() string { return fmt.Sprintf("HTTP %d: %s - %s", e.StatusCode, e.Code, e.Message) } -func NewHTTPClient(modelRegistryID string, baseURL string) (HTTPClientInterface, error) { +func NewHTTPClient(logger *slog.Logger, modelRegistryID string, baseURL string) (HTTPClientInterface, error) { return &HTTPClient{ client: &http.Client{Transport: &http.Transport{ @@ -44,6 +47,7 @@ func NewHTTPClient(modelRegistryID string, baseURL string) (HTTPClientInterface, }}, baseURL: baseURL, ModelRegistryID: modelRegistryID, + logger: logger, }, nil } @@ -52,12 +56,16 @@ func (c *HTTPClient) GetModelRegistryID() string { } func (c *HTTPClient) GET(url string) ([]byte, error) { + requestId := uuid.NewString() + fullURL := c.baseURL + url req, err := http.NewRequest("GET", fullURL, nil) if err != nil { return nil, err } + logUpstreamReq(c.logger, requestId, req) + response, err := c.client.Do(req) if err != nil { return nil, err @@ -65,6 +73,7 @@ func (c *HTTPClient) GET(url string) ([]byte, error) { defer response.Body.Close() body, err := io.ReadAll(response.Body) + logUpstreamResp(c.logger, requestId, response, body) if err != nil { return nil, fmt.Errorf("error reading response body: %w", err) } @@ -72,6 +81,8 @@ func (c *HTTPClient) GET(url string) ([]byte, error) { } func (c *HTTPClient) POST(url string, body io.Reader) ([]byte, error) { + requestId := uuid.NewString() + fullURL := c.baseURL + url fmt.Println(fullURL) req, err := http.NewRequest("POST", fullURL, body) @@ -81,6 +92,8 @@ func (c *HTTPClient) POST(url string, body io.Reader) ([]byte, error) { req.Header.Set("Content-Type", "application/json") + logUpstreamReq(c.logger, requestId, req) + response, err := c.client.Do(req) if err != nil { return nil, err @@ -88,6 +101,7 @@ func (c *HTTPClient) POST(url string, body io.Reader) ([]byte, error) { defer response.Body.Close() responseBody, err := io.ReadAll(response.Body) + logUpstreamResp(c.logger, requestId, response, responseBody) if err != nil { return nil, fmt.Errorf("error reading response body: %w", err) } @@ -120,8 +134,12 @@ func (c *HTTPClient) PATCH(url string, body io.Reader) ([]byte, error) { return nil, err } + requestId := uuid.NewString() + req.Header.Set("Content-Type", "application/json") + logUpstreamReq(c.logger, requestId, req) + response, err := c.client.Do(req) if err != nil { return nil, err @@ -129,6 +147,7 @@ func (c *HTTPClient) PATCH(url string, body io.Reader) ([]byte, error) { defer response.Body.Close() responseBody, err := io.ReadAll(response.Body) + logUpstreamResp(c.logger, requestId, response, responseBody) if err != nil { return nil, fmt.Errorf("error reading response body: %w", err) } @@ -152,3 +171,19 @@ func (c *HTTPClient) PATCH(url string, body io.Reader) ([]byte, error) { } return responseBody, nil } + +func logUpstreamReq(logger *slog.Logger, reqId string, req *http.Request) { + if logger.Enabled(context.TODO(), slog.LevelDebug) { + var body []byte + if req.Body != nil { + body, _ = CloneBody(req) + } + logger.Debug("Making upstream HTTP request", "request_id", reqId, "method", req.Method, "url", req.URL.String(), "body", body) + } +} + +func logUpstreamResp(logger *slog.Logger, reqId string, resp *http.Response, body []byte) { + if logger.Enabled(context.TODO(), slog.LevelDebug) { + logger.Debug("Received upstream HTTP response", "request_id", reqId, "status_code", resp.StatusCode, "body", body) + } +} diff --git a/clients/ui/bff/internal/integrations/http_helpers.go b/clients/ui/bff/internal/integrations/http_helpers.go new file mode 100644 index 000000000..214f9efba --- /dev/null +++ b/clients/ui/bff/internal/integrations/http_helpers.go @@ -0,0 +1,23 @@ +package integrations + +import ( + "bytes" + "fmt" + "io" + "net/http" +) + +func CloneBody(r *http.Request) ([]byte, error) { + if r.Body == nil { + return nil, fmt.Errorf("no body provided") + } + buf, _ := io.ReadAll(r.Body) + readerCopy := io.NopCloser(bytes.NewBuffer(buf)) + readerOriginal := io.NopCloser(bytes.NewBuffer(buf)) + r.Body = readerOriginal + + defer readerCopy.Close() + cloneBody, err := io.ReadAll(readerCopy) + + return cloneBody, err +} diff --git a/clients/ui/bff/internal/integrations/k8s.go b/clients/ui/bff/internal/integrations/k8s.go index 9b89bb02c..19e52ca51 100644 --- a/clients/ui/bff/internal/integrations/k8s.go +++ b/clients/ui/bff/internal/integrations/k8s.go @@ -3,6 +3,7 @@ package integrations import ( "context" "fmt" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" helper "github.com/kubeflow/model-registry/ui/bff/internal/helpers" authv1 "k8s.io/api/authorization/v1" corev1 "k8s.io/api/core/v1" @@ -22,9 +23,9 @@ import ( const ComponentLabelValue = "model-registry" type KubernetesClientInterface interface { - GetServiceNames(namespace string) ([]string, error) - GetServiceDetailsByName(namespace string, serviceName string) (ServiceDetails, error) - GetServiceDetails(namespace string) ([]ServiceDetails, error) + GetServiceNames(sessionCtx context.Context, namespace string) ([]string, error) + GetServiceDetailsByName(sessionCtx context.Context, namespace string, serviceName string) (ServiceDetails, error) + GetServiceDetails(sessionCtx context.Context, namespace string) ([]ServiceDetails, error) BearerToken() (string, error) Shutdown(ctx context.Context, logger *slog.Logger) error IsInCluster() bool @@ -152,8 +153,8 @@ func (kc *KubernetesClient) BearerToken() (string, error) { return kc.Token, nil } -func (kc *KubernetesClient) GetServiceNames(namespace string) ([]string, error) { - services, err := kc.GetServiceDetails(namespace) +func (kc *KubernetesClient) GetServiceNames(sessionCtx context.Context, namespace string) ([]string, error) { + services, err := kc.GetServiceDetails(sessionCtx, namespace) if err != nil { return nil, err } @@ -166,7 +167,7 @@ func (kc *KubernetesClient) GetServiceNames(namespace string) ([]string, error) return names, nil } -func (kc *KubernetesClient) GetServiceDetails(namespace string) ([]ServiceDetails, error) { +func (kc *KubernetesClient) GetServiceDetails(sessionCtx context.Context, namespace string) ([]ServiceDetails, error) { if namespace == "" { return nil, fmt.Errorf("namespace cannot be empty") @@ -175,6 +176,8 @@ func (kc *KubernetesClient) GetServiceDetails(namespace string) ([]ServiceDetail ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() + sessionLogger := sessionCtx.Value(constants.TraceLoggerKey).(*slog.Logger) + serviceList := &corev1.ServiceList{} labelSelector := labels.SelectorFromSet(labels.Set{ @@ -202,12 +205,12 @@ func (kc *KubernetesClient) GetServiceDetails(namespace string) ([]ServiceDetail } } if !hasHTTPPort { - kc.Logger.Error("service missing HTTP port", "serviceName", service.Name) + sessionLogger.Error("service missing HTTP port", "serviceName", service.Name) continue } if service.Spec.ClusterIP == "" { - kc.Logger.Error("service missing valid ClusterIP", "serviceName", service.Name) + sessionLogger.Error("service missing valid ClusterIP", "serviceName", service.Name) continue } @@ -220,11 +223,11 @@ func (kc *KubernetesClient) GetServiceDetails(namespace string) ([]ServiceDetail } if displayName == "" { - kc.Logger.Warn("service missing displayName annotation", "serviceName", service.Name) + sessionLogger.Warn("service missing displayName annotation", "serviceName", service.Name) } if description == "" { - kc.Logger.Warn("service missing description annotation", "serviceName", service.Name) + sessionLogger.Warn("service missing description annotation", "serviceName", service.Name) } serviceDetails := ServiceDetails{ @@ -242,8 +245,8 @@ func (kc *KubernetesClient) GetServiceDetails(namespace string) ([]ServiceDetail return services, nil } -func (kc *KubernetesClient) GetServiceDetailsByName(namespace string, serviceName string) (ServiceDetails, error) { - services, err := kc.GetServiceDetails(namespace) +func (kc *KubernetesClient) GetServiceDetailsByName(sessionCtx context.Context, namespace string, serviceName string) (ServiceDetails, error) { + services, err := kc.GetServiceDetails(sessionCtx, namespace) if err != nil { return ServiceDetails{}, fmt.Errorf("failed to get service details: %w", err) } diff --git a/clients/ui/bff/internal/mocks/k8s_mock.go b/clients/ui/bff/internal/mocks/k8s_mock.go index ce2b3e61a..7a74e65df 100644 --- a/clients/ui/bff/internal/mocks/k8s_mock.go +++ b/clients/ui/bff/internal/mocks/k8s_mock.go @@ -185,8 +185,8 @@ func setupMock(mockK8sClient client.Client, ctx context.Context) error { return nil } -func (m *KubernetesClientMock) GetServiceDetails(namespace string) ([]k8s.ServiceDetails, error) { - originalServices, err := m.KubernetesClient.GetServiceDetails(namespace) +func (m *KubernetesClientMock) GetServiceDetails(sessionCtx context.Context, namespace string) ([]k8s.ServiceDetails, error) { + originalServices, err := m.KubernetesClient.GetServiceDetails(sessionCtx, namespace) if err != nil { return nil, fmt.Errorf("failed to get service details: %w", err) } @@ -199,8 +199,8 @@ func (m *KubernetesClientMock) GetServiceDetails(namespace string) ([]k8s.Servic return originalServices, nil } -func (m *KubernetesClientMock) GetServiceDetailsByName(namespace string, serviceName string) (k8s.ServiceDetails, error) { - originalService, err := m.KubernetesClient.GetServiceDetailsByName(namespace, serviceName) +func (m *KubernetesClientMock) GetServiceDetailsByName(sessionCtx context.Context, namespace string, serviceName string) (k8s.ServiceDetails, error) { + originalService, err := m.KubernetesClient.GetServiceDetailsByName(sessionCtx, namespace, serviceName) if err != nil { return k8s.ServiceDetails{}, fmt.Errorf("failed to get service details: %w", err) } diff --git a/clients/ui/bff/internal/mocks/k8s_mock_test.go b/clients/ui/bff/internal/mocks/k8s_mock_test.go index e236326aa..3f84783fa 100644 --- a/clients/ui/bff/internal/mocks/k8s_mock_test.go +++ b/clients/ui/bff/internal/mocks/k8s_mock_test.go @@ -11,7 +11,7 @@ var _ = Describe("Kubernetes ControllerRuntimeClient Test", func() { It("should retrieve the get all service successfully", func() { By("getting service details") - services, err := k8sClient.GetServiceDetails("kubeflow") + services, err := k8sClient.GetServiceDetails(NewMockSessionContextNoParent(), "kubeflow") Expect(err).NotTo(HaveOccurred(), "Failed to create HTTP request") By("checking that all services have the modified ClusterIP and HTTPPort") @@ -37,7 +37,7 @@ var _ = Describe("Kubernetes ControllerRuntimeClient Test", func() { It("should retrieve the service details by name", func() { By("getting service by name") - service, err := k8sClient.GetServiceDetailsByName("dora-namespace", "model-registry-dora") + service, err := k8sClient.GetServiceDetailsByName(NewMockSessionContextNoParent(), "dora-namespace", "model-registry-dora") Expect(err).NotTo(HaveOccurred(), "Failed to create k8s request") By("checking that service details are correct") @@ -49,7 +49,7 @@ var _ = Describe("Kubernetes ControllerRuntimeClient Test", func() { It("should retrieve the services names", func() { By("getting service by name") - services, err := k8sClient.GetServiceNames("kubeflow") + services, err := k8sClient.GetServiceNames(NewMockSessionContextNoParent(), "kubeflow") Expect(err).NotTo(HaveOccurred(), "Failed to create HTTP request") By("checking that service details are correct") diff --git a/clients/ui/bff/internal/mocks/static_data_mock.go b/clients/ui/bff/internal/mocks/static_data_mock.go index bb0408c27..252a285e8 100644 --- a/clients/ui/bff/internal/mocks/static_data_mock.go +++ b/clients/ui/bff/internal/mocks/static_data_mock.go @@ -1,7 +1,12 @@ package mocks import ( + "context" + "github.com/google/uuid" "github.com/kubeflow/model-registry/pkg/openapi" + "github.com/kubeflow/model-registry/ui/bff/internal/constants" + "log/slog" + "os" ) func GetRegisteredModelMocks() []openapi.RegisteredModel { @@ -200,3 +205,19 @@ func newCustomProperties() *map[string]openapi.MetadataValue { return &result } + +func NewMockSessionContext(parent context.Context) context.Context { + if parent == nil { + parent = context.TODO() + } + traceId := uuid.NewString() + ctx := context.WithValue(parent, constants.TraceIdKey, traceId) + + traceLogger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ctx = context.WithValue(ctx, constants.TraceLoggerKey, traceLogger) + return ctx +} + +func NewMockSessionContextNoParent() context.Context { + return NewMockSessionContext(context.TODO()) +} diff --git a/clients/ui/bff/internal/repositories/model_registry.go b/clients/ui/bff/internal/repositories/model_registry.go index db4175952..60ec9bbb0 100644 --- a/clients/ui/bff/internal/repositories/model_registry.go +++ b/clients/ui/bff/internal/repositories/model_registry.go @@ -1,6 +1,7 @@ package repositories import ( + "context" "fmt" k8s "github.com/kubeflow/model-registry/ui/bff/internal/integrations" "github.com/kubeflow/model-registry/ui/bff/internal/models" @@ -13,9 +14,9 @@ func NewModelRegistryRepository() *ModelRegistryRepository { return &ModelRegistryRepository{} } -func (m *ModelRegistryRepository) GetAllModelRegistries(client k8s.KubernetesClientInterface, namespace string) ([]models.ModelRegistryModel, error) { +func (m *ModelRegistryRepository) GetAllModelRegistries(sessionCtx context.Context, client k8s.KubernetesClientInterface, namespace string) ([]models.ModelRegistryModel, error) { - resources, err := client.GetServiceDetails(namespace) + resources, err := client.GetServiceDetails(sessionCtx, namespace) if err != nil { return nil, fmt.Errorf("error fetching model registries: %w", err) } diff --git a/clients/ui/bff/internal/repositories/model_registry_test.go b/clients/ui/bff/internal/repositories/model_registry_test.go index a5a0d903b..06ef12621 100644 --- a/clients/ui/bff/internal/repositories/model_registry_test.go +++ b/clients/ui/bff/internal/repositories/model_registry_test.go @@ -1,6 +1,7 @@ package repositories import ( + "github.com/kubeflow/model-registry/ui/bff/internal/mocks" "github.com/kubeflow/model-registry/ui/bff/internal/models" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -13,7 +14,7 @@ var _ = Describe("TestFetchAllModelRegistry", func() { By("fetching all model registries in the repository") modelRegistryRepository := NewModelRegistryRepository() - registries, err := modelRegistryRepository.GetAllModelRegistries(k8sClient, "kubeflow") + registries, err := modelRegistryRepository.GetAllModelRegistries(mocks.NewMockSessionContextNoParent(), k8sClient, "kubeflow") Expect(err).NotTo(HaveOccurred()) By("should match the expected model registries") @@ -28,7 +29,7 @@ var _ = Describe("TestFetchAllModelRegistry", func() { By("fetching all model registries in the repository") modelRegistryRepository := NewModelRegistryRepository() - registries, err := modelRegistryRepository.GetAllModelRegistries(k8sClient, "dora-namespace") + registries, err := modelRegistryRepository.GetAllModelRegistries(mocks.NewMockSessionContextNoParent(), k8sClient, "dora-namespace") Expect(err).NotTo(HaveOccurred()) By("should match the expected model registries") @@ -42,7 +43,7 @@ var _ = Describe("TestFetchAllModelRegistry", func() { By("fetching all model registries in the repository") modelRegistryRepository := NewModelRegistryRepository() - registries, err := modelRegistryRepository.GetAllModelRegistries(k8sClient, "no-namespace") + registries, err := modelRegistryRepository.GetAllModelRegistries(mocks.NewMockSessionContextNoParent(), k8sClient, "no-namespace") Expect(err).NotTo(HaveOccurred()) By("should be empty")