From 0f45c5b6f1daab7b121732708853de379409d62b Mon Sep 17 00:00:00 2001 From: Emanuel Pargov Date: Mon, 13 Jan 2025 13:21:59 +0200 Subject: [PATCH] Add block store --- internal/block/block.go | 20 +++ internal/block/header.go | 23 +++ internal/store/chain.go | 203 +++++++++++++++++++++++++++ internal/store/chain_test.go | 264 +++++++++++++++++++++++++++++++++++ 4 files changed, 510 insertions(+) create mode 100644 internal/store/chain.go create mode 100644 internal/store/chain_test.go diff --git a/internal/block/block.go b/internal/block/block.go index 4e6bacbb..44292a05 100644 --- a/internal/block/block.go +++ b/internal/block/block.go @@ -1,5 +1,7 @@ package block +import "github.com/eigerco/strawberry/pkg/serialization/codec/jam" + // Block represents the main block structure type Block struct { Header Header @@ -14,3 +16,21 @@ type Extrinsic struct { EA AssurancesExtrinsic ED DisputeExtrinsic } + +// Bytes returns the Jam encoded bytes of the block +func (b Block) Bytes() ([]byte, error) { + bytes, err := jam.Marshal(b) + if err != nil { + return nil, err + } + return bytes, nil +} + +// BlockFromBytes unmarshals a block from Jam encoded bytes +func BlockFromBytes(data []byte) (Block, error) { + var block Block + if err := jam.Unmarshal(data, &block); err != nil { + return Block{}, err + } + return block, nil +} diff --git a/internal/block/header.go b/internal/block/header.go index dd65932b..af756e59 100644 --- a/internal/block/header.go +++ b/internal/block/header.go @@ -38,6 +38,29 @@ type EpochMarker struct { type WinningTicketMarker [jamtime.TimeslotsPerEpoch]Ticket +// Hash returns the hash of the header +func (h Header) Hash() (crypto.Hash, error) { + jamBytes, err := jam.Marshal(h) + if err != nil { + return crypto.Hash{}, fmt.Errorf("marshal header: %w", err) + } + return crypto.HashData(jamBytes), nil +} + +// Bytes returns the Jam encoded bytes of the header +func (h Header) Bytes() ([]byte, error) { + return jam.Marshal(h) +} + +// HeaderFromBytes unmarshals a header from Jam encoded bytes +func HeaderFromBytes(data []byte) (Header, error) { + var header Header + if err := jam.Unmarshal(data, &header); err != nil { + return Header{}, fmt.Errorf("unmarshal header: %w", err) + } + return header, nil +} + // AncestorStoreSingleton the in memory store for headers that need to be kept for 24 hours // TODO: Add 24 hours TTL var AncestorStoreSingleton = NewAncestorStore() diff --git a/internal/store/chain.go b/internal/store/chain.go new file mode 100644 index 00000000..9938003d --- /dev/null +++ b/internal/store/chain.go @@ -0,0 +1,203 @@ +package store + +import ( + "errors" + "fmt" + "log" + "sync/atomic" + + "github.com/eigerco/strawberry/internal/block" + "github.com/eigerco/strawberry/internal/crypto" + "github.com/eigerco/strawberry/pkg/db" + "github.com/eigerco/strawberry/pkg/db/pebble" +) + +var ( + ErrBlockNotFound = errors.New("block not found") + ErrChainClosed = errors.New("chain store is closed") +) + +const ( + prefixHeader byte = iota + 1 + prefixBlock +) + +// Chain manages blockchain storage using a key-value store +type Chain struct { + db db.KVStore + closed atomic.Bool +} + +// NewChain creates a new chain store using KVStore +func NewChain(db db.KVStore) *Chain { + return &Chain{db: db} +} + +// PutBlock stores a block and its header atomically +func (c *Chain) PutBlock(b block.Block) error { + if c.closed.Load() { + return ErrChainClosed + } + + // Create new batch for atomic operations + batch := c.db.NewBatch() + defer batch.Close() + + headerHash, err := b.Header.Hash() + if err != nil { + return fmt.Errorf("hash header: %w", err) + } + + // TODO: We should probably store the header here and refactor the Header file storage (AncestorStore) + // Store full block + blockBytes, err := b.Bytes() + if err != nil { + return fmt.Errorf("marshal block: %w", err) + } + if err := batch.Put(makeKey(prefixBlock, headerHash[:]), blockBytes); err != nil { + return fmt.Errorf("store block: %w", err) + } + + // Commit the batch + if err := batch.Commit(); err != nil { + return fmt.Errorf("commit batch: %w", err) + } + + return nil +} + +// GetBlock retrieves a block by its header hash +func (c *Chain) GetBlock(hash crypto.Hash) (block.Block, error) { + if c.closed.Load() { + return block.Block{}, ErrChainClosed + } + + blockBytes, err := c.db.Get(makeKey(prefixBlock, hash[:])) + if err != nil { + if errors.Is(err, pebble.ErrNotFound) { + return block.Block{}, ErrBlockNotFound + } + return block.Block{}, fmt.Errorf("get block: %w", err) + } + + return block.BlockFromBytes(blockBytes) +} + +// FindChildren finds all immediate child blocks for a given block hash +func (c *Chain) FindChildren(parentHash crypto.Hash) ([]block.Block, error) { + if c.closed.Load() { + return nil, ErrChainClosed + } + + var children []block.Block + + // Create iterator for block prefix + iter, err := c.db.NewIterator([]byte{prefixBlock}, []byte{prefixBlock + 1}) + if err != nil { + return nil, fmt.Errorf("create iterator: %w", err) + } + defer iter.Close() + + // Iterate through blocks + for iter.Next() { + blockBytes, err := iter.Value() + if err != nil { + log.Println("read block value from iterator", err) + continue + } + b, err := block.BlockFromBytes(blockBytes) + if err != nil { + log.Println("parse block from bytes", err) + continue + } + + if b.Header.ParentHash == parentHash { + children = append(children, b) + } + } + + return children, nil +} + +// GetBlockSequence retrieves a sequence of blocks. +// If ascending is true, returns children of the start block (exclusive). +// If ascending is false, returns the start block and its ancestors (inclusive). +func (c *Chain) GetBlockSequence(startHash crypto.Hash, ascending bool, maxBlocks uint32) ([]block.Block, error) { + if c.closed.Load() { + return nil, ErrChainClosed + } + + currentBlock, err := c.GetBlock(startHash) + if err != nil { + if errors.Is(err, ErrBlockNotFound) { + return nil, fmt.Errorf("starting block not found: %w", err) + } + return nil, fmt.Errorf("get starting block: %w", err) + } + + var blocks []block.Block + currentHash := startHash + + for uint32(len(blocks)) < maxBlocks { + if ascending { + // For ascending (exclusive), skip first block + if currentHash != startHash { + blocks = append(blocks, currentBlock) + } + // Find children and take the first one + children, err := c.FindChildren(currentHash) + if err != nil || len(children) == 0 { + break + } + + // Get hash for next iteration + currentHash, err = children[0].Header.Hash() + if err != nil { + return nil, fmt.Errorf("marshal child header: %w", err) + } + } else { + // For descending (inclusive), include current and follow parent + blocks = append(blocks, currentBlock) + currentHash = currentBlock.Header.ParentHash + } + + // Retrieve next block + currentBlock, err = c.GetBlock(currentHash) + if err != nil { + if errors.Is(err, ErrBlockNotFound) { + break + } + return nil, fmt.Errorf("get block in sequence: %w", err) + } + } + + return blocks, nil +} + +// Close closes the chain store +func (c *Chain) Close() error { + if !c.closed.CompareAndSwap(false, true) { + return nil + } + return c.db.Close() +} + +// PrefixToString converts a prefix byte to a string +func PrefixToString(p byte) string { + switch p { + case prefixHeader: + return "header" + case prefixBlock: + return "block" + default: + return "unknown" + } +} + +// makeKey creates a key from a prefix and hash +func makeKey(prefix byte, hash []byte) []byte { + key := make([]byte, 1+len(hash)) + key[0] = prefix + copy(key[1:], hash) + return key +} diff --git a/internal/store/chain_test.go b/internal/store/chain_test.go new file mode 100644 index 00000000..3bf7ec71 --- /dev/null +++ b/internal/store/chain_test.go @@ -0,0 +1,264 @@ +package store + +import ( + "testing" + + "github.com/eigerco/strawberry/internal/block" + "github.com/eigerco/strawberry/internal/crypto" + "github.com/eigerco/strawberry/internal/jamtime" + "github.com/eigerco/strawberry/internal/testutils" + "github.com/eigerco/strawberry/pkg/db/pebble" + "github.com/eigerco/strawberry/pkg/serialization/codec/jam" + "github.com/stretchr/testify/require" +) + +func Test_PutGetBlock(t *testing.T) { + chain := newStore(t) + header := block.Header{ + ParentHash: testutils.RandomHash(t), + } + hb, err := jam.Marshal(header) + require.NoError(t, err) + hh := crypto.HashData(hb) + block := block.Block{ + Header: header, + } + err = chain.PutBlock(block) + require.NoError(t, err) + resultBlock, err := chain.GetBlock(hh) + require.NoError(t, err) + require.Equal(t, header.ParentHash, resultBlock.Header.ParentHash) +} + +func Test_GetBlockNotFound(t *testing.T) { + chain := newStore(t) + _, err := chain.GetBlock(testutils.RandomHash(t)) + require.Error(t, err) + require.Equal(t, ErrBlockNotFound, err) +} + +func Test_Close(t *testing.T) { + chain := newStore(t) + err := chain.Close() + require.NoError(t, err) + err = chain.Close() + // Closing a closed chain should have no effect/error + require.NoError(t, err) +} + +func Test_ChainClosed(t *testing.T) { + chain := newStore(t) + chain.Close() + _, err := chain.GetBlock(testutils.RandomHash(t)) + require.Error(t, err) + require.Equal(t, ErrChainClosed, err) +} + +func Test_FindChildren(t *testing.T) { + chain := newStore(t) + + // Create parent block + parentBlock := block.Block{ + Header: block.Header{ + ParentHash: crypto.Hash{}, + }, + } + err := chain.PutBlock(parentBlock) + require.NoError(t, err) + + ph, err := jam.Marshal(parentBlock.Header) + require.NoError(t, err) + + // Create child blocks + childBlock1 := block.Block{ + Header: block.Header{ + ParentHash: crypto.HashData(ph), + // Random data so that the block hash is unique + ExtrinsicHash: testutils.RandomHash(t), + }, + } + err = chain.PutBlock(childBlock1) + require.NoError(t, err) + + childBlock2 := block.Block{ + Header: block.Header{ + ParentHash: crypto.HashData(ph), + // Random data so that the block hash is unique + ExtrinsicHash: testutils.RandomHash(t), + }, + } + err = chain.PutBlock(childBlock2) + require.NoError(t, err) + + // Find children of parent block + children, err := chain.FindChildren(crypto.HashData(ph)) + require.NoError(t, err) + require.Len(t, children, 2) + require.ElementsMatch(t, []block.Block{childBlock1, childBlock2}, children) +} + +func Test_FindChildren_NoChildren(t *testing.T) { + chain := newStore(t) + + // Create parent block + parentBlock := block.Block{ + Header: block.Header{ + ParentHash: crypto.Hash{}, + }, + } + err := chain.PutBlock(parentBlock) + require.NoError(t, err) + + ph, err := jam.Marshal(parentBlock.Header) + require.NoError(t, err) + // Find children of parent block (should be none) + children, err := chain.FindChildren(crypto.HashData(ph)) + require.NoError(t, err) + require.Empty(t, children) +} + +func Test_FindChildren_ChainClosed(t *testing.T) { + chain := newStore(t) + chain.Close() + + _, err := chain.FindChildren(testutils.RandomHash(t)) + require.Error(t, err) + require.Equal(t, ErrChainClosed, err) +} + +func Test_GetBlockSequence_Ascending(t *testing.T) { + chain := newStore(t) + + // Create a sequence of blocks + blocks := createNumOfRandomBlocks(5, t) + for _, b := range blocks { + err := chain.PutBlock(b) + require.NoError(t, err) + } + + // Get the hash of the first block + startHash, err := blocks[0].Header.Hash() + require.NoError(t, err) + + // Retrieve the sequence in ascending order + sequence, err := chain.GetBlockSequence(startHash, true, 4) + require.NoError(t, err) + require.Len(t, sequence, 4) + require.Equal(t, blocks[1:], sequence) // Should exclude the start block +} + +func Test_GetBlockSequence_AscendingRequestTooMany(t *testing.T) { + chain := newStore(t) + + // Create a sequence of blocks + blocks := createNumOfRandomBlocks(5, t) + for _, b := range blocks { + err := chain.PutBlock(b) + require.NoError(t, err) + } + + // Get the hash of the first block + startHash, err := blocks[0].Header.Hash() + require.NoError(t, err) + + // Request more blocks than available + sequence, err := chain.GetBlockSequence(startHash, true, 10) + require.NoError(t, err) + require.Len(t, sequence, 4) + require.Equal(t, blocks[1:], sequence) // Should exclude the start block +} + +func Test_GetBlockSequence_Descending(t *testing.T) { + chain := newStore(t) + + // Create a sequence of blocks + blocks := createNumOfRandomBlocks(5, t) + for _, b := range blocks { + err := chain.PutBlock(b) + require.NoError(t, err) + } + + // Get the hash of the last block + startHash, err := blocks[len(blocks)-1].Header.Hash() + require.NoError(t, err) + + // Retrieve the sequence in descending order + sequence, err := chain.GetBlockSequence(startHash, false, 5) + require.NoError(t, err) + require.Len(t, sequence, 5) // Should include the start block + for i := range sequence { + // The sequence should be in reverse order + require.Equal(t, blocks[len(blocks)-1-i], sequence[i]) + } +} + +func Test_GetBlockSequence_DescendingRequestTooMany(t *testing.T) { + chain := newStore(t) + + // Create a sequence of blocks + blocks := createNumOfRandomBlocks(5, t) + for _, b := range blocks { + err := chain.PutBlock(b) + require.NoError(t, err) + } + + // Get the hash of the last block + startHash, err := blocks[len(blocks)-1].Header.Hash() + require.NoError(t, err) + + // Retrieve the sequence in descending order + sequence, err := chain.GetBlockSequence(startHash, false, 10) + require.NoError(t, err) + require.Len(t, sequence, 5) // Should include the start block + for i := range sequence { + // The sequence should be in reverse order + require.Equal(t, blocks[len(blocks)-1-i], sequence[i]) + } +} + +func Test_GetBlockSequence_ChainClosed(t *testing.T) { + chain := newStore(t) + chain.Close() + + _, err := chain.GetBlockSequence(testutils.RandomHash(t), true, 5) + require.Error(t, err) + require.Equal(t, ErrChainClosed, err) +} + +// CreateRandomBlock generates a random block for testing purposes +func createRandomBlock(parentHash crypto.Hash, slot jamtime.Timeslot, t *testing.T) block.Block { + // Generate a random block header + header := block.Header{ + ParentHash: parentHash, // Parent hash (passed as argument) + TimeSlotIndex: slot, // Slot (passed as argument) + // Populate other header fields with random data + PriorStateRoot: testutils.RandomHash(t), + ExtrinsicHash: testutils.RandomHash(t), + } + // Return the random block + return block.Block{ + Header: header, + } +} + +func createNumOfRandomBlocks(num int, t *testing.T) []block.Block { + blocks := []block.Block{} + prevB := createRandomBlock(crypto.Hash{}, jamtime.MinTimeslot, t) + hh, _ := prevB.Header.Hash() + blocks = append(blocks, prevB) + for range num - 1 { + b := createRandomBlock(hh, prevB.Header.TimeSlotIndex+1, t) + blocks = append(blocks, b) + prevB = b + h, _ := prevB.Header.Bytes() + hh = crypto.HashData(h) + } + return blocks +} + +func newStore(t *testing.T) *Chain { + kvStore, err := pebble.NewKVStore() + require.NoError(t, err) + chain := NewChain(kvStore) + return chain +}