Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

erc20_rewards: outline for single-extension #1339

Merged
merged 5 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ tasks:
generates:
- node/engine/parse/gen/*.{go,interp,tokens}

generate:abi:
desc: Generate the ABI for the smart contracts
cmds:
- abigen --abi=./node/exts/erc20reward/abigen/reward_distributor_abi.json --pkg abigen --out=./node/exts/erc20reward/abigen/reward_distributor.go --type RewardDistributor
- abigen --abi=./node/exts/erc20reward/abigen/erc20_abi.json --pkg abigen --out=./node/exts/erc20reward/abigen/erc20.go --type Erc20

# ************ docker ************
vendor:
desc: Generate vendor
Expand Down
10 changes: 10 additions & 0 deletions core/types/decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ func MustParseDecimalExplicit(s string, precision, scale uint16) *Decimal {
return dec
}

func NewDecimalFromInt(i int64) *Decimal {
b, err := NewDecimalFromBigInt(big.NewInt(i), 0)
if err != nil {
// wont panic because exp is 0
panic(err)
}

return b
}

// NewDecimalFromBigInt creates a new Decimal from a big.Int and an exponent.
// The negative of the exponent is the scale of the decimal.
func NewDecimalFromBigInt(i *big.Int, exp int32) (*Decimal, error) {
Expand Down
68 changes: 56 additions & 12 deletions extensions/precompiles/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,27 @@ type Precompile struct {
Methods []Method
// OnUnuse is called when a `UNUSE ...` statement is executed
OnUnuse func(ctx *common.EngineContext, app *common.App) error
// Cache is a snapshot of the in-memory state of the extension.
// It is used to save and restore the state of the extension.
Cache Cache
}

// Cache is a snapshot of the in-memory state of a precompile extension.
type Cache interface {
// Copy creates a deep copy of the cache.
Copy() Cache
// Apply applies a previously created deep copy of the cache.
// The value passed from Apply will never be changed by the engine,
// so there is no need to copy it.
Apply(cache Cache)
}

type emptyCache struct{}

func (e *emptyCache) Apply(cache Cache) {}

func (e *emptyCache) Copy() Cache { return &emptyCache{} }

// CleanExtension verifies that the extension is correctly set up.
// It does not need to be called by extension authors, as it is called
// automatically by kwild.
Expand Down Expand Up @@ -62,6 +81,10 @@ func CleanPrecompile(e *Precompile) error {
e.OnUnuse = func(ctx *common.EngineContext, app *common.App) error { return nil }
}

if e.Cache == nil {
e.Cache = &emptyCache{}
}

return nil
}

Expand All @@ -83,9 +106,11 @@ type Method struct {
// If nil, the method does not return anything.
Returns *MethodReturn
// Handler is the function that is called when the method is invoked.
Handler func(ctx *common.EngineContext, app *common.App, inputs []any, resultFn func([]any) error) error
Handler HandlerFunc
}

type HandlerFunc func(ctx *common.EngineContext, app *common.App, inputs []any, resultFn func([]any) error) error

// Copy deep-copies a method.
func (m *Method) Copy() *Method {
m2 := &Method{
Expand All @@ -97,9 +122,8 @@ func (m *Method) Copy() *Method {

if m.Returns != nil {
m2.Returns = &MethodReturn{
IsTable: m.Returns.IsTable,
Fields: copyParams(m.Returns.Fields),
FieldNames: slices.Clone(m.Returns.FieldNames),
IsTable: m.Returns.IsTable,
Fields: copyParams(m.Returns.Fields),
}
}

Expand All @@ -119,6 +143,10 @@ func (m *Method) verify() error {
return fmt.Errorf("method name %s must be lowercase", m.Name)
}

if len(m.Name) == 0 {
return fmt.Errorf("method name must not be empty")
}

if len(m.AccessModifiers) == 0 {
return fmt.Errorf("method %s has no access modifiers", m.Name)
}
Expand All @@ -134,19 +162,35 @@ func (m *Method) verify() error {
return fmt.Errorf("method %s must have exactly one of PUBLIC, PRIVATE, or SYSTEM", m.Name)
}

if err := uniqueFieldNames(m.Parameters); err != nil {
return fmt.Errorf("method %s: %w", m.Name, err)
}

if m.Returns != nil {
if len(m.Returns.Fields) == 0 {
return fmt.Errorf("method %s has no return types", m.Name)
}

if len(m.Returns.FieldNames) != 0 && len(m.Returns.FieldNames) != len(m.Returns.Fields) {
return fmt.Errorf("method %s has %d return names, but %d return types", m.Name, len(m.Returns.FieldNames), len(m.Returns.Fields))
if err := uniqueFieldNames(m.Returns.Fields); err != nil {
return fmt.Errorf("method %s: %w", m.Name, err)
}
}

return nil
}

func uniqueFieldNames(fields []PrecompileValue) error {
fieldNames := make(map[string]struct{})
for _, field := range fields {
if _, ok := fieldNames[field.Name]; ok {
return fmt.Errorf("duplicate field name %s", field.Name)
}
fieldNames[field.Name] = struct{}{}
}

return nil
}

// MethodReturn specifies the return structure of a method.
type MethodReturn struct {
// If true, then the method returns any number of rows.
Expand All @@ -156,11 +200,6 @@ type MethodReturn struct {
// It is required. If the extension returns types that are
// not matching the column types, the engine will return an error.
Fields []PrecompileValue
// FieldNames is a list of column names.
// It is optional. If it is set, its length must be equal to the length
// of the column types. If it is not set, the column names will be generated
// based on their position in the column types.
FieldNames []string
}

// Modifier modifies the access to a procedure.
Expand Down Expand Up @@ -194,6 +233,8 @@ func (m Modifiers) Has(mod Modifier) bool {
// PrecompileValue specifies the type and nullability of a value passed to or returned from
// a precompile method.
type PrecompileValue struct {
// Name is the name of the value.
Name string
// Type is the type of the value.
Type *types.DataType
// Nullable is true if the value can be null.
Expand All @@ -202,14 +243,17 @@ type PrecompileValue struct {

func (p *PrecompileValue) Copy() PrecompileValue {
return PrecompileValue{
Name: p.Name,
Type: p.Type.Copy(),
Nullable: p.Nullable,
}
}

// NewPrecompileValue creates a new precompile value.
func NewPrecompileValue(t *types.DataType, nullable bool) PrecompileValue {
// TODO: update this signature to include name
func NewPrecompileValue(name string, t *types.DataType, nullable bool) PrecompileValue {
return PrecompileValue{
Name: name,
Type: t,
Nullable: nullable,
}
Expand Down
1 change: 1 addition & 0 deletions node/engine/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ var (
ErrBuiltInRole = errors.New("invalid operation on built-in role")
ErrInvalidTxCtx = errors.New("invalid transaction context")
ErrReservedNamespacePrefix = errors.New("namespace prefix is reserved")
ErrCannotAlterPrimaryKey = errors.New("cannot drop or alter a table's primary key")

// Errors that are the result of not having proper permissions or failing to meet a condition
// that was programmed by the user.
Expand Down
2 changes: 1 addition & 1 deletion node/engine/interpreter/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ func (e *executionContext) canExecute(newNamespace, actionName string, modifiers
return fmt.Errorf("%w: action %s can only be executed by the owner", engine.ErrActionOwnerOnly, actionName)
}

return nil
return e.checkPrivilege(_CALL_PRIVILEGE)
}

func (e *executionContext) app() *common.App {
Expand Down
5 changes: 3 additions & 2 deletions node/engine/interpreter/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ func initializeExtension(ctx context.Context, svc *common.Service, db sql.DB, i
}
}

if len(method.Returns.FieldNames) > 0 {
colNames = method.Returns.FieldNames
for _, field := range method.Returns.Fields {
colNames = append(colNames, field.Name)
}
}

Expand Down Expand Up @@ -122,6 +122,7 @@ func initializeExtension(ctx context.Context, svc *common.Service, db sql.DB, i
},
namespaceType: namespaceTypeExtension,
methods: methods,
extCache: inst.Cache,
}, &inst, nil
}

Expand Down
64 changes: 61 additions & 3 deletions node/engine/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ type namespace struct {
namespaceType namespaceType
// methods is a map of methods that are available if the namespace is an extension.
methods map[string]precompileExecutable
// extensionCache is a cache of in-memory state for an extension.
// It can be nil if the namespace does not have an extension.
extCache precompiles.Cache
}

// copy creates a deep copy of the namespace.
Expand All @@ -155,6 +158,10 @@ func (n *namespace) copy() *namespace {
methods: make(map[string]precompileExecutable), // we need to copy the methods as well, so shallow copy is not enough
}

if n.extCache != nil {
n2.extCache = n.extCache.Copy()
}

for tblName, tbl := range n.tables {
n2.tables[tblName] = tbl.Copy()
}
Expand All @@ -166,6 +173,21 @@ func (n *namespace) copy() *namespace {
return n2
}

// apply applies a previously created deep copy of the namespace.
// The value passed from Apply will never be changed by the engine,
func (n *namespace) apply(n2 *namespace) {
n.availableFunctions = n2.availableFunctions
n.tables = n2.tables
n.onDeploy = n2.onDeploy
n.onUndeploy = n2.onUndeploy
n.namespaceType = n2.namespaceType
n.methods = n2.methods

if n.extCache != nil {
n.extCache.Apply(n2.extCache)
}
}

type namespaceType string

const (
Expand Down Expand Up @@ -465,6 +487,30 @@ func (i *baseInterpreter) copy() *baseInterpreter {
}
}

// apply applies a previously copied state to the interpreter.
// It is used to roll back the interpreter to a previous state.
func (i *baseInterpreter) apply(copied *baseInterpreter) {
newNamespaces := make(map[string]*namespace)
for k, v := range copied.namespaces {
// if a namespace already exists, we need to call
// the apply function. If it is new, we just add it.
toSet, ok := i.namespaces[k]
if ok {
toSet.apply(v)
} else {
toSet = v
}

newNamespaces[k] = toSet
}
i.namespaces = newNamespaces

i.accessController = copied.accessController
i.service = copied.service
i.validators = copied.validators
i.accounts = copied.accounts
}

// Execute executes a statement against the database.
func (i *baseInterpreter) execute(ctx *common.EngineContext, db sql.DB, statement string, params map[string]any, fn func(*common.Row) error, toplevel bool) (err error) {
copied := i.copy()
Expand All @@ -483,8 +529,7 @@ func (i *baseInterpreter) execute(ctx *common.EngineContext, db sql.DB, statemen
i.syncNamespaceManager()
} else {
// rollback
i.namespaces = copied.namespaces
i.accessController = copied.accessController
i.apply(copied)
}
}()

Expand Down Expand Up @@ -569,10 +614,23 @@ func isValidVarName(s string) error {
// Call executes an action against the database.
// The resultFn is called with the result of the action, if any.
func (i *baseInterpreter) call(ctx *common.EngineContext, db sql.DB, namespace, action string, args []any, resultFn func(*common.Row) error, toplevel bool) (callRes *common.CallResult, err error) {
copied := i.copy()
defer func() {
i.syncNamespaceManager()
noErrOrPanic := true
if err != nil {
// rollback the interpreter
noErrOrPanic = false
}
if r := recover(); r != nil {
err = fmt.Errorf("panic: %v", r)
noErrOrPanic = false
}

if noErrOrPanic {
i.syncNamespaceManager()
} else {
// rollback
i.apply(copied)
}
}()

Expand Down
Loading
Loading