From a9cea3f55637f6252848b31b9afde19a55d550d2 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Sun, 19 Jan 2025 19:08:32 +0900 Subject: [PATCH] feat: referral (#459) * demo: referral * update considerations * remove unused * global getter * update getter * time guard --- referral/doc.gno | 109 +++++++ referral/errors.gno | 63 ++++ referral/global_keeper.gno | 41 +++ referral/gno.mod | 1 + referral/keeper.gno | 177 ++++++++++ referral/keeper_test.gno | 638 +++++++++++++++++++++++++++++++++++++ referral/referral.gno | 54 ++++ referral/type.gno | 38 +++ referral/utils.gno | 28 ++ setup.py | 2 +- 10 files changed, 1150 insertions(+), 1 deletion(-) create mode 100644 referral/doc.gno create mode 100644 referral/errors.gno create mode 100644 referral/global_keeper.gno create mode 100644 referral/gno.mod create mode 100644 referral/keeper.gno create mode 100644 referral/keeper_test.gno create mode 100644 referral/referral.gno create mode 100644 referral/type.gno create mode 100644 referral/utils.gno diff --git a/referral/doc.gno b/referral/doc.gno new file mode 100644 index 000000000..61d5188fb --- /dev/null +++ b/referral/doc.gno @@ -0,0 +1,109 @@ +// Package referral implements a referral system on Gno. It allows +// any authorized caller to register, update, or remove referral +// information. A referral link is defined as a mapping from one +// address (the "user") to another address (the "referrer"). +// +// ## Overview +// +// The referral package is composed of the following components: +// +// 1. **errors.gno**: Provides custom error types (ReferralError) with +// specific error codes and messages. +// 2. **utils.gno**: Contains utility functions for permission checks, +// especially isValidCaller, which ensures only specific, pre-authorized +// callers (e.g., governance or router addresses) can invoke the core +// functions. +// 3. **types.gno**: Defines core constants for event types, attributes, +// and the ReferralKeeper interface, which outlines the fundamental +// methods of the referral system (Register, Update, Remove, etc.). +// 4. **keeper.gno**: Implements the actual business logic behind the +// ReferralKeeper interface. It uses an AVL Tree (avl.Tree) to store +// referral data (address -> referrer). The keeper methods emit events +// when a new referral is registered, updated, or removed. +// 5. **referral.gno**: Exposes a public API (the Referral struct) +// that delegates to the keeper, providing external contracts or +// applications a straightforward way to interact with the system. +// +// ## Workflow +// +// Typical usage of this contract follows these steps: +// +// 1. A caller with valid permissions invokes Register, Update, or Remove +// through the Referral struct. +// 2. The Referral struct forwards the request to the internal keeper +// methods. +// 3. The keeper checks caller permission (via isValidCaller), validates +// addresses, and stores or removes data in the AVL Tree. +// 4. An event is emitted for off-chain or cross-module notifications. +// +// ## Integration with Other Contracts +// +// Other contracts can leverage the referral system in two major ways: +// +// 1. **Direct Calls**: If you wish to directly call this contract, +// instantiate the Referral object (via NewReferral) and invoke its +// methods, assuming you meet the authorized-caller criteria. +// +// 2. **Embedded or Extended**: If you have a complex module that includes +// referral features, import this package and embed a Referral instance +// in your own keeper. This way, you can handle additional validations +// or custom logic before delegating to the existing referral functions. +// +// ## Error Handling +// +// The package defines several error types through ReferralError: +// - `ErrInvalidAddress`: Returned when an address format is invalid +// - `ErrUnauthorized`: Returned when the caller lacks permission +// - `ErrNotFound`: Returned when attempting to get a non-existent referral +// - `ErrZeroAddress`: Returned when attempting operations with zero address +// +// ## Example: Integration with a Staking Contract +// +// Suppose you have a staking contract that wants to reward referrers +// when a new user stakes tokens: +// +// ```go +// +// import ( +// "std" +// "gno.land/r/gnoswap/v1/referral" +// "gno.land/p/demo/mystaking" // example staking contract +// ) +// +// func rewardReferrerOnStake(user std.Address, amount int) { +// // 1) Access the referral system +// r := referral.NewReferral() +// +// // 2) Get the user's referrer +// refAddr, err := r.GetReferral(user) +// if err != nil { +// // handle error or skip if not found +// return +// } +// +// // 3) Reward the referrer +// mystaking.AddReward(refAddr, calculateReward(amount)) +// } +// +// ``` +// +// In this simple example, the staking contract checks if the user has +// a referrer by calling `GetReferral`. If a referrer is found, it then +// calculates a reward based on the staked amount. +// +// ## Limitations and Constraints +// +// - A user can have only one referrer at a time +// - Once a referral is removed, it cannot be automatically restored +// - Only authorized contracts can modify referral relationships +// - Address validation is strict and requires proper Bech32 format +// +// # Notes +// +// - The contract strictly enforces caller restrictions via isValidCaller. +// Make sure to configure it to permit only the addresses or roles that +// should be able to register or update referrals. +// - Zero addresses are treated as a trigger for removing a referral record. +// - The system emits events (register_referral, update_referral, remove_referral) +// which can be consumed by other on-chain or off-chain services. +package referral diff --git a/referral/errors.gno b/referral/errors.gno new file mode 100644 index 000000000..f0f1fbb48 --- /dev/null +++ b/referral/errors.gno @@ -0,0 +1,63 @@ +package referral + +import ( + "gno.land/p/demo/ufmt" +) + +const ( + ErrNone = iota + ErrInvalidAddress + ErrZeroAddress + ErrSelfReferral + ErrUnauthorized + ErrInvalidCaller + ErrCyclicReference + ErrTooManyRequests + ErrNotFound +) + +var errorMessages = map[int]string{ + ErrInvalidAddress: "invalid address format", + ErrZeroAddress: "zero address is not allowed", + ErrSelfReferral: "self referral is not allowed", + ErrUnauthorized: "unauthorized caller", + ErrInvalidCaller: "invalid caller", + ErrCyclicReference: "cyclic reference is not allowed", + ErrTooManyRequests: "too many requests: operations allowed once per 24 hours for each address", + ErrNotFound: "referral not found", +} + +type ReferralError struct { + Code int + Message string + Err error +} + +func (e *ReferralError) Error() string { + // TODO: format error message to follow previous error message format + if e.Err != nil { + return ufmt.Sprintf("code: %d, message: %s, error: %v", e.Code, e.Message, e.Err) + } + return ufmt.Sprintf("code: %d, message: %s", e.Code, e.Message) +} + +func (e *ReferralError) Unwrap() error { + return e.Err +} + +func NewError(code int, args ...interface{}) *ReferralError { + msg := errorMessages[code] + var err error + + if len(args) > 0 { + if lastArg, ok := args[len(args)-1].(error); ok { + err = lastArg + } + } + + return &ReferralError{ + Code: code, + Message: msg, + Err: err, + } +} diff --git a/referral/global_keeper.gno b/referral/global_keeper.gno new file mode 100644 index 000000000..ac25c237f --- /dev/null +++ b/referral/global_keeper.gno @@ -0,0 +1,41 @@ +package referral + +import "std" + +// gReferralKeeper is the global instance of the referral keeper +var gReferralKeeper ReferralKeeper + +func init() { + gReferralKeeper = NewKeeper() +} + +// GetKeeper returns the global instance of the referral keeper +// +// Example: +// +// // In other packages: +// keeper := referral.GetKeeper() +// keeper.register(addr, refAddr) +func GetKeeper() ReferralKeeper { + return gReferralKeeper +} + +func GetReferral(addr string) string { + referral, err := gReferralKeeper.get(std.Address(addr)) + if err != nil { + panic(err) + } + return referral.String() +} + +func HasReferral(addr string) bool { + referral, err := gReferralKeeper.get(std.Address(addr)) + if err != nil { + return false + } + return referral != zeroAddress +} + +func IsEmpty() bool { + return gReferralKeeper.isEmpty() +} diff --git a/referral/gno.mod b/referral/gno.mod new file mode 100644 index 000000000..3951be4db --- /dev/null +++ b/referral/gno.mod @@ -0,0 +1 @@ +module gno.land/r/gnoswap/v1/referral diff --git a/referral/keeper.gno b/referral/keeper.gno new file mode 100644 index 000000000..911dd2e8f --- /dev/null +++ b/referral/keeper.gno @@ -0,0 +1,177 @@ +package referral + +import ( + "std" + "time" + + "gno.land/p/demo/avl" + "gno.land/p/demo/ufmt" +) + +const ( + // MinTimeBetweenUpdates represents minimum duration between operations + MinTimeBetweenUpdates int64 = 24 * 60 * 60 +) + +// keeper implements the `ReferralKeeper` interface using an AVL tree for storage. +type keeper struct { + store *avl.Tree + lastOps map[string]int64 +} + +// check interface implementation at compile time +var _ ReferralKeeper = &keeper{} + +// NewKeeper creates and returns a new instance of ReferralKeeper. +// The keeper is initialized with an empty AVL tree for storing referral relationships. +func NewKeeper() ReferralKeeper { + return &keeper{ + store: avl.NewTree(), + lastOps: make(map[string]int64), + } +} + +// register implements the `register` method of the `ReferralKeeper` interface. +// It sets a new referral relationship between the given address and referral address. +func (k *keeper) register(addr, refAddr std.Address) error { + return k.setReferral(addr, refAddr, EventTypeRegister) +} + +// update implements the `update` method of the `ReferralKeeper` interface. +// It updates the referral address for a given address. +func (k *keeper) update(addr, newRefAddr std.Address) error { + return k.setReferral(addr, newRefAddr, EventTypeUpdate) +} + +// setReferral handles the common logic for registering and updating referrals. +// It validates the addresses and caller, then stores the referral relationship. +// +// Note: The current implementation allows circular references, but since it only manages +// simple reference relationships, cycles are not a significant issue. However, when introducing +// a referral-based reward system in the future or adding business logic where cycles could cause problems, +// it will be necessary to implement validation checks. +// +// TODO: need to discuss what values to emit as event +func (k *keeper) setReferral(addr, refAddr std.Address, eventType string) error { + if err := isValidCaller(std.PrevRealm().Addr()); err != nil { + return err + } + + if err := k.validateAddresses(addr, refAddr); err != nil { + return err + } + + if err := k.checkRateLimit(addr.String()); err != nil { + // XXX: because of the weird type-related errors, this is the only + // part where we need to use `ufmt.Errorf` to build error message. + return ufmt.Errorf("too many requests") + } + + if refAddr == zeroAddress { + std.Emit( + EventTypeRemove, + "removedAddress", addr.String(), + ) + return k.remove(addr) + } + + k.store.Set(addr.String(), refAddr.String()) + std.Emit( + eventType, + "myAddress", addr.String(), + "refAddress", refAddr.String(), + ) + + return nil +} + +// validateAddresses checks if the given addresses are valid for referral +func (k *keeper) validateAddresses(addr, refAddr std.Address) error { + if !addr.IsValid() || (refAddr != zeroAddress && !refAddr.IsValid()) { + return NewError(ErrInvalidAddress) + } + if addr == refAddr { + return NewError(ErrSelfReferral) + } + return nil +} + +// remove implements the `remove` method of the `ReferralKeeper` interface. +// It validates the caller and address before removing the referral relationship. +// +// TODO: need to discuss what values to emit as event +func (k *keeper) remove(target std.Address) error { + if err := isValidCaller(std.PrevRealm().Addr()); err != nil { + return err + } + + if !target.IsValid() { + return NewError(ErrInvalidAddress) + } + + if err := k.checkRateLimit(target.String()); err != nil { + return err + } + + k.store.Remove(target.String()) + + // TODO: update event + std.Emit( + EventTypeRemove, + "removedAddress", target.String(), + ) + + return nil +} + +// has implements the `has` method of the `ReferralKeeper` interface. +// It checks if a referral relationship exists for a given address. +func (k *keeper) has(addr std.Address) bool { + _, exists := k.store.Get(addr.String()) + return exists +} + +// get implements the `get` method of the `ReferralKeeper` interface. +// It retrieves the referral address for a given address. +func (k *keeper) get(addr std.Address) (std.Address, error) { + if !addr.IsValid() { + return zeroAddress, NewError(ErrInvalidAddress) + } + + val, ok := k.store.Get(addr.String()) + if !ok { + return zeroAddress, NewError(ErrNotFound) + } + + refAddr, ok := val.(string) + if !ok { + return zeroAddress, NewError(ErrInvalidAddress) + } + + return std.Address(refAddr), nil +} + +func (k *keeper) isEmpty() bool { + empty := true + k.store.Iterate("", "", func(key string, value interface{}) bool { + empty = false + return true // stop iteration on first item + }) + return empty +} + +// checkRateLimit verifies if enough time has passed since the last operation +func (k *keeper) checkRateLimit(addr string) error { + now := time.Now().Unix() + + if lastOpTime, exists := k.lastOps[addr]; exists { + timeSinceLastOp := now - lastOpTime + + if timeSinceLastOp < MinTimeBetweenUpdates { + return NewError(ErrTooManyRequests) + } + } + + k.lastOps[addr] = now + return nil +} diff --git a/referral/keeper_test.gno b/referral/keeper_test.gno new file mode 100644 index 000000000..f3d20b869 --- /dev/null +++ b/referral/keeper_test.gno @@ -0,0 +1,638 @@ +package referral + +import ( + "std" + "testing" + "time" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/ufmt" + "gno.land/r/gnoswap/v1/consts" +) + +const STRESS_TEST_NUM = 10000 // arbitrary number + +var ( + validAddr1 = testutils.TestAddress("valid1") + validAddr2 = testutils.TestAddress("valid2") + invalidAddr = testutils.TestAddress("invalid") +) + +// time mocking +var currentTime int64 = time.Now().Unix() + +func mockTimeNow() time.Time { + return time.Unix(currentTime, 0) +} + +func setupKeeper() *keeper { return NewKeeper().(*keeper) } + +func mockValidCaller() func() { + origCaller := std.GetOrigCaller() + std.TestSetOrigCaller(consts.ROUTER_ADDR) + return func() { + std.TestSetOrigCaller(origCaller) + } +} + +func TestRegister(t *testing.T) { + tests := []struct { + name string + addr std.Address + refAddr std.Address + setupCaller func() func() + wantErr bool + errCode int + }{ + { + name: "valid registration", + addr: validAddr1, + refAddr: validAddr2, + setupCaller: mockValidCaller, + wantErr: false, + }, + { + name: "unauthorized caller", + addr: validAddr1, + refAddr: validAddr2, + setupCaller: func() func() { + origCaller := std.GetOrigCaller() + std.TestSetOrigCaller(std.Address("unauthorized")) + return func() { + std.TestSetOrigCaller(origCaller) + } + }, + wantErr: true, + errCode: ErrUnauthorized, + }, + { + name: "self referral", + addr: validAddr1, + refAddr: validAddr1, + setupCaller: mockValidCaller, + wantErr: true, + errCode: ErrSelfReferral, + }, + { + name: "zero address referral", + addr: validAddr1, + refAddr: zeroAddress, + setupCaller: mockValidCaller, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := setupKeeper() + cleanup := tt.setupCaller() + defer cleanup() + + err := k.register(tt.addr, tt.refAddr) + + if tt.wantErr { + if err == nil { + t.Errorf("register() error = nil, wantErr %v", tt.wantErr) + return + } + referralErr, ok := err.(*ReferralError) + if !ok { + t.Errorf("register() error is not ReferralError type") + return + } + if referralErr.Code != tt.errCode { + t.Errorf("register() error code = %v, want %v", referralErr.Code, tt.errCode) + } + } else if err != nil { + t.Errorf("register() unexpected error = %v", err) + } + }) + } +} + +func TestUpdate(t *testing.T) { + tests := []struct { + name string + addr std.Address + refAddr std.Address + setupState func(*keeper) + setupCaller func() func() + wantErr bool + errCode int + }{ + { + name: "valid update", + addr: validAddr1, + refAddr: validAddr2, + setupState: func(k *keeper) { + k.store.Set(validAddr1.String(), "old_ref_addr") + }, + setupCaller: mockValidCaller, + wantErr: false, + }, + { + name: "update non-existent referral", + addr: validAddr1, + refAddr: validAddr2, + setupState: func(k *keeper) {}, + setupCaller: mockValidCaller, + wantErr: false, + }, + { + name: "update to self referral", + addr: validAddr1, + refAddr: validAddr1, + setupState: func(k *keeper) { + k.store.Set(validAddr1.String(), validAddr2.String()) + }, + setupCaller: mockValidCaller, + wantErr: true, + errCode: ErrSelfReferral, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := setupKeeper() + tt.setupState(k) + cleanup := tt.setupCaller() + defer cleanup() + + err := k.update(tt.addr, tt.refAddr) + + if tt.wantErr { + if err == nil { + t.Errorf("update() error = nil, wantErr %v", tt.wantErr) + return + } + referralErr, ok := err.(*ReferralError) + if !ok { + t.Errorf("update() error is not ReferralError type") + return + } + if referralErr.Code != tt.errCode { + t.Errorf("update() error code = %v, want %v", referralErr.Code, tt.errCode) + } + } else if err != nil { + t.Errorf("update() unexpected error = %v", err) + } + }) + } +} + +func TestGet(t *testing.T) { + tests := []struct { + name string + addr std.Address + setupState func(*keeper) + wantAddr std.Address + wantErr bool + errCode int + }{ + { + name: "get existing referral", + addr: validAddr1, + setupState: func(k *keeper) { + k.store.Set(validAddr1.String(), validAddr2.String()) + }, + wantAddr: validAddr2, + wantErr: false, + }, + { + name: "get non-existent referral", + addr: validAddr1, + setupState: func(k *keeper) {}, + wantAddr: zeroAddress, + wantErr: true, + errCode: ErrNotFound, + }, + { + name: "get with invalid address", + addr: invalidAddr, + setupState: func(k *keeper) {}, + wantAddr: zeroAddress, + wantErr: true, + errCode: ErrNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := setupKeeper() + tt.setupState(k) + + gotAddr, err := k.get(tt.addr) + + if tt.wantErr { + if err == nil { + t.Errorf("get() error = nil, wantErr %v", tt.wantErr) + return + } + referralErr, ok := err.(*ReferralError) + if !ok { + t.Errorf("get() error is not ReferralError type") + return + } + if referralErr.Code != tt.errCode { + t.Errorf("get() error code = %v, want %v", referralErr.Code, tt.errCode) + } + } else { + if err != nil { + t.Errorf("get() unexpected error") + return + } + if gotAddr != tt.wantAddr { + t.Errorf("get() gotAddr = %v, want %v", gotAddr, tt.wantAddr) + } + } + }) + } +} + +func TestHas(t *testing.T) { + tests := []struct { + name string + addr std.Address + setupState func(*keeper) + want bool + }{ + { + name: "has existing referral", + addr: validAddr1, + setupState: func(k *keeper) { + k.store.Set(validAddr1.String(), validAddr2.String()) + }, + want: true, + }, + { + name: "does not have referral", + addr: validAddr1, + setupState: func(k *keeper) {}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := setupKeeper() + tt.setupState(k) + + if got := k.has(tt.addr); got != tt.want { + t.Errorf("has() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRemove(t *testing.T) { + tests := []struct { + name string + addr std.Address + setupState func(*keeper) + setupCaller func() func() + wantErr bool + errCode int + }{ + { + name: "remove existing referral", + addr: validAddr1, + setupState: func(k *keeper) { + k.store.Set(validAddr1.String(), validAddr2.String()) + }, + setupCaller: mockValidCaller, + wantErr: false, + }, + { + name: "remove non-existent referral", + addr: validAddr1, + setupState: func(k *keeper) {}, + setupCaller: mockValidCaller, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := setupKeeper() + tt.setupState(k) + cleanup := tt.setupCaller() + defer cleanup() + + err := k.remove(tt.addr) + + if tt.wantErr { + if err == nil { + t.Errorf("remove() error = nil, wantErr %v", tt.wantErr) + return + } + referralErr, ok := err.(*ReferralError) + if !ok { + t.Errorf("remove() error is not ReferralError type") + return + } + if referralErr.Code != tt.errCode { + t.Errorf("remove() error code = %v, want %v", referralErr.Code, tt.errCode) + } + } else if err != nil { + t.Errorf("remove() unexpected error = %v", err) + } + + if k.has(tt.addr) { + t.Errorf("remove() referral still exists after removal") + } + }) + } +} + +func TestUpdateNonExistentReferral(t *testing.T) { + k := setupKeeper() + cleanup := mockValidCaller() + defer cleanup() + + err := k.update(validAddr1, validAddr2) + if err != nil { + t.Errorf("update() for non-existent referral failed: %v", err) + } + + refAddr, err := k.get(validAddr1) + if err != nil { + t.Errorf("get() after update failed: %v", err) + } + if refAddr != validAddr2 { + t.Errorf("got refAddr = %v, want %v", refAddr, validAddr2) + } +} + +func TestReferralCycles(t *testing.T) { + k := setupKeeper() + cleanup := mockValidCaller() + defer cleanup() + + addr1 := testutils.TestAddress("cycle1") + addr2 := testutils.TestAddress("cycle2") + addr3 := testutils.TestAddress("cycle3") + + // A -> B -> C + err := k.register(addr1, addr2) + if err != nil { + t.Fatalf("Failed to register addr1->addr2: %v", err) + } + + err = k.register(addr2, addr3) + if err != nil { + t.Fatalf("Failed to register addr2->addr3: %v", err) + } + + // reference cycle: C -> A + err = k.register(addr3, addr1) + if err != nil { + t.Fatalf("Failed to register addr3->addr1: %v", err) + } + + refAddr, _ := k.get(addr1) + if refAddr != addr2 { + t.Error("addr1's referral should be addr2") + } + + refAddr, _ = k.get(addr2) + if refAddr != addr3 { + t.Error("addr2's referral should be addr3") + } + + refAddr, _ = k.get(addr3) + if refAddr != addr1 { + t.Error("addr3's referral should be addr1") + } +} + +func TestStress(t *testing.T) { + t.Skip("Skipping stress test") + + k := setupKeeper() + cleanup := mockValidCaller() + defer cleanup() + + addresses := make([]std.Address, STRESS_TEST_NUM) + + for i := 0; i < STRESS_TEST_NUM; i++ { + addresses[i] = testutils.TestAddress(ufmt.Sprintf("addr%d", i)) + } + + for i := 0; i < STRESS_TEST_NUM; i++ { + err := k.register(addresses[i], addresses[(i+1)%STRESS_TEST_NUM]) + if err != nil { + t.Fatalf("Registration failed at index %d: %v", i, err) + } + + err = k.update(addresses[i], addresses[(i+2)%STRESS_TEST_NUM]) + if err != nil { + t.Fatalf("Update failed at index %d: %v", i, err) + } + + // remove some addresses + if i%3 == 0 { + err = k.remove(addresses[i]) + if err != nil { + t.Fatalf("Remove failed at index %d: %v", i, err) + } + } + + // check data consistency + if i%1000 == 0 { + for j := 0; j <= i; j++ { + if j%3 == 0 { + // check removed address + if k.has(addresses[j]) { + t.Errorf("Removed address still exists at index %d", j) + } + } else { + // check registered address + refAddr, err := k.get(addresses[j]) + if err != nil { + t.Errorf("Failed to get referral at index %d: %v", j, err) + } + expectedAddr := addresses[(j+2)%STRESS_TEST_NUM] + if refAddr != expectedAddr { + t.Errorf("Incorrect referral at index %d", j) + } + } + } + } + } +} + +func TestIsEmpty(t *testing.T) { + tests := []struct { + name string + setupState func(*keeper) + want bool + }{ + { + name: "new keeper must empty", + setupState: func(k *keeper) {}, + want: true, + }, + { + name: "keeper with data must not be empty", + setupState: func(k *keeper) { + k.store.Set(validAddr1.String(), validAddr2.String()) + }, + want: false, + }, + { + name: "keeper with all data removed must be empty", + setupState: func(k *keeper) { + k.store.Set(validAddr1.String(), validAddr2.String()) + k.store.Remove(validAddr1.String()) + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := setupKeeper() + tt.setupState(k) + + if got := k.isEmpty(); got != tt.want { + t.Errorf("isEmpty() = %v, want %v", got, tt.want) + } + }) + } +} + +/* Getter Tests */ + +func TestGetReferral(t *testing.T) { + tests := []struct { + name string + addr string + setupState func() + want string + shouldPanic bool + }{ + { + name: "retrieve existing referral", + addr: validAddr1.String(), + setupState: func() { + gReferralKeeper = NewKeeper() + cleanup := mockValidCaller() + defer cleanup() + gReferralKeeper.register(validAddr1, validAddr2) + }, + want: validAddr2.String(), + shouldPanic: false, + }, + { + name: "retrieve non-existent referral", + addr: validAddr1.String(), + setupState: func() { + gReferralKeeper = NewKeeper() + }, + shouldPanic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupState() + + if tt.shouldPanic { + defer func() { + if r := recover(); r == nil { + t.Error("GetReferral() function did not panic") + } + }() + } + + got := GetReferral(tt.addr) + if !tt.shouldPanic && got != tt.want { + t.Errorf("GetReferral() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHasReferral(t *testing.T) { + tests := []struct { + name string + addr string + setupState func() + want bool + }{ + { + name: "referral exists", + addr: validAddr1.String(), + setupState: func() { + gReferralKeeper = NewKeeper() + cleanup := mockValidCaller() + defer cleanup() + gReferralKeeper.register(validAddr1, validAddr2) + }, + want: true, + }, + { + name: "referral does not exist", + addr: validAddr1.String(), + setupState: func() { + gReferralKeeper = NewKeeper() + }, + want: false, + }, + { + name: "zero address referral", + addr: validAddr1.String(), + setupState: func() { + gReferralKeeper = NewKeeper() + cleanup := mockValidCaller() + defer cleanup() + gReferralKeeper.register(validAddr1, zeroAddress) + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupState() + if got := HasReferral(tt.addr); got != tt.want { + t.Errorf("HasReferral() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGlobalIsEmpty(t *testing.T) { + tests := []struct { + name string + setupState func() + want bool + }{ + { + name: "new global keeper must be empty", + setupState: func() { + gReferralKeeper = NewKeeper() + }, + want: true, + }, + { + name: "keeper with data must not be empty", + setupState: func() { + gReferralKeeper = NewKeeper() + cleanup := mockValidCaller() + defer cleanup() + gReferralKeeper.register(validAddr1, validAddr2) + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupState() + if got := IsEmpty(); got != tt.want { + t.Errorf("IsEmpty() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/referral/referral.gno b/referral/referral.gno new file mode 100644 index 000000000..4b30d6a8e --- /dev/null +++ b/referral/referral.gno @@ -0,0 +1,54 @@ +package referral + +import ( + "std" +) + +// Referral provides the main interface for managing referral relationships. +// It encapsulates a `ReferralKeeper` instance which handles the actual storage +// and action of referral functionality. +type Referral struct { + keeper ReferralKeeper +} + +// NewReferral creates and returns a new instance of `Referral`. +// it initializes the underlying keeper with a new storage instance. +func NewReferral() *Referral { + return &Referral{ + keeper: NewKeeper(), + } +} + +// Register creates a new referral relationship between addr and refAddr. +// It validates both addresses and ensures the caller has the required permissions. +// Returns an error if the addresses are invalid or if the caller is unauthorized. +func (r *Referral) Register(addr, refAddr std.Address) error { + return r.keeper.register(addr, refAddr) +} + +// Update modifies an existing referral relationship for the given address. +// The new referral address will replace any existing referral. +// Returns an error if the addresses are invalid or if the caller is unauthorized. +func (r *Referral) Update(addr, newAddr std.Address) error { + return r.keeper.update(addr, newAddr) +} + +// Remove deletes the referral relationship for the given address. +// If no referral exists for the address, the operation is a no-op. +// Returns an error if the address is invalid or if the caller is unauthorized. +func (r *Referral) Remove(addr std.Address) error { + return r.keeper.remove(addr) +} + +// Has checks if a referral relationship exists for the given address. +// Returns true if a referral exists, false otherwise. +func (r *Referral) Has(addr std.Address) bool { + return r.keeper.has(addr) +} + +// Get retrieves the referral address for the given address. +// Returns the referral address and nil error if found. +// Returns zeroAddress and an error if the address is invalid or no referral exists. +func (r *Referral) Get(addr std.Address) (std.Address, error) { + return r.keeper.get(addr) +} diff --git a/referral/type.gno b/referral/type.gno new file mode 100644 index 000000000..6f64651dc --- /dev/null +++ b/referral/type.gno @@ -0,0 +1,38 @@ +package referral + +import "std" + +// zeroAddress represents an empty address used for validation and comparison. +var zeroAddress = std.Address("") + +// Event types for each referral actions. +const ( + EventTypeRegister = "RegisterReferral" + EventTypeUpdate = "UpdateReferral" + EventTypeRemove = "RemoveReferral" +) + +// ReferralKeeper defines the interface for managing referral relationships. +type ReferralKeeper interface { + // register creates a new refferal relationship betwwen address and referral address. + // returns an error if the addresses are invalid or if the caller is not authorized. + register(addr, refAddr std.Address) error + + // update updates the referral address for a given address. + // returns an error if the addresses are invalid or if the caller is not authorized. + update(addr, newRefAddr std.Address) error + + // remove removes the referral relationship for a given address. + // returns an error if the address is invalid or if the caller is not authorized. + remove(addr std.Address) error + + // has checks if a referral relationship exists for a given address. + has(addr std.Address) bool + + // get retrieves the referral address for a given address. + // returns an error if the address is invalid or if the referral relationship does not exist. + get(addr std.Address) (std.Address, error) + + // isEmpty checks if the referral relationship is empty. + isEmpty() bool +} diff --git a/referral/utils.gno b/referral/utils.gno new file mode 100644 index 000000000..118918832 --- /dev/null +++ b/referral/utils.gno @@ -0,0 +1,28 @@ +package referral + +import ( + "std" + + "gno.land/r/gnoswap/v1/consts" +) + +// validCallers is a lookup table of addresses that are authorized to modify referral data. +// This includes governance contracts, router, position manager, and staker contracts. +var validCallers = map[std.Address]bool{ + consts.GOV_GOVERNANCE_ADDR: true, + consts.GOV_STAKER_ADDR: true, + consts.ROUTER_ADDR: true, + consts.POSITION_ADDR: true, + consts.STAKER_ADDR: true, + consts.LAUNCHPAD_ADDR: true, +} + +// isValidCaller checks if the given address has permission to modify referral data. +// Only specific pre-authorized addresses defined in validCallers map are allowed to +// register, update, or remove referrals. +func isValidCaller(caller std.Address) error { + if validCallers[caller] { + return nil + } + return NewError(ErrUnauthorized, caller) +} diff --git a/setup.py b/setup.py index 38864132b..017067214 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,7 @@ def copy_contracts(workdir): # Copy gnoswap realms # TODO: Detect realms automatically - for realm in ["pool", "position", "router", "staker", "emission", "community_pool", "protocol_fee", "launchpad", "gov"]: + for realm in ["pool", "position", "router", "staker", "emission", "community_pool", "protocol_fee", "launchpad", "gov", "referral"]: shutil.copytree(realm, os.path.join(gno_dir, "r", "gnoswap", "v1", realm), dirs_exist_ok=True) def move_tests(workdir):