From 9569993725b239fab739b507c903b866e59ce2d5 Mon Sep 17 00:00:00 2001 From: Idan Yael Date: Tue, 25 Jul 2023 16:16:27 +0300 Subject: [PATCH] fix persistent flags --- cmd/account.go | 10 +++++----- cmd/factory.go | 1 + cmd/root.go | 9 +++++---- entities/node_client_generator.go | 5 +++++ services/contract.go | 19 +++++++++---------- services/contract_test.go | 13 +++++++++---- services/transaction.go | 15 +++++++-------- 7 files changed, 41 insertions(+), 31 deletions(-) create mode 100644 entities/node_client_generator.go diff --git a/cmd/account.go b/cmd/account.go index b85ac9d..61e6689 100644 --- a/cmd/account.go +++ b/cmd/account.go @@ -5,16 +5,16 @@ import ( "fmt" "log" - "github.com/idanya/evm-cli/clients/nodes" + "github.com/idanya/evm-cli/entities" "github.com/spf13/cobra" ) type AccountCommands struct { - nodeClient nodes.NodeClient + nodeClientGenerator entities.NodeClientGenerator } -func NewAccountCommands(nodeClient nodes.NodeClient) *AccountCommands { - return &AccountCommands{nodeClient} +func NewAccountCommands(nodeClientGenerator entities.NodeClientGenerator) *AccountCommands { + return &AccountCommands{nodeClientGenerator} } func (ac *AccountCommands) GetRootCommand() *cobra.Command { @@ -33,7 +33,7 @@ func (ac *AccountCommands) GetAccountNonceCommand() *cobra.Command { Short: "Get account nonce", Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { - count, err := ac.nodeClient.GetAccountNonce(context.Background(), args[0]) + count, err := ac.nodeClientGenerator().GetAccountNonce(context.Background(), args[0]) if err != nil { log.Fatal(err) } diff --git a/cmd/factory.go b/cmd/factory.go index ed95583..f79d6f8 100644 --- a/cmd/factory.go +++ b/cmd/factory.go @@ -5,6 +5,7 @@ import ( "github.com/spf13/viper" ) + func NodeClientFromViper() *nodes.EthereumNodeClient { chainId := viper.GetUint("chainId") rpcUrl := viper.GetString("rpcUrl") diff --git a/cmd/root.go b/cmd/root.go index 250cb81..37131c0 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -5,6 +5,7 @@ import ( "os" "github.com/idanya/evm-cli/clients/directory" + "github.com/idanya/evm-cli/clients/nodes" decompiler "github.com/idanya/evm-cli/decompiler" "github.com/idanya/evm-cli/services" "github.com/spf13/cobra" @@ -27,15 +28,15 @@ func Execute(directoryClient directory.DirectoryClient, decompiler *decompiler.D viper.BindPFlag("chainId", rootCmd.PersistentFlags().Lookup("chain-id")) viper.BindPFlag("rpcUrl", rootCmd.PersistentFlags().Lookup("rpc-url")) - nodeClient := NodeClientFromViper() - contractService := services.NewContractService(nodeClient, decompiler, decoder) + nodeGenerator := func() nodes.NodeClient { return NodeClientFromViper() } - transactionService := services.NewTransactionService(nodeClient, directoryClient, decoder) + contractService := services.NewContractService(nodeGenerator, decompiler, decoder) + transactionService := services.NewTransactionService(nodeGenerator, directoryClient, decoder) tx := NewTransactionCommands(transactionService) rootCmd.AddCommand(tx.GetRootCommand()) - accountCmd := NewAccountCommands(nodeClient) + accountCmd := NewAccountCommands(nodeGenerator) rootCmd.AddCommand(accountCmd.GetRootCommand()) contractCmd := NewContractCommands(contractService, decompiler, decoder) diff --git a/entities/node_client_generator.go b/entities/node_client_generator.go new file mode 100644 index 0000000..7dbd1a6 --- /dev/null +++ b/entities/node_client_generator.go @@ -0,0 +1,5 @@ +package entities + +import "github.com/idanya/evm-cli/clients/nodes" + +type NodeClientGenerator = func() nodes.NodeClient diff --git a/services/contract.go b/services/contract.go index f575a84..180df1c 100644 --- a/services/contract.go +++ b/services/contract.go @@ -13,7 +13,6 @@ import ( "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" - "github.com/idanya/evm-cli/clients/nodes" decompiler "github.com/idanya/evm-cli/decompiler" "github.com/idanya/evm-cli/entities" ) @@ -24,13 +23,13 @@ var ( ) type ContractService struct { - nodeClient nodes.NodeClient - decompiler *decompiler.Decompiler - decoder *Decoder + nodeClientGenerator entities.NodeClientGenerator + decompiler *decompiler.Decompiler + decoder *Decoder } -func NewContractService(nodeClient nodes.NodeClient, decompiler *decompiler.Decompiler, decoder *Decoder) *ContractService { - return &ContractService{nodeClient, decompiler, decoder} +func NewContractService(nodeClientGenerator entities.NodeClientGenerator, decompiler *decompiler.Decompiler, decoder *Decoder) *ContractService { + return &ContractService{nodeClientGenerator, decompiler, decoder} } func (cs *ContractService) ExecuteReadFunction(context context.Context, contractAddress string, inputTypes []string, outputTypes []string, functionName string, params ...string) ([]interface{}, error) { @@ -62,7 +61,7 @@ func (cs *ContractService) ExecuteReadFunction(context context.Context, contract } } - return cs.nodeClient.ExecuteReadFunction(context, contractAddress, abi, functionName, castedParams...) + return cs.nodeClientGenerator().ExecuteReadFunction(context, contractAddress, abi, functionName, castedParams...) } func (cs *ContractService) GetProxyImplementation(context context.Context, contractAddress string) (string, error) { @@ -79,7 +78,7 @@ func (cs *ContractService) GetProxyImplementation(context context.Context, contr EIP_1167_BYTECODE_PREFIX := "363d3d373d3d3d363d73" EIP_1167_BYTECODE_SUFFIX := "5af43d82803e903d91602b57fd5bf3" - contractCode, err := cs.nodeClient.GetContractCode(context, contractAddress) + contractCode, err := cs.nodeClientGenerator().GetContractCode(context, contractAddress) if err == nil { hexCode := common.Bytes2Hex(contractCode) if strings.HasPrefix(hexCode, EIP_1167_BYTECODE_PREFIX) && strings.HasSuffix(hexCode, EIP_1167_BYTECODE_SUFFIX) { @@ -112,7 +111,7 @@ func (cs *ContractService) tryGetProxyImplementationByStorage(context context.Co OPEN_ZEPPELIN_IMPLEMENTATION_SLOT, EIP_1822_LOGIC_SLOT} for _, slot := range storageSlots { - response, err := cs.nodeClient.GetContractStorageSlot(context, contractAddress, slot) + response, err := cs.nodeClientGenerator().GetContractStorageSlot(context, contractAddress, slot) if err != nil { continue } @@ -163,7 +162,7 @@ func (cs *ContractService) generateMethodABI(functionName string, inputTypes []s func (cs *ContractService) GetContractStandards(context context.Context, contractAddress string) ([]string, error) { matchingStandards := make([]string, 0) - contractCode, err := cs.nodeClient.GetContractCode(context, contractAddress) + contractCode, err := cs.nodeClientGenerator().GetContractCode(context, contractAddress) if err != nil { return nil, err } diff --git a/services/contract_test.go b/services/contract_test.go index a4921a0..9bca087 100644 --- a/services/contract_test.go +++ b/services/contract_test.go @@ -9,6 +9,7 @@ import ( "github.com/ethereum/go-ethereum/common" dirmock "github.com/idanya/evm-cli/clients/directory/mocks" + "github.com/idanya/evm-cli/clients/nodes" "github.com/idanya/evm-cli/clients/nodes/mocks" decompiler "github.com/idanya/evm-cli/decompiler" "github.com/stretchr/testify/assert" @@ -26,8 +27,9 @@ func TestContractService_DetectMinimalProxyByByteCode(t *testing.T) { nodeClientMock.On("GetContractCode", mock.Anything, "0x3348f2aee62a0ddb164c711b5937e4001c17080e").Return(proxyBytecode, nil) nodeClientMock.On("ExecuteReadFunction", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("error")) nodeClientMock.On("GetContractStorageSlot", mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("error")) + nodeGenerator := func() nodes.NodeClient { return nodeClientMock } - contractService := NewContractService(nodeClientMock, decompilerClient, decoder) + contractService := NewContractService(nodeGenerator, decompilerClient, decoder) implementation, err := contractService.GetProxyImplementation(context.Background(), "0x3348f2aee62a0ddb164c711b5937e4001c17080e") assert.NoError(t, err) assert.Equal(t, "0x4d11c446473105a02b5c1ab9ebe9b03f33902a29", implementation) @@ -44,8 +46,9 @@ func TestContractService_DetectProxyByImplementationMethods(t *testing.T) { nodeClientMock.On("ExecuteReadFunction", mock.Anything, mock.Anything, mock.Anything, method).Return([]interface{}{common.HexToAddress("0xB650eb28d35691dd1BD481325D40E65273844F9b")}, nil) nodeClientMock.On("ExecuteReadFunction", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("error")) + nodeGenerator := func() nodes.NodeClient { return nodeClientMock } - contractService := NewContractService(nodeClientMock, decompilerClient, decoder) + contractService := NewContractService(nodeGenerator, decompilerClient, decoder) implementation, err := contractService.GetProxyImplementation(context.Background(), "0x0000000000085d4780B73119b644AE5ecd22b376") assert.NoError(t, err) assert.Equal(t, "0xB650eb28d35691dd1BD481325D40E65273844F9b", implementation) @@ -61,8 +64,9 @@ func TestExecuteReadFunction(t *testing.T) { nodeClientMock.On("ExecuteReadFunction", mock.Anything, "0x0", mock.Anything, "func", common.HexToAddress("0xdac17f958d2ee523a2206206994597c13d831ec7"), new(big.Int).SetUint64(10), new(big.Int).SetUint64(100), mock.Anything).Return([]interface{}{"OK"}, nil) + nodeGenerator := func() nodes.NodeClient { return nodeClientMock } - contractService := NewContractService(nodeClientMock, decompilerClient, decoder) + contractService := NewContractService(nodeGenerator, decompilerClient, decoder) response, err := contractService.ExecuteReadFunction(context.Background(), "0x0", []string{"address", "uint256", "int256", "bool"}, []string{"address"}, "func", "0xdac17f958d2ee523a2206206994597c13d831ec7", "10", "100", "false") @@ -82,8 +86,9 @@ func TestGetContractStandards(t *testing.T) { nodeClientMock := mocks.NewNodeClient(t) nodeClientMock.On("GetContractCode", mock.Anything, "0x0").Return(erc20Code, nil) + nodeGenerator := func() nodes.NodeClient { return nodeClientMock } - contractService := NewContractService(nodeClientMock, decompilerClient, decoder) + contractService := NewContractService(nodeGenerator, decompilerClient, decoder) standards, err := contractService.GetContractStandards(context.Background(), "0x0") assert.Nil(t, err) assert.NotNil(t, standards) diff --git a/services/transaction.go b/services/transaction.go index b1c6c9b..21b9b1d 100644 --- a/services/transaction.go +++ b/services/transaction.go @@ -5,24 +5,23 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/idanya/evm-cli/clients/directory" - "github.com/idanya/evm-cli/clients/nodes" "github.com/idanya/evm-cli/entities" ) type TransactionService struct { - nodeClient nodes.NodeClient - directoryClient directory.DirectoryClient - decoder *Decoder + nodeClientGenerator entities.NodeClientGenerator + directoryClient directory.DirectoryClient + decoder *Decoder } -func NewTransactionService(nodeClient nodes.NodeClient, +func NewTransactionService(nodeClientGenerator entities.NodeClientGenerator, directoryClient directory.DirectoryClient, decoder *Decoder) *TransactionService { - return &TransactionService{nodeClient, directoryClient, decoder} + return &TransactionService{nodeClientGenerator, directoryClient, decoder} } func (ts *TransactionService) GetTransactionReceipt(context context.Context, txHash string) (*entities.EnrichedReceipt, error) { - receipt, err := ts.nodeClient.GetTransactionReceipt(context, txHash) + receipt, err := ts.nodeClientGenerator().GetTransactionReceipt(context, txHash) if err != nil { return nil, err } @@ -47,7 +46,7 @@ func (ts *TransactionService) GetTransactionReceipt(context context.Context, txH } func (ts *TransactionService) GetTransactionByHash(context context.Context, txHash string) (*entities.EnrichedTxInfo, error) { - transaction, err := ts.nodeClient.GetTransactionByHash(context, txHash) + transaction, err := ts.nodeClientGenerator().GetTransactionByHash(context, txHash) if err != nil { return nil, err }