diff --git a/ethcoder/typed_data.go b/ethcoder/typed_data.go index 509b0b9..7c0222b 100644 --- a/ethcoder/typed_data.go +++ b/ethcoder/typed_data.go @@ -1,9 +1,12 @@ package ethcoder import ( + "bytes" + "encoding/json" "fmt" "math/big" "sort" + "strings" "github.com/0xsequence/ethkit/go-ethereum/common" "github.com/0xsequence/ethkit/go-ethereum/crypto" @@ -30,16 +33,21 @@ func (t TypedDataTypes) EncodeType(primaryType string) (string, error) { s := primaryType + "(" for i, arg := range args { - _, ok := t[arg.Type] - if ok { + baseType := arg.Type + if strings.Index(baseType, "[") > 0 { + baseType = baseType[:strings.Index(baseType, "[")] + } + + if _, ok := t[baseType]; ok { set := false for _, v := range subTypes { - if v == arg.Type { + if v == baseType { set = true + break } } if !set { - subTypes = append(subTypes, arg.Type) + subTypes = append(subTypes, baseType) } } @@ -62,6 +70,18 @@ func (t TypedDataTypes) EncodeType(primaryType string) (string, error) { return s, nil } +func (t TypedDataTypes) Map() map[string]map[string]string { + out := map[string]map[string]string{} + for k, v := range t { + m := make(map[string]string, len(v)) + for _, arg := range v { + m[arg.Name] = arg.Type + } + out[k] = m + } + return out +} + func (t TypedDataTypes) TypeHash(primaryType string) ([]byte, error) { encodeType, err := t.EncodeType(primaryType) if err != nil { @@ -128,94 +148,274 @@ func (t *TypedData) encodeData(primaryType string, data map[string]interface{}) return nil, fmt.Errorf("encoding failed for type %s, expecting %d arguments but received %d data values", primaryType, len(args), len(data)) } - abiTypes := []string{} - abiValues := []interface{}{} + encodedTypes := make([]string, len(args)) + encodedValues := make([]interface{}, len(args)) - for _, arg := range args { + for i, arg := range args { dataValue, ok := data[arg.Name] if !ok { return nil, fmt.Errorf("data value missing for type %s with argument name %s", primaryType, arg.Name) } - switch arg.Type { - case "bytes", "string": - var bytesValue []byte - if v, ok := dataValue.([]byte); ok { - bytesValue = v - } else if v, ok := dataValue.(string); ok { - bytesValue = []byte(v) - } else { - return nil, fmt.Errorf("data value invalid for type %s with argument name %s", primaryType, arg.Name) - } - abiTypes = append(abiTypes, "bytes32") - abiValues = append(abiValues, BytesToBytes32(Keccak256(bytesValue))) - - default: - dataValueString, isString := dataValue.(string) - if isString { - v, err := ABIUnmarshalStringValues([]string{arg.Type}, []string{dataValueString}) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal string value for type %s with argument name %s, because %w", primaryType, arg.Name, err) - } - abiValues = append(abiValues, v[0]) - } else { - abiValues = append(abiValues, dataValue) + encValue, err := t.encodeValue(arg.Type, dataValue) + if err != nil { + return nil, fmt.Errorf("failed to encode %s: %w", arg.Name, err) + } + encodedTypes[i] = "bytes" + encodedValues[i] = encValue + } + + return SolidityPack(encodedTypes, encodedValues) +} + +// encodeValue handles the recursive encoding of values according to their types +func (t *TypedData) encodeValue(typ string, value interface{}) ([]byte, error) { + // Handle arrays + if strings.Index(typ, "[") > 0 { + baseType := typ[:strings.Index(typ, "[")] + values, ok := value.([]interface{}) + if !ok { + return nil, fmt.Errorf("expected array for type %s", typ) + } + + encodedValues := make([][]byte, len(values)) + for i, val := range values { + encoded, err := t.encodeValue(baseType, val) + if err != nil { + return nil, fmt.Errorf("failed to encode array element %d: %w", i, err) } - abiTypes = append(abiTypes, arg.Type) + encodedValues[i] = encoded } + + // For arrays, we concatenate the encoded values and hash the result + concat := bytes.Join(encodedValues, nil) + return Keccak256(concat), nil } - if len(args) != len(abiTypes) || len(args) != len(abiValues) { - return nil, fmt.Errorf("argument encoding failed to encode all values") + // Handle bytes and string + if typ == "bytes" || typ == "string" { + var bytesValue []byte + if v, ok := value.([]byte); ok { + bytesValue = v + } else if v, ok := value.(string); ok { + bytesValue = []byte(v) + } else { + return nil, fmt.Errorf("invalid value for type %s", typ) + } + return Keccak256(bytesValue), nil } - // NOTE: each part must be bytes32 - var err error - encodedTypes := make([]string, len(args)) - encodedValues := make([]interface{}, len(args)) - for i := 0; i < len(args); i++ { - pack, err := SolidityPack([]string{abiTypes[i]}, []interface{}{abiValues[i]}) - if err != nil { - return nil, err + // Handle custom struct types + if _, isCustomType := t.Types[typ]; isCustomType { + mapVal, ok := value.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid value for custom type %s", typ) } - encodedValues[i], err = PadZeros(pack, 32) + encoded, err := t.HashStruct(typ, mapVal) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to encode custom type %s: %w", typ, err) } - encodedTypes[i] = "bytes" + return PadZeros(encoded, 32) } - encodedData, err := SolidityPack(encodedTypes, encodedValues) + // Handle primitive types + packed, err := SolidityPack([]string{typ}, []interface{}{value}) if err != nil { return nil, err } - return encodedData, nil + return PadZeros(packed, 32) } -func (t *TypedData) EncodeDigest() ([]byte, error) { - EIP191_HEADER := "0x1901" +// Encode returns the digest of the typed data and the fully encoded EIP712 typed data message. +// +// NOTE: +// * the digest is the hash of the fully encoded EIP712 message +// * the encoded message is the fully encoded EIP712 message (0x1901 + domain + hashStruct(message)) +func (t *TypedData) Encode() ([]byte, []byte, error) { + EIP191_HEADER := "0x1901" // EIP191 for typed data eip191Header, err := HexDecode(EIP191_HEADER) if err != nil { - return nil, err + return nil, nil, err } // Prepare hash struct for the domain domainHash, err := t.HashStruct("EIP712Domain", t.Domain.Map()) if err != nil { - return nil, err + return nil, nil, err } // Prepare hash struct for the message object messageHash, err := t.HashStruct(t.PrimaryType, t.Message) + if err != nil { + return nil, nil, err + } + + encodedMessage, err := SolidityPack([]string{"bytes", "bytes32", "bytes32"}, []interface{}{eip191Header, domainHash, messageHash}) + if err != nil { + return nil, nil, err + } + + digest := crypto.Keccak256(encodedMessage) + + return digest, encodedMessage, nil +} + +// EncodeDigest returns the digest of the typed data message. +func (t *TypedData) EncodeDigest() ([]byte, error) { + digest, _, err := t.Encode() if err != nil { return nil, err } + return digest, nil +} - hashPack, err := SolidityPack([]string{"bytes", "bytes32", "bytes32"}, []interface{}{eip191Header, domainHash, messageHash}) +func TypedDataFromJSON(typedDataJSON string) (*TypedData, error) { + var typedData TypedData + err := json.Unmarshal([]byte(typedDataJSON), &typedData) if err != nil { - return []byte{}, err + return nil, err + } + return &typedData, nil +} + +func (t *TypedData) UnmarshalJSON(data []byte) error { + // Intermediary structure to decode message field + type TypedDataRaw struct { + Types TypedDataTypes `json:"types"` + PrimaryType string `json:"primaryType"` + Domain TypedDataDomain `json:"domain"` + Message map[string]interface{} `json:"message"` } - hashBytes := crypto.Keccak256(hashPack) - return hashBytes, nil + // Json decoder with json.Number support, so that we can decode big.Int values + dec := json.NewDecoder(bytes.NewReader(data)) + dec.UseNumber() + + var raw TypedDataRaw + if err := dec.Decode(&raw); err != nil { + return err + } + + // Ensure the "EIP712Domain" type is defined. In case its not defined + // we will add it to the types map + _, ok := raw.Types["EIP712Domain"] + if !ok { + raw.Types["EIP712Domain"] = []TypedDataArgument{} + if raw.Domain.Name != "" { + raw.Types["EIP712Domain"] = append(raw.Types["EIP712Domain"], TypedDataArgument{Name: "name", Type: "string"}) + } + if raw.Domain.Version != "" { + raw.Types["EIP712Domain"] = append(raw.Types["EIP712Domain"], TypedDataArgument{Name: "version", Type: "string"}) + } + if raw.Domain.ChainID != nil { + raw.Types["EIP712Domain"] = append(raw.Types["EIP712Domain"], TypedDataArgument{Name: "chainId", Type: "uint256"}) + } + if raw.Domain.VerifyingContract != nil { + raw.Types["EIP712Domain"] = append(raw.Types["EIP712Domain"], TypedDataArgument{Name: "verifyingContract", Type: "address"}) + } + if raw.Domain.Salt != nil { + raw.Types["EIP712Domain"] = append(raw.Types["EIP712Domain"], TypedDataArgument{Name: "salt", Type: "bytes32"}) + } + } + + // Ensure primary type is defined + if raw.PrimaryType == "" { + return fmt.Errorf("primary type is required") + } + _, ok = raw.Types[raw.PrimaryType] + if !ok { + return fmt.Errorf("primary type '%s' is not defined", raw.PrimaryType) + } + + // Decode the raw message into Go runtime types + message, err := typedDataDecodeRawMessageMap(raw.Types.Map(), raw.PrimaryType, raw.Message) + if err != nil { + return err + } + + t.Types = raw.Types + t.PrimaryType = raw.PrimaryType + t.Domain = raw.Domain + + m, ok := message.(map[string]interface{}) + if !ok { + return fmt.Errorf("resulting message is not a map") + } + t.Message = m + + return nil +} + +func typedDataDecodeRawMessageMap(typesMap map[string]map[string]string, primaryType string, data interface{}) (interface{}, error) { + // Handle array types + if arr, ok := data.([]interface{}); ok { + results := make([]interface{}, len(arr)) + for i, item := range arr { + decoded, err := typedDataDecodeRawMessageMap(typesMap, primaryType, item) + if err != nil { + return nil, err + } + results[i] = decoded + } + return results, nil + } + + // Handle primitive directly + message, ok := data.(map[string]interface{}) + if !ok { + return typedDataDecodePrimitiveValue(primaryType, data) + } + + currentType, ok := typesMap[primaryType] + if !ok { + return nil, fmt.Errorf("type %s is not defined", primaryType) + } + + processedMessage := make(map[string]interface{}) + for k, v := range message { + typ, ok := currentType[k] + if !ok { + return nil, fmt.Errorf("message field '%s' is missing type definition on '%s'", k, primaryType) + } + + // Extract base type and check if it's an array + baseType := typ + isArray := false + if idx := strings.Index(typ, "["); idx != -1 { + baseType = typ[:idx] + isArray = true + } + + // Process value based on whether it's a custom or primitive type + if _, isCustomType := typesMap[baseType]; isCustomType { + decoded, err := typedDataDecodeRawMessageMap(typesMap, baseType, v) + if err != nil { + return nil, err + } + processedMessage[k] = decoded + } else { + var decoded interface{} + var err error + if isArray { + decoded, err = typedDataDecodeRawMessageMap(typesMap, baseType, v) + } else { + decoded, err = typedDataDecodePrimitiveValue(baseType, v) + } + if err != nil { + return nil, fmt.Errorf("failed to decode field '%s': %w", k, err) + } + processedMessage[k] = decoded + } + } + + return processedMessage, nil +} + +func typedDataDecodePrimitiveValue(typ string, value interface{}) (interface{}, error) { + val := fmt.Sprintf("%v", value) + out, err := ABIUnmarshalStringValuesAny([]string{typ}, []any{val}) + if err != nil { + return nil, err + } + return out[0], nil } diff --git a/ethcoder/typed_data_test.go b/ethcoder/typed_data_test.go index 5b6e3ae..2f3d1d5 100644 --- a/ethcoder/typed_data_test.go +++ b/ethcoder/typed_data_test.go @@ -8,6 +8,7 @@ import ( "github.com/0xsequence/ethkit/ethwallet" "github.com/0xsequence/ethkit/go-ethereum/common" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestTypedDataTypes(t *testing.T) { @@ -64,9 +65,8 @@ func TestTypedDataCase1(t *testing.T) { VerifyingContract: &verifyingContract, }, Message: map[string]interface{}{ - "name": "Bob", - // "wallet": common.HexToAddress("0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB"), // NOTE: passing common.Address object works too - "wallet": "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB", + "name": "Bob", + "wallet": common.HexToAddress("0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB"), }, } @@ -74,27 +74,27 @@ func TestTypedDataCase1(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "0xf2cee375fa42b42143804025fc449deafd50cc031ca257e0b194a650a912090f", ethcoder.HexEncode(domainHash)) - digest, err := typedData.EncodeDigest() + digest, _, err := typedData.Encode() assert.NoError(t, err) assert.Equal(t, "0x0a94cf6625e5860fc4f330d75bcd0c3a4737957d2321d1a024540ab5320fe903", ethcoder.HexEncode(digest)) - // fmt.Println("===> digest", HexEncode(digest)) + // fmt.Println("===> digest", ethcoder.HexEncode(digest)) // lets sign it.. wallet, err := ethwallet.NewWalletFromMnemonic("dose weasel clever culture letter volume endorse used harvest ripple circle install") assert.NoError(t, err) - ethSigedTypedData, err := wallet.SignMessage([]byte(digest)) + ethSigedTypedData, encodedTypeData, err := wallet.SignTypedData(typedData) ethSigedTypedDataHex := ethcoder.HexEncode(ethSigedTypedData) assert.NoError(t, err) assert.Equal(t, - "0x842ed2d5c3bf97c4977ee84e600fec7d0f9c5e21d4090b5035a3ea650ec6127d18053e4aafb631de26eb3fd5d61e4a6f2d6a106ee8e3d8d5cb0c4571d06798741b", + "0x07cc7c723b24733e11494438927012ec9b086e8edcb06022231710988ff7e54c45b0bb8911b1e06d322eb24b919f2a479e3062fee75ce57c1f7d7fc16c371fa81b", ethSigedTypedDataHex, ) // recover / validate signature - valid, err := ethwallet.ValidateEthereumSignature(wallet.Address().Hex(), digest, ethSigedTypedDataHex) + valid, err := ethwallet.ValidateEthereumSignature(wallet.Address().Hex(), encodedTypeData, ethSigedTypedDataHex) assert.NoError(t, err) assert.True(t, valid) } @@ -124,9 +124,8 @@ func TestTypedDataCase2(t *testing.T) { VerifyingContract: &verifyingContract, }, Message: map[string]interface{}{ - "name": "Bob", - // "wallet": common.HexToAddress("0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB"), // NOTE: passing common.Address object works too - "wallet": "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB", + "name": "Bob", + "wallet": common.HexToAddress("0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB"), "count": uint8(4), }, } @@ -135,10 +134,160 @@ func TestTypedDataCase2(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "0xf2cee375fa42b42143804025fc449deafd50cc031ca257e0b194a650a912090f", ethcoder.HexEncode(domainHash)) - digest, err := typedData.EncodeDigest() + digest, _, err := typedData.Encode() assert.NoError(t, err) assert.Equal(t, "0x2218fda59750be7bb9e5dfb2b49e4ec000dc2542862c5826f1fe980d6d727e95", ethcoder.HexEncode(digest)) // fmt.Println("===> digest", HexEncode(digest)) +} + +func TestTypedDataFromJSON(t *testing.T) { + typedDataJson := `{ + "types": { + "EIP712Domain": [ + {"name": "name", "type": "string"}, + {"name": "version", "type": "string"}, + {"name": "chainId", "type": "uint256"}, + {"name": "verifyingContract", "type": "address"} + ], + "Person": [ + {"name": "name", "type": "string"}, + {"name": "wallet", "type": "address"}, + {"name": "count", "type": "uint8"} + ] + }, + "primaryType": "Person", + "domain": { + "name": "Ether Mail", + "version": "1", + "chainId": 1, + "verifyingContract": "0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC" + }, + "message": { + "name": "Bob", + "wallet": "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB", + "count": 4 + } + }` + + typedData, err := ethcoder.TypedDataFromJSON(typedDataJson) + require.NoError(t, err) + + domainHash, err := typedData.HashStruct("EIP712Domain", typedData.Domain.Map()) + require.NoError(t, err) + require.Equal(t, "0xf2cee375fa42b42143804025fc449deafd50cc031ca257e0b194a650a912090f", ethcoder.HexEncode(domainHash)) + + digest, typedDataEncoded, err := typedData.Encode() + require.NoError(t, err) + require.Equal(t, "0x2218fda59750be7bb9e5dfb2b49e4ec000dc2542862c5826f1fe980d6d727e95", ethcoder.HexEncode(digest)) + require.Equal(t, "0x1901f2cee375fa42b42143804025fc449deafd50cc031ca257e0b194a650a912090ff5117e79519388f3d62844df1325ebe783523d9db9762c50fa78a60400a20b5b", ethcoder.HexEncode(typedDataEncoded)) + + // Sign and validate + wallet, err := ethwallet.NewWalletFromMnemonic("dose weasel clever culture letter volume endorse used harvest ripple circle install") + require.NoError(t, err) + + ethSigedTypedData, typedDataEncodedOut, err := wallet.SignTypedData(typedData) + ethSigedTypedDataHex := ethcoder.HexEncode(ethSigedTypedData) + require.NoError(t, err) + require.Equal(t, typedDataEncoded, typedDataEncodedOut) + + // NOTE: this signature and above method has been compared against ethers v6 test + require.Equal(t, + "0x296c98bed8f3fd7ea96f55ca8148b4d092cbada953c8d9205b2fff759461ab4e6d6db0b78833b954684900530caeee9aaef8e42dfd8439a3fa107e910b57e2cc1b", + ethSigedTypedDataHex, + ) + // recover / validate signature + valid, err := ethwallet.ValidateEthereumSignature(wallet.Address().Hex(), typedDataEncodedOut, ethSigedTypedDataHex) + require.NoError(t, err) + require.True(t, valid) +} + +func TestTypedDataFromJSONPart2(t *testing.T) { + // NOTE: we omit the EIP712Domain type definition because it will + // automatically be added by the library if its not specified + typedDataJson := `{ + "types": { + "Person": [ + { "name": "name", "type": "string" }, + { "name": "wallets", "type": "address[]" } + ], + "Mail": [ + { "name": "from", "type": "Person" }, + { "name": "to", "type": "Person[]" }, + { "name": "contents", "type": "string" } + ] + }, + "domain": { + "name": "Ether Mail", + "version": "1", + "chainId": 1, + "verifyingContract": "0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC" + }, + "primaryType": "Mail", + "message": { + "from": { + "name": "Cow", + "wallets": [ + "0xCD2a3d9F938E13CD947Ec05AbC7FE734Df8DD826", + "0xDeaDbeefdEAdbeefdEadbEEFdeadbeEFdEaDbeeF" + ] + }, + "to": [{ + "name": "Bob", + "wallets": [ + "0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB", + "0xB0BdaBea57B0BDABeA57b0bdABEA57b0BDabEa57", + "0xB0B0b0b0b0b0B000000000000000000000000000" + ] + }], + "contents": "Hello, Bob!" + } + }` + + typedData, err := ethcoder.TypedDataFromJSON(typedDataJson) + require.NoError(t, err) + + domainHash, err := typedData.HashStruct("EIP712Domain", typedData.Domain.Map()) + require.NoError(t, err) + require.Equal(t, "0xf2cee375fa42b42143804025fc449deafd50cc031ca257e0b194a650a912090f", ethcoder.HexEncode(domainHash)) + + personTypeHash, err := typedData.Types.TypeHash("Person") + require.NoError(t, err) + require.Equal(t, "0xfabfe1ed996349fc6027709802be19d047da1aa5d6894ff5f6486d92db2e6860", ethcoder.HexEncode(personTypeHash)) + + fromArg, ok := typedData.Message["from"].(map[string]interface{}) + require.True(t, ok) + personHashStruct, err := typedData.HashStruct("Person", fromArg) + require.NoError(t, err) + require.Equal(t, "0x9b4846dd48b866f0ac54d61b9b21a9e746f921cefa4ee94c4c0a1c49c774f67f", ethcoder.HexEncode(personHashStruct)) + + mailHashStruct, err := typedData.HashStruct("Mail", typedData.Message) + require.NoError(t, err) + require.Equal(t, "0xeb4221181ff3f1a83ea7313993ca9218496e424604ba9492bb4052c03d5c3df8", ethcoder.HexEncode(mailHashStruct)) + + digest, typedDataEncoded, err := typedData.Encode() + require.NoError(t, err) + require.Equal(t, "0xa85c2e2b118698e88db68a8105b794a8cc7cec074e89ef991cb4f5f533819cc2", ethcoder.HexEncode(digest)) + require.Equal(t, "0x1901f2cee375fa42b42143804025fc449deafd50cc031ca257e0b194a650a912090feb4221181ff3f1a83ea7313993ca9218496e424604ba9492bb4052c03d5c3df8", ethcoder.HexEncode(typedDataEncoded)) + + // Sign and validate + wallet, err := ethwallet.NewWalletFromMnemonic("dose weasel clever culture letter volume endorse used harvest ripple circle install") + require.NoError(t, err) + + ethSigedTypedData, typedDataEncodedOut, err := wallet.SignTypedData(typedData) + ethSigedTypedDataHex := ethcoder.HexEncode(ethSigedTypedData) + require.NoError(t, err) + require.Equal(t, typedDataEncoded, typedDataEncodedOut) + + // NOTE: this signature and above method has been compared against ethers v6 test + require.Equal(t, + "0xafd9e7d3b912a9ca989b622837ab92a8616446e6a517c486de5745dda166152f2d40f1d62593da438a65b58deacfdfbbeb7bbce2a12056815b19c678c563cc311c", + ethSigedTypedDataHex, + ) + + // recover / validate signature + valid, err := ethwallet.ValidateEthereumSignature(wallet.Address().Hex(), typedDataEncodedOut, ethSigedTypedDataHex) + require.NoError(t, err) + require.True(t, valid) } diff --git a/ethwallet/ethwallet.go b/ethwallet/ethwallet.go index 7804313..8f33113 100644 --- a/ethwallet/ethwallet.go +++ b/ethwallet/ethwallet.go @@ -7,6 +7,7 @@ import ( "fmt" "math/big" + "github.com/0xsequence/ethkit/ethcoder" "github.com/0xsequence/ethkit/ethrpc" "github.com/0xsequence/ethkit/ethtxn" "github.com/0xsequence/ethkit/go-ethereum/accounts" @@ -254,6 +255,11 @@ func (w *Wallet) SignTx(tx *types.Transaction, chainID *big.Int) (*types.Transac return signedTx, nil } +// SignMessage signs a message with EIP-191 prefix with the wallet's private key. +// +// This is the same as SignData, but it adds the prefix "Ethereum Signed Message:\n" to +// the message and encodes the length of the message in the prefix. In case the message +// already has the prefix, it will not be added again. func (w *Wallet) SignMessage(message []byte) ([]byte, error) { message191 := []byte("\x19Ethereum Signed Message:\n") if !bytes.HasPrefix(message, message191) { @@ -263,21 +269,40 @@ func (w *Wallet) SignMessage(message []byte) ([]byte, error) { } else { message191 = message } + return w.SignData(message191) +} - h := crypto.Keccak256(message191) - - sig, err := crypto.Sign(h, w.hdnode.PrivateKey()) +// SignTypedData signs a typed data with EIP-712 prefix with the wallet's private key. +// It returns the signature and the digest of the typed data. +func (w *Wallet) SignTypedData(typedData *ethcoder.TypedData) ([]byte, []byte, error) { + _, encodedData, err := typedData.Encode() if err != nil { - return []byte{}, err + return []byte{}, []byte{}, err } - sig[64] += 27 - return sig, nil + sig, err := w.SignData(encodedData) + if err != nil { + return []byte{}, []byte{}, err + } + return sig, encodedData, nil } +// SignData signs a message with the wallet's private key. +// +// This is the same as SignMessage, but it does not add the EIP-191 prefix. +// Please be careful with this method as it can be used to sign arbitrary data, but +// its helpful for signing typed data as defined by EIP-712. func (w *Wallet) SignData(data []byte) ([]byte, error) { + // NOTE: this is commended out for now, in case we use this method anywhere + // without expecting the data to be EIP191 prefixed. + // + // extra protection to ensure the input data is EIP191 prefixed + // if !(data[0] == 0x19 && (data[1] == 0x00 || data[1] == 0x01 || data[1] == 0x45)) { + // return nil, fmt.Errorf("invalid EIP191 input data") + // } + + // hash the data and sign it with the wallet's private key h := crypto.Keccak256(data) - sig, err := crypto.Sign(h, w.hdnode.PrivateKey()) if err != nil { return []byte{}, err diff --git a/ethwallet/utils.go b/ethwallet/utils.go index d189efc..f5caa4a 100644 --- a/ethwallet/utils.go +++ b/ethwallet/utils.go @@ -75,15 +75,37 @@ func IsValid191Signature(address common.Address, message, signature []byte) (boo return false, fmt.Errorf("signature is not of proper length") } - message191 := []byte("\x19Ethereum Signed Message:\n") - if !bytes.HasPrefix(message, message191) { + // Ensure EIP191 signature + var message191 []byte + personalSignPrefix := []byte("\x19Ethereum Signed Message:\n") + + if message[0] == 0x19 { + if message[1] == 0x45 { + // EIP191 for "Ethereum Signed Message" prefix + if !bytes.HasPrefix(message, personalSignPrefix) { + mlen := fmt.Sprintf("%d", len(message)) + message191 = append(personalSignPrefix, []byte(mlen)...) + message191 = append(message191, message...) + } else { + message191 = message + } + } else if message[1] == 0x01 { + // EIP191 for typed data + message191 = message + } + } + + // auto-prefix if message wasn't previously prefixed + if len(message191) == 0 { + // Message is not a EIP191, so we will automatically add the EIP191 prefix + // assuming its a message scheme. + message191 = personalSignPrefix mlen := fmt.Sprintf("%d", len(message)) message191 = append(message191, []byte(mlen)...) message191 = append(message191, message...) - } else { - message191 = message } + // Recovery the address from the signature sig := make([]byte, 65) copy(sig, signature)