-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* ethcoder: merkleproof * update * Correct implementation * Default merkle tree options * hex proof returns []byte * MerkleTree uses generics for leaf * Fix error handling verify --------- Co-authored-by: Michael Standen <mstan@horizon.io>
- Loading branch information
1 parent
d096aa7
commit 162ed06
Showing
2 changed files
with
344 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
package ethcoder | ||
|
||
import ( | ||
"bytes" | ||
"errors" | ||
"sort" | ||
|
||
"github.com/0xsequence/ethkit/go-ethereum/crypto" | ||
) | ||
|
||
type Options struct { | ||
SortLeaves bool | ||
SortPairs bool | ||
} | ||
|
||
var DefaultMerkleTreeOptions = Options{ | ||
// Default to true | ||
SortLeaves: true, | ||
SortPairs: true, | ||
} | ||
|
||
type Proof struct { | ||
IsLeft bool | ||
Data []byte | ||
} | ||
|
||
type MerkleTree[TLeaf any] struct { | ||
sortLeaves bool | ||
sortPairs bool | ||
hashFn func(TLeaf) ([]byte, error) | ||
leaves []TLeaf | ||
layers [][][]byte | ||
} | ||
|
||
func NewMerkleTree[TLeaf any](leaves []TLeaf, hashFn *func(TLeaf) ([]byte, error), options *Options) *MerkleTree[TLeaf] { | ||
if hashFn == nil { | ||
// Assume TLeaf is []byte | ||
fn := func(leaf TLeaf) ([]byte, error) { | ||
return any(leaf).([]byte), nil | ||
} | ||
hashFn = &fn | ||
} | ||
if options == nil { | ||
options = &DefaultMerkleTreeOptions | ||
} | ||
mt := &MerkleTree[TLeaf]{ | ||
hashFn: *hashFn, | ||
sortLeaves: options.SortLeaves, | ||
sortPairs: options.SortPairs, | ||
} | ||
mt.processLeaves(leaves) | ||
return mt | ||
} | ||
|
||
func (mt *MerkleTree[TLeaf]) processLeaves(leaves []TLeaf) error { | ||
mt.leaves = make([]TLeaf, len(leaves)) | ||
copy(mt.leaves, leaves) | ||
nodes := make([][]byte, len(leaves)) | ||
if mt.sortLeaves { | ||
sort.Slice(mt.leaves, func(i, j int) bool { | ||
// Ignore err during sort | ||
a, _ := mt.hashFn(mt.leaves[i]) | ||
b, _ := mt.hashFn(mt.leaves[j]) | ||
return bytes.Compare(a, b) < 0 | ||
}) | ||
} | ||
for i, leaf := range mt.leaves { | ||
node, err := mt.hashFn(leaf) | ||
if err != nil { | ||
return err | ||
} | ||
nodes[i] = node | ||
} | ||
mt.createHashes(nodes) | ||
return nil | ||
} | ||
|
||
func (mt *MerkleTree[TLeaf]) createHashes(nodes [][]byte) { | ||
mt.layers = make([][][]byte, 0) | ||
mt.layers = append(mt.layers, nodes) | ||
for len(nodes) > 1 { | ||
var nextLayer [][]byte | ||
for i := 0; i < len(nodes); i += 2 { | ||
if i+1 == len(nodes) { | ||
nextLayer = append(nextLayer, nodes[i]) | ||
} else { | ||
left := nodes[i] | ||
right := nodes[i+1] | ||
if mt.sortPairs && bytes.Compare(left, right) > 0 { | ||
left, right = right, left | ||
} | ||
hash := crypto.Keccak256(append(left, right...)) | ||
nextLayer = append(nextLayer, hash) | ||
} | ||
} | ||
nodes = nextLayer | ||
mt.layers = append(mt.layers, nodes) | ||
} | ||
} | ||
|
||
func (mt *MerkleTree[TLeaf]) GetRoot() []byte { | ||
if len(mt.layers) == 0 { | ||
return nil | ||
} | ||
return mt.layers[len(mt.layers)-1][0] | ||
} | ||
|
||
func (mt *MerkleTree[TLeaf]) GetProof(leaf TLeaf) ([]Proof, error) { | ||
leafIndex := -1 | ||
targetNode, err := mt.hashFn(leaf) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
for i, l := range mt.leaves { | ||
// Ignore err. Already checked in processLeaves | ||
node, _ := mt.hashFn(l) | ||
if bytes.Equal(node, targetNode) { | ||
leafIndex = i | ||
break | ||
} | ||
} | ||
if leafIndex == -1 { | ||
return nil, errors.New("leaf not found in tree") | ||
} | ||
|
||
proof := []Proof{} | ||
for i := 0; i < len(mt.layers)-1; i++ { | ||
layer := mt.layers[i] | ||
pairIndex := leafIndex ^ 1 | ||
if pairIndex < len(layer) { | ||
isLeft := leafIndex%2 != 0 | ||
proof = append(proof, Proof{ | ||
IsLeft: isLeft, | ||
Data: layer[pairIndex], | ||
}) | ||
} | ||
leafIndex /= 2 | ||
} | ||
return proof, nil | ||
} | ||
|
||
func (mt *MerkleTree[TLeaf]) GetHexProof(leaf TLeaf) [][]byte { | ||
proof, _ := mt.GetProof(leaf) | ||
hexProof := make([][]byte, len(proof)) | ||
for _, p := range proof { | ||
hexProof = append(hexProof, []byte(p.Data)) | ||
} | ||
return hexProof | ||
} | ||
|
||
func (mt *MerkleTree[TLeaf]) Verify(proof []Proof, leaf TLeaf, root []byte) (bool, error) { | ||
hash, err := mt.hashFn(leaf) | ||
if err != nil { | ||
return false, err | ||
} | ||
|
||
if proof == nil || len(hash) == 0 || len(root) == 0 { | ||
return false, errors.New("invalid proof, leaf or root") | ||
} | ||
|
||
for i := 0; i < len(proof); i++ { | ||
node := proof[i] | ||
var data []byte | ||
var isLeftNode bool | ||
|
||
data = node.Data | ||
isLeftNode = node.IsLeft | ||
|
||
var buffers [][]byte | ||
|
||
if mt.sortPairs { | ||
if bytes.Compare(hash, data) < 0 { | ||
buffers = append(buffers, hash, data) | ||
} else { | ||
buffers = append(buffers, data, hash) | ||
} | ||
hash = crypto.Keccak256(bytes.Join(buffers, []byte{})) | ||
} else { | ||
buffers = append(buffers, hash) | ||
if isLeftNode { | ||
buffers = append([][]byte{data}, buffers...) | ||
} else { | ||
buffers = append(buffers, data) | ||
} | ||
hash = crypto.Keccak256(bytes.Join(buffers, []byte{})) | ||
} | ||
} | ||
|
||
return bytes.Equal(hash, root), nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
package ethcoder | ||
|
||
import ( | ||
"crypto/rand" | ||
"fmt" | ||
"math/big" | ||
"testing" | ||
|
||
"github.com/0xsequence/ethkit/go-ethereum/accounts/abi" | ||
"github.com/0xsequence/ethkit/go-ethereum/common" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestMerkleProofKnown(t *testing.T) { | ||
testAddr := common.HexToAddress("0x1e946c284bdBb05Fb6EF41016C524E8681e3d05E") | ||
leaves := [][]byte{ | ||
testAddr.Bytes(), | ||
common.HexToAddress("0x1D74B866598B339006160d704642459B04ba890B").Bytes(), | ||
common.HexToAddress("0x37e948435E916069D3a1431Ddf508421073fF3E7").Bytes(), | ||
common.HexToAddress("0x29c34A7d23B8BCBE7c5Ec94C6525b78bb5cbAf36").Bytes(), | ||
} | ||
mt := NewMerkleTree(leaves, nil, nil) | ||
|
||
expectedRoot := common.Hex2Bytes("2620d31912c95198ebbf40473b7b069e98587ec49d0cd46aacef8c746c682334") | ||
root := mt.GetRoot() | ||
assert.Equal(t, expectedRoot, root) | ||
fmt.Printf("Root: %x\n", root) | ||
|
||
expectedProof := [][]byte{ | ||
common.Hex2Bytes("1d74b866598b339006160d704642459b04ba890b"), | ||
common.Hex2Bytes("39ceb165765d969b9bfbbab524649adc484bab29db86b6c0df8635feebf0154e"), | ||
} | ||
proof, err := mt.GetProof(testAddr.Bytes()) | ||
assert.Nil(t, err) | ||
for i, p := range proof { | ||
fmt.Printf("Proof part %d: IsLeft=%v, Data=%x\n", i, p.IsLeft, p.Data) | ||
assert.Equal(t, expectedProof[i], []byte(p.Data)) | ||
} | ||
|
||
isValid, err := mt.Verify(proof, testAddr.Bytes(), root) | ||
assert.Nil(t, err) | ||
assert.True(t, isValid) | ||
} | ||
|
||
func TestMerkleProofLarge(t *testing.T) { | ||
addrCount := 100 | ||
leaves := make([][]byte, addrCount) | ||
for i := 0; i < addrCount; i++ { | ||
leaf := make([]byte, 20) | ||
rand.Read(leaf) | ||
leaves[i] = leaf | ||
} | ||
|
||
mt := NewMerkleTree(leaves, nil, nil) | ||
|
||
root := mt.GetRoot() | ||
assert.NotNil(t, root) | ||
|
||
proof, err := mt.GetProof(leaves[69]) | ||
assert.Nil(t, err) | ||
assert.GreaterOrEqual(t, len(proof), 1) | ||
|
||
isValid, err := mt.Verify(proof, leaves[69], root) | ||
assert.Nil(t, err) | ||
assert.True(t, isValid) | ||
} | ||
|
||
func TestMerkleInvalidLeaf(t *testing.T) { | ||
invalidLeaf := common.HexToAddress("0x1e946c284bdBb05Fb6EF41016C524E8681e3d05E").Bytes() | ||
leaves := [][]byte{ | ||
common.HexToAddress("0x1D74B866598B339006160d704642459B04ba890B").Bytes(), | ||
common.HexToAddress("0x37e948435E916069D3a1431Ddf508421073fF3E7").Bytes(), | ||
common.HexToAddress("0x29c34A7d23B8BCBE7c5Ec94C6525b78bb5cbAf36").Bytes(), | ||
} | ||
|
||
mt := NewMerkleTree(leaves, nil, nil) | ||
|
||
root := mt.GetRoot() | ||
assert.NotNil(t, root) | ||
|
||
// Invalid leaf | ||
_, err := mt.GetProof(invalidLeaf) | ||
assert.Error(t, err) | ||
|
||
// Valid proof | ||
proof, err := mt.GetProof(leaves[0]) | ||
assert.Nil(t, err) | ||
|
||
// Invalid leaf | ||
isValid, _ := mt.Verify(proof, invalidLeaf, root) | ||
assert.False(t, isValid) | ||
} | ||
|
||
func TestMerkleSingleLeaf(t *testing.T) { | ||
leaf := common.HexToAddress("0x1e946c284bdBb05Fb6EF41016C524E8681e3d05E").Bytes() | ||
leaves := [][]byte{ | ||
leaf, | ||
} | ||
|
||
mt := NewMerkleTree(leaves, nil, nil) | ||
|
||
root := mt.GetRoot() | ||
assert.NotNil(t, root) | ||
|
||
proof, err := mt.GetProof(leaf) | ||
assert.Nil(t, err) | ||
|
||
isValid, err := mt.Verify(proof, leaf, root) | ||
assert.Nil(t, err) | ||
assert.True(t, isValid) | ||
} | ||
|
||
type TLeaf struct { | ||
Addr common.Address | ||
TokenId *big.Int | ||
} | ||
|
||
func TestMerkleProofHashFn(t *testing.T) { | ||
addressTy, err := abi.NewType("address", "address", nil) | ||
assert.Nil(t, err) | ||
uintTy, err := abi.NewType("uint256", "uint256", nil) | ||
assert.Nil(t, err) | ||
arguments := []abi.Argument{ | ||
{Name: "addr", Type: addressTy}, | ||
{Name: "tokenId", Type: uintTy}, | ||
} | ||
|
||
hashFn := func(leaf TLeaf) ([]byte, error) { | ||
packed, err := abi.Arguments(arguments).Pack(leaf.Addr, leaf.TokenId) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return Keccak256(packed), nil | ||
} | ||
|
||
leaves := make([]TLeaf, 4) | ||
leaves[0] = TLeaf{Addr: common.HexToAddress("0x1e946c284bdBb05Fb6EF41016C524E8681e3d05E"), TokenId: big.NewInt(1)} | ||
leaves[1] = TLeaf{Addr: common.HexToAddress("0x1D74B866598B339006160d704642459B04ba890B"), TokenId: big.NewInt(1)} | ||
leaves[2] = TLeaf{Addr: common.HexToAddress("0x37e948435E916069D3a1431Ddf508421073fF3E7"), TokenId: big.NewInt(1)} | ||
leaves[3] = TLeaf{Addr: common.HexToAddress("0x29c34A7d23B8BCBE7c5Ec94C6525b78bb5cbAf36"), TokenId: big.NewInt(1)} | ||
|
||
mt := NewMerkleTree(leaves, &hashFn, nil) | ||
|
||
root := mt.GetRoot() | ||
assert.NotNil(t, root) | ||
|
||
proof, err := mt.GetProof(leaves[0]) | ||
assert.Nil(t, err) | ||
|
||
isValid, err := mt.Verify(proof, leaves[0], root) | ||
assert.Nil(t, err) | ||
assert.True(t, isValid) | ||
} |