From 0f3c8d164d0defdbb2b4b108086ee1f04bd9422e Mon Sep 17 00:00:00 2001 From: Tolga Ozen Date: Tue, 28 Jan 2025 16:20:14 +0300 Subject: [PATCH 1/3] refactor: enhance distributed consistent hashing configuration and error handling --- cmd/permify/permify.go | 4 +- docs/api-reference/apidocs.swagger.json | 2 +- .../openapiv2/apidocs.swagger.json | 2 +- internal/config/config.go | 10 +- internal/engines/balancer/balancer.go | 46 +- internal/info.go | 2 +- pkg/balancer/balancer.go | 250 +++++++++ pkg/balancer/balancer_test.go | 2 + pkg/balancer/builder.go | 155 +++++- pkg/balancer/builder_test.go | 47 +- pkg/balancer/errors.go | 12 - pkg/balancer/hashring.go | 345 ------------ pkg/balancer/manager.go | 95 ---- pkg/balancer/manager_test.go | 89 ---- pkg/balancer/picker.go | 98 ++-- pkg/balancer/picker_test.go | 78 --- pkg/balancer/queue.go | 53 -- pkg/balancer/queue_test.go | 61 --- pkg/cmd/config.go | 4 + pkg/cmd/flags/serve.go | 28 + pkg/cmd/serve.go | 14 +- pkg/consistent/consistent.go | 492 ++++++++++++++++++ pkg/consistent/consistent_test.go | 129 +++++ pkg/pb/base/v1/openapi.pb.go | 2 +- proto/base/v1/openapi.proto | 2 +- 25 files changed, 1147 insertions(+), 875 deletions(-) create mode 100644 pkg/balancer/balancer.go delete mode 100644 pkg/balancer/errors.go delete mode 100644 pkg/balancer/hashring.go delete mode 100644 pkg/balancer/manager.go delete mode 100644 pkg/balancer/manager_test.go delete mode 100644 pkg/balancer/picker_test.go delete mode 100644 pkg/balancer/queue.go delete mode 100644 pkg/balancer/queue_test.go create mode 100644 pkg/consistent/consistent.go create mode 100644 pkg/consistent/consistent_test.go diff --git a/cmd/permify/permify.go b/cmd/permify/permify.go index b7848b0b5..f89f10227 100644 --- a/cmd/permify/permify.go +++ b/cmd/permify/permify.go @@ -3,6 +3,8 @@ package main import ( "os" + "github.com/cespare/xxhash/v2" + "github.com/sercand/kuberesolver/v5" "google.golang.org/grpc/balancer" @@ -12,7 +14,7 @@ import ( func main() { kuberesolver.RegisterInCluster() - balancer.Register(consistentbalancer.NewConsistentHashBalancerBuilder()) + balancer.Register(consistentbalancer.NewBuilder(xxhash.Sum64)) root := cmd.NewRootCommand() diff --git a/docs/api-reference/apidocs.swagger.json b/docs/api-reference/apidocs.swagger.json index a07abbb09..95e877aaa 100644 --- a/docs/api-reference/apidocs.swagger.json +++ b/docs/api-reference/apidocs.swagger.json @@ -3,7 +3,7 @@ "info": { "title": "Permify API", "description": "Permify is an open source authorization service for creating fine-grained and scalable authorization systems.", - "version": "v1.2.7", + "version": "v1.2.8", "contact": { "name": "API Support", "url": "https://github.com/Permify/permify/issues", diff --git a/docs/api-reference/openapiv2/apidocs.swagger.json b/docs/api-reference/openapiv2/apidocs.swagger.json index c43f30205..001d40800 100644 --- a/docs/api-reference/openapiv2/apidocs.swagger.json +++ b/docs/api-reference/openapiv2/apidocs.swagger.json @@ -3,7 +3,7 @@ "info": { "title": "Permify API", "description": "Permify is an open source authorization service for creating fine-grained and scalable authorization systems.", - "version": "v1.2.7", + "version": "v1.2.8", "contact": { "name": "API Support", "url": "https://github.com/Permify/permify/issues", diff --git a/internal/config/config.go b/internal/config/config.go index e99ae81ad..093f52452 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -185,9 +185,13 @@ type ( } Distributed struct { - Enabled bool 
`mapstructure:"enabled"` - Address string `mapstructure:"address"` - Port string `mapstructure:"port"` + Enabled bool `mapstructure:"enabled"` + Address string `mapstructure:"address"` + Port string `mapstructure:"port"` + PartitionCount int `mapstructure:"partition_count"` + ReplicationFactor int `mapstructure:"replication_factor"` + Load float64 `mapstructure:"load"` + PickerWidth int `mapstructure:"picker_width"` } ) diff --git a/internal/engines/balancer/balancer.go b/internal/engines/balancer/balancer.go index 36164de26..570a7c0c3 100644 --- a/internal/engines/balancer/balancer.go +++ b/internal/engines/balancer/balancer.go @@ -2,12 +2,10 @@ package balancer import ( "context" - "encoding/hex" "fmt" "log/slog" "time" - "github.com/cespare/xxhash/v2" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" @@ -16,15 +14,10 @@ import ( "github.com/Permify/permify/internal/engines" "github.com/Permify/permify/internal/invoke" "github.com/Permify/permify/internal/storage" - "github.com/Permify/permify/pkg/balancer" base "github.com/Permify/permify/pkg/pb/base/v1" ) -var grpcServicePolicy = fmt.Sprintf(`{ - "loadBalancingPolicy": "%s" - }`, balancer.Policy) - // Balancer is a wrapper around the balancer hash implementation that type Balancer struct { schemaReader storage.SchemaReader @@ -52,7 +45,7 @@ func NewCheckEngineWithBalancer( ) // Set up TLS credentials if paths are provided - if srv.TLSConfig.CertPath != "" && srv.TLSConfig.KeyPath != "" { + if srv.TLSConfig.Enabled && srv.TLSConfig.CertPath != "" && srv.TLSConfig.KeyPath != "" { isSecure = true creds, err = credentials.NewClientTLSFromFile(srv.TLSConfig.CertPath, srv.TLSConfig.KeyPath) if err != nil { @@ -62,10 +55,22 @@ func NewCheckEngineWithBalancer( creds = insecure.NewCredentials() } + bc := &balancer.Config{ + PartitionCount: dst.PartitionCount, + ReplicationFactor: dst.ReplicationFactor, + Load: dst.Load, + PickerWidth: dst.PickerWidth, + } + + bcjson, err := bc.ServiceConfigJSON() + if err != nil { + return nil, err + } + // Append common options options = append( options, - grpc.WithDefaultServiceConfig(grpcServicePolicy), + grpc.WithDefaultServiceConfig(bcjson), grpc.WithTransportCredentials(creds), ) @@ -82,7 +87,7 @@ func NewCheckEngineWithBalancer( } } - conn, err := grpc.Dial(dst.Address, options...) + conn, err := grpc.NewClient(dst.Address, options...) if err != nil { return nil, err } @@ -112,29 +117,12 @@ func (c *Balancer) Check(ctx context.Context, request *base.PermissionCheckReque isRelational := engines.IsRelational(en, request.GetPermission()) - // Create a new xxhash instance. - h := xxhash.New() - - // Generate a unique key for the request based on its relational state. - // This key helps in distributing the request. - _, err = h.Write([]byte(engines.GenerateKey(request, isRelational))) - if err != nil { - slog.ErrorContext(ctx, err.Error()) - return &base.PermissionCheckResponse{ - Can: base.CheckResult_CHECK_RESULT_DENIED, - Metadata: &base.PermissionCheckResponseMetadata{ - CheckCount: 0, - }, - }, err - } - k := hex.EncodeToString(h.Sum(nil)) - // Add a timeout of 2 seconds to the context and also set the generated key as a value. 
- withTimeout, cancel := context.WithTimeout(context.WithValue(ctx, balancer.Key, k), 4*time.Second) + withTimeout, cancel := context.WithTimeout(context.WithValue(ctx, balancer.Key, []byte(engines.GenerateKey(request, isRelational))), 4*time.Second) defer cancel() // Logging the intention to forward the request to the underlying client. - slog.DebugContext(ctx, "Forwarding request with key to the underlying client", slog.String("key", k)) + slog.InfoContext(ctx, "Forwarding request with key to the underlying client") // Perform the actual permission check by making a call to the underlying client. response, err := c.client.Check(withTimeout, request) diff --git a/internal/info.go b/internal/info.go index d29787d75..30e1e6770 100644 --- a/internal/info.go +++ b/internal/info.go @@ -23,7 +23,7 @@ var Identifier = "" */ const ( // Version is the last release of the Permify (e.g. v0.1.0) - Version = "v1.2.7" + Version = "v1.2.8" ) // Function to create a single line of the ASCII art with centered content and color diff --git a/pkg/balancer/balancer.go b/pkg/balancer/balancer.go new file mode 100644 index 000000000..1a720e792 --- /dev/null +++ b/pkg/balancer/balancer.go @@ -0,0 +1,250 @@ +package balancer + +import ( + "errors" + "fmt" + "log/slog" + + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/resolver" + + "github.com/Permify/permify/pkg/consistent" +) + +type Balancer struct { + // Current overall connectivity state of the balancer. + state connectivity.State + + // The ClientConn to communicate with the gRPC client. + clientConn balancer.ClientConn + + // Current picker used to select SubConns for requests. + picker balancer.Picker + + // Evaluates connectivity state transitions for SubConns. + connectivityEvaluator *balancer.ConnectivityStateEvaluator + + // Map of resolver addresses to SubConns. + addressSubConns *resolver.AddressMap + + // Tracks the connectivity state of each SubConn. + subConnStates map[balancer.SubConn]connectivity.State + + // Configuration for consistent hashing and replication. + config *Config + + // Consistent hashing mechanism to distribute requests. + consistent *consistent.Consistent + + // Hasher used by the consistent hashing mechanism. + hasher consistent.Hasher + + // Stores the last resolver error encountered. + lastResolverError error + + // Stores the last connection error encountered. + lastConnectionError error +} + +func (b *Balancer) ResolverError(err error) { + b.lastResolverError = err + if b.addressSubConns.Len() == 0 { + b.state = connectivity.TransientFailure + b.picker = base.NewErrPicker(errors.Join(b.lastConnectionError, b.lastResolverError)) + } + + if b.state != connectivity.TransientFailure { + return + } + + // Update the balancer state and picker. + b.clientConn.UpdateState(balancer.State{ + ConnectivityState: b.state, + Picker: b.picker, + }) +} + +func (b *Balancer) UpdateClientConnState(s balancer.ClientConnState) error { + // Log the new ClientConn state. + slog.Info("Received new ClientConn state", + slog.Any("state", s), + ) + + // Reset any existing resolver error. + b.lastResolverError = nil + + // Handle changes to the balancer configuration. 
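+	// The *Config arrives from the builder's ParseConfig; the consistent hash ring is rebuilt only when
+	// no ring exists yet or the replication factor has changed, so address-only updates keep the current ring.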
+ if s.BalancerConfig != nil { + svcConfig := s.BalancerConfig.(*Config) + if b.config == nil || svcConfig.ReplicationFactor != b.config.ReplicationFactor { + slog.Info("Updating consistent hashing configuration", + slog.Int("partition_count", svcConfig.PartitionCount), + slog.Int("replication_factor", svcConfig.ReplicationFactor), + slog.Float64("load", svcConfig.Load), + slog.Int("picker_width", svcConfig.PickerWidth), + ) + b.consistent = consistent.New(consistent.Config{ + PartitionCount: svcConfig.PartitionCount, + ReplicationFactor: svcConfig.ReplicationFactor, + Load: svcConfig.Load, + PickerWidth: svcConfig.PickerWidth, + Hasher: b.hasher, + }) + b.config = svcConfig + } + } + + // Check if the consistent hashing configuration exists. + if b.consistent == nil { + slog.Error("No consistent hashing configuration found") + b.picker = base.NewErrPicker(errors.Join(b.lastConnectionError, b.lastResolverError)) + b.clientConn.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker}) + return fmt.Errorf("no consistent hashing configuration found") + } + + // Maintain a set of addresses provided by the resolver. + addrsSet := resolver.NewAddressMap() + for _, addr := range s.ResolverState.Addresses { + addrsSet.Set(addr, nil) + + // Add new SubConns for addresses that are not already tracked. + if _, ok := b.addressSubConns.Get(addr); !ok { + sc, err := b.clientConn.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{HealthCheckEnabled: false}) + if err != nil { + slog.Warn("Failed to create new SubConn", + slog.String("address", addr.Addr), + slog.String("server_name", addr.ServerName), + slog.String("error", err.Error()), + ) + continue + } + + b.addressSubConns.Set(addr, sc) + b.subConnStates[sc] = connectivity.Idle + b.connectivityEvaluator.RecordTransition(connectivity.Shutdown, connectivity.Idle) + sc.Connect() + + b.consistent.Add(ConsistentMember{ + SubConn: sc, + name: fmt.Sprintf("%s|%s", addr.ServerName, addr.Addr), + }) + } + } + + // Remove SubConns that are no longer part of the resolved addresses. + for _, addr := range b.addressSubConns.Keys() { + sci, _ := b.addressSubConns.Get(addr) + sc := sci.(balancer.SubConn) + if _, ok := addrsSet.Get(addr); !ok { + slog.Info("Removing SubConn", + slog.String("address", addr.Addr), + slog.String("server_name", addr.ServerName), + ) + b.clientConn.RemoveSubConn(sc) + b.addressSubConns.Delete(addr) + b.consistent.Remove(ConsistentMember{ + SubConn: sc, + name: fmt.Sprintf("%s|%s", addr.ServerName, addr.Addr), + }.String()) + } + } + + // Log the current members in the consistent hashing ring. + slog.Info("Current consistent members", + slog.Int("member_count", len(b.consistent.Members())), + ) + for _, m := range b.consistent.Members() { + slog.Info("Consistent member", slog.String("member", m.String())) + } + + // Handle the case where the resolver produces zero addresses. + if len(s.ResolverState.Addresses) == 0 { + err := errors.New("resolver produced zero addresses") + b.ResolverError(err) + slog.Error("Resolver produced zero addresses") + return balancer.ErrBadResolverState + } + + // Update the picker based on the current balancer state. 
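+	// A transient-failure state installs an error picker carrying the last resolver/connection errors;
+	// otherwise a consistent-hash picker is built, with the configured picker width clamped to at least 1.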
+ if b.state == connectivity.TransientFailure { + slog.Warn("Transient failure detected, using error picker") + b.picker = base.NewErrPicker(errors.Join(b.lastConnectionError, b.lastResolverError)) + } else { + width := b.config.PickerWidth + if width < 1 { + width = 1 + } + slog.Info("Creating new picker", + slog.Int("width", width), + ) + b.picker = &picker{ + consistent: b.consistent, + width: width, + } + } + + // Update the ClientConn state with the new picker. + slog.Info("Updating ClientConn state", + slog.String("connectivity_state", b.state.String()), + ) + b.clientConn.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker}) + + return nil +} + +func (b *Balancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { + s := state.ConnectivityState + slog.Info("Received SubConn state change", + slog.String("connectivity_state", s.String()), + slog.String("sub_conn", fmt.Sprintf("%p", sc)), + ) + + oldS, ok := b.subConnStates[sc] + if !ok { + slog.Warn("State change for unknown SubConn", + slog.String("connectivity_state", s.String()), + slog.String("sub_conn", fmt.Sprintf("%p", sc)), + ) + return + } + + if oldS == connectivity.TransientFailure && (s == connectivity.Connecting || s == connectivity.Idle) { + if s == connectivity.Idle { + slog.Info("Transitioning SubConn to connecting state", + slog.String("sub_conn", fmt.Sprintf("%p", sc)), + ) + sc.Connect() + } + return + } + + b.subConnStates[sc] = s + switch s { + case connectivity.Idle: + slog.Info("SubConn is idle, initiating connection", + slog.String("sub_conn", fmt.Sprintf("%p", sc)), + ) + sc.Connect() + case connectivity.Shutdown: + slog.Info("Removing shutdown SubConn", + slog.String("sub_conn", fmt.Sprintf("%p", sc)), + ) + delete(b.subConnStates, sc) + case connectivity.TransientFailure: + slog.Warn("SubConn in transient failure", + slog.String("sub_conn", fmt.Sprintf("%p", sc)), + slog.String("error", state.ConnectionError.Error()), + ) + b.lastConnectionError = state.ConnectionError + } + + b.state = b.connectivityEvaluator.RecordTransition(oldS, s) + slog.Info("Updating ClientConn state", + slog.String("connectivity_state", b.state.String()), + ) + b.clientConn.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker}) +} + +func (b *Balancer) Close() {} diff --git a/pkg/balancer/balancer_test.go b/pkg/balancer/balancer_test.go index c8a383254..38d746c9c 100644 --- a/pkg/balancer/balancer_test.go +++ b/pkg/balancer/balancer_test.go @@ -7,6 +7,8 @@ import ( . "github.com/onsi/gomega" ) +// This is the entry point for the test suite for the "consistent" package. +// It registers a failure handler and runs the specifications (specs) for this package. func TestBalancer(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "balancer-suite") diff --git a/pkg/balancer/builder.go b/pkg/balancer/builder.go index e48429d45..ef875f5cc 100644 --- a/pkg/balancer/builder.go +++ b/pkg/balancer/builder.go @@ -1,37 +1,150 @@ package balancer import ( + "encoding/json" + "fmt" "sync" + "golang.org/x/exp/slog" "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" + + "github.com/Permify/permify/pkg/consistent" +) + +// Package-level constants for the balancer name and consistent hash key. +const ( + Name = "consistenthashing" // Name of the balancer. + Key = "consistenthashkey" // Key for the consistent hash. 
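+
+	// Callers attach the request's hash key to the outgoing context under Key as a []byte
+	// (via context.WithValue); picker.Pick reads it to choose a member from the ring.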
) -// NewConsistentHashBalancerBuilder returns a consistentHashBalancerBuilder. -func NewConsistentHashBalancerBuilder() balancer.Builder { - return &consistentHashBalancerBuilder{} +// Config represents the configuration for the consistent hashing balancer. +type Config struct { + serviceconfig.LoadBalancingConfig `json:"-"` // Embedding the base load balancing config. + PartitionCount int `json:"partitionCount,omitempty"` // Number of partitions in the consistent hash ring. + ReplicationFactor int `json:"replicationFactor,omitempty"` // Number of replicas for each member. + Load float64 `json:"load,omitempty"` // Load factor for balancing traffic. + PickerWidth int `json:"pickerWidth,omitempty"` // Number of closest members to consider in the picker. } -// consistentHashBalancerBuilder is an empty struct with functions Build and Name, implemented from balancer.Builder -type consistentHashBalancerBuilder struct{} +// ServiceConfigJSON generates the JSON representation of the load balancer configuration. +func (c *Config) ServiceConfigJSON() (string, error) { + // Define the JSON wrapper structure for the load balancing config. + type Wrapper struct { + LoadBalancingConfig []map[string]*Config `json:"loadBalancingConfig"` + } + + // Apply default values for zero fields. + if c.PartitionCount == 0 { + c.PartitionCount = consistent.DefaultPartitionCount + } + if c.ReplicationFactor == 0 { + c.ReplicationFactor = consistent.DefaultReplicationFactor + } + if c.Load == 0 { + c.Load = consistent.DefaultLoad + } + if c.PickerWidth == 0 { + c.PickerWidth = consistent.DefaultPickerWidth + } + + // Create the wrapper with the current configuration. + wrapper := Wrapper{ + LoadBalancingConfig: []map[string]*Config{ + {Name: c}, + }, + } -// Build creates a consistentHashBalancer, and starts its scManager. -func (builder *consistentHashBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { - b := &consistentHashBalancer{ - clientConn: cc, - addressInfoMap: make(map[string]resolver.Address), - subConnectionMap: make(map[string]balancer.SubConn), - subConnInfoSyncMap: sync.Map{}, - pickerResultChannel: make(chan PickResult), - activePickResults: NewQueue(), - subConnPickCounts: make(map[balancer.SubConn]*int32), - subConnStatusMap: make(map[balancer.SubConn]bool), + // Marshal the wrapped configuration to JSON. + jsonData, err := json.Marshal(wrapper) + if err != nil { + return "", fmt.Errorf("failed to marshal service config: %w", err) } - go b.manageSubConnections() - return b + + return string(jsonData), nil +} + +// NewBuilder initializes a new builder with the given hashing function. +func NewBuilder(fn consistent.Hasher) Builder { + return &builder{hasher: fn} } -// Name returns the name of the consistentHashBalancer registering in grpc. -func (builder *consistentHashBalancerBuilder) Name() string { - return Policy +// ConsistentMember represents a member in the consistent hashing ring. +type ConsistentMember struct { + balancer.SubConn // Embedded SubConn for the gRPC connection. + name string // Unique identifier for the member. +} + +// String returns the name of the ConsistentMember. +func (s ConsistentMember) String() string { return s.name } + +// builder is responsible for creating and configuring the consistent hashing balancer. +type builder struct { + sync.Mutex // Mutex for thread-safe updates to the builder. + hasher consistent.Hasher // Hashing function for the consistent hash ring. + config Config // Current balancer configuration. 
+} + +// Builder defines the interface for the consistent hashing balancer builder. +type Builder interface { + balancer.Builder // Interface for building balancers. + balancer.ConfigParser // Interface for parsing balancer configurations. +} + +// Name returns the name of the balancer. +func (b *builder) Name() string { return Name } + +// Build creates a new instance of the consistent hashing balancer. +func (b *builder) Build(cc balancer.ClientConn, _ balancer.BuildOptions) balancer.Balancer { + // Initialize a new balancer with default values. + bal := &Balancer{ + clientConn: cc, + addressSubConns: resolver.NewAddressMap(), + subConnStates: make(map[balancer.SubConn]connectivity.State), + connectivityEvaluator: &balancer.ConnectivityStateEvaluator{}, + state: connectivity.Connecting, // Initial state. + hasher: b.hasher, + picker: base.NewErrPicker(balancer.ErrNoSubConnAvailable), // Default picker with no SubConns available. + } + + return bal +} + +// ParseConfig parses the balancer configuration from the provided JSON. +func (b *builder) ParseConfig(rm json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { + var cfg Config + // Unmarshal the JSON configuration into the Config struct. + if err := json.Unmarshal(rm, &cfg); err != nil { + return nil, fmt.Errorf("consistenthash: unable to unmarshal LB policy config: %s, error: %w", string(rm), err) + } + + // Log the parsed configuration using structured logging. + slog.Info("Parsed balancer configuration", + slog.String("raw_json", string(rm)), // Log the raw JSON string. + slog.Any("config", cfg), // Log the unmarshaled Config struct. + ) + + // Set default values for configuration if not provided. + if cfg.PartitionCount == 0 { + cfg.PartitionCount = consistent.DefaultPartitionCount + } + if cfg.ReplicationFactor == 0 { + cfg.ReplicationFactor = consistent.DefaultReplicationFactor + } + if cfg.Load == 0 { + cfg.Load = consistent.DefaultLoad + } + if cfg.PickerWidth == 0 { + cfg.PickerWidth = consistent.DefaultPickerWidth + } + + // Update the builder's configuration with thread safety. + b.Lock() + b.config = cfg + b.Unlock() + + return &cfg, nil } diff --git a/pkg/balancer/builder_test.go b/pkg/balancer/builder_test.go index e14207beb..806a59768 100644 --- a/pkg/balancer/builder_test.go +++ b/pkg/balancer/builder_test.go @@ -3,42 +3,27 @@ package balancer import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "google.golang.org/grpc/balancer" ) -var _ = Describe("consistentHashBalancerBuilder", func() { - var builder balancer.Builder - var mockClientConn balancer.ClientConn - var buildOpts balancer.BuildOptions +var _ = Describe("balancer", func() { + Context("ServiceConfigJSON", func() { + It("should generate valid JSON", func() { + // Set up a sample configuration. + config := &Config{ + PartitionCount: 271, + ReplicationFactor: 20, + Load: 1.25, + PickerWidth: 3, + } - BeforeEach(func() { - builder = NewConsistentHashBalancerBuilder() - // You'll want to mock or create a real balancer.ClientConn and balancer.BuildOptions here. - // For now, we'll keep them nil for simplicity. - mockClientConn = nil - buildOpts = balancer.BuildOptions{} - }) - - Describe("Name", func() { - It("should return the expected balancer name", func() { - name := builder.Name() - Expect(name).To(Equal(Policy)) - }) - }) + // Generate the JSON using ServiceConfigJSON. 
+ jsonString, err := config.ServiceConfigJSON() - Describe("Build", func() { - It("should return a consistentHashBalancer", func() { - b := builder.Build(mockClientConn, buildOpts) - Expect(b).To(Not(BeNil())) + // Expect no error during JSON generation. + Expect(err).ToNot(HaveOccurred(), "ServiceConfigJSON should not return an error") - // Further assertions can be made depending on the properties and - // behaviors of the consistentHashBalancer. For example: - chb, ok := b.(*consistentHashBalancer) - Expect(ok).To(BeTrue()) - Expect(chb.clientConn).To(BeNil()) - Expect(chb.addressInfoMap).To(Not(BeNil())) - Expect(chb.subConnectionMap).To(Not(BeNil())) - Expect(chb.activePickResults).To(Not(BeNil())) + // Validate the parsed Config fields match the original configuration. + Expect(jsonString).To(Equal("{\"loadBalancingConfig\":[{\"consistenthashing\":{\"partitionCount\":271,\"replicationFactor\":20,\"load\":1.25,\"pickerWidth\":3}}]}")) }) }) }) diff --git a/pkg/balancer/errors.go b/pkg/balancer/errors.go deleted file mode 100644 index 3d01ed315..000000000 --- a/pkg/balancer/errors.go +++ /dev/null @@ -1,12 +0,0 @@ -package balancer - -import ( - "errors" -) - -var ( - // ErrSubConnMissing indicates that a SubConn (sub-connection) was expected but not found. - ErrSubConnMissing = errors.New("sub-connection is missing or not found") - // ErrSubConnResetFailure indicates an error occurred while trying to reset the SubConn. - ErrSubConnResetFailure = errors.New("failed to reset the sub-connection") -) diff --git a/pkg/balancer/hashring.go b/pkg/balancer/hashring.go deleted file mode 100644 index 6a529d161..000000000 --- a/pkg/balancer/hashring.go +++ /dev/null @@ -1,345 +0,0 @@ -package balancer - -import ( - "errors" - "fmt" - "log" - "log/slog" - "sync" - "time" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/base" - "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/resolver" -) - -const ( - // Policy defines the name or identifier for the consistent hashing load balancing policy. - Policy = "consistenthashpolicy" - - // Key is the context key used to retrieve the hash key for consistent hashing. - Key = "consistenthashkey" - - // ConnectionLifetime specifies the duration for which a connection is maintained - // before being considered for termination or renewal. - ConnectionLifetime = time.Second * 5 -) - -// subConnInfo records the state and addr corresponding to the SubConn. -type subConnInfo struct { - state connectivity.State - addr string -} - -type consistentHashBalancer struct { - // clientConn represents the client connection that created this balancer. - clientConn balancer.ClientConn - - // connectionState indicates the current state of the connection. - connectionState connectivity.State - - // addressInfoMap maps address strings to their corresponding resolver addresses. - addressInfoMap map[string]resolver.Address - - // subConnectionMap maps address strings to their associated sub-connections. - subConnectionMap map[string]balancer.SubConn - - // subConnInfoSyncMap is a concurrent map that stores information about each sub-connection. - subConnInfoSyncMap sync.Map - - // currentPicker represents the picker currently used by the client connection for load balancing. - currentPicker balancer.Picker - - // lastResolverError stores the most recent error reported by the resolver. - lastResolverError error - - // lastConnectionError stores the most recent connection-related error. 
- lastConnectionError error - - // pickerResultChannel is a channel through which pickers report their results. - pickerResultChannel chan PickResult - - // activePickResults is a queue for storing pick results that have active contexts. - activePickResults *Queue - - // subConnPickCounts tracks the number of pick results associated with each sub-connection. - subConnPickCounts map[balancer.SubConn]*int32 - - // subConnStatusMap indicates the status (active/inactive) of each sub-connection. - subConnStatusMap map[balancer.SubConn]bool - - // balancerLock is a mutex used to ensure thread safety, especially when accessing the subConnPickCounts map. - balancerLock sync.Mutex -} - -// UpdateClientConnState processes the provided ClientConnState and updates -// the internal state of the balancer accordingly. -func (b *consistentHashBalancer) UpdateClientConnState(s balancer.ClientConnState) error { - b.balancerLock.Lock() // Ensure exclusive access to balancers data. - defer b.balancerLock.Unlock() - - // Update address information and get a set of active addresses. - addrsSet := b.updateAddressInfo(s) - - // Remove any sub-connections that are no longer active. - b.removeStaleSubConns(addrsSet) - - // If there are no addresses from the resolver, log an error. - if len(s.ResolverState.Addresses) == 0 { - b.ResolverError(errors.New("produced zero addresses")) - return balancer.ErrBadResolverState - } - - // Re-generate the picker based on the updated state and inform the client connection. - b.regeneratePicker() - b.clientConn.UpdateState(balancer.State{ConnectivityState: b.connectionState, Picker: b.currentPicker}) - - return nil -} - -// updateAddressInfo processes the provided ClientConnState and updates -// the balancers address information. It returns a set of active addresses. -func (b *consistentHashBalancer) updateAddressInfo(s balancer.ClientConnState) map[string]struct{} { - addrsSet := make(map[string]struct{}) - - // Iterate over the addresses from the resolver. - for _, a := range s.ResolverState.Addresses { - addr := a.Addr - b.addressInfoMap[addr] = a - addrsSet[addr] = struct{}{} - - // If there isn't a sub-connection for this address, create one. - if sc, ok := b.subConnectionMap[addr]; !ok { - if err := b.createNewSubConn(a, addr); err != nil { - log.Printf("Consistent Hash Balancer: failed to create new SubConn: %v", err) - } - } else { - // If a sub-connection exists, update its addresses. - b.clientConn.UpdateAddresses(sc, []resolver.Address{a}) - } - } - - return addrsSet -} - -// createNewSubConn creates a new sub-connection for the provided address. -func (b *consistentHashBalancer) createNewSubConn(a resolver.Address, addr string) error { - newSC, err := b.clientConn.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{HealthCheckEnabled: false}) - if err != nil { - return err - } - - // Store the new sub-connection and its info. - b.subConnectionMap[addr] = newSC - b.subConnInfoSyncMap.Store(newSC, &subConnInfo{ - state: connectivity.Idle, - addr: addr, - }) - newSC.Connect() - - return nil -} - -// removeStaleSubConns removes sub-connections that are no longer in the active addresses set. -func (b *consistentHashBalancer) removeStaleSubConns(addrsSet map[string]struct{}) { - for a, sc := range b.subConnectionMap { - // If a sub-connection's address isn't in the active set, remove it. - if _, ok := addrsSet[a]; !ok { - b.clientConn.RemoveSubConn(sc) - delete(b.subConnectionMap, a) - b.subConnInfoSyncMap.Delete(sc) // Cleanup related data. 
- } - } -} - -// ResolverError handles resolver errors and updates state/picker accordingly. -func (b *consistentHashBalancer) ResolverError(err error) { - b.balancerLock.Lock() // Ensure exclusive access to balancers data. - defer b.balancerLock.Unlock() - - // Store the error and re-generate the picker. - b.lastResolverError = err - b.regeneratePicker() - - // Update client connection state if in a TransientFailure state. - if b.connectionState != connectivity.TransientFailure { - return - } - b.clientConn.UpdateState(balancer.State{ - ConnectivityState: b.connectionState, - Picker: b.currentPicker, - }) -} - -// regeneratePicker generates a new picker to replace the old one with new data, and update the state of the balancer. -func (b *consistentHashBalancer) regeneratePicker() { - availableSCs := make(map[string]balancer.SubConn) - - for addr, sc := range b.subConnectionMap { - if stIface, ok := b.subConnInfoSyncMap.Load(sc); ok { - if st, ok := stIface.(*subConnInfo); ok { - // Only include sub-connections that are in a Ready or Idle state - if st.state == connectivity.Ready || st.state == connectivity.Idle { - availableSCs[addr] = sc - } - } else { - log.Printf("Unexpected type in scInfos for key %v: expected *subConnInfo, got %T", sc, stIface) - } - } - } - - if len(availableSCs) == 0 { - b.connectionState = connectivity.TransientFailure - b.currentPicker = base.NewErrPicker(b.mergeErrors()) - } else { - b.connectionState = connectivity.Ready - b.currentPicker = NewConsistentHashPicker(availableSCs) - } -} - -// mergeErrors - -func (b *consistentHashBalancer) mergeErrors() error { - // If both errors are nil, return a generic error. - if b.lastConnectionError == nil && b.lastResolverError == nil { - return fmt.Errorf("unknown error occurred") - } - - // If only one of the errors is nil, return the other error. - if b.lastConnectionError == nil { - return fmt.Errorf("last resolver error: %w", b.lastResolverError) - } - if b.lastResolverError == nil { - return fmt.Errorf("last connection error: %w", b.lastConnectionError) - } - - // If both errors are present, concatenate them. - return errors.Join(b.lastConnectionError, b.lastResolverError) -} - -// UpdateSubConnState - -func (b *consistentHashBalancer) UpdateSubConnState(subConn balancer.SubConn, stateUpdate balancer.SubConnState) { - currentState := stateUpdate.ConnectivityState - - storedInfo, infoExists := b.subConnInfoSyncMap.Load(subConn) - if !infoExists { - // If the subConn isn't in our info map, it's a no-op for us. - return - } - - subConnInfo := storedInfo.(*subConnInfo) - previousState := subConnInfo.state - - if previousState == currentState { - // If the state hasn't changed, no need for further processing. - return - } - - slog.Debug("State of one sub-connection changed", slog.String("previous", previousState.String()), slog.String("current", currentState.String())) - - // Handle transitions from TransientFailure to Connecting. - if previousState == connectivity.TransientFailure && currentState == connectivity.Connecting { - return - } - - // Update the state in the stored sub-connection info. - subConnInfo.state = currentState - - switch currentState { - case connectivity.Idle: - subConn.Connect() - case connectivity.Shutdown: - b.subConnInfoSyncMap.Delete(subConn) - case connectivity.TransientFailure: - b.lastConnectionError = stateUpdate.ConnectionError - } - - // If there's a significant change in the connection state, regenerate the picker. 
- if hasSignificantStateChange(previousState, currentState) || b.connectionState == connectivity.TransientFailure { - b.regeneratePicker() - } - - b.clientConn.UpdateState(balancer.State{ConnectivityState: b.connectionState, Picker: b.currentPicker}) -} - -// hasSignificantStateChange - -func hasSignificantStateChange(oldState, newState connectivity.State) bool { - isOldStateSignificant := oldState == connectivity.TransientFailure || oldState == connectivity.Shutdown - isNewStateSignificant := newState == connectivity.TransientFailure || newState == connectivity.Shutdown - return isOldStateSignificant != isNewStateSignificant -} - -// Close - -func (b *consistentHashBalancer) Close() {} - -// resetSubConn resets the given SubConn. -// It first retrieves the address associated with the SubConn, then tries to reset it using that address. -func (b *consistentHashBalancer) resetSubConn(subConn balancer.SubConn) error { - // Get the address associated with the SubConn. - address, err := b.getSubConnAddr(subConn) - if err != nil { - return fmt.Errorf("failed to get address for sub connection: %v", err) - } - - slog.Debug("Resetting connection with", slog.String("address", address)) - - // Reset the SubConn using its address. - if resetErr := b.resetSubConnWithAddr(address); resetErr != nil { - return fmt.Errorf("failed to reset sub connection with address %s: %v", address, resetErr) - } - - return nil -} - -// getSubConnAddr retrieves the address associated with the given SubConn from the stored connection info. -func (b *consistentHashBalancer) getSubConnAddr(subConn balancer.SubConn) (string, error) { - // Load the sub connection info from the sync map. - connInfoValue, exists := b.subConnInfoSyncMap.Load(subConn) - - if !exists { - return "", ErrSubConnMissing - } - - // Type assert and return the address from the sub connection info. - subConnInfo := connInfoValue.(*subConnInfo) - return subConnInfo.addr, nil -} - -// resetSubConnWithAddr replaces the current SubConn associated with the provided address with a new one. -func (b *consistentHashBalancer) resetSubConnWithAddr(address string) error { - // Retrieve the current SubConn associated with the address. - currentSubConn, exists := b.subConnectionMap[address] - if !exists { - return ErrSubConnMissing - } - - // Delete the info and remove the SubConn. - b.subConnInfoSyncMap.Delete(currentSubConn) - b.clientConn.RemoveSubConn(currentSubConn) - - // Fetch the address information. - addressInfo, infoExists := b.addressInfoMap[address] - if !infoExists { - slog.Error("Consistent Hash Balancer: Address information missing for", slog.String("address", address)) - return ErrSubConnResetFailure - } - - // Create a new SubConn with the address information. - newSubConn, err := b.clientConn.NewSubConn([]resolver.Address{addressInfo}, balancer.NewSubConnOptions{HealthCheckEnabled: false}) - if err != nil { - return err - } - - // Store the new SubConn and its information. - b.subConnectionMap[address] = newSubConn - b.subConnInfoSyncMap.Store(newSubConn, &subConnInfo{ - state: connectivity.Idle, - addr: address, - }) - - // Regenerate picker and update the balancer state. 
- b.regeneratePicker() - b.clientConn.UpdateState(balancer.State{ConnectivityState: b.connectionState, Picker: b.currentPicker}) - - return nil -} diff --git a/pkg/balancer/manager.go b/pkg/balancer/manager.go deleted file mode 100644 index 4f2f1b24b..000000000 --- a/pkg/balancer/manager.go +++ /dev/null @@ -1,95 +0,0 @@ -package balancer - -import ( - "context" - "log" - "sync/atomic" - "time" -) - -// manageSubConnections initiates two goroutines to handle the pick results. -// One goroutine enqueues and tracks pick results while the other handles dequeued pick results. -func (b *consistentHashBalancer) manageSubConnections() { - // Goroutine to listen for pick results and enqueue them for further processing. - go func() { - for { - pr := <-b.pickerResultChannel - b.enqueueAndTrackPickResult(pr) - } - }() - - // Goroutine to process and handle dequeued pick results. - go func() { - for { - v, ok := b.activePickResults.DeQueue() - if !ok { - time.Sleep(ConnectionLifetime) - continue - } - pr := v.(PickResult) - b.handleDequeuedPickResult(pr) - } - }() -} - -// enqueueAndTrackPickResult enqueues a given pick result and tracks its state. -// It also enqueues a shadow pick result with a limited lifespan. -func (b *consistentHashBalancer) enqueueAndTrackPickResult(pr PickResult) { - // Enqueue the original pick result. - b.activePickResults.EnQueue(pr) - - // Create a shadow context with a predefined lifespan. - shadowCtx, cancel := context.WithTimeout(context.Background(), ConnectionLifetime) - defer cancel() - - // Enqueue the shadow pick result. - b.activePickResults.EnQueue(PickResult{Ctx: shadowCtx, SC: pr.SC}) - - // Update the sub-connection counts in a thread-safe manner. - b.balancerLock.Lock() - cnt, ok := b.subConnPickCounts[pr.SC] - if !ok { - cnt = new(int32) - b.subConnPickCounts[pr.SC] = cnt - } - *cnt += 2 - b.balancerLock.Unlock() -} - -// handleDequeuedPickResult processes a dequeued pick result. -// Depending on the state of the context and the sub-connection, various actions are taken. -func (b *consistentHashBalancer) handleDequeuedPickResult(pr PickResult) { - select { - // If the context associated with the pick result is done... - case <-pr.Ctx.Done(): - b.balancerLock.Lock() - defer b.balancerLock.Unlock() - - // If the sub-connection is in a certain status, re-enqueue the pick result. - if b.subConnStatusMap[pr.SC] { - b.activePickResults.EnQueue(pr) - return - } - - // Decrease the count for the sub-connection. - cnt, ok := b.subConnPickCounts[pr.SC] - if !ok { - return - } - - atomic.AddInt32(cnt, -1) - // If count becomes zero, reset the sub-connection. - if *cnt == 0 { - delete(b.subConnPickCounts, pr.SC) - - b.subConnStatusMap[pr.SC] = true - if err := b.resetSubConn(pr.SC); err != nil { - log.Printf("Failed to reset SubConn: %v", err) - } - delete(b.subConnStatusMap, pr.SC) - } - // If the context isn't done yet, re-enqueue the pick result. - default: - b.activePickResults.EnQueue(pr) - } -} diff --git a/pkg/balancer/manager_test.go b/pkg/balancer/manager_test.go deleted file mode 100644 index 6747bfb33..000000000 --- a/pkg/balancer/manager_test.go +++ /dev/null @@ -1,89 +0,0 @@ -package balancer - -import ( - "context" - "sync" - - . "github.com/onsi/ginkgo/v2" - . 
"github.com/onsi/gomega" - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/resolver" -) - -var _ = Describe("consistentHashBalancer", func() { - var b *consistentHashBalancer - - BeforeEach(func() { - b = &consistentHashBalancer{ - clientConn: nil, // Mock this if needed - connectionState: connectivity.Ready, - addressInfoMap: make(map[string]resolver.Address), - subConnectionMap: make(map[string]balancer.SubConn), - subConnInfoSyncMap: sync.Map{}, - currentPicker: nil, // Mock this if needed - lastResolverError: nil, - lastConnectionError: nil, - pickerResultChannel: make(chan PickResult, 10), - activePickResults: NewQueue(), - subConnPickCounts: make(map[balancer.SubConn]*int32), - subConnStatusMap: make(map[balancer.SubConn]bool), - balancerLock: sync.Mutex{}, - } - }) - - Describe("manageSubConnections", func() { - It("should start without panicking", func() { - go b.manageSubConnections() - }) - }) - - Describe("enqueueAndTrackPickResult", func() { - It("should enqueue and track pick results", func() { - mockSC := balancer.SubConn(nil) - pr := PickResult{ - Ctx: context.Background(), - SC: mockSC, - } - - b.enqueueAndTrackPickResult(pr) - - Expect(b.activePickResults.Len()).Should(Equal(2)) - cnt, ok := b.subConnPickCounts[mockSC] - Expect(ok).Should(BeTrue()) - Expect(*cnt).Should(Equal(int32(2))) - }) - }) - - Describe("handleDequeuedPickResult", func() { - It("should re-enqueue pick result when context isn't done", func() { - mockSC := balancer.SubConn(nil) - pr := PickResult{ - Ctx: context.Background(), - SC: mockSC, - } - - b.activePickResults.EnQueue(pr) - b.handleDequeuedPickResult(pr) - Expect(b.activePickResults.Len()).Should(Equal(2)) - }) - - It("should decrease count and possibly reset SubConn when context is done", func() { - mockSC := balancer.SubConn(nil) - ctx, cancel := context.WithCancel(context.Background()) - pr := PickResult{ - Ctx: ctx, - SC: mockSC, - } - - b.subConnPickCounts[mockSC] = new(int32) - *b.subConnPickCounts[mockSC] = 1 - - cancel() - b.handleDequeuedPickResult(pr) - - _, ok := b.subConnPickCounts[mockSC] - Expect(ok).Should(BeFalse()) - }) - }) -}) diff --git a/pkg/balancer/picker.go b/pkg/balancer/picker.go index b604fe824..ab459700e 100644 --- a/pkg/balancer/picker.go +++ b/pkg/balancer/picker.go @@ -1,68 +1,72 @@ package balancer import ( - "context" - "log/slog" - "sync" + "crypto/rand" + "fmt" + "log" + "math/big" - "github.com/serialx/hashring" "google.golang.org/grpc/balancer" -) -// ConsistentHashPicker is a custom gRPC picker that uses consistent hashing -// to determine which backend server should handle the request. -type ConsistentHashPicker struct { - subConns map[string]balancer.SubConn // Map of server addresses to their respective SubConns - mu sync.RWMutex // Mutex to protect concurrent access to subConns - hashRing *hashring.HashRing // Hash ring used for consistent hashing -} + "github.com/Permify/permify/pkg/consistent" +) -// PickResult represents the result of a pick operation. -// It contains the context and the selected SubConn for the request. -type PickResult struct { - Ctx context.Context - SC balancer.SubConn +type picker struct { + consistent *consistent.Consistent + width int } -// NewConsistentHashPicker initializes and returns a new ConsistentHashPicker. -// It creates a hash ring from the provided set of backend server addresses. 
-func NewConsistentHashPicker(subConns map[string]balancer.SubConn) *ConsistentHashPicker { - addrs := make([]string, 0, len(subConns)) - - // Extract addresses from the subConns map - for addr := range subConns { - addrs = append(addrs, addr) +// Generate a cryptographically secure random index function with resilient error handling +var randomIndex = func(max int) int { + // Ensure max > 0 to avoid issues + if max <= 0 { + log.Println("randomIndex: max value is less than or equal to 0, returning 0 as fallback") + return 0 } - slog.Debug("consistent hash picker built", slog.Any("addresses", addrs)) - - return &ConsistentHashPicker{ - subConns: subConns, - hashRing: hashring.New(addrs), + // Use crypto/rand to generate a random index + n, err := rand.Int(rand.Reader, big.NewInt(int64(max))) + if err != nil { + // Log the error and return a deterministic fallback value (e.g., 0) + log.Printf("randomIndex: failed to generate a secure random number, returning 0 as fallback: %v", err) + return 0 } + + return int(n.Int64()) } -// Pick selects an appropriate backend server (SubConn) for the incoming request. -// If a custom key is provided in the context, it will be used for consistent hashing; -// otherwise, the full method name of the request will be used. -func (p *ConsistentHashPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { - var ret balancer.PickResult - key, ok := info.Ctx.Value(Key).(string) +func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { + // Safely extract the key from the context + keyValue := info.Ctx.Value(Key) + if keyValue == nil { + return balancer.PickResult{}, fmt.Errorf("context key missing") + } + key, ok := keyValue.([]byte) if !ok { - key = info.FullMethodName + return balancer.PickResult{}, fmt.Errorf("context key is not of type []byte") + } + + // Retrieve the closest N members + members, err := p.consistent.ClosestN(key, p.width) + if err != nil { + return balancer.PickResult{}, fmt.Errorf("failed to get closest members: %v", err) + } + if len(members) == 0 { + return balancer.PickResult{}, fmt.Errorf("no available members") } - slog.Debug("pick for", slog.String("key", key)) - // Safely read from the subConns map using the read lock - p.mu.RLock() - if targetAddr, ok := p.hashRing.GetNode(key); ok { - ret.SubConn = p.subConns[targetAddr] + // Randomly pick one member if width > 1 + index := 0 + if p.width > 1 { + index = randomIndex(p.width) } - p.mu.RUnlock() - // If no valid SubConn was found, return an error - if ret.SubConn == nil { - return ret, balancer.ErrNoSubConnAvailable + // Assert the member type + chosen, ok := members[index].(ConsistentMember) + if !ok { + return balancer.PickResult{}, fmt.Errorf("invalid member type: expected subConnMember") } - return ret, nil + + // Return the chosen connection + return balancer.PickResult{SubConn: chosen.SubConn}, nil } diff --git a/pkg/balancer/picker_test.go b/pkg/balancer/picker_test.go deleted file mode 100644 index 6d84a30f0..000000000 --- a/pkg/balancer/picker_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package balancer - -import ( - "context" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" - - "google.golang.org/grpc/balancer" -) - -var _ = Describe("ConsistentHashPicker", func() { - var ( - testSubConns = map[string]balancer.SubConn{ - "addr1": nil, // Normally, you would use mock SubConn objects. For simplicity, we use nil here. 
- "addr2": nil, - } - picker *ConsistentHashPicker - ) - - BeforeEach(func() { - picker = NewConsistentHashPicker(testSubConns) - }) - - Describe("Initialization", func() { - It("should initialize with provided subConns", func() { - Expect(picker.subConns).To(Equal(testSubConns)) - Expect(picker.hashRing).ShouldNot(BeNil()) - }) - }) - - Describe("Pick", func() { - var pickInfo balancer.PickInfo - - Context("with custom key in context", func() { - BeforeEach(func() { - pickInfo = balancer.PickInfo{ - Ctx: context.WithValue(context.Background(), Key, "customKey"), - } - }) - - It("should return an ErrNoSubConnAvailable error", func() { - _, err := picker.Pick(pickInfo) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(Equal("no SubConn is available")) - }) - }) - - Context("without custom key in context", func() { - BeforeEach(func() { - pickInfo = balancer.PickInfo{ - FullMethodName: "testMethod", - Ctx: context.Background(), - } - }) - - It("should return an ErrNoSubConnAvailable error", func() { - _, err := picker.Pick(pickInfo) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(Equal("no SubConn is available")) - }) - }) - - Context("with unavailable SubConn", func() { - BeforeEach(func() { - picker.subConns = make(map[string]balancer.SubConn) // Empty the subConns map - pickInfo = balancer.PickInfo{ - Ctx: context.WithValue(context.Background(), Key, "unavailableKey"), - } - }) - - It("should return an ErrNoSubConnAvailable error", func() { - _, err := picker.Pick(pickInfo) - Expect(err).To(Equal(balancer.ErrNoSubConnAvailable)) - }) - }) - }) -}) diff --git a/pkg/balancer/queue.go b/pkg/balancer/queue.go deleted file mode 100644 index 8fc546cb0..000000000 --- a/pkg/balancer/queue.go +++ /dev/null @@ -1,53 +0,0 @@ -package balancer - -import ( - "container/list" - "sync" -) - -// Queue is a basic FIFO queue based on a doubly linked list. -type Queue struct { - list *list.List - mu sync.Mutex -} - -// NewQueue returns a new queue. -func NewQueue() *Queue { - return &Queue{ - list: list.New(), - } -} - -// EnQueue adds an item to the end of the queue. -func (q *Queue) EnQueue(value interface{}) { - q.mu.Lock() - defer q.mu.Unlock() - q.list.PushBack(value) -} - -// DeQueue removes and returns the first item in the queue. -func (q *Queue) DeQueue() (interface{}, bool) { - q.mu.Lock() - defer q.mu.Unlock() - e := q.list.Front() - if e == nil { - return nil, false - } - v := e.Value - q.list.Remove(e) - return v, true -} - -// IsEmpty returns true if the queue is empty. -func (q *Queue) IsEmpty() bool { - q.mu.Lock() - defer q.mu.Unlock() - return q.list.Len() == 0 -} - -// Len returns the number of items in the queue. -func (q *Queue) Len() int { - q.mu.Lock() - defer q.mu.Unlock() - return q.list.Len() -} diff --git a/pkg/balancer/queue_test.go b/pkg/balancer/queue_test.go deleted file mode 100644 index 17aaa15a7..000000000 --- a/pkg/balancer/queue_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package balancer - -import ( - . "github.com/onsi/ginkgo/v2" - . 
"github.com/onsi/gomega" -) - -var _ = Describe("Queue", func() { - Describe("Newly initialized queue", func() { - var q *Queue - BeforeEach(func() { - q = NewQueue() - }) - - It("should not be nil", func() { - Expect(q).ShouldNot(BeNil()) - }) - - It("should be empty", func() { - Expect(q.IsEmpty()).Should(BeTrue()) - }) - - It("should have a length of 0", func() { - Expect(q.Len()).Should(Equal(0)) - }) - }) - - Describe("Queue operations", func() { - var q *Queue - BeforeEach(func() { - q = NewQueue() - }) - - Context("EnQueue", func() { - It("should add items to the queue", func() { - q.EnQueue(1) - Expect(q.IsEmpty()).Should(BeFalse()) - Expect(q.Len()).Should(Equal(1)) - }) - }) - - Context("DeQueue", func() { - It("should remove and return the first item in the queue", func() { - q.EnQueue(1) - q.EnQueue(2) - - val, ok := q.DeQueue() - Expect(ok).Should(BeTrue()) - Expect(val).Should(Equal(1)) - - val, ok = q.DeQueue() - Expect(ok).Should(BeTrue()) - Expect(val).Should(Equal(2)) - - val, ok = q.DeQueue() - Expect(ok).Should(BeFalse()) - Expect(val).Should(BeNil()) - }) - }) - }) -}) diff --git a/pkg/cmd/config.go b/pkg/cmd/config.go index 0b9cd9dcd..6a5d8eb48 100644 --- a/pkg/cmd/config.go +++ b/pkg/cmd/config.go @@ -102,6 +102,10 @@ func NewConfigCommand() *cobra.Command { f.Bool("distributed-enabled", conf.Distributed.Enabled, "enable distributed") f.String("distributed-address", conf.Distributed.Address, "distributed address") f.String("distributed-port", conf.Distributed.Port, "distributed port") + f.Int("distributed-partition-count", conf.Distributed.PartitionCount, "number of partitions for distributed hashing") + f.Int("distributed-replication-factor", conf.Distributed.ReplicationFactor, "number of replicas for distributed hashing") + f.Float64("distributed-load", conf.Distributed.Load, "load factor for distributed hashing") + f.Int("distributed-picker-width", conf.Distributed.PickerWidth, "picker width for distributed hashing") command.PreRun = func(cmd *cobra.Command, args []string) { flags.RegisterServeFlags(f) diff --git a/pkg/cmd/flags/serve.go b/pkg/cmd/flags/serve.go index 365a98d32..9ae65d86b 100644 --- a/pkg/cmd/flags/serve.go +++ b/pkg/cmd/flags/serve.go @@ -563,4 +563,32 @@ func RegisterServeFlags(flags *pflag.FlagSet) { if err = viper.BindEnv("distributed.port", "PERMIFY_DISTRIBUTED_PORT"); err != nil { panic(err) } + + if err = viper.BindPFlag("distributed.partition_count", flags.Lookup("distributed-partition-count")); err != nil { + panic(err) + } + if err = viper.BindEnv("distributed.partition_count", "PERMIFY_DISTRIBUTED_PARTITION_COUNT"); err != nil { + panic(err) + } + + if err = viper.BindPFlag("distributed.replication_factor", flags.Lookup("distributed-replication-factor")); err != nil { + panic(err) + } + if err = viper.BindEnv("distributed.replication_factor", "PERMIFY_DISTRIBUTED_REPLICATION_FACTOR"); err != nil { + panic(err) + } + + if err = viper.BindPFlag("distributed.load", flags.Lookup("distributed-load")); err != nil { + panic(err) + } + if err = viper.BindEnv("distributed.load", "PERMIFY_DISTRIBUTED_LOAD"); err != nil { + panic(err) + } + + if err = viper.BindPFlag("distributed.picker_width", flags.Lookup("distributed-picker-width")); err != nil { + panic(err) + } + if err = viper.BindEnv("distributed.picker_width", "PERMIFY_DISTRIBUTED_PICKER_WIDTH"); err != nil { + panic(err) + } } diff --git a/pkg/cmd/serve.go b/pkg/cmd/serve.go index 7b6fc3da6..3747acf5b 100644 --- a/pkg/cmd/serve.go +++ b/pkg/cmd/serve.go @@ -134,6 +134,10 @@ func 
NewServeCommand() *cobra.Command { f.Bool("distributed-enabled", conf.Distributed.Enabled, "enable distributed") f.String("distributed-address", conf.Distributed.Address, "distributed address") f.String("distributed-port", conf.Distributed.Port, "distributed port") + f.Int("distributed-partition-count", conf.Distributed.PartitionCount, "number of partitions for distributed hashing") + f.Int("distributed-replication-factor", conf.Distributed.ReplicationFactor, "number of replicas for distributed hashing") + f.Float64("distributed-load", conf.Distributed.Load, "load factor for distributed hashing") + f.Int("distributed-picker-width", conf.Distributed.PickerWidth, "picker width for distributed hashing") // SilenceUsage is set to true to suppress usage when an error occurs command.SilenceUsage = true @@ -450,11 +454,11 @@ func serve() func(cmd *cobra.Command, args []string) error { if err != nil { return err } - checker = cache.NewCheckEngineWithCache( - checker, - schemaReader, - engineKeyCache, - ) + //checker = cache.NewCheckEngineWithCache( + // checker, + // schemaReader, + // engineKeyCache, + //) } else { checker = cache.NewCheckEngineWithCache( checkEngine, diff --git a/pkg/consistent/consistent.go b/pkg/consistent/consistent.go new file mode 100644 index 000000000..2b1117a18 --- /dev/null +++ b/pkg/consistent/consistent.go @@ -0,0 +1,492 @@ +package consistent + +import ( + "encoding/binary" + "errors" + "fmt" + "math" + "sort" + "sync" +) + +const ( + // DefaultPartitionCount defines the default number of virtual partitions in the hash ring. + // This helps balance the load distribution among members, even with a small number of members. + DefaultPartitionCount int = 271 + + // DefaultReplicationFactor specifies the default number of replicas for each partition. + // This ensures redundancy and fault tolerance by assigning partitions to multiple members. + DefaultReplicationFactor int = 20 + + // DefaultLoad defines the default maximum load factor for each member. + // A higher value allows members to handle more load before being considered full. + DefaultLoad float64 = 1.25 + + // DefaultPickerWidth determines the default range of candidates considered when picking members. + // This can influence the selection logic in advanced configurations. + DefaultPickerWidth int = 1 +) + +type Hasher func([]byte) uint64 + +type Member interface { + String() string +} + +// Config represents the configuration settings for a specific system or application. +// It includes settings for hashing, partitioning, replication, load balancing, and picker width. +type Config struct { + // Hasher is an interface or implementation used for generating hash values. + // It is typically used to distribute data evenly across partitions. + Hasher Hasher + + // PartitionCount defines the number of partitions in the system. + // This value affects how data is distributed and processed. + PartitionCount int + + // ReplicationFactor specifies the number of replicas for each partition. + // It ensures data redundancy and fault tolerance in the system. + ReplicationFactor int + + // Load represents the load balancing factor for the system. + // It could be a threshold or weight used for distributing work. + Load float64 + + // PickerWidth determines the width or range of the picker mechanism. + // It is typically used to influence how selections are made in certain operations. + PickerWidth int +} + +// Consistent implements a consistent hashing mechanism with partitioning and load balancing. 
+// It is used for distributing data across a dynamic set of members efficiently. +type Consistent struct { + // mu is a read-write mutex used to protect shared resources from concurrent access. + mu sync.RWMutex + + // config holds the configuration settings for the consistent hashing instance. + config Config + + // hasher is an implementation of the Hasher interface used for generating hash values. + hasher Hasher + + // sortedSet maintains a sorted slice of hash values to represent the hash ring. + sortedSet []uint64 + + // partitionCount specifies the number of partitions in the hash ring. + partitionCount uint64 + + // loads tracks the load distribution for each member in the hash ring. + // The key is the member's identifier, and the value is the load. + loads map[string]float64 + + // members is a map of member identifiers to their corresponding Member struct. + members map[string]*Member + + // partitions maps each partition index to the corresponding member. + partitions map[int]*Member + + // ring is a map that associates each hash value in the ring with a specific member. + ring map[uint64]*Member +} + +// New initializes and returns a new instance of the Consistent struct. +// It takes a Config parameter and applies default values for any unset fields. +func New(config Config) *Consistent { + // Ensure the Hasher is not nil; a nil Hasher would make consistent hashing unusable. + if config.Hasher == nil { + panic("Hasher cannot be nil") + } + + // Set default values for partition count, replication factor, load, and picker width if not provided. + if config.PartitionCount == 0 { + config.PartitionCount = DefaultPartitionCount + } + if config.ReplicationFactor == 0 { + config.ReplicationFactor = DefaultReplicationFactor + } + if config.Load == 0 { + config.Load = DefaultLoad + } + if config.PickerWidth == 0 { + config.PickerWidth = DefaultPickerWidth + } + + // Initialize a new Consistent instance with the provided configuration. + c := &Consistent{ + config: config, + members: make(map[string]*Member), + partitionCount: uint64(config.PartitionCount), + ring: make(map[uint64]*Member), + } + + // Assign the provided Hasher implementation to the instance. + c.hasher = config.Hasher + return c +} + +// Members returns a slice of all the members currently in the consistent hash ring. +// It safely retrieves the members using a read lock to prevent data races while +// accessing the shared `members` map. +func (c *Consistent) Members() []Member { + // Acquire a read lock to ensure thread-safe access to the members map. + c.mu.RLock() + defer c.mu.RUnlock() + + // Create a slice to hold the members, pre-allocating its capacity to avoid resizing. + members := make([]Member, 0, len(c.members)) + + // Iterate over the members map and append each member to the slice. + for _, member := range c.members { + members = append(members, *member) + } + + // Return the slice of members. + return members +} + +// GetAverageLoad calculates and returns the current average load across all members. +// It is a public method that provides thread-safe access to the load calculation. +func (c *Consistent) GetAverageLoad() float64 { + // Acquire a read lock to ensure thread-safe access to shared resources. + c.mu.RLock() + defer c.mu.RUnlock() + + // Delegate the actual load calculation to the internal helper method. + return c.calculateAverageLoad() +} + +// calculateAverageLoad is a private helper method that performs the actual calculation +// of the average load across all members. 
It is not thread-safe and should be called
+// only from within methods that already manage locking.
+func (c *Consistent) calculateAverageLoad() float64 {
+	// If there are no members, return an average load of 0 to prevent division by zero.
+	if len(c.members) == 0 {
+		return 0
+	}
+
+	// Calculate the average load by dividing the total partition count by the number of members
+	// and multiplying by the configured load factor.
+	avgLoad := float64(c.partitionCount/uint64(len(c.members))) * c.config.Load
+
+	// Use math.Ceil to round up the average load to the nearest whole number.
+	return math.Ceil(avgLoad)
+}
+
+// assignPartitionWithLoad distributes a partition to a member based on the load factor.
+// It ensures that no member exceeds the calculated average load while distributing partitions.
+// If the distribution fails due to insufficient capacity, it panics with an error message.
+func (c *Consistent) assignPartitionWithLoad(
+	partitionID, startIndex int,
+	partitionAssignments map[int]*Member,
+	memberLoads map[string]float64,
+) {
+	// Calculate the average load to determine the maximum load a member can handle.
+	averageLoad := c.calculateAverageLoad()
+	var attempts int
+
+	// Iterate to find a suitable member for the partition.
+	for {
+		attempts++
+
+		// If the loop exceeds the number of members, it indicates that the partition
+		// cannot be distributed with the current configuration.
+		if attempts >= len(c.sortedSet) {
+			panic("not enough capacity to distribute partitions: consider decreasing the partition count, increasing the member count, or increasing the load factor")
+		}
+
+		// Get the current hash value from the sorted set.
+		currentHash := c.sortedSet[startIndex]
+
+		// Retrieve the member associated with the hash value.
+		currentMember := *c.ring[currentHash]
+
+		// Check the current load of the member.
+		currentLoad := memberLoads[currentMember.String()]
+
+		// If the member's load is within the acceptable range, assign the partition.
+		if currentLoad+1 <= averageLoad {
+			partitionAssignments[partitionID] = &currentMember
+			memberLoads[currentMember.String()]++
+			return
+		}
+
+		// Move to the next member in the sorted set.
+		startIndex++
+		if startIndex >= len(c.sortedSet) {
+			// Loop back to the beginning of the sorted set if we reach the end.
+			startIndex = 0
+		}
+	}
+}
+
+// distributePartitions evenly distributes partitions among members while respecting the load factor.
+// It ensures that partitions are assigned to members based on consistent hashing and load constraints.
+func (c *Consistent) distributePartitions() {
+	// Initialize maps to track the load for each member and partition assignments.
+	memberLoads := make(map[string]float64)
+	partitionAssignments := make(map[int]*Member)
+
+	// Create a buffer for converting partition IDs into byte slices for hashing.
+	partitionKeyBuffer := make([]byte, 8)
+
+	// Iterate over all partition IDs to distribute them among members.
+	for partitionID := uint64(0); partitionID < c.partitionCount; partitionID++ {
+		// Convert the partition ID into a byte slice for hashing.
+		binary.LittleEndian.PutUint64(partitionKeyBuffer, partitionID)
+
+		// Generate a hash key for the partition using the configured hasher.
+		hashKey := c.hasher(partitionKeyBuffer)
+
+		// Find the index of the member in the sorted set where the hash key should be placed.
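+		// sort.Search performs a binary search over the sorted hashes and returns len(c.sortedSet)
+		// when no hash is >= hashKey, which is why the index wraps back to 0 below.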
+ index := sort.Search(len(c.sortedSet), func(i int) bool { + return c.sortedSet[i] >= hashKey + }) + + // If the index is beyond the end of the sorted set, wrap around to the beginning. + if index >= len(c.sortedSet) { + index = 0 + } + + // Assign the partition to a member, ensuring the load factor is respected. + c.assignPartitionWithLoad(int(partitionID), index, partitionAssignments, memberLoads) + } + + // Update the Consistent instance with the new partition assignments and member loads. + c.partitions = partitionAssignments + c.loads = memberLoads +} + +// addMemberToRing adds a member to the consistent hash ring and updates the sorted set of hashes. +func (c *Consistent) addMemberToRing(member Member) { + // Add replication factor entries for the member in the hash ring. + for replicaIndex := 0; replicaIndex < c.config.ReplicationFactor; replicaIndex++ { + // Generate a unique key for each replica of the member. + replicaKey := []byte(fmt.Sprintf("%s%d", member.String(), replicaIndex)) + hashValue := c.hasher(replicaKey) + + // Add the hash value to the ring and associate it with the member. + c.ring[hashValue] = &member + + // Append the hash value to the sorted set of hashes. + c.sortedSet = append(c.sortedSet, hashValue) + } + + // Sort the hash values to maintain the ring's order. + sort.Slice(c.sortedSet, func(i, j int) bool { + return c.sortedSet[i] < c.sortedSet[j] + }) + + // Add the member to the members map. + c.members[member.String()] = &member +} + +// Add safely adds a new member to the consistent hash circle. +// It ensures thread safety and redistributes partitions after adding the member. +func (c *Consistent) Add(member Member) { + // Acquire a write lock to ensure thread safety. + c.mu.Lock() + defer c.mu.Unlock() + + // Check if the member already exists in the ring. If it does, exit early. + if _, exists := c.members[member.String()]; exists { + return + } + + // Add the member to the ring and redistribute partitions. + c.addMemberToRing(member) + c.distributePartitions() +} + +// removeFromSortedSet removes a hash value from the sorted set of hashes. +func (c *Consistent) removeFromSortedSet(hashValue uint64) { + for i := 0; i < len(c.sortedSet); i++ { + if c.sortedSet[i] == hashValue { + // Remove the hash value by slicing the sorted set. + c.sortedSet = append(c.sortedSet[:i], c.sortedSet[i+1:]...) + break + } + } +} + +// Remove deletes a member from the consistent hash circle and redistributes partitions. +// If the member does not exist, the method exits early. +func (c *Consistent) Remove(memberName string) { + // Acquire a write lock to ensure thread-safe access. + c.mu.Lock() + defer c.mu.Unlock() + + // Check if the member exists in the hash ring. If not, exit early. + if _, exists := c.members[memberName]; !exists { + return + } + + // Remove all replicas of the member from the hash ring and sorted set. + for replicaIndex := 0; replicaIndex < c.config.ReplicationFactor; replicaIndex++ { + // Generate the unique key for each replica of the member. + replicaKey := []byte(fmt.Sprintf("%s%d", memberName, replicaIndex)) + hashValue := c.hasher(replicaKey) + + // Remove the hash value from the hash ring. + delete(c.ring, hashValue) + + // Remove the hash value from the sorted set. + c.removeFromSortedSet(hashValue) + } + + // Remove the member from the members map. + delete(c.members, memberName) + + // If no members remain, reset the partition table and exit. 
+ if len(c.members) == 0 { + c.partitions = make(map[int]*Member) + return + } + + // Redistribute partitions among the remaining members. + c.distributePartitions() +} + +// GetLoadDistribution provides a thread-safe snapshot of the current load distribution across members. +// It returns a map where the keys are member identifiers and the values are their respective loads. +func (c *Consistent) GetLoadDistribution() map[string]float64 { + // Acquire a read lock to ensure thread-safe access to the loads map. + c.mu.RLock() + defer c.mu.RUnlock() + + // Create a copy of the loads map to avoid exposing internal state. + loadDistribution := make(map[string]float64) + for memberName, memberLoad := range c.loads { + loadDistribution[memberName] = memberLoad + } + + return loadDistribution +} + +// GetPartitionID calculates and returns the partition ID for a given key. +// The partition ID is determined by hashing the key and applying modulo operation with the partition count. +func (c *Consistent) GetPartitionID(key []byte) int { + // Generate a hash value for the given key using the configured hasher. + hashValue := c.hasher(key) + + // Calculate the partition ID by taking the modulus of the hash value with the partition count. + return int(hashValue % c.partitionCount) +} + +// GetPartitionOwner retrieves the owner of the specified partition in a thread-safe manner. +// It ensures that the access to shared resources is synchronized. +func (c *Consistent) GetPartitionOwner(partitionID int) Member { + // Acquire a read lock to ensure thread-safe access to the partitions map. + c.mu.RLock() + defer c.mu.RUnlock() + + // Delegate the actual lookup to the non-thread-safe internal helper function. + return c.getPartitionOwnerInternal(partitionID) +} + +// getPartitionOwnerInternal retrieves the owner of the specified partition without thread safety. +// This function assumes that synchronization has been handled by the caller. +func (c *Consistent) getPartitionOwnerInternal(partitionID int) Member { + // Lookup the member associated with the given partition ID. + member, exists := c.partitions[partitionID] + if !exists { + // If the partition ID does not exist, return a nil Member. + return nil + } + + // Return a copy of the member to ensure thread safety. + return *member +} + +// LocateKey determines the owner of the partition corresponding to the given key. +// It calculates the partition ID for the key and retrieves the associated member in a thread-safe manner. +func (c *Consistent) LocateKey(key []byte) Member { + // Calculate the partition ID based on the hash of the key. + partitionID := c.GetPartitionID(key) + + // Retrieve the owner of the partition using the thread-safe method. + return c.GetPartitionOwner(partitionID) +} + +// closestN retrieves the closest N members to the given partition ID in the consistent hash ring. +// It ensures thread-safe access and validates that the requested count of members can be satisfied. +func (c *Consistent) closestN(partitionID, count int) ([]Member, error) { + // Acquire a read lock to ensure thread-safe access to the members map. + c.mu.RLock() + defer c.mu.RUnlock() + + // Validate that the requested number of members can be satisfied. + if count > len(c.members) { + return nil, errors.New("not enough members to satisfy the request") + } + + // Prepare a result slice to store the closest members. + var closestMembers []Member + + // Get the owner of the given partition. 
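+	// The partition owner seeds the result; additional members are then taken by walking
+	// clockwise around a ring built from the hashes of the member names.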
+ partitionOwner := c.getPartitionOwnerInternal(partitionID) + var partitionOwnerHash uint64 + + // Build a hash ring by hashing all member names. + var memberHashes []uint64 + hashToMember := make(map[uint64]*Member) + for memberName, member := range c.members { + // Compute the hash value for each member name. + hash := c.hasher([]byte(memberName)) + + // Track the hash for the partition owner. + if memberName == partitionOwner.String() { + partitionOwnerHash = hash + } + + // Append the hash value and map it to the corresponding member. + memberHashes = append(memberHashes, hash) + hashToMember[hash] = member + } + + // Sort the hash values to create a consistent hash ring. + sort.Slice(memberHashes, func(i, j int) bool { + return memberHashes[i] < memberHashes[j] + }) + + // Find the index of the partition owner's hash in the sorted hash ring. + ownerIndex := -1 + for i, hash := range memberHashes { + if hash == partitionOwnerHash { + ownerIndex = i + closestMembers = append(closestMembers, *hashToMember[hash]) + break + } + } + + // If the partition owner's hash is not found (unexpected), return an error. + if ownerIndex == -1 { + return nil, errors.New("partition owner not found in hash ring") + } + + // Find the additional closest members by iterating around the hash ring. + currentIndex := ownerIndex + for len(closestMembers) < count { + // Move to the next hash in the ring, wrapping around if necessary. + currentIndex++ + if currentIndex >= len(memberHashes) { + currentIndex = 0 + } + + // Add the member corresponding to the current hash to the result. + hash := memberHashes[currentIndex] + closestMembers = append(closestMembers, *hashToMember[hash]) + } + + return closestMembers, nil +} + +// ClosestN calculates the closest N members to a given key in the consistent hash ring. +// It uses the key to determine the partition ID and then retrieves the closest members. +// This is useful for identifying members for replication or redundancy. +func (c *Consistent) ClosestN(key []byte, count int) ([]Member, error) { + // Calculate the partition ID based on the hash of the key. + partitionID := c.GetPartitionID(key) + + // Retrieve the closest N members for the calculated partition ID. + return c.closestN(partitionID, count) +} diff --git a/pkg/consistent/consistent_test.go b/pkg/consistent/consistent_test.go new file mode 100644 index 000000000..6503ef1ec --- /dev/null +++ b/pkg/consistent/consistent_test.go @@ -0,0 +1,129 @@ +package consistent + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// This is the entry point for the test suite for the "consistent" package. +// It registers a failure handler and runs the specifications (specs) for this package. 
+func TestConsistent(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "consistent-suite") +} + +type TestMember string + +func (m TestMember) String() string { + return string(m) +} + +var _ = Describe("Consistent", func() { + var ( + config Config + consist *Consistent + hasher Hasher + testMember Member + ) + + BeforeEach(func() { + hasher = func(data []byte) uint64 { + var hash uint64 + for _, b := range data { + hash = hash*31 + uint64(b) + } + return hash + } + + config = Config{ + Hasher: hasher, + PartitionCount: 271, + ReplicationFactor: 20, + Load: 1.25, + PickerWidth: 1, + } + + consist = New(config) + }) + + Describe("Initialization", func() { + It("should initialize with default values when config is incomplete", func() { + incompleteConfig := Config{Hasher: hasher} + instance := New(incompleteConfig) + Expect(instance).NotTo(BeNil()) + Expect(instance.GetAverageLoad()).To(BeNumerically("==", float64(0))) + }) + }) + + Describe("Member Management", func() { + BeforeEach(func() { + testMember = TestMember("member1") + }) + + It("should add a member to the consistent hash ring", func() { + consist.Add(testMember) + members := consist.Members() + Expect(members).To(HaveLen(1)) + Expect(members[0].String()).To(Equal(testMember.String())) + }) + + It("should not add the same member twice", func() { + consist.Add(testMember) + consist.Add(testMember) + members := consist.Members() + Expect(members).To(HaveLen(1)) + }) + + It("should remove a member from the consistent hash ring", func() { + consist.Add(testMember) + consist.Remove(testMember.String()) + members := consist.Members() + Expect(members).To(BeEmpty()) + }) + }) + + Describe("Partition Management", func() { + BeforeEach(func() { + consist.Add(TestMember("member1")) + consist.Add(TestMember("member2")) + consist.Add(TestMember("member3")) + }) + + It("should distribute partitions across members without exceeding average load", func() { + loads := consist.GetLoadDistribution() + Expect(loads).To(HaveLen(3)) // There are 3 members. + + // Calculate the maximum allowed load per member. + expectedAverageLoad := consist.GetAverageLoad() + + // Verify each member's load does not exceed the expected average load. 
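+			// GetAverageLoad is the same per-member cap used during distribution,
+			// roughly ceil((PartitionCount / memberCount) * Load).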
+ for _, load := range loads { + Expect(load).To(BeNumerically("<=", expectedAverageLoad)) + } + }) + + It("should locate the correct partition owner for a key", func() { + key := []byte("test_key") + member := consist.LocateKey(key) + Expect(member).NotTo(BeNil()) + }) + + It("should find the closest N members", func() { + key := []byte("test_key") + members, err := consist.ClosestN(key, 2) + Expect(err).NotTo(HaveOccurred()) + Expect(members).To(HaveLen(2)) + }) + }) + + Describe("Error Handling", func() { + It("should return an error when requesting more members than available", func() { + key := []byte("test_key") + _, err := consist.ClosestN(key, 5) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not enough members to satisfy the request")) + }) + }) +}) diff --git a/pkg/pb/base/v1/openapi.pb.go b/pkg/pb/base/v1/openapi.pb.go index f191734fb..db1778f9b 100644 --- a/pkg/pb/base/v1/openapi.pb.go +++ b/pkg/pb/base/v1/openapi.pb.go @@ -46,7 +46,7 @@ var file_base_v1_openapi_proto_rawDesc = []byte{ 0x2f, 0x2f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x66, 0x79, 0x2f, 0x70, 0x65, 0x72, 0x6d, 0x69, 0x66, 0x79, 0x2f, 0x62, 0x6c, 0x6f, 0x62, 0x2f, 0x6d, 0x61, 0x73, 0x74, 0x65, 0x72, 0x2f, 0x4c, 0x49, 0x43, 0x45, 0x4e, 0x53, 0x45, - 0x32, 0x06, 0x76, 0x31, 0x2e, 0x32, 0x2e, 0x37, 0x2a, 0x01, 0x02, 0x32, 0x10, 0x61, 0x70, 0x70, + 0x32, 0x06, 0x76, 0x31, 0x2e, 0x32, 0x2e, 0x38, 0x2a, 0x01, 0x02, 0x32, 0x10, 0x61, 0x70, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x6a, 0x73, 0x6f, 0x6e, 0x3a, 0x10, 0x61, 0x70, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x6a, 0x73, 0x6f, 0x6e, 0x5a, 0x23, 0x0a, 0x21, 0x0a, 0x0a, 0x41, 0x70, 0x69, 0x4b, 0x65, 0x79, 0x41, 0x75, 0x74, 0x68, 0x12, diff --git a/proto/base/v1/openapi.proto b/proto/base/v1/openapi.proto index 1d68954a5..43fab8ab8 100644 --- a/proto/base/v1/openapi.proto +++ b/proto/base/v1/openapi.proto @@ -9,7 +9,7 @@ option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_swagger) = { info: { title: "Permify API"; description: "Permify is an open source authorization service for creating fine-grained and scalable authorization systems."; - version: "v1.2.7"; + version: "v1.2.8"; contact: { name: "API Support"; url: "https://github.com/Permify/permify/issues"; From 56932e6e60d7422e65c9bffc7006fd5fee9f85a8 Mon Sep 17 00:00:00 2001 From: Tolga Ozen Date: Tue, 28 Jan 2025 17:17:40 +0300 Subject: [PATCH 2/3] test: add picker logic tests for member picking and error handling. --- pkg/balancer/picker_test.go | 105 ++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 pkg/balancer/picker_test.go diff --git a/pkg/balancer/picker_test.go b/pkg/balancer/picker_test.go new file mode 100644 index 000000000..8534ff7c6 --- /dev/null +++ b/pkg/balancer/picker_test.go @@ -0,0 +1,105 @@ +package balancer + +import ( + "context" + "crypto/sha256" + "encoding/binary" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + "google.golang.org/grpc/balancer" + + "github.com/Permify/permify/pkg/consistent" +) + +type mockConnection struct { + id string +} + +type testMember struct { + name string + conn *mockConnection +} + +func (m testMember) String() string { + return m.name +} + +func (m testMember) Connection() *mockConnection { + return m.conn +} + +var _ = Describe("Picker and Consistent Hashing", func() { + var ( + c *consistent.Consistent + testMembers []testMember + hasher func(data []byte) uint64 + ) + + // Custom hasher using SHA-256 for consistent hashing + hasher = func(data []byte) uint64 { + hash := sha256.Sum256(data) + return binary.BigEndian.Uint64(hash[:8]) // Use the first 8 bytes as the hash + } + + BeforeEach(func() { + // Initialize consistent hashing with a valid hasher + c = consistent.New(consistent.Config{ + Hasher: hasher, + PartitionCount: 100, + ReplicationFactor: 2, + Load: 1.5, + }) + + // Add test members to the consistent hash ring + testMembers = []testMember{ + {name: "member1", conn: &mockConnection{id: "conn1"}}, + {name: "member2", conn: &mockConnection{id: "conn2"}}, + {name: "member3", conn: &mockConnection{id: "conn3"}}, + } + for _, m := range testMembers { + c.Add(m) + } + }) + + Describe("Picker Logic", func() { + var ( + p *picker + testCtx context.Context + ) + + BeforeEach(func() { + // Initialize picker with consistent hashing and a width of 2 + p = &picker{ + consistent: c, + width: 2, + } + // Set up context with a valid key + testCtx = context.WithValue(context.Background(), Key, []byte("test-key")) + }) + + It("should pick a member successfully", func() { + // Mock picker behavior + members, err := c.ClosestN([]byte("test-key"), 2) + Expect(err).To(BeNil()) + Expect(len(members)).To(BeNumerically(">", 0)) + Expect(members[0].String()).To(Equal("member1")) + }) + + It("should return an error if the context key is missing", func() { + result, err := p.Pick(balancer.PickInfo{Ctx: context.Background()}) + Expect(err).To(MatchError("context key missing")) + Expect(result.SubConn).To(BeNil()) + }) + + It("should return an error if no members are available", func() { + // Remove all members + for _, m := range testMembers { + c.Remove(m.String()) + } + result, err := p.Pick(balancer.PickInfo{Ctx: testCtx}) + Expect(err).To(MatchError("failed to get closest members: not enough members to satisfy the request")) + Expect(result.SubConn).To(BeNil()) + }) + }) +}) From 16a2e3a52fc879cd636ac8130bbd0c0830b54578 Mon Sep 17 00:00:00 2001 From: Tolga Ozen Date: Tue, 28 Jan 2025 17:22:01 +0300 Subject: [PATCH 3/3] refactor(pkg/balancer): Refactor Config struct default value checks --- pkg/balancer/builder.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pkg/balancer/builder.go b/pkg/balancer/builder.go index ef875f5cc..b2ec4d62f 100644 --- a/pkg/balancer/builder.go +++ b/pkg/balancer/builder.go @@ -38,16 +38,16 @@ func (c *Config) ServiceConfigJSON() (string, error) { } // Apply default values for zero fields. 
- if c.PartitionCount == 0 { + if c.PartitionCount <= 0 { c.PartitionCount = consistent.DefaultPartitionCount } - if c.ReplicationFactor == 0 { + if c.ReplicationFactor <= 0 { c.ReplicationFactor = consistent.DefaultReplicationFactor } - if c.Load == 0 { + if c.Load <= 1.0 { c.Load = consistent.DefaultLoad } - if c.PickerWidth == 0 { + if c.PickerWidth < 1 { c.PickerWidth = consistent.DefaultPickerWidth } @@ -128,16 +128,16 @@ func (b *builder) ParseConfig(rm json.RawMessage) (serviceconfig.LoadBalancingCo ) // Set default values for configuration if not provided. - if cfg.PartitionCount == 0 { + if cfg.PartitionCount <= 0 { cfg.PartitionCount = consistent.DefaultPartitionCount } - if cfg.ReplicationFactor == 0 { + if cfg.ReplicationFactor <= 0 { cfg.ReplicationFactor = consistent.DefaultReplicationFactor } - if cfg.Load == 0 { + if cfg.Load <= 1.0 { cfg.Load = consistent.DefaultLoad } - if cfg.PickerWidth == 0 { + if cfg.PickerWidth < 1 { cfg.PickerWidth = consistent.DefaultPickerWidth }