Skip to content

Commit

Permalink
MerkleTree uses generics for leaf
Browse files Browse the repository at this point in the history
  • Loading branch information
ScreamingHawk committed Jun 6, 2024
1 parent a184c2d commit 8129f99
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 58 deletions.
128 changes: 78 additions & 50 deletions ethcoder/merkle_proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ import (
"github.com/0xsequence/ethkit/go-ethereum/crypto"
)

type TLeaf []byte
type TLayer []TLeaf

type Options struct {
SortLeaves bool
SortPairs bool
Expand All @@ -23,45 +20,66 @@ var DefaultMerkleTreeOptions = Options{
}

type Proof struct {
IsLeft bool
Data TLeaf
IsLeft bool
Data []byte
}

type MerkleTree struct {
leaves []TLeaf
layers []TLayer
type MerkleTree[TLeaf any] struct {
sortLeaves bool
sortPairs bool
hashFn func(TLeaf) ([]byte, error)
leaves []TLeaf
layers [][][]byte
}

func NewMerkleTree(leaves []TLeaf, options *Options) *MerkleTree {
func NewMerkleTree[TLeaf any](leaves []TLeaf, hashFn *func(TLeaf) ([]byte, error), options *Options) *MerkleTree[TLeaf] {
if hashFn == nil {
// Assume TLeaf is []byte
fn := func(leaf TLeaf) ([]byte, error) {
return any(leaf).([]byte), nil
}
hashFn = &fn
}
if options == nil {
options = &DefaultMerkleTreeOptions
}
mt := &MerkleTree{
mt := &MerkleTree[TLeaf]{
hashFn: *hashFn,
sortLeaves: options.SortLeaves,
sortPairs: options.SortPairs,
}
mt.processLeaves(leaves)
return mt
}

func (mt *MerkleTree) processLeaves(leaves []TLeaf) {
func (mt *MerkleTree[TLeaf]) processLeaves(leaves []TLeaf) error {
mt.leaves = make([]TLeaf, len(leaves))
copy(mt.leaves, leaves)
nodes := make([][]byte, len(leaves))
if mt.sortLeaves {
sort.Slice(mt.leaves, func(i, j int) bool {
return bytes.Compare(mt.leaves[i], mt.leaves[j]) < 0
// Ignore err during sort
a, _ := mt.hashFn(mt.leaves[i])
b, _ := mt.hashFn(mt.leaves[j])
return bytes.Compare(a, b) < 0
})
}
mt.createHashes(mt.leaves)
for i, leaf := range mt.leaves {
node, err := mt.hashFn(leaf)
if err != nil {
return err
}
nodes[i] = node
}
mt.createHashes(nodes)
return nil
}

func (mt *MerkleTree) createHashes(nodes []TLeaf) {
mt.layers = make([]TLayer, 0)
func (mt *MerkleTree[TLeaf]) createHashes(nodes [][]byte) {
mt.layers = make([][][]byte, 0)
mt.layers = append(mt.layers, nodes)
for len(nodes) > 1 {
var nextLayer []TLeaf
var nextLayer [][]byte
for i := 0; i < len(nodes); i += 2 {
if i+1 == len(nodes) {
nextLayer = append(nextLayer, nodes[i])
Expand All @@ -80,17 +98,24 @@ func (mt *MerkleTree) createHashes(nodes []TLeaf) {
}
}

func (mt *MerkleTree) GetRoot() []byte {
func (mt *MerkleTree[TLeaf]) GetRoot() []byte {
if len(mt.layers) == 0 {
return TLeaf{}
return nil
}
return mt.layers[len(mt.layers)-1][0]
}

func (mt *MerkleTree) GetProof(leaf TLeaf) ([]Proof, error) {
func (mt *MerkleTree[TLeaf]) GetProof(leaf TLeaf) ([]Proof, error) {
leafIndex := -1
targetNode, err := mt.hashFn(leaf)
if err != nil {
return nil, err
}

for i, l := range mt.leaves {
if bytes.Equal(l, leaf) {
// Ignore err. Already checked in processLeaves
node, _ := mt.hashFn(l)
if bytes.Equal(node, targetNode) {
leafIndex = i
break
}
Expand All @@ -106,16 +131,16 @@ func (mt *MerkleTree) GetProof(leaf TLeaf) ([]Proof, error) {
if pairIndex < len(layer) {
isLeft := leafIndex%2 != 0
proof = append(proof, Proof{
IsLeft: isLeft,
Data: layer[pairIndex],
IsLeft: isLeft,
Data: layer[pairIndex],
})
}
leafIndex /= 2
}
return proof, nil
}

func (mt *MerkleTree) GetHexProof(leaf TLeaf) [][]byte {
func (mt *MerkleTree[TLeaf]) GetHexProof(leaf TLeaf) [][]byte {
proof, _ := mt.GetProof(leaf)
hexProof := make([][]byte, len(proof))
for _, p := range proof {
Expand All @@ -124,39 +149,42 @@ func (mt *MerkleTree) GetHexProof(leaf TLeaf) [][]byte {
return hexProof
}

func (mt *MerkleTree) Verify(proof []Proof, targetNode, root []byte) (bool, error) {
hash := targetNode
func (mt *MerkleTree[TLeaf]) Verify(proof []Proof, leaf TLeaf, root []byte) (bool, error) {
hash, err := mt.hashFn(leaf)
if err != nil {
return false, err
}

if proof == nil || len(targetNode) == 0 || len(root) == 0 {
return false, nil
if proof == nil || len(hash) == 0 || len(root) == 0 {
return false, nil
}

for i := 0; i < len(proof); i++ {
node := proof[i]
var data []byte
var isLeftNode bool

data = node.Data
isLeftNode = node.IsLeft

var buffers [][]byte

if mt.sortPairs {
if bytes.Compare(hash, data) < 0 {
buffers = append(buffers, hash, data)
} else {
buffers = append(buffers, data, hash)
}
hash = crypto.Keccak256(bytes.Join(buffers, []byte{}))
node := proof[i]
var data []byte
var isLeftNode bool

data = node.Data
isLeftNode = node.IsLeft

var buffers [][]byte

if mt.sortPairs {
if bytes.Compare(hash, data) < 0 {
buffers = append(buffers, hash, data)
} else {
buffers = append(buffers, data, hash)
}
hash = crypto.Keccak256(bytes.Join(buffers, []byte{}))
} else {
buffers = append(buffers, hash)
if isLeftNode {
buffers = append([][]byte{data}, buffers...)
} else {
buffers = append(buffers, hash)
if isLeftNode {
buffers = append([][]byte{data}, buffers...)
} else {
buffers = append(buffers, data)
}
hash = crypto.Keccak256(bytes.Join(buffers, []byte{}))
buffers = append(buffers, data)
}
hash = crypto.Keccak256(bytes.Join(buffers, []byte{}))
}
}

return bytes.Equal(hash, root), nil
Expand Down
60 changes: 52 additions & 8 deletions ethcoder/merkle_proof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@ package ethcoder
import (
"crypto/rand"
"fmt"
"math/big"
"testing"

"github.com/0xsequence/ethkit/go-ethereum/accounts/abi"
"github.com/0xsequence/ethkit/go-ethereum/common"
"github.com/stretchr/testify/assert"
)

func TestMerkleProofKnown(t *testing.T) {
testAddr := common.HexToAddress("0x1e946c284bdBb05Fb6EF41016C524E8681e3d05E")
leaves := []TLeaf{
leaves := [][]byte{
testAddr.Bytes(),
common.HexToAddress("0x1D74B866598B339006160d704642459B04ba890B").Bytes(),
common.HexToAddress("0x37e948435E916069D3a1431Ddf508421073fF3E7").Bytes(),
common.HexToAddress("0x29c34A7d23B8BCBE7c5Ec94C6525b78bb5cbAf36").Bytes(),
}
mt := NewMerkleTree(leaves, nil)
mt := NewMerkleTree(leaves, nil, nil)

expectedRoot := common.Hex2Bytes("2620d31912c95198ebbf40473b7b069e98587ec49d0cd46aacef8c746c682334")
root := mt.GetRoot()
Expand All @@ -42,14 +44,14 @@ func TestMerkleProofKnown(t *testing.T) {

func TestMerkleProofLarge(t *testing.T) {
addrCount := 100
leaves := make([]TLeaf, addrCount)
leaves := make([][]byte, addrCount)
for i := 0; i < addrCount; i++ {
leaf := make([]byte, 20)
rand.Read(leaf)
leaves[i] = leaf
}

mt := NewMerkleTree(leaves, nil)
mt := NewMerkleTree(leaves, nil, nil)

root := mt.GetRoot()
assert.NotNil(t, root)
Expand All @@ -65,13 +67,13 @@ func TestMerkleProofLarge(t *testing.T) {

func TestMerkleInvalidLeaf(t *testing.T) {
invalidLeaf := common.HexToAddress("0x1e946c284bdBb05Fb6EF41016C524E8681e3d05E").Bytes()
leaves := []TLeaf{
leaves := [][]byte{
common.HexToAddress("0x1D74B866598B339006160d704642459B04ba890B").Bytes(),
common.HexToAddress("0x37e948435E916069D3a1431Ddf508421073fF3E7").Bytes(),
common.HexToAddress("0x29c34A7d23B8BCBE7c5Ec94C6525b78bb5cbAf36").Bytes(),
}

mt := NewMerkleTree(leaves, nil)
mt := NewMerkleTree(leaves, nil, nil)

root := mt.GetRoot()
assert.NotNil(t, root)
Expand All @@ -92,11 +94,11 @@ func TestMerkleInvalidLeaf(t *testing.T) {

func TestMerkleSingleLeaf(t *testing.T) {
leaf := common.HexToAddress("0x1e946c284bdBb05Fb6EF41016C524E8681e3d05E").Bytes()
leaves := []TLeaf{
leaves := [][]byte{
leaf,
}

mt := NewMerkleTree(leaves, nil)
mt := NewMerkleTree(leaves, nil, nil)

root := mt.GetRoot()
assert.NotNil(t, root)
Expand All @@ -108,3 +110,45 @@ func TestMerkleSingleLeaf(t *testing.T) {
assert.Nil(t, err)
assert.True(t, isValid)
}

type TLeaf struct {
Addr common.Address
TokenId *big.Int
}

func TestMerkleProofHashFn(t *testing.T) {
addressTy, err := abi.NewType("address", "address", nil)
assert.Nil(t, err)
uintTy, err := abi.NewType("uint256", "uint256", nil)
assert.Nil(t, err)
arguments := []abi.Argument{
{Name: "addr", Type: addressTy},
{Name: "tokenId", Type: uintTy},
}

hashFn := func(leaf TLeaf) ([]byte, error) {
packed, err := abi.Arguments(arguments).Pack(leaf.Addr, leaf.TokenId)
if err != nil {
return nil, err
}
return Keccak256(packed), nil
}

leaves := make([]TLeaf, 4)
leaves[0] = TLeaf{Addr: common.HexToAddress("0x1e946c284bdBb05Fb6EF41016C524E8681e3d05E"), TokenId: big.NewInt(1)}
leaves[1] = TLeaf{Addr: common.HexToAddress("0x1D74B866598B339006160d704642459B04ba890B"), TokenId: big.NewInt(1)}
leaves[2] = TLeaf{Addr: common.HexToAddress("0x37e948435E916069D3a1431Ddf508421073fF3E7"), TokenId: big.NewInt(1)}
leaves[3] = TLeaf{Addr: common.HexToAddress("0x29c34A7d23B8BCBE7c5Ec94C6525b78bb5cbAf36"), TokenId: big.NewInt(1)}

mt := NewMerkleTree(leaves, &hashFn, nil)

root := mt.GetRoot()
assert.NotNil(t, root)

proof, err := mt.GetProof(leaves[0])
assert.Nil(t, err)

isValid, err := mt.Verify(proof, leaves[0], root)
assert.Nil(t, err)
assert.True(t, isValid)
}

0 comments on commit 8129f99

Please sign in to comment.