Skip to content

Commit

Permalink
ethcoder: merkleproof (#123)
Browse files Browse the repository at this point in the history
* ethcoder: merkleproof

* update

* Correct implementation

* Default merkle tree options

* hex proof returns []byte

* MerkleTree uses generics for leaf

* Fix error handling verify

---------

Co-authored-by: Michael Standen <mstan@horizon.io>
  • Loading branch information
pkieltyka and ScreamingHawk authored Jun 6, 2024
1 parent d096aa7 commit 162ed06
Show file tree
Hide file tree
Showing 2 changed files with 344 additions and 0 deletions.
191 changes: 191 additions & 0 deletions ethcoder/merkle_proof.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
package ethcoder

import (
"bytes"
"errors"
"sort"

"github.com/0xsequence/ethkit/go-ethereum/crypto"
)

type Options struct {
SortLeaves bool
SortPairs bool
}

var DefaultMerkleTreeOptions = Options{
// Default to true
SortLeaves: true,
SortPairs: true,
}

type Proof struct {
IsLeft bool
Data []byte
}

type MerkleTree[TLeaf any] struct {
sortLeaves bool
sortPairs bool
hashFn func(TLeaf) ([]byte, error)
leaves []TLeaf
layers [][][]byte
}

func NewMerkleTree[TLeaf any](leaves []TLeaf, hashFn *func(TLeaf) ([]byte, error), options *Options) *MerkleTree[TLeaf] {
if hashFn == nil {
// Assume TLeaf is []byte
fn := func(leaf TLeaf) ([]byte, error) {
return any(leaf).([]byte), nil
}
hashFn = &fn
}
if options == nil {
options = &DefaultMerkleTreeOptions
}
mt := &MerkleTree[TLeaf]{
hashFn: *hashFn,
sortLeaves: options.SortLeaves,
sortPairs: options.SortPairs,
}
mt.processLeaves(leaves)
return mt
}

func (mt *MerkleTree[TLeaf]) processLeaves(leaves []TLeaf) error {
mt.leaves = make([]TLeaf, len(leaves))
copy(mt.leaves, leaves)
nodes := make([][]byte, len(leaves))
if mt.sortLeaves {
sort.Slice(mt.leaves, func(i, j int) bool {
// Ignore err during sort
a, _ := mt.hashFn(mt.leaves[i])
b, _ := mt.hashFn(mt.leaves[j])
return bytes.Compare(a, b) < 0
})
}
for i, leaf := range mt.leaves {
node, err := mt.hashFn(leaf)
if err != nil {
return err
}
nodes[i] = node
}
mt.createHashes(nodes)
return nil
}

func (mt *MerkleTree[TLeaf]) createHashes(nodes [][]byte) {
mt.layers = make([][][]byte, 0)
mt.layers = append(mt.layers, nodes)
for len(nodes) > 1 {
var nextLayer [][]byte
for i := 0; i < len(nodes); i += 2 {
if i+1 == len(nodes) {
nextLayer = append(nextLayer, nodes[i])
} else {
left := nodes[i]
right := nodes[i+1]
if mt.sortPairs && bytes.Compare(left, right) > 0 {
left, right = right, left
}
hash := crypto.Keccak256(append(left, right...))
nextLayer = append(nextLayer, hash)
}
}
nodes = nextLayer
mt.layers = append(mt.layers, nodes)
}
}

func (mt *MerkleTree[TLeaf]) GetRoot() []byte {
if len(mt.layers) == 0 {
return nil
}
return mt.layers[len(mt.layers)-1][0]
}

func (mt *MerkleTree[TLeaf]) GetProof(leaf TLeaf) ([]Proof, error) {
leafIndex := -1
targetNode, err := mt.hashFn(leaf)
if err != nil {
return nil, err
}

for i, l := range mt.leaves {
// Ignore err. Already checked in processLeaves
node, _ := mt.hashFn(l)
if bytes.Equal(node, targetNode) {
leafIndex = i
break
}
}
if leafIndex == -1 {
return nil, errors.New("leaf not found in tree")
}

proof := []Proof{}
for i := 0; i < len(mt.layers)-1; i++ {
layer := mt.layers[i]
pairIndex := leafIndex ^ 1
if pairIndex < len(layer) {
isLeft := leafIndex%2 != 0
proof = append(proof, Proof{
IsLeft: isLeft,
Data: layer[pairIndex],
})
}
leafIndex /= 2
}
return proof, nil
}

func (mt *MerkleTree[TLeaf]) GetHexProof(leaf TLeaf) [][]byte {
proof, _ := mt.GetProof(leaf)
hexProof := make([][]byte, len(proof))
for _, p := range proof {
hexProof = append(hexProof, []byte(p.Data))
}
return hexProof
}

func (mt *MerkleTree[TLeaf]) Verify(proof []Proof, leaf TLeaf, root []byte) (bool, error) {
hash, err := mt.hashFn(leaf)
if err != nil {
return false, err
}

if proof == nil || len(hash) == 0 || len(root) == 0 {
return false, errors.New("invalid proof, leaf or root")
}

for i := 0; i < len(proof); i++ {
node := proof[i]
var data []byte
var isLeftNode bool

data = node.Data
isLeftNode = node.IsLeft

var buffers [][]byte

if mt.sortPairs {
if bytes.Compare(hash, data) < 0 {
buffers = append(buffers, hash, data)
} else {
buffers = append(buffers, data, hash)
}
hash = crypto.Keccak256(bytes.Join(buffers, []byte{}))
} else {
buffers = append(buffers, hash)
if isLeftNode {
buffers = append([][]byte{data}, buffers...)
} else {
buffers = append(buffers, data)
}
hash = crypto.Keccak256(bytes.Join(buffers, []byte{}))
}
}

return bytes.Equal(hash, root), nil
}
153 changes: 153 additions & 0 deletions ethcoder/merkle_proof_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package ethcoder

import (
"crypto/rand"
"fmt"
"math/big"
"testing"

"github.com/0xsequence/ethkit/go-ethereum/accounts/abi"
"github.com/0xsequence/ethkit/go-ethereum/common"
"github.com/stretchr/testify/assert"
)

func TestMerkleProofKnown(t *testing.T) {
testAddr := common.HexToAddress("0x1e946c284bdBb05Fb6EF41016C524E8681e3d05E")
leaves := [][]byte{
testAddr.Bytes(),
common.HexToAddress("0x1D74B866598B339006160d704642459B04ba890B").Bytes(),
common.HexToAddress("0x37e948435E916069D3a1431Ddf508421073fF3E7").Bytes(),
common.HexToAddress("0x29c34A7d23B8BCBE7c5Ec94C6525b78bb5cbAf36").Bytes(),
}
mt := NewMerkleTree(leaves, nil, nil)

expectedRoot := common.Hex2Bytes("2620d31912c95198ebbf40473b7b069e98587ec49d0cd46aacef8c746c682334")
root := mt.GetRoot()
assert.Equal(t, expectedRoot, root)
fmt.Printf("Root: %x\n", root)

expectedProof := [][]byte{
common.Hex2Bytes("1d74b866598b339006160d704642459b04ba890b"),
common.Hex2Bytes("39ceb165765d969b9bfbbab524649adc484bab29db86b6c0df8635feebf0154e"),
}
proof, err := mt.GetProof(testAddr.Bytes())
assert.Nil(t, err)
for i, p := range proof {
fmt.Printf("Proof part %d: IsLeft=%v, Data=%x\n", i, p.IsLeft, p.Data)
assert.Equal(t, expectedProof[i], []byte(p.Data))
}

isValid, err := mt.Verify(proof, testAddr.Bytes(), root)
assert.Nil(t, err)
assert.True(t, isValid)
}

func TestMerkleProofLarge(t *testing.T) {
addrCount := 100
leaves := make([][]byte, addrCount)
for i := 0; i < addrCount; i++ {
leaf := make([]byte, 20)
rand.Read(leaf)
leaves[i] = leaf
}

mt := NewMerkleTree(leaves, nil, nil)

root := mt.GetRoot()
assert.NotNil(t, root)

proof, err := mt.GetProof(leaves[69])
assert.Nil(t, err)
assert.GreaterOrEqual(t, len(proof), 1)

isValid, err := mt.Verify(proof, leaves[69], root)
assert.Nil(t, err)
assert.True(t, isValid)
}

func TestMerkleInvalidLeaf(t *testing.T) {
invalidLeaf := common.HexToAddress("0x1e946c284bdBb05Fb6EF41016C524E8681e3d05E").Bytes()
leaves := [][]byte{
common.HexToAddress("0x1D74B866598B339006160d704642459B04ba890B").Bytes(),
common.HexToAddress("0x37e948435E916069D3a1431Ddf508421073fF3E7").Bytes(),
common.HexToAddress("0x29c34A7d23B8BCBE7c5Ec94C6525b78bb5cbAf36").Bytes(),
}

mt := NewMerkleTree(leaves, nil, nil)

root := mt.GetRoot()
assert.NotNil(t, root)

// Invalid leaf
_, err := mt.GetProof(invalidLeaf)
assert.Error(t, err)

// Valid proof
proof, err := mt.GetProof(leaves[0])
assert.Nil(t, err)

// Invalid leaf
isValid, _ := mt.Verify(proof, invalidLeaf, root)
assert.False(t, isValid)
}

func TestMerkleSingleLeaf(t *testing.T) {
leaf := common.HexToAddress("0x1e946c284bdBb05Fb6EF41016C524E8681e3d05E").Bytes()
leaves := [][]byte{
leaf,
}

mt := NewMerkleTree(leaves, nil, nil)

root := mt.GetRoot()
assert.NotNil(t, root)

proof, err := mt.GetProof(leaf)
assert.Nil(t, err)

isValid, err := mt.Verify(proof, leaf, root)
assert.Nil(t, err)
assert.True(t, isValid)
}

type TLeaf struct {
Addr common.Address
TokenId *big.Int
}

func TestMerkleProofHashFn(t *testing.T) {
addressTy, err := abi.NewType("address", "address", nil)
assert.Nil(t, err)
uintTy, err := abi.NewType("uint256", "uint256", nil)
assert.Nil(t, err)
arguments := []abi.Argument{
{Name: "addr", Type: addressTy},
{Name: "tokenId", Type: uintTy},
}

hashFn := func(leaf TLeaf) ([]byte, error) {
packed, err := abi.Arguments(arguments).Pack(leaf.Addr, leaf.TokenId)
if err != nil {
return nil, err
}
return Keccak256(packed), nil
}

leaves := make([]TLeaf, 4)
leaves[0] = TLeaf{Addr: common.HexToAddress("0x1e946c284bdBb05Fb6EF41016C524E8681e3d05E"), TokenId: big.NewInt(1)}
leaves[1] = TLeaf{Addr: common.HexToAddress("0x1D74B866598B339006160d704642459B04ba890B"), TokenId: big.NewInt(1)}
leaves[2] = TLeaf{Addr: common.HexToAddress("0x37e948435E916069D3a1431Ddf508421073fF3E7"), TokenId: big.NewInt(1)}
leaves[3] = TLeaf{Addr: common.HexToAddress("0x29c34A7d23B8BCBE7c5Ec94C6525b78bb5cbAf36"), TokenId: big.NewInt(1)}

mt := NewMerkleTree(leaves, &hashFn, nil)

root := mt.GetRoot()
assert.NotNil(t, root)

proof, err := mt.GetProof(leaves[0])
assert.Nil(t, err)

isValid, err := mt.Verify(proof, leaves[0], root)
assert.Nil(t, err)
assert.True(t, isValid)
}

0 comments on commit 162ed06

Please sign in to comment.