Skip to content

Commit

Permalink
fix persistent flags
Browse files Browse the repository at this point in the history
  • Loading branch information
idanya committed Jul 25, 2023
1 parent fdef1f3 commit 9569993
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 31 deletions.
10 changes: 5 additions & 5 deletions cmd/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ import (
"fmt"
"log"

"github.com/idanya/evm-cli/clients/nodes"
"github.com/idanya/evm-cli/entities"
"github.com/spf13/cobra"
)

type AccountCommands struct {
nodeClient nodes.NodeClient
nodeClientGenerator entities.NodeClientGenerator
}

func NewAccountCommands(nodeClient nodes.NodeClient) *AccountCommands {
return &AccountCommands{nodeClient}
func NewAccountCommands(nodeClientGenerator entities.NodeClientGenerator) *AccountCommands {
return &AccountCommands{nodeClientGenerator}
}

func (ac *AccountCommands) GetRootCommand() *cobra.Command {
Expand All @@ -33,7 +33,7 @@ func (ac *AccountCommands) GetAccountNonceCommand() *cobra.Command {
Short: "Get account nonce",
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
count, err := ac.nodeClient.GetAccountNonce(context.Background(), args[0])
count, err := ac.nodeClientGenerator().GetAccountNonce(context.Background(), args[0])
if err != nil {
log.Fatal(err)
}
Expand Down
1 change: 1 addition & 0 deletions cmd/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"github.com/spf13/viper"
)


func NodeClientFromViper() *nodes.EthereumNodeClient {
chainId := viper.GetUint("chainId")
rpcUrl := viper.GetString("rpcUrl")
Expand Down
9 changes: 5 additions & 4 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"

"github.com/idanya/evm-cli/clients/directory"
"github.com/idanya/evm-cli/clients/nodes"
decompiler "github.com/idanya/evm-cli/decompiler"
"github.com/idanya/evm-cli/services"
"github.com/spf13/cobra"
Expand All @@ -27,15 +28,15 @@ func Execute(directoryClient directory.DirectoryClient, decompiler *decompiler.D
viper.BindPFlag("chainId", rootCmd.PersistentFlags().Lookup("chain-id"))
viper.BindPFlag("rpcUrl", rootCmd.PersistentFlags().Lookup("rpc-url"))

nodeClient := NodeClientFromViper()
contractService := services.NewContractService(nodeClient, decompiler, decoder)
nodeGenerator := func() nodes.NodeClient { return NodeClientFromViper() }

transactionService := services.NewTransactionService(nodeClient, directoryClient, decoder)
contractService := services.NewContractService(nodeGenerator, decompiler, decoder)
transactionService := services.NewTransactionService(nodeGenerator, directoryClient, decoder)

tx := NewTransactionCommands(transactionService)
rootCmd.AddCommand(tx.GetRootCommand())

accountCmd := NewAccountCommands(nodeClient)
accountCmd := NewAccountCommands(nodeGenerator)
rootCmd.AddCommand(accountCmd.GetRootCommand())

contractCmd := NewContractCommands(contractService, decompiler, decoder)
Expand Down
5 changes: 5 additions & 0 deletions entities/node_client_generator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package entities

import "github.com/idanya/evm-cli/clients/nodes"

type NodeClientGenerator = func() nodes.NodeClient
19 changes: 9 additions & 10 deletions services/contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/idanya/evm-cli/clients/nodes"
decompiler "github.com/idanya/evm-cli/decompiler"
"github.com/idanya/evm-cli/entities"
)
Expand All @@ -24,13 +23,13 @@ var (
)

type ContractService struct {
nodeClient nodes.NodeClient
decompiler *decompiler.Decompiler
decoder *Decoder
nodeClientGenerator entities.NodeClientGenerator
decompiler *decompiler.Decompiler
decoder *Decoder
}

func NewContractService(nodeClient nodes.NodeClient, decompiler *decompiler.Decompiler, decoder *Decoder) *ContractService {
return &ContractService{nodeClient, decompiler, decoder}
func NewContractService(nodeClientGenerator entities.NodeClientGenerator, decompiler *decompiler.Decompiler, decoder *Decoder) *ContractService {
return &ContractService{nodeClientGenerator, decompiler, decoder}
}

func (cs *ContractService) ExecuteReadFunction(context context.Context, contractAddress string, inputTypes []string, outputTypes []string, functionName string, params ...string) ([]interface{}, error) {
Expand Down Expand Up @@ -62,7 +61,7 @@ func (cs *ContractService) ExecuteReadFunction(context context.Context, contract
}
}

return cs.nodeClient.ExecuteReadFunction(context, contractAddress, abi, functionName, castedParams...)
return cs.nodeClientGenerator().ExecuteReadFunction(context, contractAddress, abi, functionName, castedParams...)
}

func (cs *ContractService) GetProxyImplementation(context context.Context, contractAddress string) (string, error) {
Expand All @@ -79,7 +78,7 @@ func (cs *ContractService) GetProxyImplementation(context context.Context, contr
EIP_1167_BYTECODE_PREFIX := "363d3d373d3d3d363d73"
EIP_1167_BYTECODE_SUFFIX := "5af43d82803e903d91602b57fd5bf3"

contractCode, err := cs.nodeClient.GetContractCode(context, contractAddress)
contractCode, err := cs.nodeClientGenerator().GetContractCode(context, contractAddress)
if err == nil {
hexCode := common.Bytes2Hex(contractCode)
if strings.HasPrefix(hexCode, EIP_1167_BYTECODE_PREFIX) && strings.HasSuffix(hexCode, EIP_1167_BYTECODE_SUFFIX) {
Expand Down Expand Up @@ -112,7 +111,7 @@ func (cs *ContractService) tryGetProxyImplementationByStorage(context context.Co
OPEN_ZEPPELIN_IMPLEMENTATION_SLOT, EIP_1822_LOGIC_SLOT}

for _, slot := range storageSlots {
response, err := cs.nodeClient.GetContractStorageSlot(context, contractAddress, slot)
response, err := cs.nodeClientGenerator().GetContractStorageSlot(context, contractAddress, slot)
if err != nil {
continue
}
Expand Down Expand Up @@ -163,7 +162,7 @@ func (cs *ContractService) generateMethodABI(functionName string, inputTypes []s

func (cs *ContractService) GetContractStandards(context context.Context, contractAddress string) ([]string, error) {
matchingStandards := make([]string, 0)
contractCode, err := cs.nodeClient.GetContractCode(context, contractAddress)
contractCode, err := cs.nodeClientGenerator().GetContractCode(context, contractAddress)
if err != nil {
return nil, err
}
Expand Down
13 changes: 9 additions & 4 deletions services/contract_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/ethereum/go-ethereum/common"
dirmock "github.com/idanya/evm-cli/clients/directory/mocks"
"github.com/idanya/evm-cli/clients/nodes"
"github.com/idanya/evm-cli/clients/nodes/mocks"
decompiler "github.com/idanya/evm-cli/decompiler"
"github.com/stretchr/testify/assert"
Expand All @@ -26,8 +27,9 @@ func TestContractService_DetectMinimalProxyByByteCode(t *testing.T) {
nodeClientMock.On("GetContractCode", mock.Anything, "0x3348f2aee62a0ddb164c711b5937e4001c17080e").Return(proxyBytecode, nil)
nodeClientMock.On("ExecuteReadFunction", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("error"))
nodeClientMock.On("GetContractStorageSlot", mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("error"))
nodeGenerator := func() nodes.NodeClient { return nodeClientMock }

contractService := NewContractService(nodeClientMock, decompilerClient, decoder)
contractService := NewContractService(nodeGenerator, decompilerClient, decoder)
implementation, err := contractService.GetProxyImplementation(context.Background(), "0x3348f2aee62a0ddb164c711b5937e4001c17080e")
assert.NoError(t, err)
assert.Equal(t, "0x4d11c446473105a02b5c1ab9ebe9b03f33902a29", implementation)
Expand All @@ -44,8 +46,9 @@ func TestContractService_DetectProxyByImplementationMethods(t *testing.T) {

nodeClientMock.On("ExecuteReadFunction", mock.Anything, mock.Anything, mock.Anything, method).Return([]interface{}{common.HexToAddress("0xB650eb28d35691dd1BD481325D40E65273844F9b")}, nil)
nodeClientMock.On("ExecuteReadFunction", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("error"))
nodeGenerator := func() nodes.NodeClient { return nodeClientMock }

contractService := NewContractService(nodeClientMock, decompilerClient, decoder)
contractService := NewContractService(nodeGenerator, decompilerClient, decoder)
implementation, err := contractService.GetProxyImplementation(context.Background(), "0x0000000000085d4780B73119b644AE5ecd22b376")
assert.NoError(t, err)
assert.Equal(t, "0xB650eb28d35691dd1BD481325D40E65273844F9b", implementation)
Expand All @@ -61,8 +64,9 @@ func TestExecuteReadFunction(t *testing.T) {
nodeClientMock.On("ExecuteReadFunction", mock.Anything, "0x0", mock.Anything, "func",
common.HexToAddress("0xdac17f958d2ee523a2206206994597c13d831ec7"),
new(big.Int).SetUint64(10), new(big.Int).SetUint64(100), mock.Anything).Return([]interface{}{"OK"}, nil)
nodeGenerator := func() nodes.NodeClient { return nodeClientMock }

contractService := NewContractService(nodeClientMock, decompilerClient, decoder)
contractService := NewContractService(nodeGenerator, decompilerClient, decoder)
response, err := contractService.ExecuteReadFunction(context.Background(), "0x0",
[]string{"address", "uint256", "int256", "bool"},
[]string{"address"}, "func", "0xdac17f958d2ee523a2206206994597c13d831ec7", "10", "100", "false")
Expand All @@ -82,8 +86,9 @@ func TestGetContractStandards(t *testing.T) {

nodeClientMock := mocks.NewNodeClient(t)
nodeClientMock.On("GetContractCode", mock.Anything, "0x0").Return(erc20Code, nil)
nodeGenerator := func() nodes.NodeClient { return nodeClientMock }

contractService := NewContractService(nodeClientMock, decompilerClient, decoder)
contractService := NewContractService(nodeGenerator, decompilerClient, decoder)
standards, err := contractService.GetContractStandards(context.Background(), "0x0")
assert.Nil(t, err)
assert.NotNil(t, standards)
Expand Down
15 changes: 7 additions & 8 deletions services/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,23 @@ import (

"github.com/ethereum/go-ethereum/common"
"github.com/idanya/evm-cli/clients/directory"
"github.com/idanya/evm-cli/clients/nodes"
"github.com/idanya/evm-cli/entities"
)

type TransactionService struct {
nodeClient nodes.NodeClient
directoryClient directory.DirectoryClient
decoder *Decoder
nodeClientGenerator entities.NodeClientGenerator
directoryClient directory.DirectoryClient
decoder *Decoder
}

func NewTransactionService(nodeClient nodes.NodeClient,
func NewTransactionService(nodeClientGenerator entities.NodeClientGenerator,
directoryClient directory.DirectoryClient,
decoder *Decoder) *TransactionService {
return &TransactionService{nodeClient, directoryClient, decoder}
return &TransactionService{nodeClientGenerator, directoryClient, decoder}
}

func (ts *TransactionService) GetTransactionReceipt(context context.Context, txHash string) (*entities.EnrichedReceipt, error) {
receipt, err := ts.nodeClient.GetTransactionReceipt(context, txHash)
receipt, err := ts.nodeClientGenerator().GetTransactionReceipt(context, txHash)
if err != nil {
return nil, err
}
Expand All @@ -47,7 +46,7 @@ func (ts *TransactionService) GetTransactionReceipt(context context.Context, txH
}

func (ts *TransactionService) GetTransactionByHash(context context.Context, txHash string) (*entities.EnrichedTxInfo, error) {
transaction, err := ts.nodeClient.GetTransactionByHash(context, txHash)
transaction, err := ts.nodeClientGenerator().GetTransactionByHash(context, txHash)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 9569993

Please sign in to comment.