Skip to content

Commit

Permalink
review feedback - pass peer ID instead of boolean
Browse files Browse the repository at this point in the history
  • Loading branch information
poszu committed Feb 7, 2025
1 parent 8c4b94b commit a21cfea
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 29 deletions.
42 changes: 31 additions & 11 deletions activation/handler_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"math/bits"
"time"

"github.com/libp2p/go-libp2p/core/peer"
"github.com/spacemeshos/post/shared"
"github.com/spacemeshos/post/verifying"
"go.uber.org/zap"
Expand Down Expand Up @@ -321,7 +322,12 @@ func (h *HandlerV1) cacheAtx(ctx context.Context, atx *types.ActivationTx, malic
}

// checkDoublePublish verifies if a node has already published an ATX in the same epoch.
func (h *HandlerV1) checkDoublePublish(ctx context.Context, tx sql.Executor, atx *wire.ActivationTxV1, publishing bool) (bool, error) {
func (h *HandlerV1) checkDoublePublish(
ctx context.Context,
tx sql.Executor,
atx *wire.ActivationTxV1,
peer peer.ID,
) (bool, error) {
prev, err := atxs.GetByEpochAndNodeID(tx, atx.PublishEpoch, atx.SmesherID)
if err != nil && !errors.Is(err, sql.ErrNotFound) {
return false, err
Expand All @@ -331,7 +337,7 @@ func (h *HandlerV1) checkDoublePublish(ctx context.Context, tx sql.Executor, atx
return false, nil
}

if publishing {
if peer == h.local {
// if we land here we tried to publish 2 ATXs in the same epoch
// don't punish ourselves but fail validation and thereby the handling of the incoming ATX
return false, fmt.Errorf(
Expand Down Expand Up @@ -380,7 +386,12 @@ func (h *HandlerV1) checkDoublePublish(ctx context.Context, tx sql.Executor, atx
}

// checkWrongPrevAtx verifies if the previous ATX referenced in the ATX is correct.
func (h *HandlerV1) checkWrongPrevAtx(ctx context.Context, tx sql.Executor, atx *wire.ActivationTxV1, publishing bool) (bool, error) {
func (h *HandlerV1) checkWrongPrevAtx(
ctx context.Context,
tx sql.Executor,
atx *wire.ActivationTxV1,
peer peer.ID,
) (bool, error) {
expectedPrevID, err := atxs.PrevIDByNodeID(tx, atx.SmesherID, atx.PublishEpoch)
if err != nil && !errors.Is(err, sql.ErrNotFound) {
return false, fmt.Errorf("get last atx by node id: %w", err)
Expand All @@ -389,7 +400,7 @@ func (h *HandlerV1) checkWrongPrevAtx(ctx context.Context, tx sql.Executor, atx
return false, nil
}

if publishing {
if peer == h.local {
// if we land here we tried to publish an ATX with a wrong prevATX
h.logger.Warn(
"Node produced an ATX with a wrong prevATX. This can happened when the node wasn't synced when "+
Expand Down Expand Up @@ -457,23 +468,33 @@ func (h *HandlerV1) checkWrongPrevAtx(ctx context.Context, tx sql.Executor, atx
return true, h.malPublisher.PublishProof(ctx, atx.SmesherID, proof)
}

func (h *HandlerV1) checkMalicious(ctx context.Context, tx sql.Transaction, watx *wire.ActivationTxV1, publishing bool) (bool, error) {
malicious, err := h.checkDoublePublish(ctx, tx, watx, publishing)
func (h *HandlerV1) checkMalicious(
ctx context.Context,
tx sql.Transaction,
watx *wire.ActivationTxV1,
peer peer.ID,
) (bool, error) {
malicious, err := h.checkDoublePublish(ctx, tx, watx, peer)
if err != nil {
return malicious, fmt.Errorf("check double publish: %w", err)
}
if malicious {
return true, nil
}
malicious, err = h.checkWrongPrevAtx(ctx, tx, watx, publishing)
malicious, err = h.checkWrongPrevAtx(ctx, tx, watx, peer)
if err != nil {
return malicious, fmt.Errorf("check wrong prev atx: %w", err)
}
return malicious, nil
}

// storeAtx stores an ATX and notifies subscribers of the ATXID.
func (h *HandlerV1) storeAtx(ctx context.Context, atx *types.ActivationTx, watx *wire.ActivationTxV1, publishing bool) error {
func (h *HandlerV1) storeAtx(
ctx context.Context,
atx *types.ActivationTx,
watx *wire.ActivationTxV1,
peer peer.ID,
) error {
var malicious bool
if err := h.cdb.WithTxImmediate(ctx, func(tx sql.Transaction) error {
var err error
Expand All @@ -487,7 +508,7 @@ func (h *HandlerV1) storeAtx(ctx context.Context, atx *types.ActivationTx, watx
}
malicious = malicious || malicious2
if !malicious {
malicious, err = h.checkMalicious(ctx, tx, watx, publishing)
malicious, err = h.checkMalicious(ctx, tx, watx, peer)
if err != nil {
return fmt.Errorf("check malicious: %w", err)
}
Expand Down Expand Up @@ -560,8 +581,7 @@ func (h *HandlerV1) processATX(
return fmt.Errorf("%w: validating atx %s (deps): %w", pubsub.ErrValidationReject, watx.ID(), err)
}

publishing := h.local == peer
if err := h.storeAtx(ctx, atx, watx, publishing); err != nil {
if err := h.storeAtx(ctx, atx, watx, peer); err != nil {
return fmt.Errorf("cannot store atx %s: %w", atx.ShortString(), err)
}

Expand Down
36 changes: 18 additions & 18 deletions activation/handler_v1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) {
return atx.ID() == watx.ID()
}))
atxHdlr.mTortoise.EXPECT().OnAtx(watx.PublishEpoch+1, watx.ID(), gomock.Any())
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx, watx, false))
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx, watx, p2p.Peer("other")))

atxFromDb, err := atxs.Get(atxHdlr.cdb, atx.ID())
require.NoError(t, err)
Expand All @@ -601,13 +601,13 @@ func TestHandlerV1_StoreAtx(t *testing.T) {
return atx.ID() == watx.ID()
}))
atxHdlr.mTortoise.EXPECT().OnAtx(watx.PublishEpoch+1, watx.ID(), gomock.Any())
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx, watx, false))
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx, watx, p2p.Peer("other")))

atxHdlr.mBeacon.EXPECT().OnAtx(gomock.Cond(func(atx *types.ActivationTx) bool {
return atx.ID() == watx.ID()
}))
// Note: tortoise is not informed about the same ATX again
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx, watx, false))
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx, watx, p2p.Peer("other")))
})

t.Run("stores ATX of malicious identity", func(t *testing.T) {
Expand All @@ -625,7 +625,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) {
return atx.ID() == watx.ID()
}))
atxHdlr.mTortoise.EXPECT().OnAtx(watx.PublishEpoch+1, watx.ID(), gomock.Any())
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx, watx, false))
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx, watx, p2p.Peer("other")))

atxFromDb, err := atxs.Get(atxHdlr.cdb, atx.ID())
require.NoError(t, err)
Expand All @@ -644,7 +644,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) {
return atx.ID() == watx0.ID()
}))
atxHdlr.mTortoise.EXPECT().OnAtx(watx0.PublishEpoch+1, watx0.ID(), gomock.Any())
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx0, watx0, false))
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx0, watx0, p2p.Peer("other")))

watx1 := newInitialATXv1(t, goldenATXID)
watx1.Coinbase = types.GenerateAddress([]byte("aaaa"))
Expand All @@ -667,7 +667,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) {
return nil
},
)
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx1, watx1, false))
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx1, watx1, p2p.Peer("other")))
})

t.Run("another atx for the same epoch for registered ID doesn't create a malfeasance proof", func(t *testing.T) {
Expand All @@ -681,15 +681,15 @@ func TestHandlerV1_StoreAtx(t *testing.T) {
return atx.ID() == watx0.ID()
}))
atxHdlr.mTortoise.EXPECT().OnAtx(watx0.PublishEpoch+1, watx0.ID(), gomock.Any())
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx0, watx0, false))
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx0, watx0, atxHdlr.local))

watx1 := newInitialATXv1(t, goldenATXID)
watx1.Coinbase = types.GenerateAddress([]byte("aaaa"))
watx1.Sign(sig)
atx1 := toAtx(t, watx1)

require.ErrorContains(t,
atxHdlr.storeAtx(context.Background(), atx1, watx1, true),
atxHdlr.storeAtx(context.Background(), atx1, watx1, atxHdlr.local),
fmt.Sprintf("%s already published an ATX", sig.NodeID().ShortString()),
)
})
Expand All @@ -705,7 +705,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) {
return atx.ID() == initialATX.ID()
}))
atxHdlr.mTortoise.EXPECT().OnAtx(initialATX.PublishEpoch+1, initialATX.ID(), gomock.Any())
require.NoError(t, atxHdlr.storeAtx(context.Background(), wInitialATX, initialATX, false))
require.NoError(t, atxHdlr.storeAtx(context.Background(), wInitialATX, initialATX, p2p.Peer("other")))

// valid first non-initial ATX
watx1 := newChainedActivationTxV1(t, initialATX, goldenATXID)
Expand All @@ -716,7 +716,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) {
return atx.ID() == watx1.ID()
}))
atxHdlr.mTortoise.EXPECT().OnAtx(watx1.PublishEpoch+1, watx1.ID(), gomock.Any())
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx1, watx1, false))
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx1, watx1, p2p.Peer("other")))

watx2 := newChainedActivationTxV1(t, watx1, goldenATXID)
watx2.Sign(sig)
Expand All @@ -726,7 +726,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) {
return atx.ID() == watx2.ID()
}))
atxHdlr.mTortoise.EXPECT().OnAtx(watx2.PublishEpoch+1, watx2.ID(), gomock.Any())
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx2, watx2, false))
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx2, watx2, p2p.Peer("other")))

// third non-initial ATX references initial ATX as prevATX
watx3 := newChainedActivationTxV1(t, initialATX, goldenATXID)
Expand All @@ -751,7 +751,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) {
},
)

require.NoError(t, atxHdlr.storeAtx(context.Background(), atx3, watx3, false))
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx3, watx3, p2p.Peer("other")))
})

t.Run("another atx of v2 with the same prevatx is considered malicious", func(t *testing.T) {
Expand All @@ -765,7 +765,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) {
return atx.ID() == initialATX.ID()
}))
atxHdlr.mTortoise.EXPECT().OnAtx(initialATX.PublishEpoch+1, initialATX.ID(), gomock.Any())
require.NoError(t, atxHdlr.v1.storeAtx(context.Background(), wInitialATX, initialATX, false))
require.NoError(t, atxHdlr.v1.storeAtx(context.Background(), wInitialATX, initialATX, p2p.Peer("other")))

// valid first non-initial ATX
watx1 := newChainedActivationTxV1(t, initialATX, goldenATXID)
Expand All @@ -776,7 +776,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) {
return atx.ID() == watx1.ID()
}))
atxHdlr.mTortoise.EXPECT().OnAtx(watx1.PublishEpoch+1, watx1.ID(), gomock.Any())
require.NoError(t, atxHdlr.v1.storeAtx(context.Background(), atx1, watx1, false))
require.NoError(t, atxHdlr.v1.storeAtx(context.Background(), atx1, watx1, p2p.Peer("other")))

watx2 := newSoloATXv2(t, watx1.PublishEpoch+1, watx1.ID(), watx1.ID())
watx2.Sign(sig)
Expand Down Expand Up @@ -812,7 +812,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) {
},
)

require.NoError(t, atxHdlr.v1.storeAtx(context.Background(), atx3, watx3, false))
require.NoError(t, atxHdlr.v1.storeAtx(context.Background(), atx3, watx3, p2p.Peer("other")))
})

t.Run("another atx with the same prevatx when publishing doesn't create a malfeasance proof", func(t *testing.T) {
Expand All @@ -827,7 +827,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) {
return atx.ID() == wInitialATX.ID()
}))
atxHdlr.mTortoise.EXPECT().OnAtx(wInitialATX.PublishEpoch+1, wInitialATX.ID(), gomock.Any())
require.NoError(t, atxHdlr.storeAtx(context.Background(), initialAtx, wInitialATX, false))
require.NoError(t, atxHdlr.storeAtx(context.Background(), initialAtx, wInitialATX, atxHdlr.local))

// valid first non-initial ATX
watx1 := newChainedActivationTxV1(t, wInitialATX, goldenATXID)
Expand All @@ -838,7 +838,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) {
return atx.ID() == watx1.ID()
}))
atxHdlr.mTortoise.EXPECT().OnAtx(watx1.PublishEpoch+1, watx1.ID(), gomock.Any())
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx1, watx1, false))
require.NoError(t, atxHdlr.storeAtx(context.Background(), atx1, watx1, atxHdlr.local))

// second non-initial ATX references empty as prevATX
watx2 := newInitialATXv1(t, goldenATXID)
Expand All @@ -847,7 +847,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) {
atx2 := toAtx(t, watx2)

require.ErrorContains(t,
atxHdlr.storeAtx(context.Background(), atx2, watx2, true),
atxHdlr.storeAtx(context.Background(), atx2, watx2, atxHdlr.local),
fmt.Sprintf("%s referenced incorrect previous ATX", sig.NodeID().ShortString()),
)
})
Expand Down

0 comments on commit a21cfea

Please sign in to comment.