Skip to content

Commit

Permalink
fix(eds): WriteEDS thread safety for concurrent writingSessions (#1498)
Browse files Browse the repository at this point in the history
  • Loading branch information
distractedm1nd authored Dec 20, 2022
1 parent d57e493 commit d96f79e
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 27 deletions.
34 changes: 18 additions & 16 deletions share/eds/eds.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
format "github.com/ipfs/go-ipld-format"
"github.com/ipld/go-car"
"github.com/ipld/go-car/util"
"github.com/minio/sha256-simd"

"github.com/celestiaorg/celestia-app/pkg/appconsts"
"github.com/celestiaorg/celestia-app/pkg/da"
Expand All @@ -33,11 +34,10 @@ var ErrEmptySquare = errors.New("share: importing empty data")
// writingSession contains the components needed to write an EDS to a CARv1 file with our custom
// node order.
type writingSession struct {
eds *rsmt2d.ExtendedDataSquare
// store is an in-memory blockstore, used to cache the inner nodes (proofs) while we walk the nmt
// tree.
store bstore.Blockstore
w io.Writer
eds *rsmt2d.ExtendedDataSquare
store bstore.Blockstore // caches inner nodes (proofs) while we walk the nmt tree.
hasher *nmt.Hasher
w io.Writer
}

// WriteEDS writes the entire EDS into the given io.Writer as CARv1 file.
Expand Down Expand Up @@ -106,9 +106,10 @@ func initializeWriter(ctx context.Context, eds *rsmt2d.ExtendedDataSquare, w io.
}

return &writingSession{
eds: eds,
store: store,
w: w,
eds: eds,
store: store,
hasher: nmt.NewNmtHasher(sha256.New(), ipld.NamespaceSize, ipld.NMTIgnoreMaxNamespace),
w: w,
}, nil
}

Expand All @@ -129,7 +130,7 @@ func (w *writingSession) writeHeader() error {
func (w *writingSession) writeQuadrants() error {
shares := quadrantOrder(w.eds)
for _, share := range shares {
cid, err := ipld.CidFromNamespacedSha256(nmt.Sha256Namespace8FlaggedLeaf(share))
cid, err := ipld.CidFromNamespacedSha256(w.hasher.HashLeaf(share))
if err != nil {
return fmt.Errorf("getting cid from share: %w", err)
}
Expand All @@ -151,15 +152,18 @@ func (w *writingSession) writeProofs(ctx context.Context) error {
return fmt.Errorf("getting all keys from the blockstore: %w", err)
}
for proofCid := range proofs {
node, err := w.store.Get(ctx, proofCid)
block, err := w.store.Get(ctx, proofCid)
if err != nil {
return fmt.Errorf("getting proof from the blockstore: %w", err)
}
cid, err := ipld.CidFromNamespacedSha256(nmt.Sha256Namespace8FlaggedInner(node.RawData()))

node := block.RawData()
left, right := node[:ipld.NmtHashSize], node[ipld.NmtHashSize:]
cid, err := ipld.CidFromNamespacedSha256(w.hasher.HashNode(left, right))
if err != nil {
return fmt.Errorf("getting cid: %w", err)
}
err = util.LdWrite(w.w, cid.Bytes(), node.RawData())
err = util.LdWrite(w.w, cid.Bytes(), node)
if err != nil {
return fmt.Errorf("writing proof to the car: %w", err)
}
Expand Down Expand Up @@ -201,11 +205,10 @@ func getQuadrantCells(eds *rsmt2d.ExtendedDataSquare, i, j uint) [][]byte {

// prependNamespace adds the namespace to the passed share if in the first quadrant,
// otherwise it adds the ParitySharesNamespace to the beginning.
// TODO(@walldiss): this method will be obsolete once the redundant namespace is removed
func prependNamespace(quadrant int, share []byte) []byte {
switch quadrant {
case 0:
return append(share[:appconsts.NamespaceSize], share...)
return append(share[:ipld.NamespaceSize], share...)
case 1, 2, 3:
return append(appconsts.ParitySharesNamespaceID, share...)
default:
Expand Down Expand Up @@ -250,8 +253,7 @@ func ReadEDS(ctx context.Context, r io.Reader, root share.DataHash) (*rsmt2d.Ext
}
// The stored first-quadrant shares are prefixed with the namespace twice.
// We strip one prefix here, because it is added again while importing to the tree below.
// TODO(@walldiss): remove redundant namespace
shares[i] = block.RawData()[appconsts.NamespaceSize:]
shares[i] = block.RawData()[ipld.NamespaceSize:]
}

eds, err := rsmt2d.ComputeExtendedDataSquare(
Expand Down
8 changes: 4 additions & 4 deletions share/ipld/namespace_hasher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

func TestNamespaceHasherWrite(t *testing.T) {
leafSize := appconsts.ShareSize + appconsts.NamespaceSize
innerSize := nmtHashSize * 2
innerSize := NmtHashSize * 2
tt := []struct {
name string
expectedSize int
Expand Down Expand Up @@ -59,20 +59,20 @@ func TestNamespaceHasherWrite(t *testing.T) {

func TestNamespaceHasherSum(t *testing.T) {
leafSize := appconsts.ShareSize + appconsts.NamespaceSize
innerSize := nmtHashSize * 2
innerSize := NmtHashSize * 2
tt := []struct {
name string
expectedSize int
writtenSize int
}{
{
"Leaf",
nmtHashSize,
NmtHashSize,
leafSize,
},
{
"Inner",
nmtHashSize,
NmtHashSize,
innerSize,
},
{
Expand Down
12 changes: 6 additions & 6 deletions share/ipld/nmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ const (
// NamespaceSize is a system-wide size for NMT namespaces.
NamespaceSize = appconsts.NamespaceSize

// nmtHashSize is the size of a digest created by an NMT in bytes.
nmtHashSize = 2*NamespaceSize + sha256.Size
// NmtHashSize is the size of a digest created by an NMT in bytes.
NmtHashSize = 2*NamespaceSize + sha256.Size

// innerNodeSize is the size of data in inner nodes.
innerNodeSize = nmtHashSize * 2
innerNodeSize = NmtHashSize * 2

// leafNodeSize is the size of data in leaf nodes.
leafNodeSize = NamespaceSize + appconsts.ShareSize
Expand Down Expand Up @@ -98,8 +98,8 @@ func (n nmtNode) Links() []*ipld.Link {
default:
panic(fmt.Sprintf("unexpected size %v", len(n.RawData())))
case innerNodeSize:
leftCid := MustCidFromNamespacedSha256(n.RawData()[:nmtHashSize])
rightCid := MustCidFromNamespacedSha256(n.RawData()[nmtHashSize:])
leftCid := MustCidFromNamespacedSha256(n.RawData()[:NmtHashSize])
rightCid := MustCidFromNamespacedSha256(n.RawData()[NmtHashSize:])

return []*ipld.Link{{Cid: leftCid}, {Cid: rightCid}}
case leafNodeSize:
Expand Down Expand Up @@ -129,7 +129,7 @@ func (n nmtNode) Size() (uint64, error) {

// CidFromNamespacedSha256 uses a hash from an nmt tree to create a CID
func CidFromNamespacedSha256(namespacedHash []byte) (cid.Cid, error) {
if got, want := len(namespacedHash), nmtHashSize; got != want {
if got, want := len(namespacedHash), NmtHashSize; got != want {
return cid.Cid{}, fmt.Errorf("invalid namespaced hash length, got: %v, want: %v", got, want)
}
buf, err := mh.Encode(namespacedHash, sha256Namespace8Flagged)
Expand Down
2 changes: 1 addition & 1 deletion share/ipld/test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func RandNamespacedCID(t *testing.T) cid.Cid {
raw := make([]byte, nmtHashSize)
raw := make([]byte, NmtHashSize)
_, err := mrand.Read(raw)
require.NoError(t, err)
id, err := CidFromNamespacedSha256(raw)
Expand Down

0 comments on commit d96f79e

Please sign in to comment.