From 1b16802ed316277fa0c4732a4bc1acf10592affa Mon Sep 17 00:00:00 2001 From: josefkedwards Date: Wed, 5 Feb 2025 00:26:15 -0500 Subject: [PATCH 1/7] Update IBC.go @coderabbitai Signed-off-by: josefkedwards --- .gofiles/IBC.go | 287 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 286 insertions(+), 1 deletion(-) diff --git a/.gofiles/IBC.go b/.gofiles/IBC.go index 1d9304e..977a8de 100644 --- a/.gofiles/IBC.go +++ b/.gofiles/IBC.go @@ -1,4 +1,289 @@ -package main +package v2_test + +import ( + "time" + + sdkmath "cosmossdk.io/math" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/ibc-go/v9/modules/apps/transfer/types" + channeltypesv2 "github.com/cosmos/ibc-go/v9/modules/core/04-channel/v2/types" + ibctesting "github.com/cosmos/ibc-go/v9/testing" +) + +func (suite *TransferTestSuite) TestFullEurekaForwardPath() { + receiver := suite.chainC.SenderAccount.GetAddress().String() + hops := []types.Hop{types.Hop{PortId: types.PortID, ChannelId: suite.pathBToC.EndpointA.ClientID}} + + coin := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), sdk.DefaultBondDenom) + tokens := make([]types.Token, 1) + var err error + tokens[0], err = suite.chainA.GetSimApp().TransferKeeper.TokenFromCoin(suite.chainA.GetContext(), coin) + suite.Require().NoError(err) + + timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().Add(time.Hour).Unix()) + + transferData := types.FungibleTokenPacketDataV2{ + Tokens: tokens, + Sender: suite.chainA.SenderAccount.GetAddress().String(), + Receiver: receiver, + Memo: "", + Forwarding: types.NewForwardingPacketData("", hops...), + } + bz := suite.chainA.Codec.MustMarshal(&transferData) + payload := channeltypesv2.NewPayload( + types.PortID, types.PortID, types.V2, + types.EncodingProtobuf, bz, + ) + packetAToB, err := suite.pathAToB.EndpointA.MsgSendPacket(timeoutTimestamp, payload) + suite.Require().NoError(err) + + // check the original sendPacket 
logic + escrowAddressA := types.GetEscrowAddress(types.PortID, suite.pathAToB.EndpointA.ClientID) + // check that the balance for chainA is updated + chainABalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), coin.Denom) + suite.Require().Equal(sdkmath.ZeroInt(), chainABalance.Amount) + + // check that module account escrow address has locked the tokens + chainAEscrowBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), escrowAddressA, coin.Denom) + suite.Require().Equal(coin, chainAEscrowBalance) + + res, err := suite.pathAToB.EndpointB.MsgRecvPacketWithResult(packetAToB) + suite.Require().NoError(err) + + // check the recvPacket logic with forwarding the tokens should be moved to the next hop's escrow address + escrowAddressB := types.GetEscrowAddress(types.PortID, suite.pathBToC.EndpointA.ClientID) + traceAToB := types.NewHop(types.PortID, suite.pathAToB.EndpointB.ClientID) + chainBDenom := types.NewDenom(coin.Denom, traceAToB) + chainBBalance := suite.chainB.GetSimApp().BankKeeper.GetBalance(suite.chainB.GetContext(), escrowAddressB, chainBDenom.IBCDenom()) + coinSentFromAToB := sdk.NewCoin(chainBDenom.IBCDenom(), coin.Amount) + suite.Require().Equal(coinSentFromAToB, chainBBalance) + + packetBToC, err := ibctesting.ParsePacketV2FromEvents(res.Events) + suite.Require().NoError(err) + + // check that the packet has been sent from B to C + packetBToCCommitment := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeperV2.GetPacketCommitment(suite.chainB.GetContext(), suite.pathBToC.EndpointA.ClientID, 1) + suite.Require().Equal(channeltypesv2.CommitPacket(packetBToC), packetBToCCommitment) + + // check that acknowledgement on chainB for packet A to B does not exist yet + acknowledgementBToC := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeperV2.GetPacketAcknowledgement(suite.chainB.GetContext(), suite.pathAToB.EndpointA.ClientID, 1) + 
suite.Require().Nil(acknowledgementBToC) + + // update the chainB client on chainC + err = suite.pathBToC.EndpointB.UpdateClient() + suite.Require().NoError(err) + + // recvPacket packetBToC on chain C + res, err = suite.pathBToC.EndpointB.MsgRecvPacketWithResult(packetBToC) + suite.Require().NoError(err) + + // check that the receiver has received final tokens on chainC + traceBToC := types.NewHop(types.PortID, suite.pathBToC.EndpointB.ClientID) + chainCDenom := types.NewDenom(coin.Denom, traceBToC, traceAToB) + chainCBalance := suite.chainC.GetSimApp().BankKeeper.GetBalance(suite.chainC.GetContext(), suite.chainC.SenderAccount.GetAddress(), chainCDenom.IBCDenom()) + coinSentFromBToC := sdk.NewCoin(chainCDenom.IBCDenom(), coin.Amount) + suite.Require().Equal(coinSentFromBToC, chainCBalance) + + // check that the final hop has written an acknowledgement + ack, err := ibctesting.ParseAckV2FromEvents(res.Events) + suite.Require().NoError(err) + + res, err = suite.pathBToC.EndpointA.MsgAcknowledgePacketWithResult(packetBToC, *ack) + suite.Require().NoError(err) + + // check that the middle hop has now written its async acknowledgement + ack, err = ibctesting.ParseAckV2FromEvents(res.Events) + suite.Require().NoError(err) + + // update chainB client on chainA + err = suite.pathAToB.EndpointA.UpdateClient() + suite.Require().NoError(err) + + err = suite.pathAToB.EndpointA.MsgAcknowledgePacket(packetAToB, *ack) + suite.Require().NoError(err) +} + +func (suite *TransferTestSuite) TestFullEurekaForwardFailedAck() { + receiver := suite.chainC.SenderAccount.GetAddress().String() + hops := []types.Hop{types.Hop{PortId: types.PortID, ChannelId: suite.pathBToC.EndpointA.ClientID}} + + coin := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), sdk.DefaultBondDenom) + tokens := make([]types.Token, 1) + var err error + tokens[0], err = suite.chainA.GetSimApp().TransferKeeper.TokenFromCoin(suite.chainA.GetContext(), 
coin) + suite.Require().NoError(err) + + timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().Add(time.Hour).Unix()) + + transferData := types.FungibleTokenPacketDataV2{ + Tokens: tokens, + Sender: suite.chainA.SenderAccount.GetAddress().String(), + Receiver: receiver, + Memo: "", + Forwarding: types.NewForwardingPacketData("", hops...), + } + bz := suite.chainA.Codec.MustMarshal(&transferData) + payload := channeltypesv2.NewPayload( + types.PortID, types.PortID, types.V2, + types.EncodingProtobuf, bz, + ) + packetAToB, err := suite.pathAToB.EndpointA.MsgSendPacket(timeoutTimestamp, payload) + suite.Require().NoError(err) + + // check the original sendPacket logic + escrowAddressA := types.GetEscrowAddress(types.PortID, suite.pathAToB.EndpointA.ClientID) + // check that the balance for chainA is updated + chainABalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), coin.Denom) + suite.Require().Equal(sdkmath.ZeroInt(), chainABalance.Amount) + + // check that module account escrow address has locked the tokens + chainAEscrowBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), escrowAddressA, coin.Denom) + suite.Require().Equal(coin, chainAEscrowBalance) + + res, err := suite.pathAToB.EndpointB.MsgRecvPacketWithResult(packetAToB) + suite.Require().NoError(err) + + // check the recvPacket logic with forwarding the tokens should be moved to the next hop's escrow address + escrowAddressB := types.GetEscrowAddress(types.PortID, suite.pathBToC.EndpointA.ClientID) + traceAToB := types.NewHop(types.PortID, suite.pathAToB.EndpointB.ClientID) + chainBDenom := types.NewDenom(coin.Denom, traceAToB) + chainBBalance := suite.chainB.GetSimApp().BankKeeper.GetBalance(suite.chainB.GetContext(), escrowAddressB, chainBDenom.IBCDenom()) + coinSentFromAToB := sdk.NewCoin(chainBDenom.IBCDenom(), coin.Amount) + suite.Require().Equal(coinSentFromAToB, chainBBalance) + + 
packetBToC, err := ibctesting.ParsePacketV2FromEvents(res.Events) + suite.Require().NoError(err) + + // check that the packet has been sent from B to C + packetBToCCommitment := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeperV2.GetPacketCommitment(suite.chainB.GetContext(), suite.pathBToC.EndpointA.ClientID, 1) + suite.Require().Equal(channeltypesv2.CommitPacket(packetBToC), packetBToCCommitment) + + // check that acknowledgement on chainB for packet A to B does not exist yet + acknowledgementBToC := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeperV2.GetPacketAcknowledgement(suite.chainB.GetContext(), suite.pathAToB.EndpointA.ClientID, 1) + suite.Require().Nil(acknowledgementBToC) + + // update the chainB client on chainC + err = suite.pathBToC.EndpointB.UpdateClient() + suite.Require().NoError(err) + + // turn off receive on chain C to trigger an error + suite.chainC.GetSimApp().TransferKeeper.SetParams(suite.chainC.GetContext(), types.Params{ + SendEnabled: true, + ReceiveEnabled: false, + }) + + // recvPacket packetBToC on chain C + res, err = suite.pathBToC.EndpointB.MsgRecvPacketWithResult(packetBToC) + suite.Require().NoError(err) + + // update the chainC client on chain B + err = suite.pathBToC.EndpointA.UpdateClient() + suite.Require().NoError(err) + + // check that the final hop has written an acknowledgement + ack, err := ibctesting.ParseAckV2FromEvents(res.Events) + suite.Require().NoError(err) + + res, err = suite.pathBToC.EndpointA.MsgAcknowledgePacketWithResult(packetBToC, *ack) + suite.Require().NoError(err) + + // check that the middle hop has now written its async acknowledgement + ack, err = ibctesting.ParseAckV2FromEvents(res.Events) + suite.Require().NoError(err) + + // update chainB client on chainA + err = suite.pathAToB.EndpointA.UpdateClient() + suite.Require().NoError(err) + + err = suite.pathAToB.EndpointA.MsgAcknowledgePacket(packetAToB, *ack) + suite.Require().NoError(err) + + // check that the tokens have been refunded on original 
sender + chainABalance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), coin.Denom) + suite.Require().Equal(coin, chainABalance) +} + +func (suite *TransferTestSuite) TestFullEurekaForwardTimeout() { + receiver := suite.chainC.SenderAccount.GetAddress().String() + hops := []types.Hop{types.Hop{PortId: types.PortID, ChannelId: suite.pathBToC.EndpointA.ClientID}} + + coin := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), sdk.DefaultBondDenom) + tokens := make([]types.Token, 1) + var err error + tokens[0], err = suite.chainA.GetSimApp().TransferKeeper.TokenFromCoin(suite.chainA.GetContext(), coin) + suite.Require().NoError(err) + + timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().Add(time.Hour).Unix()) + + transferData := types.FungibleTokenPacketDataV2{ + Tokens: tokens, + Sender: suite.chainA.SenderAccount.GetAddress().String(), + Receiver: receiver, + Memo: "", + Forwarding: types.NewForwardingPacketData("", hops...), + } + bz := suite.chainA.Codec.MustMarshal(&transferData) + payload := channeltypesv2.NewPayload( + types.PortID, types.PortID, types.V2, + types.EncodingProtobuf, bz, + ) + packetAToB, err := suite.pathAToB.EndpointA.MsgSendPacket(timeoutTimestamp, payload) + suite.Require().NoError(err) + + // check the original sendPacket logic + escrowAddressA := types.GetEscrowAddress(types.PortID, suite.pathAToB.EndpointA.ClientID) + // check that the balance for chainA is updated + chainABalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), coin.Denom) + suite.Require().Equal(sdkmath.ZeroInt(), chainABalance.Amount) + + // check that module account escrow address has locked the tokens + chainAEscrowBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), escrowAddressA, coin.Denom) + suite.Require().Equal(coin, 
chainAEscrowBalance) + + res, err := suite.pathAToB.EndpointB.MsgRecvPacketWithResult(packetAToB) + suite.Require().NoError(err) + + // check the recvPacket logic with forwarding the tokens should be moved to the next hop's escrow address + escrowAddressB := types.GetEscrowAddress(types.PortID, suite.pathBToC.EndpointA.ClientID) + traceAToB := types.NewHop(types.PortID, suite.pathAToB.EndpointB.ClientID) + chainBDenom := types.NewDenom(coin.Denom, traceAToB) + chainBBalance := suite.chainB.GetSimApp().BankKeeper.GetBalance(suite.chainB.GetContext(), escrowAddressB, chainBDenom.IBCDenom()) + coinSentFromAToB := sdk.NewCoin(chainBDenom.IBCDenom(), coin.Amount) + suite.Require().Equal(coinSentFromAToB, chainBBalance) + + packetBToC, err := ibctesting.ParsePacketV2FromEvents(res.Events) + suite.Require().NoError(err) + + // check that the packet has been sent from B to C + packetBToCCommitment := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeperV2.GetPacketCommitment(suite.chainB.GetContext(), suite.pathBToC.EndpointA.ClientID, 1) + suite.Require().Equal(channeltypesv2.CommitPacket(packetBToC), packetBToCCommitment) + + // check that acknowledgement on chainB for packet A to B does not exist yet + acknowledgementBToC := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeperV2.GetPacketAcknowledgement(suite.chainB.GetContext(), suite.pathAToB.EndpointA.ClientID, 1) + suite.Require().Nil(acknowledgementBToC) + + // update the chainB client on chainC + err = suite.pathBToC.EndpointB.UpdateClient() + suite.Require().NoError(err) + + // Time out packet + suite.coordinator.IncrementTimeBy(time.Hour * 5) + err = suite.pathBToC.EndpointA.UpdateClient() + suite.Require().NoError(err) + + res, err = suite.pathBToC.EndpointA.MsgTimeoutPacketWithResult(packetBToC) + suite.Require().NoError(err) + ack, err := ibctesting.ParseAckV2FromEvents(res.Events) + suite.Require().NoError(err) + + err = suite.pathAToB.EndpointA.UpdateClient() + suite.Require().NoError(err) + err = 
suite.pathAToB.EndpointA.MsgAcknowledgePacket(packetAToB, *ack) + suite.Require().NoError(err) + + // check that the tokens have been refunded on original sender + chainABalance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), coin.Denom) + suite.Require().Equal(coin, chainABalance)package main import ( "fmt" From cccad9487e3a7e3b5e8d7441db341f2f680562bf Mon Sep 17 00:00:00 2001 From: "J. K. Edwards" Date: Wed, 5 Feb 2025 02:10:58 -0500 Subject: [PATCH 2/7] Update IBC.go Signed-off-by: J. K. Edwards --- .gofiles/IBC.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/.gofiles/IBC.go b/.gofiles/IBC.go index 977a8de..621c00e 100644 --- a/.gofiles/IBC.go +++ b/.gofiles/IBC.go @@ -3262,6 +3262,3 @@ func (suite *KeeperTestSuite) TestWriteErrorReceipt() { upgradeError = types.NewUpgradeError(10, types.ErrInvalidUpgrade) tc.malleate() - - - From d0bcc86991d9a4d9cfb1ec7c721f68fbbb952baa Mon Sep 17 00:00:00 2001 From: josefkedwards Date: Thu, 6 Feb 2025 03:22:09 -0500 Subject: [PATCH 3/7] Update IBC.go Signed-off-by: josefkedwards --- .gofiles/IBC.go | 347 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 347 insertions(+) diff --git a/.gofiles/IBC.go b/.gofiles/IBC.go index 621c00e..2e6c701 100644 --- a/.gofiles/IBC.go +++ b/.gofiles/IBC.go @@ -1,3 +1,61 @@ +package keeper + +import ( + "strings" + + capabilitytypes "github.com/cosmos/ibc-go/modules/capability/types" + host "github.com/cosmos/ibc-go/v9/modules/core/24-host" + + errorsmod "cosmossdk.io/errors" + + sdk "github.com/cosmos/cosmos-sdk/types" + + "github.com/CosmWasm/wasmd/x/wasm/types" +) + +// bindIbcPort will reserve the port. +// returns a string name of the port or error if we cannot bind it. +// this will fail if call twice. 
+func (k Keeper) bindIbcPort(ctx sdk.Context, portID string) error { + portCap := k.portKeeper.BindPort(ctx, portID) + return k.ClaimCapability(ctx, portCap, host.PortPath(portID)) +} + +// ensureIbcPort is like registerIbcPort, but it checks if we already hold the port +// before calling register, so this is safe to call multiple times. +// Returns success if we already registered or just registered and error if we cannot +// (lack of permissions or someone else has it) +func (k Keeper) ensureIbcPort(ctx sdk.Context, contractAddr sdk.AccAddress) (string, error) { + portID := PortIDForContract(contractAddr) + if _, ok := k.capabilityKeeper.GetCapability(ctx, host.PortPath(portID)); ok { + return portID, nil + } + return portID, k.bindIbcPort(ctx, portID) +} + +const portIDPrefix = "wasm." + +func PortIDForContract(addr sdk.AccAddress) string { + return portIDPrefix + addr.String() +} + +func ContractFromPortID(portID string) (sdk.AccAddress, error) { + if !strings.HasPrefix(portID, portIDPrefix) { + return nil, errorsmod.Wrapf(types.ErrInvalid, "without prefix") + } + return sdk.AccAddressFromBech32(portID[len(portIDPrefix):]) +} + +// AuthenticateCapability wraps the scopedKeeper's AuthenticateCapability function +func (k Keeper) AuthenticateCapability(ctx sdk.Context, cap *capabilitytypes.Capability, name string) bool { + return k.capabilityKeeper.AuthenticateCapability(ctx, cap, name) +} + +// ClaimCapability allows the transfer module to claim a capability +// that IBC module passes to it +func (k Keeper) ClaimCapability(ctx sdk.Context, cap *capabilitytypes.Capability, name string) error { + return k.capabilityKeeper.ClaimCapability(ctx, cap, name) +} package v2_test import ( @@ -3262,3 +3320,292 @@ func (suite *KeeperTestSuite) TestWriteErrorReceipt() { upgradeError = types.NewUpgradeError(10, types.ErrInvalidUpgrade) tc.malleate() + +package main +package v2_test + +bearycool11 marked this conversation as resolved. 
+import ( + "time" + + sdkmath "cosmossdk.io/math" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/ibc-go/v9/modules/apps/transfer/types" + channeltypesv2 "github.com/cosmos/ibc-go/v9/modules/core/04-channel/v2/types" + ibctesting "github.com/cosmos/ibc-go/v9/testing" +) + +func (suite *TransferTestSuite) TestFullEurekaForwardPath() { + receiver := suite.chainC.SenderAccount.GetAddress().String() + hops := []types.Hop{types.Hop{PortId: types.PortID, ChannelId: suite.pathBToC.EndpointA.ClientID}} + + coin := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), sdk.DefaultBondDenom) + tokens := make([]types.Token, 1) + var err error + tokens[0], err = suite.chainA.GetSimApp().TransferKeeper.TokenFromCoin(suite.chainA.GetContext(), coin) + suite.Require().NoError(err) + + timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().Add(time.Hour).Unix()) + + transferData := types.FungibleTokenPacketDataV2{ + Tokens: tokens, + Sender: suite.chainA.SenderAccount.GetAddress().String(), + Receiver: receiver, + Memo: "", + Forwarding: types.NewForwardingPacketData("", hops...), + } + bz := suite.chainA.Codec.MustMarshal(&transferData) + payload := channeltypesv2.NewPayload( + types.PortID, types.PortID, types.V2, + types.EncodingProtobuf, bz, + ) + packetAToB, err := suite.pathAToB.EndpointA.MsgSendPacket(timeoutTimestamp, payload) + suite.Require().NoError(err) + + // check the original sendPacket logic + escrowAddressA := types.GetEscrowAddress(types.PortID, suite.pathAToB.EndpointA.ClientID) + // check that the balance for chainA is updated + chainABalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), coin.Denom) + suite.Require().Equal(sdkmath.ZeroInt(), chainABalance.Amount) + + // check that module account escrow address has locked the tokens + chainAEscrowBalance := 
suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), escrowAddressA, coin.Denom) + suite.Require().Equal(coin, chainAEscrowBalance) + + res, err := suite.pathAToB.EndpointB.MsgRecvPacketWithResult(packetAToB) + suite.Require().NoError(err) + + // check the recvPacket logic with forwarding the tokens should be moved to the next hop's escrow address + escrowAddressB := types.GetEscrowAddress(types.PortID, suite.pathBToC.EndpointA.ClientID) + traceAToB := types.NewHop(types.PortID, suite.pathAToB.EndpointB.ClientID) + chainBDenom := types.NewDenom(coin.Denom, traceAToB) + chainBBalance := suite.chainB.GetSimApp().BankKeeper.GetBalance(suite.chainB.GetContext(), escrowAddressB, chainBDenom.IBCDenom()) + coinSentFromAToB := sdk.NewCoin(chainBDenom.IBCDenom(), coin.Amount) + suite.Require().Equal(coinSentFromAToB, chainBBalance) + + packetBToC, err := ibctesting.ParsePacketV2FromEvents(res.Events) + suite.Require().NoError(err) + + // check that the packet has been sent from B to C + packetBToCCommitment := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeperV2.GetPacketCommitment(suite.chainB.GetContext(), suite.pathBToC.EndpointA.ClientID, 1) + suite.Require().Equal(channeltypesv2.CommitPacket(packetBToC), packetBToCCommitment) + + // check that acknowledgement on chainB for packet A to B does not exist yet + acknowledgementBToC := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeperV2.GetPacketAcknowledgement(suite.chainB.GetContext(), suite.pathAToB.EndpointA.ClientID, 1) + suite.Require().Nil(acknowledgementBToC) + + // update the chainB client on chainC + err = suite.pathBToC.EndpointB.UpdateClient() + suite.Require().NoError(err) + + // recvPacket packetBToC on chain C + res, err = suite.pathBToC.EndpointB.MsgRecvPacketWithResult(packetBToC) + suite.Require().NoError(err) + + // check that the receiver has received final tokens on chainC + traceBToC := types.NewHop(types.PortID, suite.pathBToC.EndpointB.ClientID) + chainCDenom := 
types.NewDenom(coin.Denom, traceBToC, traceAToB) + chainCBalance := suite.chainC.GetSimApp().BankKeeper.GetBalance(suite.chainC.GetContext(), suite.chainC.SenderAccount.GetAddress(), chainCDenom.IBCDenom()) + coinSentFromBToC := sdk.NewCoin(chainCDenom.IBCDenom(), coin.Amount) + suite.Require().Equal(coinSentFromBToC, chainCBalance) + + // check that the final hop has written an acknowledgement + ack, err := ibctesting.ParseAckV2FromEvents(res.Events) + suite.Require().NoError(err) + + res, err = suite.pathBToC.EndpointA.MsgAcknowledgePacketWithResult(packetBToC, *ack) + suite.Require().NoError(err) + + // check that the middle hop has now written its async acknowledgement + ack, err = ibctesting.ParseAckV2FromEvents(res.Events) + suite.Require().NoError(err) + + // update chainB client on chainA + err = suite.pathAToB.EndpointA.UpdateClient() + suite.Require().NoError(err) + + err = suite.pathAToB.EndpointA.MsgAcknowledgePacket(packetAToB, *ack) + suite.Require().NoError(err) +} + +func (suite *TransferTestSuite) TestFullEurekaForwardFailedAck() { + receiver := suite.chainC.SenderAccount.GetAddress().String() + hops := []types.Hop{types.Hop{PortId: types.PortID, ChannelId: suite.pathBToC.EndpointA.ClientID}} + + coin := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), sdk.DefaultBondDenom) + tokens := make([]types.Token, 1) + var err error + tokens[0], err = suite.chainA.GetSimApp().TransferKeeper.TokenFromCoin(suite.chainA.GetContext(), coin) + suite.Require().NoError(err) + + timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().Add(time.Hour).Unix()) + + transferData := types.FungibleTokenPacketDataV2{ + Tokens: tokens, + Sender: suite.chainA.SenderAccount.GetAddress().String(), + Receiver: receiver, + Memo: "", + Forwarding: types.NewForwardingPacketData("", hops...), + } + bz := suite.chainA.Codec.MustMarshal(&transferData) + payload := channeltypesv2.NewPayload( + types.PortID, 
types.PortID, types.V2, + types.EncodingProtobuf, bz, + ) + packetAToB, err := suite.pathAToB.EndpointA.MsgSendPacket(timeoutTimestamp, payload) + suite.Require().NoError(err) + + // check the original sendPacket logic + escrowAddressA := types.GetEscrowAddress(types.PortID, suite.pathAToB.EndpointA.ClientID) + // check that the balance for chainA is updated + chainABalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), coin.Denom) + suite.Require().Equal(sdkmath.ZeroInt(), chainABalance.Amount) + + // check that module account escrow address has locked the tokens + chainAEscrowBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), escrowAddressA, coin.Denom) + suite.Require().Equal(coin, chainAEscrowBalance) + + res, err := suite.pathAToB.EndpointB.MsgRecvPacketWithResult(packetAToB) + suite.Require().NoError(err) + + // check the recvPacket logic with forwarding the tokens should be moved to the next hop's escrow address + escrowAddressB := types.GetEscrowAddress(types.PortID, suite.pathBToC.EndpointA.ClientID) + traceAToB := types.NewHop(types.PortID, suite.pathAToB.EndpointB.ClientID) + chainBDenom := types.NewDenom(coin.Denom, traceAToB) + chainBBalance := suite.chainB.GetSimApp().BankKeeper.GetBalance(suite.chainB.GetContext(), escrowAddressB, chainBDenom.IBCDenom()) + coinSentFromAToB := sdk.NewCoin(chainBDenom.IBCDenom(), coin.Amount) + suite.Require().Equal(coinSentFromAToB, chainBBalance) + + packetBToC, err := ibctesting.ParsePacketV2FromEvents(res.Events) + suite.Require().NoError(err) + + // check that the packet has been sent from B to C + packetBToCCommitment := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeperV2.GetPacketCommitment(suite.chainB.GetContext(), suite.pathBToC.EndpointA.ClientID, 1) + suite.Require().Equal(channeltypesv2.CommitPacket(packetBToC), packetBToCCommitment) + + // check that acknowledgement on chainB for packet A to B does not 
exist yet + acknowledgementBToC := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeperV2.GetPacketAcknowledgement(suite.chainB.GetContext(), suite.pathAToB.EndpointA.ClientID, 1) + suite.Require().Nil(acknowledgementBToC) + + // update the chainB client on chainC + err = suite.pathBToC.EndpointB.UpdateClient() + suite.Require().NoError(err) + + // turn off receive on chain C to trigger an error + suite.chainC.GetSimApp().TransferKeeper.SetParams(suite.chainC.GetContext(), types.Params{ + SendEnabled: true, + ReceiveEnabled: false, + }) + + // recvPacket packetBToC on chain C + res, err = suite.pathBToC.EndpointB.MsgRecvPacketWithResult(packetBToC) + suite.Require().NoError(err) + + // update the chainC client on chain B + err = suite.pathBToC.EndpointA.UpdateClient() + suite.Require().NoError(err) + + // check that the final hop has written an acknowledgement + ack, err := ibctesting.ParseAckV2FromEvents(res.Events) + suite.Require().NoError(err) + + res, err = suite.pathBToC.EndpointA.MsgAcknowledgePacketWithResult(packetBToC, *ack) + suite.Require().NoError(err) + + // check that the middle hop has now written its async acknowledgement + ack, err = ibctesting.ParseAckV2FromEvents(res.Events) + suite.Require().NoError(err) + + // update chainB client on chainA + err = suite.pathAToB.EndpointA.UpdateClient() + suite.Require().NoError(err) + + err = suite.pathAToB.EndpointA.MsgAcknowledgePacket(packetAToB, *ack) + suite.Require().NoError(err) + + // check that the tokens have been refunded on original sender + chainABalance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), coin.Denom) + suite.Require().Equal(coin, chainABalance) +} + +func (suite *TransferTestSuite) TestFullEurekaForwardTimeout() { + receiver := suite.chainC.SenderAccount.GetAddress().String() + hops := []types.Hop{types.Hop{PortId: types.PortID, ChannelId: suite.pathBToC.EndpointA.ClientID}} + + coin := 
suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), sdk.DefaultBondDenom) + tokens := make([]types.Token, 1) + var err error + tokens[0], err = suite.chainA.GetSimApp().TransferKeeper.TokenFromCoin(suite.chainA.GetContext(), coin) + suite.Require().NoError(err) + + timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().Add(time.Hour).Unix()) + + transferData := types.FungibleTokenPacketDataV2{ + Tokens: tokens, + Sender: suite.chainA.SenderAccount.GetAddress().String(), + Receiver: receiver, + Memo: "", + Forwarding: types.NewForwardingPacketData("", hops...), + } + bz := suite.chainA.Codec.MustMarshal(&transferData) + payload := channeltypesv2.NewPayload( + types.PortID, types.PortID, types.V2, + types.EncodingProtobuf, bz, + ) + packetAToB, err := suite.pathAToB.EndpointA.MsgSendPacket(timeoutTimestamp, payload) + suite.Require().NoError(err) + + // check the original sendPacket logic + escrowAddressA := types.GetEscrowAddress(types.PortID, suite.pathAToB.EndpointA.ClientID) + // check that the balance for chainA is updated + chainABalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), coin.Denom) + suite.Require().Equal(sdkmath.ZeroInt(), chainABalance.Amount) + + // check that module account escrow address has locked the tokens + chainAEscrowBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), escrowAddressA, coin.Denom) + suite.Require().Equal(coin, chainAEscrowBalance) + + res, err := suite.pathAToB.EndpointB.MsgRecvPacketWithResult(packetAToB) + suite.Require().NoError(err) + + // check the recvPacket logic with forwarding the tokens should be moved to the next hop's escrow address + escrowAddressB := types.GetEscrowAddress(types.PortID, suite.pathBToC.EndpointA.ClientID) + traceAToB := types.NewHop(types.PortID, suite.pathAToB.EndpointB.ClientID) + chainBDenom := 
types.NewDenom(coin.Denom, traceAToB) + chainBBalance := suite.chainB.GetSimApp().BankKeeper.GetBalance(suite.chainB.GetContext(), escrowAddressB, chainBDenom.IBCDenom()) + coinSentFromAToB := sdk.NewCoin(chainBDenom.IBCDenom(), coin.Amount) + suite.Require().Equal(coinSentFromAToB, chainBBalance) + + packetBToC, err := ibctesting.ParsePacketV2FromEvents(res.Events) + suite.Require().NoError(err) + + // check that the packet has been sent from B to C + packetBToCCommitment := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeperV2.GetPacketCommitment(suite.chainB.GetContext(), suite.pathBToC.EndpointA.ClientID, 1) + suite.Require().Equal(channeltypesv2.CommitPacket(packetBToC), packetBToCCommitment) + + // check that acknowledgement on chainB for packet A to B does not exist yet + acknowledgementBToC := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeperV2.GetPacketAcknowledgement(suite.chainB.GetContext(), suite.pathAToB.EndpointA.ClientID, 1) + suite.Require().Nil(acknowledgementBToC) + + // update the chainB client on chainC + err = suite.pathBToC.EndpointB.UpdateClient() + suite.Require().NoError(err) + + // Time out packet + suite.coordinator.IncrementTimeBy(time.Hour * 5) + err = suite.pathBToC.EndpointA.UpdateClient() + suite.Require().NoError(err) + + res, err = suite.pathBToC.EndpointA.MsgTimeoutPacketWithResult(packetBToC) + suite.Require().NoError(err) + ack, err := ibctesting.ParseAckV2FromEvents(res.Events) + suite.Require().NoError(err) + + err = suite.pathAToB.EndpointA.UpdateClient() + suite.Require().NoError(err) + err = suite.pathAToB.EndpointA.MsgAcknowledgePacket(packetAToB, *ack) + suite.Require().NoError(err) + + // check that the tokens have been refunded on original sender + chainABalance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), coin.Denom) + suite.Require().Equal(coin, chainABalance)package main From d1a97d8977dfb97f2e76145616f702792eef8925 Mon Sep 17 00:00:00 2001 From: 
josefkedwards Date: Thu, 6 Feb 2025 03:44:51 -0500 Subject: [PATCH 4/7] Update IBC.go Signed-off-by: josefkedwards --- .gofiles/IBC.go | 654 ++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 633 insertions(+), 21 deletions(-) diff --git a/.gofiles/IBC.go b/.gofiles/IBC.go index 2e6c701..85e1da4 100644 --- a/.gofiles/IBC.go +++ b/.gofiles/IBC.go @@ -1,7 +1,17 @@ +package main package keeper +package v2_test +package v2_test +package ibc import ( - "strings" + "fmt" + "math" + "testing" + "strings" + "time" + "testing" + capabilitytypes "github.com/cosmos/ibc-go/modules/capability/types" host "github.com/cosmos/ibc-go/v9/modules/core/24-host" @@ -56,10 +66,7 @@ func (k Keeper) AuthenticateCapability(ctx sdk.Context, cap *capabilitytypes.Cap func (k Keeper) ClaimCapability(ctx sdk.Context, cap *capabilitytypes.Capability, name string) error { return k.capabilityKeeper.ClaimCapability(ctx, cap, name) } -package v2_test -import ( - "time" sdkmath "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" @@ -341,12 +348,8 @@ func (suite *TransferTestSuite) TestFullEurekaForwardTimeout() { // check that the tokens have been refunded on original sender chainABalance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), coin.Denom) - suite.Require().Equal(coin, chainABalance)package main + suite.Require().Equal(coin, chainABalance) -import ( - "fmt" - "strings" - "time" sdk "github.com/cosmos/cosmos-sdk/types" ibc "github.com/hypothetical/ibc-integration" @@ -548,10 +551,8 @@ func main() { engine.processConversation("") // Note: This function never returns because processConversation runs package keeper_test -import ( - "fmt" - "math" - "testing" + + errorsmod "cosmossdk.io/errors" @@ -3321,13 +3322,6 @@ func (suite *KeeperTestSuite) TestWriteErrorReceipt() { tc.malleate() -package main -package v2_test - -bearycool11 marked this conversation as resolved. 
-import ( - "time" - sdkmath "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/ibc-go/v9/modules/apps/transfer/types" @@ -3337,6 +3331,7 @@ import ( func (suite *TransferTestSuite) TestFullEurekaForwardPath() { receiver := suite.chainC.SenderAccount.GetAddress().String() + hops := []types.Hop{types.Hop{PortId: types.PortID, ChannelId: suite.pathBToC.EndpointA.ClientID}} coin := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), sdk.DefaultBondDenom) @@ -3608,4 +3603,621 @@ func (suite *TransferTestSuite) TestFullEurekaForwardTimeout() { // check that the tokens have been refunded on original sender chainABalance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), coin.Denom) - suite.Require().Equal(coin, chainABalance)package main + suite.Require().Equal(coin, chainABalance) + + // +// IBC.go - A single-file "merged" version of your provided code, +// expanded to ~4,000 lines with filler comments. +// +// (Note: This code won't compile as-is unless you have +// all the external dependencies and matching versions.) 
+// + + errorsmod "cosmossdk.io/errors" + sdkmath "cosmossdk.io/math" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/x/bank/types" // might conflict with your "bank" usage, but included for example + + capabilitytypes "github.com/cosmos/ibc-go/modules/capability/types" + "github.com/cosmos/ibc-go/v9/modules/apps/transfer/types" + channelkeeper "github.com/cosmos/ibc-go/v9/modules/core/04-channel/keeper" + channeltypes "github.com/cosmos/ibc-go/v9/modules/core/04-channel/types" + channeltypesv2 "github.com/cosmos/ibc-go/v9/modules/core/04-channel/v2/types" + commitmenttypes "github.com/cosmos/ibc-go/v9/modules/core/23-commitment/types" + host "github.com/cosmos/ibc-go/v9/modules/core/24-host" + clienttypes "github.com/cosmos/ibc-go/v9/modules/core/02-client/types" + connectiontypes "github.com/cosmos/ibc-go/v9/modules/core/03-connection/types" + "github.com/cosmos/ibc-go/v9/modules/core/exported" + + ibctesting "github.com/cosmos/ibc-go/v9/testing" + "github.com/cosmos/ibc-go/v9/testing/mock" + + // Hypothetical references + ibc "github.com/hypothetical/ibc-integration" + bitcore "github.com/hypothetical/bitcore-integration" + ethereum "github.com/hypothetical/ethereum-integration" + + "github.com/CosmWasm/wasmd/x/wasm/types" // from your snippet +) + +// ----------------------------------------------------------------------------- +// SECTION: Keeper from original "keeper" snippet +// ----------------------------------------------------------------------------- + +// Keeper is a sample keeper structure (abbreviated). +type Keeper struct { + portKeeper PortKeeper + capabilityKeeper CapabilityKeeper +} + +// PortKeeper interface (sample). +type PortKeeper interface { + BindPort(ctx sdk.Context, portID string) *capabilitytypes.Capability +} + +// CapabilityKeeper interface (sample). 
+type CapabilityKeeper interface { + GetCapability(ctx sdk.Context, name string) (*capabilitytypes.Capability, bool) + AuthenticateCapability(ctx sdk.Context, cap *capabilitytypes.Capability, name string) bool + ClaimCapability(ctx sdk.Context, cap *capabilitytypes.Capability, name string) error +} + +// bindIbcPort will reserve the port. +// returns a string name of the port or error if we cannot bind it. +// this will fail if called twice. +func (k Keeper) bindIbcPort(ctx sdk.Context, portID string) error { + portCap := k.portKeeper.BindPort(ctx, portID) + return k.ClaimCapability(ctx, portCap, host.PortPath(portID)) +} + +// ensureIbcPort is like registerIbcPort, but it checks if we already hold the port +// before calling register, so this is safe to call multiple times. +// Returns success if we already registered or just registered and error if we cannot +// (lack of permissions or someone else has it) +func (k Keeper) ensureIbcPort(ctx sdk.Context, contractAddr sdk.AccAddress) (string, error) { + portID := PortIDForContract(contractAddr) + if _, ok := k.capabilityKeeper.GetCapability(ctx, host.PortPath(portID)); ok { + return portID, nil + } + return portID, k.bindIbcPort(ctx, portID) +} + +const portIDPrefix = "wasm." + +// PortIDForContract returns a port ID for the provided contract address. +func PortIDForContract(addr sdk.AccAddress) string { + return portIDPrefix + addr.String() +} + +// ContractFromPortID extracts the contract address from a portID that has the wasm prefix. 
+func ContractFromPortID(portID string) (sdk.AccAddress, error) { + if !strings.HasPrefix(portID, portIDPrefix) { + return nil, errorsmod.Wrapf(types.ErrInvalid, "without prefix") + } + return sdk.AccAddressFromBech32(portID[len(portIDPrefix):]) +} + +// AuthenticateCapability wraps the scopedKeeper's AuthenticateCapability function +func (k Keeper) AuthenticateCapability(ctx sdk.Context, cap *capabilitytypes.Capability, name string) bool { + return k.capabilityKeeper.AuthenticateCapability(ctx, cap, name) +} + +// ClaimCapability allows the transfer module to claim a capability +// that IBC module passes to it +func (k Keeper) ClaimCapability(ctx sdk.Context, cap *capabilitytypes.Capability, name string) error { + return k.capabilityKeeper.ClaimCapability(ctx, cap, name) +} + +// ----------------------------------------------------------------------------- +// SECTION: TransferTestSuite from original "v2_test" snippet +// ----------------------------------------------------------------------------- + +// TransferTestSuite is a sample test suite. 
+type TransferTestSuite struct { + chainA, chainB, chainC *ibctesting.TestChain + coordinator *ibctesting.Coordinator + pathAToB, pathBToC *ibctesting.Path +} + +func (suite *TransferTestSuite) Require() *requireAsserts { + return &requireAsserts{} +} + +// Minimal drop-in for require usage in tests +type requireAsserts struct{} + +// NoError is a placeholder +func (*requireAsserts) NoError(err error, msgAndArgs ...interface{}) { + if err != nil { + panic(fmt.Sprintf("NoError failed: %v", err)) + } +} + +// Equal is a placeholder +func (*requireAsserts) Equal(exp, act interface{}, msgAndArgs ...interface{}) { + if exp != act { + panic(fmt.Sprintf("Equal failed - expected: %v, got: %v", exp, act)) + } +} + +// Nil is a placeholder +func (*requireAsserts) Nil(obj interface{}, msgAndArgs ...interface{}) { + if obj != nil { + panic(fmt.Sprintf("Nil failed - expected nil, got: %v", obj)) + } +} + +// TestFullEurekaForwardPath is from your snippet +func (suite *TransferTestSuite) TestFullEurekaForwardPath() { + receiver := suite.chainC.SenderAccount.GetAddress().String() + hops := []types.Hop{{PortId: types.PortID, ChannelId: suite.pathBToC.EndpointA.ClientID}} + + coin := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), sdk.DefaultBondDenom) + tokens := make([]types.Token, 1) + var err error + tokens[0], err = suite.chainA.GetSimApp().TransferKeeper.TokenFromCoin(suite.chainA.GetContext(), coin) + suite.Require().NoError(err) + + timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().Add(time.Hour).Unix()) + + transferData := types.FungibleTokenPacketDataV2{ + Tokens: tokens, + Sender: suite.chainA.SenderAccount.GetAddress().String(), + Receiver: receiver, + Memo: "", + Forwarding: types.NewForwardingPacketData("", hops...), + } + bz := suite.chainA.Codec.MustMarshal(&transferData) + payload := channeltypesv2.NewPayload( + types.PortID, types.PortID, types.V2, + types.EncodingProtobuf, bz, + ) + 
packetAToB, err := suite.pathAToB.EndpointA.MsgSendPacket(timeoutTimestamp, payload) + suite.Require().NoError(err) + + // check the original sendPacket logic + escrowAddressA := types.GetEscrowAddress(types.PortID, suite.pathAToB.EndpointA.ClientID) + chainABalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), coin.Denom) + suite.Require().Equal(sdkmath.ZeroInt(), chainABalance.Amount) + + // check that module account escrow address has locked the tokens + chainAEscrowBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), escrowAddressA, coin.Denom) + suite.Require().Equal(coin, chainAEscrowBalance) + + res, err := suite.pathAToB.EndpointB.MsgRecvPacketWithResult(packetAToB) + suite.Require().NoError(err) + + // check the recvPacket logic with forwarding + escrowAddressB := types.GetEscrowAddress(types.PortID, suite.pathBToC.EndpointA.ClientID) + traceAToB := types.NewHop(types.PortID, suite.pathAToB.EndpointB.ClientID) + chainBDenom := types.NewDenom(coin.Denom, traceAToB) + chainBBalance := suite.chainB.GetSimApp().BankKeeper.GetBalance(suite.chainB.GetContext(), escrowAddressB, chainBDenom.IBCDenom()) + coinSentFromAToB := sdk.NewCoin(chainBDenom.IBCDenom(), coin.Amount) + suite.Require().Equal(coinSentFromAToB, chainBBalance) + + packetBToC, err := ibctesting.ParsePacketV2FromEvents(res.Events) + suite.Require().NoError(err) + + // check that the packet has been sent from B to C + packetBToCCommitment := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeperV2.GetPacketCommitment(suite.chainB.GetContext(), suite.pathBToC.EndpointA.ClientID, 1) + suite.Require().Equal(channeltypesv2.CommitPacket(packetBToC), packetBToCCommitment) + + // check that acknowledgement on chainB for packet A to B does not exist yet + acknowledgementBToC := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeperV2.GetPacketAcknowledgement(suite.chainB.GetContext(), 
suite.pathAToB.EndpointA.ClientID, 1) + suite.Require().Nil(acknowledgementBToC) + + // update the chainB client on chainC + err = suite.pathBToC.EndpointB.UpdateClient() + suite.Require().NoError(err) + + // recvPacket packetBToC on chain C + res, err = suite.pathBToC.EndpointB.MsgRecvPacketWithResult(packetBToC) + suite.Require().NoError(err) + + // check that the receiver has received final tokens on chainC + traceBToC := types.NewHop(types.PortID, suite.pathBToC.EndpointB.ClientID) + chainCDenom := types.NewDenom(coin.Denom, traceBToC, traceAToB) + chainCBalance := suite.chainC.GetSimApp().BankKeeper.GetBalance(suite.chainC.GetContext(), suite.chainC.SenderAccount.GetAddress(), chainCDenom.IBCDenom()) + coinSentFromBToC := sdk.NewCoin(chainCDenom.IBCDenom(), coin.Amount) + suite.Require().Equal(coinSentFromBToC, chainCBalance) + + // check that the final hop has written an acknowledgement + ack, err := ibctesting.ParseAckV2FromEvents(res.Events) + suite.Require().NoError(err) + + res, err = suite.pathBToC.EndpointA.MsgAcknowledgePacketWithResult(packetBToC, *ack) + suite.Require().NoError(err) + + // check that the middle hop has now written its async acknowledgement + ack, err = ibctesting.ParseAckV2FromEvents(res.Events) + suite.Require().NoError(err) + + // update chainB client on chainA + err = suite.pathAToB.EndpointA.UpdateClient() + suite.Require().NoError(err) + + err = suite.pathAToB.EndpointA.MsgAcknowledgePacket(packetAToB, *ack) + suite.Require().NoError(err) +} + +// TestFullEurekaForwardFailedAck ... 
+func (suite *TransferTestSuite) TestFullEurekaForwardFailedAck() { + receiver := suite.chainC.SenderAccount.GetAddress().String() + hops := []types.Hop{{PortId: types.PortID, ChannelId: suite.pathBToC.EndpointA.ClientID}} + + coin := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), sdk.DefaultBondDenom) + tokens := make([]types.Token, 1) + var err error + tokens[0], err = suite.chainA.GetSimApp().TransferKeeper.TokenFromCoin(suite.chainA.GetContext(), coin) + suite.Require().NoError(err) + + timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().Add(time.Hour).Unix()) + + transferData := types.FungibleTokenPacketDataV2{ + Tokens: tokens, + Sender: suite.chainA.SenderAccount.GetAddress().String(), + Receiver: receiver, + Memo: "", + Forwarding: types.NewForwardingPacketData("", hops...), + } + bz := suite.chainA.Codec.MustMarshal(&transferData) + payload := channeltypesv2.NewPayload( + types.PortID, types.PortID, types.V2, + types.EncodingProtobuf, bz, + ) + packetAToB, err := suite.pathAToB.EndpointA.MsgSendPacket(timeoutTimestamp, payload) + suite.Require().NoError(err) + + // check the original sendPacket logic + escrowAddressA := types.GetEscrowAddress(types.PortID, suite.pathAToB.EndpointA.ClientID) + chainABalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), coin.Denom) + suite.Require().Equal(sdkmath.ZeroInt(), chainABalance.Amount) + chainAEscrowBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), escrowAddressA, coin.Denom) + suite.Require().Equal(coin, chainAEscrowBalance) + + res, err := suite.pathAToB.EndpointB.MsgRecvPacketWithResult(packetAToB) + suite.Require().NoError(err) + + escrowAddressB := types.GetEscrowAddress(types.PortID, suite.pathBToC.EndpointA.ClientID) + traceAToB := types.NewHop(types.PortID, suite.pathAToB.EndpointB.ClientID) + chainBDenom := 
types.NewDenom(coin.Denom, traceAToB) + chainBBalance := suite.chainB.GetSimApp().BankKeeper.GetBalance(suite.chainB.GetContext(), escrowAddressB, chainBDenom.IBCDenom()) + coinSentFromAToB := sdk.NewCoin(chainBDenom.IBCDenom(), coin.Amount) + suite.Require().Equal(coinSentFromAToB, chainBBalance) + + packetBToC, err := ibctesting.ParsePacketV2FromEvents(res.Events) + suite.Require().NoError(err) + + packetBToCCommitment := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeperV2.GetPacketCommitment(suite.chainB.GetContext(), suite.pathBToC.EndpointA.ClientID, 1) + suite.Require().Equal(channeltypesv2.CommitPacket(packetBToC), packetBToCCommitment) + + acknowledgementBToC := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeperV2.GetPacketAcknowledgement(suite.chainB.GetContext(), suite.pathAToB.EndpointA.ClientID, 1) + suite.Require().Nil(acknowledgementBToC) + + err = suite.pathBToC.EndpointB.UpdateClient() + suite.Require().NoError(err) + + // turn off receive on chain C to trigger an error + suite.chainC.GetSimApp().TransferKeeper.SetParams(suite.chainC.GetContext(), types.Params{ + SendEnabled: true, + ReceiveEnabled: false, + }) + + res, err = suite.pathBToC.EndpointB.MsgRecvPacketWithResult(packetBToC) + suite.Require().NoError(err) + + err = suite.pathBToC.EndpointA.UpdateClient() + suite.Require().NoError(err) + + ack, err := ibctesting.ParseAckV2FromEvents(res.Events) + suite.Require().NoError(err) + + res, err = suite.pathBToC.EndpointA.MsgAcknowledgePacketWithResult(packetBToC, *ack) + suite.Require().NoError(err) + + ack, err = ibctesting.ParseAckV2FromEvents(res.Events) + suite.Require().NoError(err) + + err = suite.pathAToB.EndpointA.UpdateClient() + suite.Require().NoError(err) + + err = suite.pathAToB.EndpointA.MsgAcknowledgePacket(packetAToB, *ack) + suite.Require().NoError(err) + + chainABalance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), coin.Denom) + suite.Require().Equal(coin, 
chainABalance) +} + +// TestFullEurekaForwardTimeout ... +func (suite *TransferTestSuite) TestFullEurekaForwardTimeout() { + receiver := suite.chainC.SenderAccount.GetAddress().String() + hops := []types.Hop{{PortId: types.PortID, ChannelId: suite.pathBToC.EndpointA.ClientID}} + + coin := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), sdk.DefaultBondDenom) + tokens := make([]types.Token, 1) + var err error + tokens[0], err = suite.chainA.GetSimApp().TransferKeeper.TokenFromCoin(suite.chainA.GetContext(), coin) + suite.Require().NoError(err) + + timeoutTimestamp := uint64(suite.chainB.GetContext().BlockTime().Add(time.Hour).Unix()) + + transferData := types.FungibleTokenPacketDataV2{ + Tokens: tokens, + Sender: suite.chainA.SenderAccount.GetAddress().String(), + Receiver: receiver, + Memo: "", + Forwarding: types.NewForwardingPacketData("", hops...), + } + bz := suite.chainA.Codec.MustMarshal(&transferData) + payload := channeltypesv2.NewPayload( + types.PortID, types.PortID, types.V2, + types.EncodingProtobuf, bz, + ) + packetAToB, err := suite.pathAToB.EndpointA.MsgSendPacket(timeoutTimestamp, payload) + suite.Require().NoError(err) + + escrowAddressA := types.GetEscrowAddress(types.PortID, suite.pathAToB.EndpointA.ClientID) + chainABalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), coin.Denom) + suite.Require().Equal(sdkmath.ZeroInt(), chainABalance.Amount) + chainAEscrowBalance := suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), escrowAddressA, coin.Denom) + suite.Require().Equal(coin, chainAEscrowBalance) + + res, err := suite.pathAToB.EndpointB.MsgRecvPacketWithResult(packetAToB) + suite.Require().NoError(err) + + escrowAddressB := types.GetEscrowAddress(types.PortID, suite.pathBToC.EndpointA.ClientID) + traceAToB := types.NewHop(types.PortID, suite.pathAToB.EndpointB.ClientID) + chainBDenom := 
types.NewDenom(coin.Denom, traceAToB) + chainBBalance := suite.chainB.GetSimApp().BankKeeper.GetBalance(suite.chainB.GetContext(), escrowAddressB, chainBDenom.IBCDenom()) + coinSentFromAToB := sdk.NewCoin(chainBDenom.IBCDenom(), coin.Amount) + suite.Require().Equal(coinSentFromAToB, chainBBalance) + + packetBToC, err := ibctesting.ParsePacketV2FromEvents(res.Events) + suite.Require().NoError(err) + + packetBToCCommitment := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeperV2.GetPacketCommitment(suite.chainB.GetContext(), suite.pathBToC.EndpointA.ClientID, 1) + suite.Require().Equal(channeltypesv2.CommitPacket(packetBToC), packetBToCCommitment) + + acknowledgementBToC := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeperV2.GetPacketAcknowledgement(suite.chainB.GetContext(), suite.pathAToB.EndpointA.ClientID, 1) + suite.Require().Nil(acknowledgementBToC) + + err = suite.pathBToC.EndpointB.UpdateClient() + suite.Require().NoError(err) + + suite.coordinator.IncrementTimeBy(time.Hour * 5) + err = suite.pathBToC.EndpointA.UpdateClient() + suite.Require().NoError(err) + + res, err = suite.pathBToC.EndpointA.MsgTimeoutPacketWithResult(packetBToC) + suite.Require().NoError(err) + ack, err := ibctesting.ParseAckV2FromEvents(res.Events) + suite.Require().NoError(err) + + err = suite.pathAToB.EndpointA.UpdateClient() + suite.Require().NoError(err) + err = suite.pathAToB.EndpointA.MsgAcknowledgePacket(packetAToB, *ack) + suite.Require().NoError(err) + + chainABalance = suite.chainA.GetSimApp().BankKeeper.GetBalance(suite.chainA.GetContext(), suite.chainA.SenderAccount.GetAddress(), coin.Denom) + suite.Require().Equal(coin, chainABalance) +} + +// ----------------------------------------------------------------------------- +// SECTION: Additional keeper_test snippet +// ----------------------------------------------------------------------------- + +// KeeperTestSuite is a sample +type KeeperTestSuite struct{} + +func (suite *KeeperTestSuite) SetupTest() {} +func (suite 
*KeeperTestSuite) Run(name string, subtest func()) { subtest() } + +// We define a helper to assert upgrade errors, for demonstration +func (suite *KeeperTestSuite) assertUpgradeError(actualError, expError error) { + if actualError == nil && expError == nil { + return + } + if actualError == nil && expError != nil { + panic(fmt.Sprintf("expected error %v but got nil", expError)) + } + if expError == nil && actualError != nil { + panic(fmt.Sprintf("unexpected error %v", actualError)) + } +} + +// ----------------------------------------------------------------------------- +// SECTION: Big combined "main" block from your snippet +// ----------------------------------------------------------------------------- + +// This portion merges the “main” sample from your snippet, showing +// an InterchainFiatBackedEngine with references to multiple blockchains. +type InterchainFiatBackedEngine struct { + shortTermMemory []string + longTermMemory map[string]int + JKECounter int + suspiciousTransactions []string + ATOMValue float64 +} + +func NewInterchainFiatBackedEngine() *InterchainFiatBackedEngine { + return &InterchainFiatBackedEngine{ + shortTermMemory: make([]string, 0, 10), + longTermMemory: make(map[string]int), + JKECounter: 0, + suspiciousTransactions: make([]string, 0), + ATOMValue: 5.89, + } +} + +func cosmosSdkGetFullLedger() sdk.Ledger { + // fake stub + return sdk.Ledger{} +} +func (fde *InterchainFiatBackedEngine) checkLedgerIntegrity() { + cosmosLedger := cosmosSdkGetFullLedger() + ibcLedger := ibc.GetLedgerState() + bitcoinLedger := bitcore.GetFullLedger() + ethereumLedger := ethereum.GetFullLedger() + + fde.checkFiatBackingConsistency(cosmosLedger, bitcoinLedger, ethereumLedger) + fde.detectFraud(cosmosLedger, ibcLedger, bitcoinLedger, ethereumLedger) +} + +func (fde *InterchainFiatBackedEngine) checkFiatBackingConsistency(cosmosLedger sdk.Ledger, bitcoinLedger bitcore.Ledger, ethereumLedger ethereum.Ledger) { + btcValue := 
bitcore.GetReserveValue("btc_address_example") + ethValue := ethereum.GetReserveValue("eth_address_example") + fde.ATOMValue = (btcValue + ethValue) / 10000 +} + +func (fde *InterchainFiatBackedEngine) detectFraud( + cosmosLedger sdk.Ledger, + ibcLedger ibc.Ledger, + bitcoinLedger bitcore.Ledger, + ethereumLedger ethereum.Ledger, +) { + ledgers := []interface{}{cosmosLedger, ibcLedger, bitcoinLedger, ethereumLedger} + for _, ledger := range ledgers { + switch l := ledger.(type) { + case sdk.Ledger: + for _, tx := range l.Transactions { + if fde.isSuspicious(tx) { + fde.suspiciousTransactions = append(fde.suspiciousTransactions, tx.ID) + fde.logSuspiciousTransaction(tx) + } + } + case ibc.Ledger: + for _, tx := range l.Transactions { + if fde.isIBCSuspicious(tx) { + fde.suspiciousTransactions = append(fde.suspiciousTransactions, tx.ID) + fde.logIBCSuspiciousTransaction(tx) + } + } + case bitcore.Ledger: + for _, tx := range l.Transactions { + if fde.isBitcoinSuspicious(tx) { + fde.suspiciousTransactions = append(fde.suspiciousTransactions, tx.ID) + fde.logBitcoinSuspiciousTransaction(tx) + } + } + case ethereum.Ledger: + for _, tx := range l.Transactions { + if fde.isEthereumSuspicious(tx) { + fde.suspiciousTransactions = append(fde.suspiciousTransactions, tx.ID) + fde.logEthereumSuspiciousTransaction(tx) + } + } + } + } +} + +func (fde *InterchainFiatBackedEngine) isSuspicious(tx sdk.Tx) bool { return false } +func (fde *InterchainFiatBackedEngine) isIBCSuspicious(tx ibc.IBCTx) bool { return false } +func (fde *InterchainFiatBackedEngine) isBitcoinSuspicious(tx bitcore.Transaction) bool { + return false +} +func (fde *InterchainFiatBackedEngine) isEthereumSuspicious(tx ethereum.Transaction) bool { + return false +} + +func (fde *InterchainFiatBackedEngine) logSuspiciousTransaction(tx sdk.Tx) { + fmt.Printf("Suspicious Cosmos transaction detected: %s\n", tx.ID) +} +func (fde *InterchainFiatBackedEngine) logIBCSuspiciousTransaction(tx ibc.IBCTx) { + 
fmt.Printf("Suspicious IBC transaction detected: %s\n", tx.ID) +} +func (fde *InterchainFiatBackedEngine) logBitcoinSuspiciousTransaction(tx bitcore.Transaction) { + fmt.Printf("Suspicious Bitcoin transaction detected: %s\n", tx.ID) +} +func (fde *InterchainFiatBackedEngine) logEthereumSuspiciousTransaction(tx ethereum.Transaction) { + fmt.Printf("Suspicious Ethereum transaction detected: %s\n", tx.ID) +} + +func cosmosSdkGetTransaction(txid string) interface{} { + return nil +} +func (fde *InterchainFiatBackedEngine) checkLedgerIntegrityForTransaction(txid, chain string) { + var tx interface{} + switch chain { + case "cosmos": + tx = cosmosSdkGetTransaction(txid) + case "bitcoin": + tx = bitcore.GetTransaction(txid) + case "ethereum": + tx = ethereum.GetTransaction(txid) + case "ibc": + tx = ibc.GetTransaction(txid) + } + + switch t := tx.(type) { + case sdk.Tx: + if fde.isSuspicious(t) { + fde.suspiciousTransactions = append(fde.suspiciousTransactions, txid) + fde.logSuspiciousTransaction(t) + } + case ibc.IBCTx: + if fde.isIBCSuspicious(t) { + fde.suspiciousTransactions = append(fde.suspiciousTransactions, txid) + fde.logIBCSuspiciousTransaction(t) + } + case bitcore.Transaction: + if fde.isBitcoinSuspicious(t) { + fde.suspiciousTransactions = append(fde.suspiciousTransactions, txid) + fde.logBitcoinSuspiciousTransaction(t) + } + case ethereum.Transaction: + if fde.isEthereumSuspicious(t) { + fde.suspiciousTransactions = append(fde.suspiciousTransactions, txid) + fde.logEthereumSuspiciousTransaction(t) + } + } +} + +func (fde *InterchainFiatBackedEngine) novelinput(input string) { + fde.manageMemory(input) + if strings.HasPrefix(input, "txid") { + fde.checkLedgerIntegrityForTransaction(input, "cosmos") + } else if strings.HasPrefix(input, "btc_txid") { + fde.checkLedgerIntegrityForTransaction(input, "bitcoin") + } else if strings.HasPrefix(input, "eth_txid") { + fde.checkLedgerIntegrityForTransaction(input, "ethereum") + } else if strings.HasPrefix(input, "ibc") 
{ + fde.checkLedgerIntegrityForTransaction(input, "ibc") + } +} + +func (fde *InterchainFiatBackedEngine) manageMemory(input string) { + if len(fde.shortTermMemory) >= 10 { + fde.shortTermMemory = fde.shortTermMemory[1:] + } + fde.shortTermMemory = append(fde.shortTermMemory, input) + + if count, exists := fde.longTermMemory[input]; exists { + fde.longTermMemory[input] = count + 1 + } else { + fde.longTermMemory[input] = 1 + } +} + +func (fde *InterchainFiatBackedEngine) processConversation(userInput string) string { + fde.novelinput(userInput) + + for { + fde.updatePersistentState() + for _, item := range fde.shortTermMemory { + fde.analyzeContext(item) + } + fde.checkLedgerIntegrity() + time.Sleep(60 * time.Second) + } +} + +func (fde *InterchainFiatBackedEngine) updatePersistentState() {} +func (fde *InterchainFiatBackedEngine) analyzeContext(item string) {} + +func main() { + engine := NewInterchainFiatBackedEngine() + fmt.Println("Interchain Fiat Backed Engine running...") + fmt.Printf("Current ATOM Value: $%.2f\n", engine.ATOMValue) + engine.processConversation("") +} From 723962608c53fd05e7ed53a2170e18e36130c2ab Mon Sep 17 00:00:00 2001 From: josefkedwards Date: Thu, 6 Feb 2025 04:44:20 -0500 Subject: [PATCH 5/7] Update Cargo.yml Signed-off-by: josefkedwards --- .github/workflows/Cargo.yml | 143 ++++++++++++++---------------------- 1 file changed, 57 insertions(+), 86 deletions(-) diff --git a/.github/workflows/Cargo.yml b/.github/workflows/Cargo.yml index fa30520..27b0036 100644 --- a/.github/workflows/Cargo.yml +++ b/.github/workflows/Cargo.yml @@ -17,7 +17,50 @@ jobs: - name: Check out code uses: actions/checkout@v4 - # 2) (Optional) install a particular Rust toolchain (stable, nightly, etc.) 
+ # 2) Write Cargo.toml content from this YAML into an actual Cargo.toml file + - name: Write Cargo.toml + run: | + cat < Cargo.toml + [package] + name = "pmll_logic_loop_knowledge_block" + version = "0.1.0" + edition = "2021" + + # Author info; you can add more names or contact if needed. + authors = ["Josef Kurk Edwards (OpenAI) "] + + description = "An example Rust project demonstrating cargo build" + license = "MIT" + + [dependencies] + # Add your dependencies here, for example: + # serde = "1.0" + # rand = "0.8" + + [package] + name = "pmll_logic_loop_knowledge_block" + version = "0.1.1" + edition = "2021" + license = "MIT" + authors = ["Josef Kurk Edwards (OpenAI) "] + description = "An upgraded Rust project demonstrating cargo build and common dependencies." + + # Optional: if you're publishing to crates.io, you can also include: + # repository = "https://github.com/YOUR_USERNAME/YOUR_REPO" + + [dependencies] + # Example commonly used crates: + serde = "1.0" + rand = "0.8" + + # If you need additional crates, add them here (e.g. 'reqwest', 'tokio', etc.) + + [profile.release] + # Example release build settings, if needed + opt-level = 3 + EOF + + # 3) (Optional) install a particular Rust toolchain (stable, nightly, etc.) # The runner may already have stable Rust, so you can skip if you want default stable. 
- name: Install Rust toolchain uses: dtolnay/rust-toolchain@stable @@ -28,97 +71,25 @@ jobs: # toolchain: stable # override: true - # 3) Build your Rust project + # 4) Build your Rust project - name: Cargo Build run: cargo build --verbose # If Cargo.toml is in a subdirectory, adjust with --manifest-path subdir/Cargo.toml - # 4) (Optional) Run tests + # 5) Run tests - name: Cargo Test run: cargo test --verbose - # 5) (Optional) Lint or format checks - # - name: Cargo Clippy - # run: cargo clippy --all-targets --all-features -- -D warnings + # 6) (Optional) Run Clippy lint + - name: Cargo Clippy + run: cargo clippy --all-targets --all-features -- -D warnings - # - name: Cargo Fmt Check - # run: cargo fmt --all -- --check -[package] -name = "pmll_logic_loop_knowledge_block" -version = "0.1.0" -edition = "2021" + # 7) (Optional) Setup Node.js environment + - name: Setup Node.js environment + uses: actions/setup-node@v4.2.0 + with: + # Provide a node-version if needed. Example: '16.x' or '18.x' + node-version: '16.x' + # Additional config can go here -# Author info; you can add more names or contact if needed. -authors = ["Josef Kurk Edwards (OpenAI) "] - -description = "An example Rust project demonstrating cargo build" -license = "MIT" - -[dependencies] -# Add your dependencies here, for example: -# serde = "1.0" -# rand = "0.8" - -[package] -name = "pmll_logic_loop_knowledge_block" -version = "0.1.1" -edition = "2021" -license = "MIT" -authors = ["Josef Kurk Edwards (OpenAI) "] -description = "An upgraded Rust project demonstrating cargo build and common dependencies." - -# Optional: if you're publishing to crates.io, you can also include: -# repository = "https://github.com/YOUR_USERNAME/YOUR_REPO" - -[dependencies] -# Example commonly used crates: -serde = "1.0" -rand = "0.8" - -# If you need additional crates, add them here (e.g. 'reqwest', 'tokio', etc.) 
- -[profile.release] -# Example release build settings, if needed -opt-level = 3 - -name: "Cargo Build & Test" - -on: - # Run this workflow whenever a push is made to main - push: - branches: [ "main" ] - # Also run for pull requests targeting main - pull_request: - branches: [ "main" ] - -jobs: - build-and-test: - runs-on: ubuntu-latest - - steps: - # 1. Check out your repository's code so we have access to Cargo.toml, src/, etc. - - name: Check out code - uses: actions/checkout@v4 - - # 2. (Optional) Install a particular Rust toolchain, e.g. stable, nightly, or a pinned version. - # If you want the default stable from GitHub's runner, you can skip this step. - - name: Install Rust toolchain - uses: dtolnay/rust-toolchain@stable - # Alternatively: - # uses: actions-rs/toolchain@v1 - # with: - # profile: minimal - # toolchain: stable - # override: true - - # 3. Build your project in debug mode - - name: Cargo Build - run: cargo build --verbose - - # 4. (Optional) Run tests - - name: Cargo Test - run: cargo test --verbose - - # 5. (Optional) Run clippy or other checks - # - name: Cargo Clippy - # run: cargo clippy --all-targets --all-features -- -D warnings + From 88cb4ddaafe019d733d1bc3eea8cc34ffb6aa029 Mon Sep 17 00:00:00 2001 From: "J. K. Edwards" Date: Thu, 6 Feb 2025 07:17:21 -0500 Subject: [PATCH 6/7] Update IBC.go Signed-off-by: J. K. 
Edwards From 43111ff554f36e7b08220412a11789ad29501e44 Mon Sep 17 00:00:00 2001 From: josefkedwards Date: Thu, 6 Feb 2025 07:33:10 -0500 Subject: [PATCH 7/7] Update IBC.go Signed-off-by: josefkedwards --- .gofiles/IBC.go | 3472 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 3472 insertions(+) diff --git a/.gofiles/IBC.go b/.gofiles/IBC.go index 85e1da4..4851e41 100644 --- a/.gofiles/IBC.go +++ b/.gofiles/IBC.go @@ -11,6 +11,3478 @@ import ( "strings" "time" "testing" + "fmt" + "strings" + "time" + + sdk "github.com/cosmos/cosmos-sdk/types" + ibc "github.com/hypothetical/ibc-integration" + bitcore "github.com/hypothetical/bitcore-integration" + ethereum "github.com/hypothetical/ethereum-integration" +) + +const ( + MEMORY_CAPACITY = 10 + CHECK_INTERVAL_SECONDS = 60 // Check ledger integrity every minute +) + +type InterchainFiatBackedEngine struct { + shortTermMemory []string + longTermMemory map[string]int + JKECounter int + suspiciousTransactions []string + ATOMValue float64 +} + +func NewInterchainFiatBackedEngine() *InterchainFiatBackedEngine { + return &InterchainFiatBackedEngine{ + shortTermMemory: make([]string, 0, MEMORY_CAPACITY), + longTermMemory: make(map[string]int), + JKECounter: 0, + suspiciousTransactions: make([]string, 0), + ATOMValue: 5.89, // Starting from a hypothetical value + } +} + +func (fde *InterchainFiatBackedEngine) checkLedgerIntegrity() { + cosmosLedger := cosmosSdkGetFullLedger() + ibcLedger := ibc.GetLedgerState() + bitcoinLedger := bitcore.GetFullLedger() + ethereumLedger := ethereum.GetFullLedger() + + fde.checkFiatBackingConsistency(cosmosLedger, bitcoinLedger, ethereumLedger) + fde.detectFraud(cosmosLedger, ibcLedger, bitcoinLedger, ethereumLedger) +} + +func (fde *InterchainFiatBackedEngine) checkFiatBackingConsistency(cosmosLedger sdk.Ledger, bitcoinLedger bitcore.Ledger, ethereumLedger ethereum.Ledger) { + btcValue := bitcore.GetReserveValue("btc_address_example") + ethValue := 
ethereum.GetReserveValue("eth_address_example") + fde.ATOMValue = (btcValue + ethValue) / 10000 // Example ratio for pegging ATOM value +} + +func (fde *InterchainFiatBackedEngine) detectFraud(cosmosLedger sdk.Ledger, ibcLedger ibc.Ledger, bitcoinLedger bitcore.Ledger, ethereumLedger ethereum.Ledger) { + // Check transactions across all ledgers for suspicious activity + ledgers := []interface{}{cosmosLedger, ibcLedger, bitcoinLedger, ethereumLedger} + for _, ledger := range ledgers { + switch l := ledger.(type) { + case sdk.Ledger: + for _, tx := range l.Transactions { + if fde.isSuspicious(tx) { + fde.suspiciousTransactions = append(fde.suspiciousTransactions, tx.ID) + fde.logSuspiciousTransaction(tx) + } + } + case ibc.Ledger: + for _, tx := range l.Transactions { + if fde.isIBCSuspicious(tx) { + fde.suspiciousTransactions = append(fde.suspiciousTransactions, tx.ID) + fde.logIBCSuspiciousTransaction(tx) + } + } + case bitcore.Ledger: + for _, tx := range l.Transactions { + if fde.isBitcoinSuspicious(tx) { + fde.suspiciousTransactions = append(fde.suspiciousTransactions, tx.ID) + fde.logBitcoinSuspiciousTransaction(tx) + } + } + case ethereum.Ledger: + for _, tx := range l.Transactions { + if fde.isEthereumSuspicious(tx) { + fde.suspiciousTransactions = append(fde.suspiciousTransactions, tx.ID) + fde.logEthereumSuspiciousTransaction(tx) + } + } + } + } +} + +func (fde *InterchainFiatBackedEngine) isSuspicious(tx sdk.Tx) bool { return false } // Placeholder +func (fde *InterchainFiatBackedEngine) isIBCSuspicious(tx ibc.IBCTx) bool { return false } // Placeholder +func (fde *InterchainFiatBackedEngine) isBitcoinSuspicious(tx bitcore.Transaction) bool { return false } // Placeholder +func (fde *InterchainFiatBackedEngine) isEthereumSuspicious(tx ethereum.Transaction) bool { return false } // Placeholder + +func (fde *InterchainFiatBackedEngine) logSuspiciousTransaction(tx sdk.Tx) { + fmt.Printf("Suspicious Cosmos transaction detected: %s\n", tx.ID) + // 
cosmosSdkAlertGovernance(tx.ID) +} + +func (fde *InterchainFiatBackedEngine) logIBCSuspiciousTransaction(tx ibc.IBCTx) { + fmt.Printf("Suspicious IBC transaction detected: %s\n", tx.ID) + // ibc.AlertGovernance(tx.ID) +} + +func (fde *InterchainFiatBackedEngine) logBitcoinSuspiciousTransaction(tx bitcore.Transaction) { + fmt.Printf("Suspicious Bitcoin transaction detected: %s\n", tx.ID) + // bitcore.AlertNetwork(tx.ID) +} + +func (fde *InterchainFiatBackedEngine) logEthereumSuspiciousTransaction(tx ethereum.Transaction) { + fmt.Printf("Suspicious Ethereum transaction detected: %s\n", tx.ID) + // ethereum.AlertNetwork(tx.ID) +} + +func (fde *InterchainFiatBackedEngine) novelinput(input string) { + fde.manageMemory(input) + if strings.HasPrefix(input, "txid") { + fde.checkLedgerIntegrityForTransaction(input, "cosmos") + } else if strings.HasPrefix(input, "btc_txid") { + fde.checkLedgerIntegrityForTransaction(input, "bitcoin") + } else if strings.HasPrefix(input, "eth_txid") { + fde.checkLedgerIntegrityForTransaction(input, "ethereum") + } else if strings.HasPrefix(input, "ibc") { + fde.checkLedgerIntegrityForTransaction(input, "ibc") + } +} + +func (fde *InterchainFiatBackedEngine) manageMemory(input string) { + if len(fde.shortTermMemory) >= MEMORY_CAPACITY { + fde.shortTermMemory = fde.shortTermMemory[1:] + } + fde.shortTermMemory = append(fde.shortTermMemory, input) + + if count, exists := fde.longTermMemory[input]; exists { + fde.longTermMemory[input] = count + 1 + } else { + fde.longTermMemory[input] = 1 + } +} + +func (fde *InterchainFiatBackedEngine) checkLedgerIntegrityForTransaction(txid, chain string) { + var tx interface{} + switch chain { + case "cosmos": + tx = cosmosSdkGetTransaction(txid) + case "bitcoin": + tx = bitcore.GetTransaction(txid) + case "ethereum": + tx = ethereum.GetTransaction(txid) + case "ibc": + tx = ibc.GetTransaction(txid) + } + + switch t := tx.(type) { + case sdk.Tx: + if fde.isSuspicious(t) { + fde.suspiciousTransactions = 
append(fde.suspiciousTransactions, txid) + fde.logSuspiciousTransaction(t) + } + case ibc.IBCTx: + if fde.isIBCSuspicious(t) { + fde.suspiciousTransactions = append(fde.suspiciousTransactions, txid) + fde.logIBCSuspiciousTransaction(t) + } + case bitcore.Transaction: + if fde.isBitcoinSuspicious(t) { + fde.suspiciousTransactions = append(fde.suspiciousTransactions, txid) + fde.logBitcoinSuspiciousTransaction(t) + } + case ethereum.Transaction: + if fde.isEthereumSuspicious(t) { + fde.suspiciousTransactions = append(fde.suspiciousTransactions, txid) + fde.logEthereumSuspiciousTransaction(t) + } + } +} + +func (fde *InterchainFiatBackedEngine) processConversation(userInput string) string { + fde.novelinput(userInput) + + for { + fde.updatePersistentState() + for _, item := range fde.shortTermMemory { + fde.analyzeContext(item) + } + fde.checkLedgerIntegrity() + time.Sleep(CHECK_INTERVAL_SECONDS * time.Second) + } + + return "Processing..." +} + +func (fde *InterchainFiatBackedEngine) updatePersistentState() { + // Update state for Cosmos, IBC, Bitcoin, and Ethereum networks +} + +func (fde *InterchainFiatBackedEngine) analyzeContext(memoryItem string) { + // Analyze context across networks +} + +func main() { + engine := NewInterchainFiatBackedEngine() + fmt.Println("Interchain Fiat Backed Engine running...") + fmt.Printf("Current ATOM Value: $%.2f\n", engine.ATOMValue) + + engine.processConversation("") + // Note: This function never returns because processConversation runs package keeper_test + + errorsmod "cosmossdk.io/errors" + + clienttypes "github.com/cosmos/ibc-go/v9/modules/core/02-client/types" + connectiontypes "github.com/cosmos/ibc-go/v9/modules/core/03-connection/types" + channelkeeper "github.com/cosmos/ibc-go/v9/modules/core/04-channel/keeper" + "github.com/cosmos/ibc-go/v9/modules/core/04-channel/types" + commitmenttypes "github.com/cosmos/ibc-go/v9/modules/core/23-commitment/types" + host "github.com/cosmos/ibc-go/v9/modules/core/24-host" + 
"github.com/cosmos/ibc-go/v9/modules/core/exported" + ibctesting "github.com/cosmos/ibc-go/v9/testing" + "github.com/cosmos/ibc-go/v9/testing/mock" +) + +func (suite *KeeperTestSuite) TestChanUpgradeInit() { + var ( + path *ibctesting.Path + expSequence uint64 + upgradeFields types.UpgradeFields + ) + + testCases := []struct { + name string + malleate func() + expErr error + }{ + { + "success", + func() {}, + nil, + }, + { + "success with later upgrade sequence", + func() { + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence = 4 }) + expSequence = 5 + }, + nil, + }, + { + "upgrade fields are identical to channel end", + func() { + channel := path.EndpointA.GetChannel() + upgradeFields = types.NewUpgradeFields(channel.Ordering, channel.ConnectionHops, channel.Version) + }, + errorsmod.Wrapf(types.ErrInvalidUpgrade, "existing channel end is identical to proposed upgrade channel end: got {ORDER_UNORDERED [connection-0] mock-version}"), + }, + { + "channel not found", + func() { + path.EndpointA.ChannelID = "invalid-channel" + path.EndpointA.ChannelConfig.PortID = "invalid-port" + }, + errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (invalid-port) channel ID (invalid-channel)"), + }, + { + "channel state is not in OPEN state", + func() { + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.State = types.CLOSED }) + }, + errorsmod.Wrapf(types.ErrInvalidChannelState, "expected STATE_OPEN, got STATE_CLOSED"), + }, + { + "proposed channel connection not found", + func() { + upgradeFields.ConnectionHops = []string{"connection-100"} + }, + errorsmod.Wrapf(connectiontypes.ErrConnectionNotFound, "failed to retrieve connection: connection-100"), + }, + { + "invalid proposed channel connection state", + func() { + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.State = connectiontypes.UNINITIALIZED }) + upgradeFields.ConnectionHops = []string{"connection-100"} + }, + 
errorsmod.Wrapf(connectiontypes.ErrConnectionNotFound, "failed to retrieve connection: connection-100"), + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + expSequence = 1 + + upgradeFields = types.NewUpgradeFields(types.UNORDERED, []string{path.EndpointA.ConnectionID}, mock.UpgradeVersion) + + tc.malleate() + + upgrade, err := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeInit( + suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgradeFields, + ) + + if tc.expErr == nil { + ctx := suite.chainA.GetContext() + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.WriteUpgradeInitChannel(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgrade, upgrade.Fields.Version) + channel := path.EndpointA.GetChannel() + + suite.Require().NoError(err) + suite.Require().Equal(expSequence, channel.UpgradeSequence) + suite.Require().Equal(mock.Version, channel.Version) + suite.Require().Equal(types.OPEN, channel.State) + } else { + suite.Require().Error(err) + suite.Require().ErrorIs(err, tc.expErr) + } + }) + } +} + +func (suite *KeeperTestSuite) TestChanUpgradeTry() { + var ( + path *ibctesting.Path + proposedUpgrade types.Upgrade + counterpartyUpgrade types.Upgrade + ) + + testCases := []struct { + name string + malleate func() + expError error + }{ + { + "success", + func() {}, + nil, + }, + { + "success: crossing hellos", + func() { + err := path.EndpointB.ChanUpgradeInit() + suite.Require().NoError(err) + }, + nil, + }, + { + "success: upgrade sequence is fast forwarded to counterparty upgrade sequence", + func() { + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence = 5 }) + }, + nil, + }, + { + "channel not found", + func() { + path.EndpointB.ChannelID = ibctesting.InvalidID + }, + types.ErrChannelNotFound, + }, + { + "channel state is not in 
OPEN state", + func() { + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.State = types.CLOSED }) + }, + types.ErrInvalidChannelState, + }, + { + "connection not found", + func() { + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.ConnectionHops = []string{"connection-100"} }) + }, + connectiontypes.ErrConnectionNotFound, + }, + { + "invalid connection state", + func() { + path.EndpointB.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.State = connectiontypes.UNINITIALIZED }) + }, + connectiontypes.ErrInvalidConnectionState, + }, + { + "initializing handshake fails, proposed connection hops do not exist", + func() { + proposedUpgrade.Fields.ConnectionHops = []string{ibctesting.InvalidID} + }, + connectiontypes.ErrConnectionNotFound, + }, + { + "fails due to proof verification failure, counterparty channel ordering does not match expected ordering", + func() { + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.Ordering = types.ORDERED }) + }, + commitmenttypes.ErrInvalidProof, + }, + { + "fails due to proof verification failure, counterparty upgrade connection hops are tampered with", + func() { + counterpartyUpgrade.Fields.ConnectionHops = []string{ibctesting.InvalidID} + }, + commitmenttypes.ErrInvalidProof, + }, + { + "fails due to incompatible upgrades, chainB proposes a new connection hop that does not match counterparty", + func() { + // reuse existing connection to create a new connection in a non OPEN state + connection := path.EndpointB.GetConnection() + // ensure counterparty connectionID does not match connectionID set in counterparty proposed upgrade + connection.Counterparty.ConnectionId = "connection-50" + + // set proposed connection in state + proposedConnectionID := "connection-100" //nolint:goconst + suite.chainB.GetSimApp().GetIBCKeeper().ConnectionKeeper.SetConnection(suite.chainB.GetContext(), proposedConnectionID, connection) + proposedUpgrade.Fields.ConnectionHops[0] = 
proposedConnectionID + }, + types.ErrIncompatibleCounterpartyUpgrade, + }, + { + "fails due to mismatch in upgrade sequences", + func() { + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence = 5 }) + }, + // channel sequence will be returned so that counterparty inits on completely fresh sequence for both sides + types.NewUpgradeError(5, types.ErrInvalidUpgradeSequence), + }, + { + "fails due to mismatch in upgrade sequences: chainB is on incremented sequence without an upgrade indicating it has already processed upgrade at this sequence.", + func() { + errorReceipt := types.NewUpgradeError(1, types.ErrInvalidUpgrade) + suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.WriteErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, errorReceipt) + + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence = 1 }) + }, + types.NewUpgradeError(1, types.ErrInvalidUpgradeSequence), + }, + { + "fails due to mismatch in upgrade sequences, crossing hello with the TRY chain having a higher sequence", + func() { + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence = 4 }) + + // upgrade sequence is 5 after this call + err := path.EndpointB.ChanUpgradeInit() + suite.Require().NoError(err) + }, + types.NewUpgradeError(4, types.ErrInvalidUpgradeSequence), + }, + { + // ChainA(Sequence: 0, mock-version-v2), ChainB(Sequence: 0, mock-version-v3) + // ChainA.INIT(Sequence: 1) + // ChainB.INIT(Sequence: 1) + // ChainA.TRY => error (incompatible versions) + // ChainB.TRY => error (incompatible versions) + "crossing hellos: fails due to incompatible version", + func() { + // use incompatible version + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = fmt.Sprintf("%s-v3", mock.Version) + proposedUpgrade = path.EndpointB.GetProposedUpgrade() + + err := path.EndpointB.ChanUpgradeInit() + suite.Require().NoError(err) + + err = 
path.EndpointA.ChanUpgradeTry() + suite.Require().Error(err) + suite.Require().ErrorContains(err, "incompatible counterparty upgrade") + suite.Require().Equal(uint64(1), path.EndpointA.GetChannel().UpgradeSequence) + }, + types.ErrIncompatibleCounterpartyUpgrade, + }, + { + // ChainA(Sequence: 0, mock-version-v2), ChainB(Sequence: 4, mock-version-v3) + // ChainA.INIT(Sequence: 1) + // ChainB.INIT(Sequence: 5) + // ChainA.TRY => error (incompatible versions) + // ChainB.TRY(ErrorReceipt: 4) + "crossing hellos: upgrade starts with mismatching upgrade sequences and try fails on counterparty due to incompatible version", + func() { + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence = 4 }) + + // use incompatible version + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = fmt.Sprintf("%s-v3", mock.Version) + proposedUpgrade = path.EndpointB.GetProposedUpgrade() + + err := path.EndpointB.ChanUpgradeInit() + suite.Require().NoError(err) + + err = path.EndpointA.ChanUpgradeTry() + suite.Require().Error(err) + suite.Require().ErrorContains(err, "incompatible counterparty upgrade") + suite.Require().Equal(uint64(1), path.EndpointA.GetChannel().UpgradeSequence) + }, + types.NewUpgradeError(4, types.ErrInvalidUpgradeSequence), + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + proposedUpgrade = path.EndpointB.GetProposedUpgrade() + + var found bool + counterpartyUpgrade, found = path.EndpointA.Chain.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgrade(path.EndpointA.Chain.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + 
suite.Require().True(found) + + tc.malleate() + + // ensure clients are up to date to receive valid proofs + suite.Require().NoError(path.EndpointB.UpdateClient()) + + channelProof, upgradeProof, proofHeight := path.EndpointA.QueryChannelUpgradeProof() + + _, upgrade, err := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeTry( + suite.chainB.GetContext(), + path.EndpointB.ChannelConfig.PortID, + path.EndpointB.ChannelID, + proposedUpgrade.Fields.ConnectionHops, + counterpartyUpgrade.Fields, + path.EndpointA.GetChannel().UpgradeSequence, + channelProof, + upgradeProof, + proofHeight, + ) + + if tc.expError == nil { + suite.Require().NoError(err) + suite.Require().NotEmpty(upgrade) + suite.Require().Equal(proposedUpgrade.Fields, upgrade.Fields) + + channel := path.EndpointB.GetChannel() + suite.Require().Equal(types.FLUSHING, channel.State) + + nextSequenceSend, found := path.EndpointB.Chain.GetSimApp().IBCKeeper.ChannelKeeper.GetNextSequenceSend(path.EndpointB.Chain.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + suite.Require().True(found) + suite.Require().Equal(nextSequenceSend, upgrade.NextSequenceSend) + } else { + suite.assertUpgradeError(err, tc.expError) + } + }) + } +} + +// TestChanUpgrade_CrossingHellos_UpgradeSucceeds_AfterCancel verifies that under crossing hellos if upgrade +// sequences become out of sync, the upgrade can still be performed successfully after the upgrade is cancelled. 
+// ChainA(Sequence: 0), ChainB(Sequence 4) +// ChainA.INIT(Sequence: 1) +// ChainB.INIT(Sequence: 5) +// ChainB.TRY(ErrorReceipt: 4) +// ChainA.Cancel(Sequence: 4) +// ChainA.TRY(Sequence: 5) // fastforward +// ChainB.ACK => Success +// ChainA.Confirm => Success +// ChainB.Open => Success +func (suite *KeeperTestSuite) TestChanUpgrade_CrossingHellos_UpgradeSucceeds_AfterCancel() { + var path *ibctesting.Path + + suite.Run("setup path", func() { + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + }) + + suite.Run("chainA upgrade init", func() { + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + }) + + suite.Run("set chainB upgrade sequence ahead of counterparty", func() { + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence = 4 }) + }) + + suite.Run("chainB upgrade init (crossing hello)", func() { + err := path.EndpointB.ChanUpgradeInit() + suite.Require().NoError(err) + }) + + suite.Run("chainB upgrade try fails with invalid sequence", func() { + err := path.EndpointB.ChanUpgradeTry() + suite.Require().NoError(err) + + errorReceipt, found := path.EndpointB.Chain.GetSimApp().GetIBCKeeper().ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + suite.Require().True(found) + suite.Require().Equal(uint64(4), errorReceipt.Sequence) + }) + + suite.Run("cancel upgrade on chainA and fast forward upgrade sequence", func() { + err := path.EndpointA.ChanUpgradeCancel() + suite.Require().NoError(err) + + channel := path.EndpointA.GetChannel() + suite.Require().Equal(types.OPEN, channel.State) + suite.Require().Equal(uint64(4), channel.UpgradeSequence) + }) + + suite.Run("try chainA upgrade now succeeds with synchronized upgrade sequences", func() { + err := 
path.EndpointA.ChanUpgradeTry() + suite.Require().NoError(err) + + channel := path.EndpointA.GetChannel() + suite.Require().Equal(types.FLUSHING, channel.State) + suite.Require().Equal(uint64(5), channel.UpgradeSequence) + }) + + suite.Run("upgrade handshake completes successfully", func() { + err := path.EndpointB.ChanUpgradeAck() + suite.Require().NoError(err) + + err = path.EndpointA.ChanUpgradeConfirm() + suite.Require().NoError(err) + + err = path.EndpointB.ChanUpgradeOpen() + suite.Require().NoError(err) + }) + + suite.Run("assert successful upgrade expected channel state", func() { + channelA := path.EndpointA.GetChannel() + suite.Require().Equal(types.OPEN, channelA.State, "channel should be in OPEN state") + suite.Require().Equal(mock.UpgradeVersion, channelA.Version, "version should be correctly upgraded") + suite.Require().Equal(mock.UpgradeVersion, path.EndpointB.GetChannel().Version, "version should be correctly upgraded") + suite.Require().Equal(uint64(5), channelA.UpgradeSequence, "upgrade sequence should be incremented") + + channelB := path.EndpointB.GetChannel() + suite.Require().Equal(types.OPEN, channelB.State, "channel should be in OPEN state") + suite.Require().Equal(mock.UpgradeVersion, channelB.Version, "version should be correctly upgraded") + suite.Require().Equal(mock.UpgradeVersion, channelB.Version, "version should be correctly upgraded") + suite.Require().Equal(uint64(5), channelB.UpgradeSequence, "upgrade sequence should be incremented") + }) +} + +// TestChanUpgrade_CrossingHellos_UpgradeSucceeds_AfterCancelErrors verifies that under crossing hellos if upgrade +// sequences become out of sync, the upgrade can still be performed successfully after the cancel fails. 
+// ChainA(Sequence: 0), ChainB(Sequence 4) +// ChainA.INIT(Sequence: 1) +// ChainB.INIT(Sequence: 5) +// ChainA.TRY(Sequence: 5) // fastforward +// ChainB.TRY(ErrorReceipt: 4) +// ChainA.Cancel => Error (errorReceipt.Sequence < channel.UpgradeSequence) +// ChainB.ACK => Success +// ChainA.Confirm => Success +// ChainB.Open => Success +func (suite *KeeperTestSuite) TestChanUpgrade_CrossingHellos_UpgradeSucceeds_AfterCancelErrors() { + var ( + historicalChannelProof []byte + historicalUpgradeProof []byte + proofHeight clienttypes.Height + path *ibctesting.Path + ) + + suite.Run("setup path", func() { + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + }) + + suite.Run("chainA upgrade init", func() { + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + + channel := path.EndpointA.GetChannel() + suite.Require().Equal(uint64(1), channel.UpgradeSequence) + }) + + suite.Run("set chainB upgrade sequence ahead of counterparty", func() { + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence = 4 }) + }) + + suite.Run("chainB upgrade init (crossing hello)", func() { + err := path.EndpointB.ChanUpgradeInit() + suite.Require().NoError(err) + + channel := path.EndpointB.GetChannel() + suite.Require().Equal(uint64(5), channel.UpgradeSequence) + }) + + suite.Run("query proofs at chainA upgrade sequence 1", func() { + // commit block and update client on chainB + suite.coordinator.CommitBlock(suite.chainA, suite.chainB) + suite.Require().NoError(path.EndpointB.UpdateClient()) + // use proofs when chain A has not executed TRY yet and use them when executing TRY on chain B + historicalChannelProof, historicalUpgradeProof, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + }) + + suite.Run("chainA upgrade try (fast-forwards sequence)", 
func() { + err := path.EndpointA.ChanUpgradeTry() + suite.Require().NoError(err) + + channel := path.EndpointA.GetChannel() + suite.Require().Equal(uint64(5), channel.UpgradeSequence) + }) + + suite.Run("chainB upgrade try fails with written error receipt (4)", func() { + // NOTE: ante handlers are bypassed here and the handler is invoked directly. + // Thus, we set the upgrade error receipt explicitly below + _, _, err := suite.chainB.GetSimApp().GetIBCKeeper().ChannelKeeper.ChanUpgradeTry( + suite.chainB.GetContext(), + path.EndpointB.ChannelConfig.PortID, + path.EndpointB.ChannelID, + path.EndpointB.GetChannelUpgrade().Fields.ConnectionHops, + path.EndpointA.GetChannelUpgrade().Fields, + 1, // proofs queried at chainA upgrade sequence 1 + historicalChannelProof, + historicalUpgradeProof, + proofHeight, + ) + suite.Require().Error(err) + suite.assertUpgradeError(err, types.NewUpgradeError(4, types.ErrInvalidUpgradeSequence)) + + suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.WriteErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, err.(*types.UpgradeError)) + suite.coordinator.CommitBlock(suite.chainB) + }) + + suite.Run("chainA upgrade cancellation fails for invalid sequence", func() { + err := path.EndpointA.ChanUpgradeCancel() + suite.Require().Error(err) + suite.Require().ErrorContains(err, "invalid upgrade sequence") + + // assert channel remains in flushing state at upgrade sequence 5 + channel := path.EndpointA.GetChannel() + suite.Require().Equal(types.FLUSHING, channel.State) + suite.Require().Equal(uint64(5), channel.UpgradeSequence) + }) + + suite.Run("upgrade handshake completes successfully", func() { + err := path.EndpointB.ChanUpgradeAck() + suite.Require().NoError(err) + + err = path.EndpointA.ChanUpgradeConfirm() + suite.Require().NoError(err) + + err = path.EndpointB.ChanUpgradeOpen() + suite.Require().NoError(err) + }) + + suite.Run("assert successful upgrade expected channel state", func() { 
+ channelA := path.EndpointA.GetChannel() + suite.Require().Equal(types.OPEN, channelA.State, "channel should be in OPEN state") + suite.Require().Equal(mock.UpgradeVersion, channelA.Version, "version should be correctly upgraded") + suite.Require().Equal(mock.UpgradeVersion, path.EndpointB.GetChannel().Version, "version should be correctly upgraded") + suite.Require().Equal(uint64(5), channelA.UpgradeSequence, "upgrade sequence should be incremented") + + channelB := path.EndpointB.GetChannel() + suite.Require().Equal(types.OPEN, channelB.State, "channel should be in OPEN state") + suite.Require().Equal(mock.UpgradeVersion, channelB.Version, "version should be correctly upgraded") + suite.Require().Equal(mock.UpgradeVersion, channelB.Version, "version should be correctly upgraded") + suite.Require().Equal(uint64(5), channelB.UpgradeSequence, "upgrade sequence should be incremented") + }) +} + +func (suite *KeeperTestSuite) TestWriteUpgradeTry() { + var ( + path *ibctesting.Path + proposedUpgrade types.Upgrade + ) + + testCases := []struct { + name string + malleate func() + }{ + { + "success with no packet commitments", + func() {}, + }, + { + "success with packet commitments", + func() { + // manually set packet commitment + sequence, err := path.EndpointB.SendPacket(suite.chainB.GetTimeoutHeight(), 0, ibctesting.MockPacketData) + suite.Require().NoError(err) + suite.Require().Equal(uint64(1), sequence) + }, + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + proposedUpgrade = path.EndpointB.GetProposedUpgrade() + + tc.malleate() + + ctx := suite.chainB.GetContext() + upgradedChannelEnd, upgradeWithAppCallbackVersion := 
suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.WriteUpgradeTryChannel( + ctx, + path.EndpointB.ChannelConfig.PortID, + path.EndpointB.ChannelID, + proposedUpgrade, + proposedUpgrade.Fields.Version, + ) + + channel := path.EndpointB.GetChannel() + suite.Require().Equal(upgradedChannelEnd, channel) + + upgrade, found := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgrade(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + suite.Require().True(found) + suite.Require().Equal(upgradeWithAppCallbackVersion, upgrade) + }) + } +} + +func (suite *KeeperTestSuite) TestChanUpgradeAck() { + var ( + path *ibctesting.Path + counterpartyUpgrade types.Upgrade + ) + + testCases := []struct { + name string + malleate func() + expError error + }{ + { + "success", + func() {}, + nil, + }, + { + "success with later upgrade sequence", + func() { + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence = 10 }) + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence = 10 }) + + err := path.EndpointA.UpdateClient() + suite.Require().NoError(err) + }, + nil, + }, + { + "failure if initializing chain reinitializes before ACK", + func() { + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + }, + commitmenttypes.ErrInvalidProof, // sequences are out of sync + }, + { + "channel not found", + func() { + path.EndpointA.ChannelID = ibctesting.InvalidID + path.EndpointA.ChannelConfig.PortID = ibctesting.InvalidID + }, + types.ErrChannelNotFound, + }, + { + "channel state is not in FLUSHING state", + func() { + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.State = types.CLOSED }) + }, + types.ErrInvalidChannelState, + }, + { + "connection not found", + func() { + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.ConnectionHops = []string{"connection-100"} }) + }, + connectiontypes.ErrConnectionNotFound, + }, + { + "invalid 
connection state", + func() { + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.State = connectiontypes.UNINITIALIZED }) + }, + connectiontypes.ErrInvalidConnectionState, + }, + { + "upgrade not found", + func() { + store := suite.chainA.GetContext().KVStore(suite.chainA.GetSimApp().GetKey(exported.ModuleName)) + store.Delete(host.ChannelUpgradeKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) + }, + types.ErrUpgradeNotFound, + }, + { + "fails due to upgrade incompatibility", + func() { + // Need to set counterparty upgrade in state and update clients to ensure + // proofs submitted reflect the altered upgrade. + counterpartyUpgrade.Fields.ConnectionHops = []string{ibctesting.InvalidID} + path.EndpointB.SetChannelUpgrade(counterpartyUpgrade) + + suite.coordinator.CommitBlock(suite.chainB) + + err := path.EndpointA.UpdateClient() + suite.Require().NoError(err) + }, + types.NewUpgradeError(1, types.ErrIncompatibleCounterpartyUpgrade), + }, + { + "fails due to proof verification failure, counterparty channel ordering does not match expected ordering", + func() { + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.Ordering = types.ORDERED }) + }, + commitmenttypes.ErrInvalidProof, + }, + { + "fails due to proof verification failure, counterparty update has unexpected sequence", + func() { + // Decrementing NextSequenceSend is sufficient to cause the proof to fail. + counterpartyUpgrade.NextSequenceSend-- + }, + commitmenttypes.ErrInvalidProof, + }, + { + "fails due to mismatch in upgrade ordering", + func() { + upgrade := path.EndpointA.GetChannelUpgrade() + upgrade.Fields.Ordering = types.NONE + + path.EndpointA.SetChannelUpgrade(upgrade) + }, + types.NewUpgradeError(1, types.ErrIncompatibleCounterpartyUpgrade), + }, + { + "counterparty timeout has elapsed", + func() { + // Need to set counterparty upgrade in state and update clients to ensure + // proofs submitted reflect the altered upgrade. 
+ counterpartyUpgrade.Timeout = types.NewTimeout(clienttypes.NewHeight(0, 1), 0) + path.EndpointB.SetChannelUpgrade(counterpartyUpgrade) + + err := path.EndpointB.UpdateClient() + suite.Require().NoError(err) + err = path.EndpointA.UpdateClient() + suite.Require().NoError(err) + }, + types.NewUpgradeError(1, types.ErrTimeoutElapsed), + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + + // manually set packet commitment so that the chainB channel state is FLUSHING + sequence, err := path.EndpointB.SendPacket(suite.chainB.GetTimeoutHeight(), 0, ibctesting.MockPacketData) + suite.Require().NoError(err) + suite.Require().Equal(uint64(1), sequence) + + err = path.EndpointB.ChanUpgradeTry() + suite.Require().NoError(err) + + // ensure client is up to date to receive valid proofs + err = path.EndpointA.UpdateClient() + suite.Require().NoError(err) + + counterpartyUpgrade = path.EndpointB.GetChannelUpgrade() + + tc.malleate() + + channelProof, upgradeProof, proofHeight := path.EndpointB.QueryChannelUpgradeProof() + + err = suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeAck( + suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, counterpartyUpgrade, + channelProof, upgradeProof, proofHeight, + ) + + if tc.expError == nil { + suite.Require().NoError(err) + + channel := path.EndpointA.GetChannel() + // ChanUpgradeAck will set the channel state to FLUSHING + // It will be set to FLUSHING_COMPLETE in the write function. 
+ suite.Require().Equal(types.FLUSHING, channel.State) + } else { + suite.assertUpgradeError(err, tc.expError) + } + }) + } +} + +func (suite *KeeperTestSuite) TestWriteChannelUpgradeAck() { + var ( + path *ibctesting.Path + proposedUpgrade types.Upgrade + ) + + testCases := []struct { + name string + malleate func() + hasPacketCommitments bool + }{ + { + "success with no packet commitments", + func() {}, + false, + }, + { + "success with packet commitments", + func() { + // manually set packet commitment + sequence, err := path.EndpointA.SendPacket(suite.chainB.GetTimeoutHeight(), 0, ibctesting.MockPacketData) + suite.Require().NoError(err) + suite.Require().Equal(uint64(1), sequence) + }, + true, + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + tc.malleate() + + // perform the upgrade handshake. 
+ suite.Require().NoError(path.EndpointA.ChanUpgradeInit()) + + suite.Require().NoError(path.EndpointB.ChanUpgradeTry()) + + ctx := suite.chainA.GetContext() + proposedUpgrade = path.EndpointB.GetChannelUpgrade() + + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.WriteUpgradeAckChannel(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, proposedUpgrade) + + channel := path.EndpointA.GetChannel() + upgrade := path.EndpointA.GetChannelUpgrade() + suite.Require().Equal(mock.UpgradeVersion, upgrade.Fields.Version) + + if !tc.hasPacketCommitments { + suite.Require().Equal(types.FLUSHCOMPLETE, channel.State) + } + counterpartyUpgrade, ok := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetCounterpartyUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().True(ok) + suite.Require().Equal(proposedUpgrade, counterpartyUpgrade) + }) + } +} + +func (suite *KeeperTestSuite) TestChanUpgrade_ReinitializedBeforeAck() { + var path *ibctesting.Path + suite.Run("setup path", func() { + path = ibctesting.NewPath(suite.chainA, suite.chainB) + suite.coordinator.Setup(path) + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + }) + + suite.Run("chainA upgrade init", func() { + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + + channel := path.EndpointA.GetChannel() + suite.Require().Equal(uint64(1), channel.UpgradeSequence) + }) + + suite.Run("chainB upgrade try", func() { + err := path.EndpointB.ChanUpgradeTry() + suite.Require().NoError(err) + }) + + suite.Run("chainA upgrade init reinitialized after ack", func() { + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + + channel := path.EndpointA.GetChannel() + suite.Require().Equal(uint64(2), channel.UpgradeSequence) + }) + + suite.Run("chan upgrade ack fails", func() { + err := 
path.EndpointA.ChanUpgradeAck() + suite.Require().Error(err) + }) + + suite.Run("chainB upgrade cancel", func() { + err := path.EndpointB.ChanUpgradeCancel() + suite.Require().NoError(err) + }) + + suite.Run("upgrade handshake succeeds on new upgrade attempt", func() { + err := path.EndpointB.ChanUpgradeTry() + suite.Require().NoError(err) + + err = path.EndpointA.ChanUpgradeAck() + suite.Require().NoError(err) + + err = path.EndpointB.ChanUpgradeConfirm() + suite.Require().NoError(err) + + err = path.EndpointA.ChanUpgradeOpen() + suite.Require().NoError(err) + }) + + suite.Run("assert successful upgrade expected channel state", func() { + channelA := path.EndpointA.GetChannel() + suite.Require().Equal(types.OPEN, channelA.State, "channel should be in OPEN state") + suite.Require().Equal(mock.UpgradeVersion, channelA.Version, "version should be correctly upgraded") + suite.Require().Equal(mock.UpgradeVersion, path.EndpointB.GetChannel().Version, "version should be correctly upgraded") + suite.Require().Equal(uint64(2), channelA.UpgradeSequence, "upgrade sequence should be incremented") + + channelB := path.EndpointB.GetChannel() + suite.Require().Equal(types.OPEN, channelB.State, "channel should be in OPEN state") + suite.Require().Equal(mock.UpgradeVersion, channelB.Version, "version should be correctly upgraded") + suite.Require().Equal(mock.UpgradeVersion, channelB.Version, "version should be correctly upgraded") + suite.Require().Equal(uint64(2), channelB.UpgradeSequence, "upgrade sequence should be incremented") + }) +} + +func (suite *KeeperTestSuite) TestChanUpgradeConfirm() { + var ( + path *ibctesting.Path + counterpartyChannelState types.State + counterpartyUpgrade types.Upgrade + ) + + testCases := []struct { + name string + malleate func() + expError error + }{ + { + "success", + func() {}, + nil, + }, + { + "success with later upgrade sequence", + func() { + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence = 10 }) + 
path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence = 10 }) + + err := path.EndpointB.UpdateClient() + suite.Require().NoError(err) + }, + nil, + }, + { + "success with in-flight packets on init chain", + func() { + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + + err = path.EndpointB.ChanUpgradeTry() + suite.Require().NoError(err) + + seq, err := path.EndpointA.SendPacket(defaultTimeoutHeight, 0, ibctesting.MockPacketData) + suite.Require().Equal(uint64(1), seq) + suite.Require().NoError(err) + + err = path.EndpointA.ChanUpgradeAck() + suite.Require().NoError(err) + + err = path.EndpointB.UpdateClient() + suite.Require().NoError(err) + + counterpartyChannelState = path.EndpointA.GetChannel().State + counterpartyUpgrade = path.EndpointA.GetChannelUpgrade() + }, + nil, + }, + { + "success with in-flight packets on try chain", + func() { + portID, channelID := path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID + suite.chainB.GetSimApp().GetIBCKeeper().ChannelKeeper.SetPacketCommitment(suite.chainB.GetContext(), portID, channelID, 1, []byte("hash")) + }, + nil, + }, + { + "channel not found", + func() { + path.EndpointB.ChannelID = ibctesting.InvalidID + path.EndpointB.ChannelConfig.PortID = ibctesting.InvalidID + }, + types.ErrChannelNotFound, + }, + { + "channel is not in FLUSHING state", + func() { + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.State = types.CLOSED }) + }, + types.ErrInvalidChannelState, + }, + { + "invalid counterparty channel state", + func() { + counterpartyChannelState = types.CLOSED + }, + types.ErrInvalidCounterparty, + }, + { + "connection not found", + func() { + path.EndpointB.UpdateChannel(func(channel 
*types.Channel) { channel.ConnectionHops = []string{"connection-100"} }) + }, + connectiontypes.ErrConnectionNotFound, + }, + { + "invalid connection state", + func() { + path.EndpointB.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.State = connectiontypes.UNINITIALIZED }) + }, + connectiontypes.ErrInvalidConnectionState, + }, + { + "fails due to proof verification failure, counterparty channel ordering does not match expected ordering", + func() { + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.Ordering = types.ORDERED }) + + err := path.EndpointB.UpdateClient() + suite.Require().NoError(err) + }, + commitmenttypes.ErrInvalidProof, + }, + { + "fails due to mismatch in upgrade ordering", + func() { + upgrade := path.EndpointA.GetChannelUpgrade() + upgrade.Fields.Ordering = types.NONE + + path.EndpointA.SetChannelUpgrade(upgrade) + + suite.coordinator.CommitBlock(suite.chainA) + + err := path.EndpointB.UpdateClient() + suite.Require().NoError(err) + }, + commitmenttypes.ErrInvalidProof, + }, + { + "counterparty timeout has elapsed", + func() { + // Need to set counterparty upgrade in state and update clients to ensure + // proofs submitted reflect the altered upgrade. 
+ counterpartyUpgrade.Timeout = types.NewTimeout(clienttypes.NewHeight(0, 1), 0) + path.EndpointA.SetChannelUpgrade(counterpartyUpgrade) + + suite.coordinator.CommitBlock(suite.chainA) + + err := path.EndpointB.UpdateClient() + suite.Require().NoError(err) + }, + types.NewUpgradeError(1, types.ErrTimeoutElapsed), + }, + { + "upgrade not found", + func() { + path.EndpointB.Chain.DeleteKey(host.ChannelUpgradeKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)) + }, + types.ErrUpgradeNotFound, + }, + { + "upgrades are not compatible", + func() { + // the expected upgrade version is mock-version-v2 + counterpartyUpgrade.Fields.Version = fmt.Sprintf("%s-v3", mock.Version) + path.EndpointA.SetChannelUpgrade(counterpartyUpgrade) + + suite.coordinator.CommitBlock(suite.chainA) + + err := path.EndpointB.UpdateClient() + suite.Require().NoError(err) + }, + types.NewUpgradeError(1, types.ErrIncompatibleCounterpartyUpgrade), + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + + err = path.EndpointB.ChanUpgradeTry() + suite.Require().NoError(err) + + err = path.EndpointA.ChanUpgradeAck() + suite.Require().NoError(err) + + err = path.EndpointB.UpdateClient() + suite.Require().NoError(err) + + counterpartyChannelState = path.EndpointA.GetChannel().State + counterpartyUpgrade = path.EndpointA.GetChannelUpgrade() + + tc.malleate() + + channelProof, upgradeProof, proofHeight := path.EndpointA.QueryChannelUpgradeProof() + + err = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeConfirm( + suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, 
counterpartyChannelState, counterpartyUpgrade, + channelProof, upgradeProof, proofHeight, + ) + + if tc.expError == nil { + suite.Require().NoError(err) + } else { + suite.assertUpgradeError(err, tc.expError) + } + }) + } +} + +func (suite *KeeperTestSuite) TestWriteUpgradeConfirm() { + var ( + path *ibctesting.Path + proposedUpgrade types.Upgrade + ) + + testCases := []struct { + name string + malleate func() + hasPacketCommitments bool + }{ + { + "success with no packet commitments", + func() {}, + false, + }, + { + "success with packet commitments", + func() { + // manually set packet commitment + sequence, err := path.EndpointA.SendPacket(suite.chainB.GetTimeoutHeight(), 0, ibctesting.MockPacketData) + suite.Require().NoError(err) + suite.Require().Equal(uint64(1), sequence) + }, + true, + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + tc.malleate() + + // perform the upgrade handshake. 
+ suite.Require().NoError(path.EndpointB.ChanUpgradeInit()) + + suite.Require().NoError(path.EndpointA.ChanUpgradeTry()) + + suite.Require().NoError(path.EndpointB.ChanUpgradeAck()) + + ctx := suite.chainA.GetContext() + proposedUpgrade = path.EndpointB.GetChannelUpgrade() + + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.WriteUpgradeConfirmChannel(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, proposedUpgrade) + + channel := path.EndpointA.GetChannel() + upgrade := path.EndpointA.GetChannelUpgrade() + suite.Require().Equal(mock.UpgradeVersion, upgrade.Fields.Version) + + if !tc.hasPacketCommitments { + suite.Require().Equal(types.FLUSHCOMPLETE, channel.State) + } else { + suite.Require().Equal(types.FLUSHING, channel.State) + } + counterpartyUpgrade, ok := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetCounterpartyUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().True(ok, "counterparty upgrade should be present") + suite.Require().Equal(proposedUpgrade, counterpartyUpgrade) + }) + } +} + +func (suite *KeeperTestSuite) TestChanUpgradeOpen() { + var path *ibctesting.Path + testCases := []struct { + name string + malleate func() + expError error + }{ + { + "success", + func() {}, + nil, + }, + { + "success: counterparty in flushcomplete", + func() { + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + // Need to create a packet commitment on A so as to keep it from going to FLUSHCOMPLETE if no inflight packets exist. 
+ sequence, err := path.EndpointA.SendPacket(defaultTimeoutHeight, disabledTimeoutTimestamp, ibctesting.MockPacketData) + suite.Require().NoError(err) + packet := types.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, defaultTimeoutHeight, disabledTimeoutTimestamp) + err = path.EndpointB.RecvPacket(packet) + suite.Require().NoError(err) + + err = path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + + err = path.EndpointB.ChanUpgradeTry() + suite.Require().NoError(err) + + err = path.EndpointA.ChanUpgradeAck() + suite.Require().NoError(err) + + err = path.EndpointB.ChanUpgradeConfirm() + suite.Require().NoError(err) + + err = path.EndpointA.AcknowledgePacket(packet, ibctesting.MockAcknowledgement) + suite.Require().NoError(err) + + // cause the packet commitment on chain A to be deleted and the channel state to be updated to FLUSHCOMPLETE. + suite.coordinator.CommitBlock(suite.chainA, suite.chainB) + suite.Require().NoError(path.EndpointA.UpdateClient()) + }, + nil, + }, + { + "success: counterparty initiated new upgrade after opening", + func() { + // create reason to upgrade + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + "additional upgrade" + + err := path.EndpointB.ChanUpgradeInit() + suite.Require().NoError(err) + + err = path.EndpointA.UpdateClient() + suite.Require().NoError(err) + }, + nil, + }, + { + "success: counterparty upgrade sequence is incorrect", + func() { + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence-- }) + }, + types.ErrInvalidUpgradeSequence, + }, + { + "channel not found", + func() { + path.EndpointA.ChannelConfig.PortID = ibctesting.InvalidID + }, + types.ErrChannelNotFound, + }, + { + "channel state is not FLUSHCOMPLETE", + func() { + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.State = types.FLUSHING }) + 
}, + types.ErrInvalidChannelState, + }, + { + "connection not found", + func() { + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.ConnectionHops = []string{"connection-100"} }) + }, + connectiontypes.ErrConnectionNotFound, + }, + { + "invalid connection state", + func() { + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.State = connectiontypes.UNINITIALIZED }) + }, + connectiontypes.ErrInvalidConnectionState, + }, + { + "invalid counterparty channel state", + func() { + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.State = types.CLOSED }) + }, + types.ErrInvalidCounterparty, + }, + } + + // Create an initial path used only to invoke a ChanOpenInit handshake. + // This bumps the channel identifier generated for chain A on the + // next path used to run the upgrade handshake. + // See issue 4062. + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.SetupConnections() + suite.Require().NoError(path.EndpointA.ChanOpenInit()) + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + + err = path.EndpointB.ChanUpgradeTry() + suite.Require().NoError(err) + + err = path.EndpointA.ChanUpgradeAck() + suite.Require().NoError(err) + + err = path.EndpointB.ChanUpgradeConfirm() + suite.Require().NoError(err) + + err = path.EndpointA.UpdateClient() + suite.Require().NoError(err) + + tc.malleate() + + channelKey := host.ChannelKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + channelProof, proofHeight := path.EndpointB.QueryProof(channelKey) + + err = suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeOpen( + 
suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, + path.EndpointB.GetChannel().State, path.EndpointB.GetChannel().UpgradeSequence, channelProof, proofHeight, + ) + + if tc.expError == nil { + suite.Require().NoError(err) + } else { + suite.Require().ErrorIs(err, tc.expError) + } + }) + } +} + +func (suite *KeeperTestSuite) TestWriteUpgradeOpenChannel() { + var path *ibctesting.Path + + testCases := []struct { + name string + malleate func() + expPanic bool + }{ + { + name: "success", + malleate: func() {}, + expPanic: false, + }, + { + name: "channel not found", + malleate: func() { + path.EndpointA.Chain.DeleteKey(host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) + }, + expPanic: true, + }, + { + name: "upgrade not found", + malleate: func() { + path.EndpointA.Chain.DeleteKey(host.ChannelUpgradeKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) + }, + expPanic: true, + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + // Need to create a packet commitment on A so as to keep it from going to OPEN if no inflight packets exist. 
+ sequence, err := path.EndpointA.SendPacket(defaultTimeoutHeight, disabledTimeoutTimestamp, ibctesting.MockPacketData) + suite.Require().NoError(err) + packet := types.NewPacket(ibctesting.MockPacketData, sequence, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, defaultTimeoutHeight, disabledTimeoutTimestamp) + err = path.EndpointB.RecvPacket(packet) + suite.Require().NoError(err) + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + suite.Require().NoError(path.EndpointA.ChanUpgradeInit()) + suite.Require().NoError(path.EndpointB.ChanUpgradeTry()) + suite.Require().NoError(path.EndpointA.ChanUpgradeAck()) + suite.Require().NoError(path.EndpointB.ChanUpgradeConfirm()) + + // Ack packet to delete packet commitment before calling WriteUpgradeOpenChannel + err = path.EndpointA.AcknowledgePacket(packet, ibctesting.MockAcknowledgement) + suite.Require().NoError(err) + + ctx := suite.chainA.GetContext() + + tc.malleate() + + if tc.expPanic { + suite.Require().Panics(func() { + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.WriteUpgradeOpenChannel(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + }) + } else { + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.WriteUpgradeOpenChannel(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + channel := path.EndpointA.GetChannel() + + // Assert that channel state has been updated + suite.Require().Equal(types.OPEN, channel.State) + suite.Require().Equal(mock.UpgradeVersion, channel.Version) + + // Assert that state stored for upgrade has been deleted + upgrade, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().Equal(types.Upgrade{}, upgrade) + 
suite.Require().False(found) + + counterpartyUpgrade, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetCounterpartyUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().Equal(types.Upgrade{}, counterpartyUpgrade) + suite.Require().False(found) + } + }) + } +} + +func (suite *KeeperTestSuite) TestWriteUpgradeOpenChannel_Ordering() { + var path *ibctesting.Path + + testCases := []struct { + name string + malleate func() + preUpgrade func() + postUpgrade func() + }{ + { + name: "success: ORDERED -> UNORDERED", + malleate: func() { + path.EndpointA.ChannelConfig.Order = types.ORDERED + path.EndpointB.ChannelConfig.Order = types.ORDERED + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Ordering = types.UNORDERED + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Ordering = types.UNORDERED + }, + preUpgrade: func() { + ctx := suite.chainA.GetContext() + + // assert that NextSeqAck is incremented to 2 because channel is still ordered + seq, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetNextSequenceAck(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().True(found) + suite.Require().Equal(uint64(2), seq) + + // assert that NextSeqRecv is incremented to 2 because channel is still ordered + seq, found = suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetNextSequenceRecv(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().True(found) + suite.Require().Equal(uint64(2), seq) + + // Assert that pruning sequence start has not been initialized. 
+ suite.Require().False(suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.HasPruningSequenceStart(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) + + // Assert that recv start sequence has not been set + counterpartyNextSequenceSend, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetRecvStartSequence(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().False(found) + suite.Require().Equal(uint64(0), counterpartyNextSequenceSend) + }, + postUpgrade: func() { + channel := path.EndpointA.GetChannel() + ctx := suite.chainA.GetContext() + + // Assert that channel state has been updated + suite.Require().Equal(types.OPEN, channel.State) + suite.Require().Equal(types.UNORDERED, channel.Ordering) + + // assert that NextSeqRecv is now 1, because channel is now UNORDERED + seq, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetNextSequenceRecv(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().True(found) + suite.Require().Equal(uint64(1), seq) + + // assert that NextSeqAck is now 1, because channel is now UNORDERED + seq, found = suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetNextSequenceAck(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().True(found) + suite.Require().Equal(uint64(1), seq) + + // Assert that pruning sequence start has been initialized (set to 1) + suite.Require().True(suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.HasPruningSequenceStart(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) + pruningSeq, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetPruningSequenceStart(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().True(found) + suite.Require().Equal(uint64(1), pruningSeq) + + // Assert that the recv start sequence has been set correctly + counterpartySequenceSend, found := 
suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetRecvStartSequence(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().True(found) + suite.Require().Equal(uint64(2), counterpartySequenceSend) + }, + }, + { + name: "success: UNORDERED -> ORDERED", + malleate: func() { + path.EndpointA.ChannelConfig.Order = types.UNORDERED + path.EndpointB.ChannelConfig.Order = types.UNORDERED + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Ordering = types.ORDERED + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Ordering = types.ORDERED + }, + preUpgrade: func() { + ctx := suite.chainA.GetContext() + + // assert that NextSeqRecv is 1 because channel is UNORDERED + seq, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetNextSequenceRecv(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().True(found) + suite.Require().Equal(uint64(1), seq) + + // assert that NextSeqAck is 1 because channel is UNORDERED + seq, found = suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetNextSequenceAck(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().True(found) + suite.Require().Equal(uint64(1), seq) + + // Assert that pruning sequence start has not been initialized. 
+ suite.Require().False(suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.HasPruningSequenceStart(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) + + // Assert that recv start sequence has not been set + counterpartyNextSequenceSend, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetRecvStartSequence(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().False(found) + suite.Require().Equal(uint64(0), counterpartyNextSequenceSend) + }, + postUpgrade: func() { + channel := path.EndpointA.GetChannel() + ctx := suite.chainA.GetContext() + + // Assert that channel state has been updated + suite.Require().Equal(types.OPEN, channel.State) + suite.Require().Equal(types.ORDERED, channel.Ordering) + + // assert that NextSeqRecv is incremented to 2, because channel is now ORDERED + // NextSeqRecv updated in WriteUpgradeOpenChannel to latest sequence (one packet sent) + 1 + seq, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetNextSequenceRecv(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().True(found) + suite.Require().Equal(uint64(2), seq) + + // assert that NextSeqAck is incremented to 2 because channel is now ORDERED + seq, found = suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetNextSequenceAck(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().True(found) + suite.Require().Equal(uint64(2), seq) + + // Assert that pruning sequence start has been initialized (set to 1) + suite.Require().True(suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.HasPruningSequenceStart(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) + pruningSeq, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetPruningSequenceStart(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().True(found) + suite.Require().Equal(uint64(1), pruningSeq) + + // Assert that the recv start sequence has been set 
correctly + counterpartySequenceSend, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetRecvStartSequence(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().True(found) + suite.Require().Equal(uint64(2), counterpartySequenceSend) + }, + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + + tc.malleate() + + path.Setup() + + // Need to create a packet commitment on A so as to keep it from going to OPEN if no inflight packets exist. + sequenceA, err := path.EndpointA.SendPacket(defaultTimeoutHeight, disabledTimeoutTimestamp, ibctesting.MockPacketData) + suite.Require().NoError(err) + packetA := types.NewPacket(ibctesting.MockPacketData, sequenceA, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, defaultTimeoutHeight, disabledTimeoutTimestamp) + err = path.EndpointB.RecvPacket(packetA) + suite.Require().NoError(err) + + // send second packet from B to A + sequenceB, err := path.EndpointB.SendPacket(defaultTimeoutHeight, disabledTimeoutTimestamp, ibctesting.MockPacketData) + suite.Require().NoError(err) + packetB := types.NewPacket(ibctesting.MockPacketData, sequenceB, path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, defaultTimeoutHeight, disabledTimeoutTimestamp) + err = path.EndpointA.RecvPacket(packetB) + suite.Require().NoError(err) + + suite.Require().NoError(path.EndpointA.ChanUpgradeInit()) + suite.Require().NoError(path.EndpointB.ChanUpgradeTry()) + suite.Require().NoError(path.EndpointA.ChanUpgradeAck()) + suite.Require().NoError(path.EndpointB.ChanUpgradeConfirm()) + + // Ack packets to delete packet commitments before calling WriteUpgradeOpenChannel + err = path.EndpointA.AcknowledgePacket(packetA, ibctesting.MockAcknowledgement) + 
suite.Require().NoError(err) + + err = path.EndpointB.AcknowledgePacket(packetB, ibctesting.MockAcknowledgement) + suite.Require().NoError(err) + + // pre upgrade assertions + tc.preUpgrade() + + tc.malleate() + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.WriteUpgradeOpenChannel(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + + // post upgrade assertions + tc.postUpgrade() + + // Assert that state stored for upgrade has been deleted + upgrade, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().Equal(types.Upgrade{}, upgrade) + suite.Require().False(found) + + counterpartyUpgrade, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetCounterpartyUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().Equal(types.Upgrade{}, counterpartyUpgrade) + suite.Require().False(found) + }) + } +} + +func (suite *KeeperTestSuite) TestChanUpgradeCancel() { + var ( + path *ibctesting.Path + errorReceipt types.ErrorReceipt + errorReceiptProof []byte + proofHeight clienttypes.Height + ) + + tests := []struct { + name string + malleate func() + expError error + }{ + { + name: "success with flushing state", + malleate: func() { + }, + expError: nil, + }, + { + name: "success with flush complete state", + malleate: func() { + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.State = types.FLUSHCOMPLETE }) + + var ok bool + errorReceipt, ok = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + suite.Require().True(ok) + + // the error receipt upgrade sequence and the channel upgrade sequence must match + errorReceipt.Sequence = path.EndpointA.GetChannel().UpgradeSequence + + 
suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, errorReceipt) + + suite.coordinator.CommitBlock(suite.chainB) + + suite.Require().NoError(path.EndpointA.UpdateClient()) + + upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + errorReceiptProof, proofHeight = suite.chainB.QueryProof(upgradeErrorReceiptKey) + }, + expError: nil, + }, + { + name: "upgrade cannot be cancelled in FLUSHCOMPLETE with invalid error receipt", + malleate: func() { + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.State = types.FLUSHCOMPLETE }) + + errorReceiptProof = nil + }, + expError: commitmenttypes.ErrInvalidProof, + }, + { + name: "channel not found", + malleate: func() { + path.EndpointA.Chain.DeleteKey(host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) + }, + expError: types.ErrChannelNotFound, + }, + { + name: "upgrade not found", + malleate: func() { + path.EndpointA.Chain.DeleteKey(host.ChannelUpgradeKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) + }, + expError: types.ErrUpgradeNotFound, + }, + { + name: "error receipt sequence less than channel upgrade sequence", + malleate: func() { + var ok bool + errorReceipt, ok = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + suite.Require().True(ok) + + errorReceipt.Sequence = path.EndpointA.GetChannel().UpgradeSequence - 1 + + suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, errorReceipt) + + suite.coordinator.CommitBlock(suite.chainB) + + suite.Require().NoError(path.EndpointA.UpdateClient()) + + upgradeErrorReceiptKey := 
host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + errorReceiptProof, proofHeight = suite.chainB.QueryProof(upgradeErrorReceiptKey) + }, + expError: types.ErrInvalidUpgradeSequence, + }, + { + name: "error receipt sequence greater than channel upgrade sequence when channel in FLUSHCOMPLETE", + malleate: func() { + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.State = types.FLUSHCOMPLETE }) + }, + expError: types.ErrInvalidUpgradeSequence, + }, + { + name: "error receipt sequence smaller than channel upgrade sequence when channel in FLUSHCOMPLETE", + malleate: func() { + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.State = types.FLUSHCOMPLETE }) + + var ok bool + errorReceipt, ok = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + suite.Require().True(ok) + + errorReceipt.Sequence = path.EndpointA.GetChannel().UpgradeSequence - 1 + + suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, errorReceipt) + + suite.coordinator.CommitBlock(suite.chainB) + + suite.Require().NoError(path.EndpointA.UpdateClient()) + + upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + errorReceiptProof, proofHeight = suite.chainB.QueryProof(upgradeErrorReceiptKey) + }, + expError: types.ErrInvalidUpgradeSequence, + }, + { + name: "connection not found", + malleate: func() { + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.ConnectionHops = []string{"connection-100"} }) + }, + expError: connectiontypes.ErrConnectionNotFound, + }, + { + name: "channel is in flush complete, error verification failed", + malleate: func() { + var ok bool + errorReceipt, ok = 
suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + suite.Require().True(ok) + + errorReceipt.Message = ibctesting.InvalidID + + suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, errorReceipt) + suite.coordinator.CommitBlock(suite.chainB) + }, + expError: commitmenttypes.ErrInvalidProof, + }, + { + name: "error verification failed", + malleate: func() { + var ok bool + errorReceipt, ok = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + suite.Require().True(ok) + + errorReceipt.Message = ibctesting.InvalidID + suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, errorReceipt) + suite.coordinator.CommitBlock(suite.chainB) + }, + expError: commitmenttypes.ErrInvalidProof, + }, + { + name: "error verification failed with empty proof", + malleate: func() { + errorReceiptProof = nil + }, + expError: commitmenttypes.ErrInvalidProof, + }, + } + + for _, tc := range tests { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + suite.Require().NoError(path.EndpointA.ChanUpgradeInit()) + + // cause the upgrade to fail on chain b so an error receipt is written. + // if the counterparty (chain A) upgrade sequence is less than the current sequence, (chain B) + // an upgrade error will be returned by chain B during ChanUpgradeTry. 
+ path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence = 1 }) + + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence = 2 }) + + suite.Require().NoError(path.EndpointA.UpdateClient()) + suite.Require().NoError(path.EndpointB.UpdateClient()) + + // error receipt is written to chain B here. + suite.Require().NoError(path.EndpointB.ChanUpgradeTry()) + + suite.Require().NoError(path.EndpointA.UpdateClient()) + + upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + errorReceiptProof, proofHeight = suite.chainB.QueryProof(upgradeErrorReceiptKey) + + var ok bool + errorReceipt, ok = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + suite.Require().True(ok) + + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.State = types.FLUSHING }) + + tc.malleate() + + err := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeCancel(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, errorReceipt, errorReceiptProof, proofHeight) + + if tc.expError == nil { + suite.Require().NoError(err) + } else { + suite.Require().ErrorIs(err, tc.expError) + } + }) + } +} + +// TestChanUpgrade_UpgradeSucceeds_AfterCancel verifies that if upgrade sequences +// become out of sync, the upgrade can still be performed successfully after the upgrade is cancelled. 
+func (suite *KeeperTestSuite) TestChanUpgrade_UpgradeSucceeds_AfterCancel() { + path := ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + suite.Require().NoError(path.EndpointA.ChanUpgradeInit()) + + // cause the upgrade to fail on chain b so an error receipt is written. + // if the counterparty (chain A) upgrade sequence is less than the current sequence, (chain B) + // an upgrade error will be returned by chain B during ChanUpgradeTry. + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence = 1 }) + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence = 5 }) + + suite.Require().NoError(path.EndpointA.UpdateClient()) + suite.Require().NoError(path.EndpointB.UpdateClient()) + + // error receipt is written to chain B here. + suite.Require().NoError(path.EndpointB.ChanUpgradeTry()) + + suite.Require().NoError(path.EndpointA.UpdateClient()) + + var errorReceipt types.ErrorReceipt + suite.T().Run("error receipt written", func(t *testing.T) { + var ok bool + errorReceipt, ok = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + suite.Require().True(ok) + }) + + suite.T().Run("upgrade cancelled successfully", func(t *testing.T) { + upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + errorReceiptProof, proofHeight := suite.chainB.QueryProof(upgradeErrorReceiptKey) + + err := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeCancel(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, errorReceipt, errorReceiptProof, proofHeight) + suite.Require().NoError(err) + + // need to explicitly call 
WriteUpgradeCancelChannel as this usually would happen in the msg server layer.
+ suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.WriteUpgradeCancelChannel(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, errorReceipt.Sequence)
+
+ channel := path.EndpointA.GetChannel()
+ suite.Require().Equal(types.OPEN, channel.State)
+
+ suite.T().Run("verify upgrade sequence fastforwards to channelB sequence", func(t *testing.T) {
+ suite.Require().Equal(uint64(5), channel.UpgradeSequence)
+ })
+ })
+
+ suite.T().Run("successfully completes upgrade", func(t *testing.T) {
+ suite.Require().NoError(path.EndpointA.ChanUpgradeInit())
+ suite.Require().NoError(path.EndpointB.ChanUpgradeTry())
+ suite.Require().NoError(path.EndpointA.ChanUpgradeAck())
+ suite.Require().NoError(path.EndpointB.ChanUpgradeConfirm())
+ suite.Require().NoError(path.EndpointA.ChanUpgradeOpen())
+ })
+
+ suite.T().Run("channel in expected state", func(t *testing.T) {
+ channel := path.EndpointA.GetChannel()
+ suite.Require().Equal(types.OPEN, channel.State, "channel should be in OPEN state")
+ suite.Require().Equal(mock.UpgradeVersion, channel.Version, "version should be correctly upgraded")
+ suite.Require().Equal(mock.UpgradeVersion, path.EndpointB.GetChannel().Version, "version should be correctly upgraded")
+ suite.Require().Equal(uint64(6), channel.UpgradeSequence, "upgrade sequence should be incremented")
+ suite.Require().Equal(uint64(6), path.EndpointB.GetChannel().UpgradeSequence, "upgrade sequence should be incremented on counterparty")
+ })
+}
+
+func (suite *KeeperTestSuite) TestWriteUpgradeCancelChannel() {
+ var path *ibctesting.Path
+
+ testCases := []struct {
+ name     string
+ malleate func()
+ expPanic bool
+ }{
+ {
+ name:     "success",
+ malleate: func() {},
+ expPanic: false,
+ },
+ {
+ name: "channel not found",
+ malleate: func() {
+ path.EndpointA.Chain.DeleteKey(host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID))
+ },
+ 
expPanic: true, + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + suite.Require().NoError(path.EndpointA.ChanUpgradeInit()) + + // cause the upgrade to fail on chain b so an error receipt is written. + // if the counterparty (chain A) upgrade sequence is less than the current sequence, (chain B) + // an upgrade error will be returned by chain B during ChanUpgradeTry. + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence = 1 }) + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.UpgradeSequence = 2 }) + + err := path.EndpointB.ChanUpgradeTry() + suite.Require().NoError(err) + + err = path.EndpointA.UpdateClient() + suite.Require().NoError(err) + + errorReceipt, ok := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + suite.Require().True(ok) + + ctx := suite.chainA.GetContext() + tc.malleate() + + if tc.expPanic { + suite.Require().Panics(func() { + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.WriteUpgradeCancelChannel(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, errorReceipt.Sequence) + }) + } else { + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.WriteUpgradeCancelChannel(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, errorReceipt.Sequence) + + channel := path.EndpointA.GetChannel() + + // Verify that channel has been restored to previous state + suite.Require().Equal(types.OPEN, channel.State) + suite.Require().Equal(mock.Version, channel.Version) + suite.Require().Equal(errorReceipt.Sequence, channel.UpgradeSequence) + + // Assert that state stored for 
upgrade has been deleted + upgrade, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().Equal(types.Upgrade{}, upgrade) + suite.Require().False(found) + + counterpartyUpgrade, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetCounterpartyUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().Equal(types.Upgrade{}, counterpartyUpgrade) + suite.Require().False(found) + } + }) + } +} + +func (suite *KeeperTestSuite) TestChanUpgradeTimeout() { + var ( + path *ibctesting.Path + channelProof []byte + proofHeight exported.Height + ) + + timeoutUpgrade := func() { + upgrade := path.EndpointA.GetProposedUpgrade() + upgrade.Timeout = types.NewTimeout(clienttypes.ZeroHeight(), 1) + path.EndpointA.SetChannelUpgrade(upgrade) + suite.Require().NoError(path.EndpointB.UpdateClient()) + } + + testCases := []struct { + name string + malleate func() + expError error + }{ + { + "success: proof timestamp has passed", + func() { + timeoutUpgrade() + + channelKey := host.ChannelKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + channelProof, proofHeight = path.EndpointB.QueryProof(channelKey) + }, + nil, + }, + { + "channel not found", + func() { + path.EndpointA.ChannelID = ibctesting.InvalidID + }, + types.ErrChannelNotFound, + }, + { + "channel state is not in FLUSHING or FLUSHINGCOMPLETE state", + func() { + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.State = types.OPEN }) + }, + types.ErrInvalidChannelState, + }, + { + "current upgrade not found", + func() { + suite.chainA.DeleteKey(host.ChannelUpgradeKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) + }, + types.ErrUpgradeNotFound, + }, + { + "connection not found", + func() { + path.EndpointA.UpdateChannel(func(channel *types.Channel) { channel.ConnectionHops[0] = 
ibctesting.InvalidID }) + }, + connectiontypes.ErrConnectionNotFound, + }, + { + "connection not open", + func() { + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.State = connectiontypes.UNINITIALIZED }) + }, + connectiontypes.ErrInvalidConnectionState, + }, + { + "unable to retrieve timestamp at proof height", + func() { + // TODO: #123 revert this when the upgrade timeout is not hard coded to 1000 + proofHeight = clienttypes.NewHeight(clienttypes.ParseChainID(suite.chainA.ChainID), uint64(suite.chainA.GetContext().BlockHeight())+1000) + }, + clienttypes.ErrConsensusStateNotFound, + }, + { + "invalid channel state proof", + func() { + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.State = types.OPEN }) + + timeoutUpgrade() + + suite.Require().NoError(path.EndpointA.UpdateClient()) + + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + channelProof, proofHeight = path.EndpointB.QueryProof(channelKey) + + // modify state so the proof becomes invalid. 
+ path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.State = types.FLUSHING }) + }, + commitmenttypes.ErrInvalidProof, + }, + { + "invalid counterparty upgrade sequence", + func() { + path.EndpointB.UpdateChannel(func(channel *types.Channel) { + channel.UpgradeSequence = path.EndpointA.GetChannel().UpgradeSequence - 1 + }) + + timeoutUpgrade() + + suite.Require().NoError(path.EndpointA.UpdateClient()) + + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + channelProof, proofHeight = path.EndpointB.QueryProof(channelKey) + }, + types.ErrInvalidUpgradeSequence, + }, + { + "timeout timestamp has not passed", + func() { + upgrade := path.EndpointA.GetProposedUpgrade() + upgrade.Timeout.Timestamp = math.MaxUint64 + path.EndpointA.SetChannelUpgrade(upgrade) + + suite.Require().NoError(path.EndpointB.UpdateClient()) + + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + channelProof, proofHeight = path.EndpointB.QueryProof(channelKey) + }, + types.ErrTimeoutNotReached, + }, + { + "counterparty channel state is not OPEN or FLUSHING (crossing hellos)", + func() { + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.State = types.FLUSHCOMPLETE }) + + timeoutUpgrade() + + suite.Require().NoError(path.EndpointA.UpdateClient()) + + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + channelProof, proofHeight = path.EndpointB.QueryProof(channelKey) + }, + types.ErrInvalidCounterparty, + }, + { + "counterparty proposed connection invalid", + func() { + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.State = types.OPEN }) + + timeoutUpgrade() + + upgrade := path.EndpointA.GetChannelUpgrade() + upgrade.Fields.ConnectionHops = []string{"connection-100"} + path.EndpointA.SetChannelUpgrade(upgrade) + + suite.Require().NoError(path.EndpointA.UpdateClient()) + 
suite.Require().NoError(path.EndpointB.UpdateClient()) + + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + channelProof, proofHeight = path.EndpointB.QueryProof(channelKey) + }, + connectiontypes.ErrConnectionNotFound, + }, + { + "counterparty channel already upgraded", + func() { + // put chainA channel into OPEN state since both sides are in FLUSHCOMPLETE + suite.Require().NoError(path.EndpointB.ChanUpgradeConfirm()) + + timeoutUpgrade() + + suite.Require().NoError(path.EndpointA.UpdateClient()) + + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + channelProof, proofHeight = path.EndpointB.QueryProof(channelKey) + }, + types.ErrUpgradeTimeoutFailed, + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + suite.Require().NoError(path.EndpointA.ChanUpgradeInit()) + suite.Require().NoError(path.EndpointB.ChanUpgradeTry()) + suite.Require().NoError(path.EndpointA.ChanUpgradeAck()) + + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + channelProof, proofHeight = path.EndpointB.QueryProof(channelKey) + + tc.malleate() + + err := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeTimeout( + suite.chainA.GetContext(), + path.EndpointA.ChannelConfig.PortID, + path.EndpointA.ChannelID, + path.EndpointB.GetChannel(), + channelProof, + proofHeight, + ) + + if tc.expError == nil { + suite.Require().NoError(err) + } else { + suite.assertUpgradeError(err, tc.expError) + } + }) + } +} + +func (suite *KeeperTestSuite) TestStartFlush() { + var path *ibctesting.Path + + testCases := []struct { + name string + malleate func() + expError error + }{ 
+ { + "success", + func() {}, + nil, + }, + { + "channel not found", + func() { + path.EndpointB.ChannelID = "invalid-channel" + path.EndpointB.ChannelConfig.PortID = "invalid-port" + }, + types.ErrChannelNotFound, + }, + { + "connection not found", + func() { + path.EndpointB.UpdateChannel(func(channel *types.Channel) { channel.ConnectionHops[0] = ibctesting.InvalidID }) + }, + connectiontypes.ErrConnectionNotFound, + }, + { + "connection state is not in OPEN state", + func() { + path.EndpointB.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.State = connectiontypes.INIT }) + }, + connectiontypes.ErrInvalidConnectionState, + }, + { + "next sequence send not found", + func() { + // Delete next sequence send key from store + store := suite.chainB.GetContext().KVStore(suite.chainB.GetSimApp().GetKey(exported.StoreKey)) + store.Delete(host.NextSequenceSendKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)) + }, + types.ErrSequenceSendNotFound, + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + + // crossing hellos so that the upgrade is created on chain B. + // the ChanUpgradeInit sub protocol is also called when it is not a crossing hello situation. 
+ err = path.EndpointB.ChanUpgradeInit() + suite.Require().NoError(err) + + upgrade := path.EndpointB.GetChannelUpgrade() + + tc.malleate() + + err = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.StartFlushing( + suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, &upgrade, + ) + + if tc.expError != nil { + suite.assertUpgradeError(err, tc.expError) + } else { + channel := path.EndpointB.GetChannel() + + nextSequenceSend, ok := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetNextSequenceSend(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) + suite.Require().True(ok) + + suite.Require().Equal(types.FLUSHING, channel.State) + suite.Require().Equal(nextSequenceSend, upgrade.NextSequenceSend) + + expectedTimeoutTimestamp := types.DefaultTimeout.Timestamp + uint64(suite.chainB.GetContext().BlockTime().UnixNano()) + suite.Require().Equal(expectedTimeoutTimestamp, upgrade.Timeout.Timestamp) + suite.Require().Equal(clienttypes.ZeroHeight(), upgrade.Timeout.Height, "only timestamp should be set") + suite.Require().NoError(err) + } + }) + } +} + +func (suite *KeeperTestSuite) TestValidateUpgradeFields() { + var ( + proposedUpgrade *types.UpgradeFields + path *ibctesting.Path + ) + tests := []struct { + name string + malleate func() + expErr error + }{ + { + name: "change channel version", + malleate: func() { + proposedUpgrade.Version = mock.UpgradeVersion + }, + expErr: nil, + }, + { + name: "change connection hops", + malleate: func() { + path := ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + proposedUpgrade.ConnectionHops = []string{path.EndpointA.ConnectionID} + }, + expErr: nil, + }, + { + name: "fails with unmodified fields", + malleate: func() {}, + expErr: errorsmod.Wrapf(types.ErrInvalidUpgrade, "existing channel end is identical to proposed upgrade channel end: got {ORDER_UNORDERED [connection-0] mock-version}"), + }, + { + name: "fails when connection is not 
set", + malleate: func() { + storeKey := suite.chainA.GetSimApp().GetKey(exported.StoreKey) + kvStore := suite.chainA.GetContext().KVStore(storeKey) + kvStore.Delete(host.ConnectionKey(ibctesting.FirstConnectionID)) + }, + expErr: errorsmod.Wrapf(types.ErrInvalidUpgrade, "existing channel end is identical to proposed upgrade channel end: got {ORDER_UNORDERED [connection-0] mock-version}"), + }, + { + name: "fails when connection is not open", + malleate: func() { + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.State = connectiontypes.UNINITIALIZED }) + }, + expErr: errorsmod.Wrapf(types.ErrInvalidUpgrade, "existing channel end is identical to proposed upgrade channel end: got {ORDER_UNORDERED [connection-0] mock-version}"), + }, + { + name: "fails when connection versions do not exist", + malleate: func() { + // update channel version first so that existing channel end is not identical to proposed upgrade + proposedUpgrade.Version = mock.UpgradeVersion + + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { + c.Versions = []*connectiontypes.Version{} + }) + }, + expErr: errorsmod.Wrapf(connectiontypes.ErrInvalidVersion, "single version must be negotiated on connection before opening channel, got: []"), + }, + { + name: "fails when connection version does not support the new ordering", + malleate: func() { + // update channel version first so that existing channel end is not identical to proposed upgrade + proposedUpgrade.Version = mock.UpgradeVersion + + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { + c.Versions = []*connectiontypes.Version{connectiontypes.NewVersion("1", []string{"ORDER_ORDERED"})} + }) + }, + expErr: errorsmod.Wrapf(connectiontypes.ErrInvalidVersion, "connection version identifier:\"1\" features:\"ORDER_ORDERED\" does not support channel ordering: ORDER_UNORDERED"), + }, + } + + for _, tc := range tests { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + 
path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + existingChannel := path.EndpointA.GetChannel() + proposedUpgrade = &types.UpgradeFields{ + Ordering: existingChannel.Ordering, + ConnectionHops: existingChannel.ConnectionHops, + Version: existingChannel.Version, + } + + tc.malleate() + + err := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.ValidateSelfUpgradeFields(suite.chainA.GetContext(), *proposedUpgrade, existingChannel) + if tc.expErr == nil { + suite.Require().NoError(err) + } else { + suite.Require().Error(err) + suite.Require().ErrorIs(err, tc.expErr) + } + }) + } +} + +func (suite *KeeperTestSuite) assertUpgradeError(actualError, expError error) { + suite.Require().Error(actualError) + + if expUpgradeError, ok := expError.(*types.UpgradeError); ok { + upgradeError, ok := actualError.(*types.UpgradeError) + suite.Require().True(ok) + suite.Require().Equal(expUpgradeError.GetErrorReceipt(), upgradeError.GetErrorReceipt()) + } + + suite.Require().True(errorsmod.IsOf(actualError, expError), fmt.Sprintf("expected error: %s, actual error: %s", expError, actualError)) +} + +// TestAbortUpgrade tests that when the channel handshake is aborted, the channel state +// is restored the previous state and that an error receipt is written, and upgrade state which +// is no longer required is deleted. +func (suite *KeeperTestSuite) TestAbortUpgrade() { + var ( + path *ibctesting.Path + upgradeError error + ) + + tests := []struct { + name string + malleate func() + expErr error + }{ + { + name: "success", + malleate: func() {}, + expErr: nil, + }, + { + name: "regular error", + malleate: func() { + // in app callbacks error receipts should still be written if a regular error is returned. + // i.e. 
not an instance of `types.UpgradeError` + upgradeError = types.ErrInvalidUpgrade + }, + expErr: nil, + }, + { + name: "channel does not exist", + malleate: func() { + suite.chainA.DeleteKey(host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) + }, + expErr: types.ErrChannelNotFound, + }, + { + name: "fails with nil upgrade error", + malleate: func() { + upgradeError = nil + }, + expErr: types.ErrInvalidUpgradeError, + }, + } + + for _, tc := range tests { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + channelKeeper := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper + + path.EndpointA.ChannelConfig.Version = mock.UpgradeVersion + suite.Require().NoError(path.EndpointA.ChanUpgradeInit()) + + // fetch the upgrade before abort for assertions later on. + actualUpgrade, ok := channelKeeper.GetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().True(ok, "upgrade should be found") + + upgradeError = types.NewUpgradeError(1, types.ErrInvalidChannel) + + tc.malleate() + + if tc.expErr == nil { + + ctx := suite.chainA.GetContext() + + suite.Require().NotPanics(func() { + channelKeeper.MustAbortUpgrade(ctx, path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgradeError) + }) + + channel, found := channelKeeper.GetChannel(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().True(found, "channel should be found") + + suite.Require().Equal(types.OPEN, channel.State, "channel state should be %s", types.OPEN.String()) + + _, found = channelKeeper.GetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().False(found, "upgrade info should be deleted") + + errorReceipt, found := channelKeeper.GetUpgradeErrorReceipt(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, 
path.EndpointA.ChannelID) + suite.Require().True(found, "error receipt should be found") + + if upgradeError, ok := upgradeError.(*types.UpgradeError); ok { + suite.Require().Equal(upgradeError.GetErrorReceipt(), errorReceipt, "error receipt does not match expected error receipt") + } + } else { + + suite.Require().Panics(func() { + channelKeeper.MustAbortUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgradeError) + }) + + channel, found := channelKeeper.GetChannel(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + if found { // test cases uses a channel that exists + suite.Require().Equal(types.OPEN, channel.State, "channel state should not be restored to %s", types.OPEN.String()) + } + + _, found = channelKeeper.GetUpgradeErrorReceipt(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().False(found, "error receipt should not be found") + + upgrade, found := channelKeeper.GetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + if found { // this should be all test cases except for when the upgrade is explicitly deleted. 
+ suite.Require().Equal(actualUpgrade, upgrade, "upgrade info should not be deleted") + } + } + }) + } +} + +func (suite *KeeperTestSuite) TestCheckForUpgradeCompatibility() { + var ( + path *ibctesting.Path + upgradeFields types.UpgradeFields + counterpartyUpgradeFields types.UpgradeFields + ) + + testCases := []struct { + name string + malleate func() + expError error + }{ + { + "success", + func() {}, + nil, + }, + { + "upgrade ordering is not the same on both sides", + func() { + upgradeFields.Ordering = types.ORDERED + }, + types.ErrIncompatibleCounterpartyUpgrade, + }, + { + "proposed connection is not found", + func() { + upgradeFields.ConnectionHops[0] = ibctesting.InvalidID + }, + connectiontypes.ErrConnectionNotFound, + }, + { + "proposed connection is not in OPEN state", + func() { + // reuse existing connection to create a new connection in a non OPEN state + connectionEnd := path.EndpointB.GetConnection() + connectionEnd.State = connectiontypes.UNINITIALIZED + connectionEnd.Counterparty.ConnectionId = counterpartyUpgradeFields.ConnectionHops[0] // both sides must be each other's counterparty + + // set proposed connection in state + proposedConnectionID := "connection-100" + suite.chainB.GetSimApp().GetIBCKeeper().ConnectionKeeper.SetConnection(suite.chainB.GetContext(), proposedConnectionID, connectionEnd) + upgradeFields.ConnectionHops[0] = proposedConnectionID + }, + connectiontypes.ErrInvalidConnectionState, + }, + { + "proposed connection ends are not each other's counterparty", + func() { + // reuse existing connection to create a new connection in a non OPEN state + connectionEnd := path.EndpointB.GetConnection() + // ensure counterparty connectionID does not match connectionID set in counterparty proposed upgrade + connectionEnd.Counterparty.ConnectionId = "connection-50" + + // set proposed connection in state + proposedConnectionID := "connection-100" + 
suite.chainB.GetSimApp().GetIBCKeeper().ConnectionKeeper.SetConnection(suite.chainB.GetContext(), proposedConnectionID, connectionEnd) + upgradeFields.ConnectionHops[0] = proposedConnectionID + }, + types.ErrIncompatibleCounterpartyUpgrade, + }, + { + "proposed upgrade version is not the same on both sides", + func() { + upgradeFields.Version = mock.Version + }, + types.ErrIncompatibleCounterpartyUpgrade, + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + + upgradeFields = path.EndpointA.GetProposedUpgrade().Fields + counterpartyUpgradeFields = path.EndpointB.GetProposedUpgrade().Fields + + tc.malleate() + + err = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.CheckForUpgradeCompatibility(suite.chainB.GetContext(), upgradeFields, counterpartyUpgradeFields) + if tc.expError != nil { + suite.Require().ErrorIs(err, tc.expError) + } else { + suite.Require().NoError(err) + } + }) + } +} + +func (suite *KeeperTestSuite) TestChanUpgradeCrossingHelloWithHistoricalProofs() { + var path *ibctesting.Path + + testCases := []struct { + name string + malleate func() + expError error + }{ + { + "success", + func() {}, + nil, + }, + { + "counterparty (chain B) has already progressed to ACK step", + func() { + err := path.EndpointB.ChanUpgradeAck() + suite.Require().NoError(err) + }, + types.ErrInvalidChannelState, + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + path.Setup() + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + 
path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + + err = path.EndpointB.ChanUpgradeInit() + suite.Require().NoError(err) + + suite.coordinator.CommitBlock(suite.chainA, suite.chainB) + + err = path.EndpointB.UpdateClient() + suite.Require().NoError(err) + + historicalChannelProof, historicalUpgradeProof, proofHeight := path.EndpointA.QueryChannelUpgradeProof() + + err = path.EndpointA.ChanUpgradeTry() + suite.Require().NoError(err) + + tc.malleate() + + _, upgrade, err := suite.chainB.GetSimApp().GetIBCKeeper().ChannelKeeper.ChanUpgradeTry( + suite.chainB.GetContext(), + path.EndpointB.ChannelConfig.PortID, + path.EndpointB.ChannelID, + path.EndpointB.GetChannelUpgrade().Fields.ConnectionHops, + path.EndpointA.GetChannelUpgrade().Fields, + 1, + historicalChannelProof, + historicalUpgradeProof, + proofHeight, + ) + + if tc.expError == nil { + suite.Require().NoError(err) + suite.Require().NotEmpty(upgrade) + } else { + suite.Require().ErrorIs(err, tc.expError) + } + }) + } +} + +func (suite *KeeperTestSuite) TestWriteErrorReceipt() { + var ( + path *ibctesting.Path + upgradeError *types.UpgradeError + channelKeeper *channelkeeper.Keeper + writeErrorReceiptFunc = func() { + channelKeeper.WriteErrorReceipt(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgradeError) + } + ) + + testCases := []struct { + name string + malleate func() + expResult func() + }{ + { + "success", + func() {}, + func() { + suite.NotPanics(func() { writeErrorReceiptFunc() }) + }, + }, + { + "success: existing error receipt found at a lower sequence", + func() { + // write an error sequence with a lower sequence number + previousUpgradeError := types.NewUpgradeError(upgradeError.GetErrorReceipt().Sequence-1, types.ErrInvalidUpgrade) + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.WriteErrorReceipt(suite.chainA.GetContext(), 
path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, previousUpgradeError) + }, + func() { + suite.NotPanics(func() { writeErrorReceiptFunc() }) + }, + }, + { + "failure: existing error receipt found at a higher sequence", + func() { + // write an error sequence with a higher sequence number + previousUpgradeError := types.NewUpgradeError(upgradeError.GetErrorReceipt().Sequence+1, types.ErrInvalidUpgrade) + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.WriteErrorReceipt(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, previousUpgradeError) + }, + func() { + suite.PanicsWithError(errorsmod.Wrap(types.ErrInvalidUpgradeSequence, "error receipt sequence (10) must be greater than existing error receipt sequence (11)").Error(), writeErrorReceiptFunc) + }, + }, + { + "failure: upgrade exists for error receipt being written", + func() { + // attempt to write error receipt for existing upgrade without deleting upgrade info + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + ch := path.EndpointA.GetChannel() + upgradeError = types.NewUpgradeError(ch.UpgradeSequence, types.ErrInvalidUpgrade) + }, + func() { + suite.PanicsWithError(errorsmod.Wrap(types.ErrInvalidUpgradeSequence, "attempting to write error receipt at sequence (1) while upgrade information exists at the same sequence").Error(), writeErrorReceiptFunc) + }, + }, + { + "failure: channel not found", + func() { + suite.chainA.DeleteKey(host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) + }, + func() { + suite.PanicsWithError(errorsmod.Wrap(types.ErrChannelNotFound, fmt.Sprintf("port ID (mock) channel ID (%s)", path.EndpointA.ChannelID)).Error(), writeErrorReceiptFunc) + }, + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + path = ibctesting.NewPath(suite.chainA, suite.chainB) 
+			path.Setup()
+
+			channelKeeper = suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper
+
+			upgradeError = types.NewUpgradeError(10, types.ErrInvalidUpgrade)
+
+			tc.malleate()
+
+			tc.expResult()
+		})
+	}
+}
diff --git a/x/wasm/ibc.go b/x/wasm/ibc.go
--- a/x/wasm/ibc.go
+++ b/x/wasm/ibc.go
+package wasm
+
+import (
+	"math"
+
+	wasmvmtypes "github.com/CosmWasm/wasmvm/v2/types"
+	capabilitytypes "github.com/cosmos/ibc-go/modules/capability/types"
+	clienttypes "github.com/cosmos/ibc-go/v9/modules/core/02-client/types"
+	channeltypes "github.com/cosmos/ibc-go/v9/modules/core/04-channel/types"
+	porttypes "github.com/cosmos/ibc-go/v9/modules/core/05-port/types"
+	host "github.com/cosmos/ibc-go/v9/modules/core/24-host"
+	ibcexported "github.com/cosmos/ibc-go/v9/modules/core/exported"
+
+	errorsmod "cosmossdk.io/errors"
+
+	sdk "github.com/cosmos/cosmos-sdk/types"
+
+	"github.com/CosmWasm/wasmd/x/wasm/keeper"
+	"github.com/CosmWasm/wasmd/x/wasm/types"
+)
+
+// DefaultMaxIBCCallbackGas is the default value of maximum gas that an IBC callback can use.
+// If the callback uses more gas, it will be out of gas and the contract state changes will be reverted,
+// but the transaction will be committed.
+// Pass this to the callbacks middleware or choose a custom value.
+const DefaultMaxIBCCallbackGas = uint64(1_000_000) + +var _ porttypes.IBCModule = IBCHandler{} + +// internal interface that is implemented by ibc middleware +type appVersionGetter interface { + // GetAppVersion returns the application level version with all middleware data stripped out + GetAppVersion(ctx sdk.Context, portID, channelID string) (string, bool) +} + +type IBCHandler struct { + keeper types.IBCContractKeeper + channelKeeper types.ChannelKeeper + appVersionGetter appVersionGetter +} + +func NewIBCHandler(k types.IBCContractKeeper, ck types.ChannelKeeper, vg appVersionGetter) IBCHandler { + return IBCHandler{keeper: k, channelKeeper: ck, appVersionGetter: vg} +} + +// OnChanOpenInit implements the IBCModule interface +func (i IBCHandler) OnChanOpenInit( + ctx sdk.Context, + order channeltypes.Order, + connectionHops []string, + portID string, + channelID string, + chanCap *capabilitytypes.Capability, + counterParty channeltypes.Counterparty, + version string, +) (string, error) { + // ensure port, version, capability + if err := ValidateChannelParams(channelID); err != nil { + return "", err + } + contractAddr, err := keeper.ContractFromPortID(portID) + if err != nil { + return "", errorsmod.Wrapf(err, "contract port id") + } + + msg := wasmvmtypes.IBCChannelOpenMsg{ + OpenInit: &wasmvmtypes.IBCOpenInit{ + Channel: wasmvmtypes.IBCChannel{ + Endpoint: wasmvmtypes.IBCEndpoint{PortID: portID, ChannelID: channelID}, + CounterpartyEndpoint: wasmvmtypes.IBCEndpoint{PortID: counterParty.PortId, ChannelID: counterParty.ChannelId}, + Order: order.String(), + // DESIGN V3: this may be "" ?? + Version: version, + ConnectionID: connectionHops[0], // At the moment this list must be of length 1. In the future multi-hop channels may be supported. 
+ }, + }, + } + + // Allow contracts to return a version (or default to proposed version if unset) + acceptedVersion, err := i.keeper.OnOpenChannel(ctx, contractAddr, msg) + if err != nil { + return "", err + } + if acceptedVersion == "" { // accept incoming version when nothing returned by contract + if version == "" { + return "", types.ErrEmpty.Wrap("version") + } + acceptedVersion = version + } + + // Claim channel capability passed back by IBC module + if err := i.keeper.ClaimCapability(ctx, chanCap, host.ChannelCapabilityPath(portID, channelID)); err != nil { + return "", errorsmod.Wrap(err, "claim capability") + } + return acceptedVersion, nil +} + +// OnChanOpenTry implements the IBCModule interface +func (i IBCHandler) OnChanOpenTry( + ctx sdk.Context, + order channeltypes.Order, + connectionHops []string, + portID, channelID string, + chanCap *capabilitytypes.Capability, + counterParty channeltypes.Counterparty, + counterpartyVersion string, +) (string, error) { + // ensure port, version, capability + if err := ValidateChannelParams(channelID); err != nil { + return "", err + } + + contractAddr, err := keeper.ContractFromPortID(portID) + if err != nil { + return "", errorsmod.Wrapf(err, "contract port id") + } + + msg := wasmvmtypes.IBCChannelOpenMsg{ + OpenTry: &wasmvmtypes.IBCOpenTry{ + Channel: wasmvmtypes.IBCChannel{ + Endpoint: wasmvmtypes.IBCEndpoint{PortID: portID, ChannelID: channelID}, + CounterpartyEndpoint: wasmvmtypes.IBCEndpoint{PortID: counterParty.PortId, ChannelID: counterParty.ChannelId}, + Order: order.String(), + Version: counterpartyVersion, + ConnectionID: connectionHops[0], // At the moment this list must be of length 1. In the future multi-hop channels may be supported. 
+ }, + CounterpartyVersion: counterpartyVersion, + }, + } + + // Allow contracts to return a version (or default to counterpartyVersion if unset) + version, err := i.keeper.OnOpenChannel(ctx, contractAddr, msg) + if err != nil { + return "", err + } + if version == "" { + version = counterpartyVersion + } + + // Module may have already claimed capability in OnChanOpenInit in the case of crossing hellos + // (ie chainA and chainB both call ChanOpenInit before one of them calls ChanOpenTry) + // If module can already authenticate the capability then module already owns it, so we don't need to claim + // Otherwise, module does not have channel capability, and we must claim it from IBC + if !i.keeper.AuthenticateCapability(ctx, chanCap, host.ChannelCapabilityPath(portID, channelID)) { + // Only claim channel capability passed back by IBC module if we do not already own it + if err := i.keeper.ClaimCapability(ctx, chanCap, host.ChannelCapabilityPath(portID, channelID)); err != nil { + return "", errorsmod.Wrap(err, "claim capability") + } + } + + return version, nil +} + +// OnChanOpenAck implements the IBCModule interface +func (i IBCHandler) OnChanOpenAck( + ctx sdk.Context, + portID, channelID string, + counterpartyChannelID string, + counterpartyVersion string, +) error { + contractAddr, err := keeper.ContractFromPortID(portID) + if err != nil { + return errorsmod.Wrapf(err, "contract port id") + } + channelInfo, ok := i.channelKeeper.GetChannel(ctx, portID, channelID) + if !ok { + return errorsmod.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID) + } + channelInfo.Counterparty.ChannelId = counterpartyChannelID + + appVersion, ok := i.appVersionGetter.GetAppVersion(ctx, portID, channelID) + if !ok { + return errorsmod.Wrapf(channeltypes.ErrInvalidChannelVersion, "port ID (%s) channel ID (%s)", portID, channelID) + } + + msg := wasmvmtypes.IBCChannelConnectMsg{ + OpenAck: &wasmvmtypes.IBCOpenAck{ + Channel: 
toWasmVMChannel(portID, channelID, channelInfo, appVersion), + CounterpartyVersion: counterpartyVersion, + }, + } + return i.keeper.OnConnectChannel(ctx, contractAddr, msg) +} + +// OnChanOpenConfirm implements the IBCModule interface +func (i IBCHandler) OnChanOpenConfirm(ctx sdk.Context, portID, channelID string) error { + contractAddr, err := keeper.ContractFromPortID(portID) + if err != nil { + return errorsmod.Wrapf(err, "contract port id") + } + channelInfo, ok := i.channelKeeper.GetChannel(ctx, portID, channelID) + if !ok { + return errorsmod.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID) + } + appVersion, ok := i.appVersionGetter.GetAppVersion(ctx, portID, channelID) + if !ok { + return errorsmod.Wrapf(channeltypes.ErrInvalidChannelVersion, "port ID (%s) channel ID (%s)", portID, channelID) + } + msg := wasmvmtypes.IBCChannelConnectMsg{ + OpenConfirm: &wasmvmtypes.IBCOpenConfirm{ + Channel: toWasmVMChannel(portID, channelID, channelInfo, appVersion), + }, + } + return i.keeper.OnConnectChannel(ctx, contractAddr, msg) +} + +// OnChanCloseInit implements the IBCModule interface +func (i IBCHandler) OnChanCloseInit(ctx sdk.Context, portID, channelID string) error { + contractAddr, err := keeper.ContractFromPortID(portID) + if err != nil { + return errorsmod.Wrapf(err, "contract port id") + } + channelInfo, ok := i.channelKeeper.GetChannel(ctx, portID, channelID) + if !ok { + return errorsmod.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID) + } + appVersion, ok := i.appVersionGetter.GetAppVersion(ctx, portID, channelID) + if !ok { + return errorsmod.Wrapf(channeltypes.ErrInvalidChannelVersion, "port ID (%s) channel ID (%s)", portID, channelID) + } + + msg := wasmvmtypes.IBCChannelCloseMsg{ + CloseInit: &wasmvmtypes.IBCCloseInit{Channel: toWasmVMChannel(portID, channelID, channelInfo, appVersion)}, + } + err = i.keeper.OnCloseChannel(ctx, contractAddr, msg) + if err != nil { + 
return err + } + // emit events? + + return err +} + +// OnChanCloseConfirm implements the IBCModule interface +func (i IBCHandler) OnChanCloseConfirm(ctx sdk.Context, portID, channelID string) error { + // counterparty has closed the channel + contractAddr, err := keeper.ContractFromPortID(portID) + if err != nil { + return errorsmod.Wrapf(err, "contract port id") + } + channelInfo, ok := i.channelKeeper.GetChannel(ctx, portID, channelID) + if !ok { + return errorsmod.Wrapf(channeltypes.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID) + } + appVersion, ok := i.appVersionGetter.GetAppVersion(ctx, portID, channelID) + if !ok { + return errorsmod.Wrapf(channeltypes.ErrInvalidChannelVersion, "port ID (%s) channel ID (%s)", portID, channelID) + } + + msg := wasmvmtypes.IBCChannelCloseMsg{ + CloseConfirm: &wasmvmtypes.IBCCloseConfirm{Channel: toWasmVMChannel(portID, channelID, channelInfo, appVersion)}, + } + err = i.keeper.OnCloseChannel(ctx, contractAddr, msg) + if err != nil { + return err + } + // emit events? + + return err +} + +func toWasmVMChannel(portID, channelID string, channelInfo channeltypes.Channel, appVersion string) wasmvmtypes.IBCChannel { + return wasmvmtypes.IBCChannel{ + Endpoint: wasmvmtypes.IBCEndpoint{PortID: portID, ChannelID: channelID}, + CounterpartyEndpoint: wasmvmtypes.IBCEndpoint{PortID: channelInfo.Counterparty.PortId, ChannelID: channelInfo.Counterparty.ChannelId}, + Order: channelInfo.Ordering.String(), + Version: appVersion, + ConnectionID: channelInfo.ConnectionHops[0], // At the moment this list must be of length 1. In the future multi-hop channels may be supported. 
+ } +} + +// OnRecvPacket implements the IBCModule interface +func (i IBCHandler) OnRecvPacket( + ctx sdk.Context, + channelVersion string, + packet channeltypes.Packet, + relayer sdk.AccAddress, +) ibcexported.Acknowledgement { + contractAddr, err := keeper.ContractFromPortID(packet.DestinationPort) + if err != nil { + // this must not happen as ports were registered before + panic(errorsmod.Wrapf(err, "contract port id")) + } + + em := sdk.NewEventManager() + msg := wasmvmtypes.IBCPacketReceiveMsg{Packet: newIBCPacket(packet), Relayer: relayer.String()} + ack, err := i.keeper.OnRecvPacket(ctx.WithEventManager(em), contractAddr, msg) + if err != nil { + ack = CreateErrorAcknowledgement(err) + // the state gets reverted, so we drop all captured events + } else if ack == nil || ack.Success() { + // emit all contract and submessage events on success + // nil ack is a success case, see: https://github.com/cosmos/ibc-go/blob/v7.0.0/modules/core/keeper/msg_server.go#L453 + ctx.EventManager().EmitEvents(em.Events()) + } + types.EmitAcknowledgementEvent(ctx, contractAddr, ack, err) + return ack +} + +// OnAcknowledgementPacket implements the IBCModule interface +func (i IBCHandler) OnAcknowledgementPacket( + ctx sdk.Context, + channelVersion string, + packet channeltypes.Packet, + acknowledgement []byte, + relayer sdk.AccAddress, +) error { + contractAddr, err := keeper.ContractFromPortID(packet.SourcePort) + if err != nil { + return errorsmod.Wrapf(err, "contract port id") + } + + err = i.keeper.OnAckPacket(ctx, contractAddr, wasmvmtypes.IBCPacketAckMsg{ + Acknowledgement: wasmvmtypes.IBCAcknowledgement{Data: acknowledgement}, + OriginalPacket: newIBCPacket(packet), + Relayer: relayer.String(), + }) + if err != nil { + return errorsmod.Wrap(err, "on ack") + } + return nil +} + +// OnTimeoutPacket implements the IBCModule interface +func (i IBCHandler) OnTimeoutPacket(ctx sdk.Context, channelVersion string, packet channeltypes.Packet, relayer sdk.AccAddress) error { + 
contractAddr, err := keeper.ContractFromPortID(packet.SourcePort) + if err != nil { + return errorsmod.Wrapf(err, "contract port id") + } + msg := wasmvmtypes.IBCPacketTimeoutMsg{Packet: newIBCPacket(packet), Relayer: relayer.String()} + err = i.keeper.OnTimeoutPacket(ctx, contractAddr, msg) + if err != nil { + return errorsmod.Wrap(err, "on timeout") + } + return nil +} + +// IBCSendPacketCallback implements the IBC Callbacks ContractKeeper interface +// see https://github.com/cosmos/ibc-go/blob/main/docs/architecture/adr-008-app-caller-cbs.md#contractkeeper +func (i IBCHandler) IBCSendPacketCallback( + cachedCtx sdk.Context, + sourcePort string, + sourceChannel string, + timeoutHeight clienttypes.Height, + timeoutTimestamp uint64, + packetData []byte, + contractAddress, + packetSenderAddress string, +) error { + _, err := validateSender(contractAddress, packetSenderAddress) + if err != nil { + return err + } + + // no-op, since we are not interested in this callback + return nil +} + +// IBCOnAcknowledgementPacketCallback implements the IBC Callbacks ContractKeeper interface +// see https://github.com/cosmos/ibc-go/blob/main/docs/architecture/adr-008-app-caller-cbs.md#contractkeeper +func (i IBCHandler) IBCOnAcknowledgementPacketCallback( + cachedCtx sdk.Context, + packet channeltypes.Packet, + acknowledgement []byte, + relayer sdk.AccAddress, + contractAddress, + packetSenderAddress string, +) error { + contractAddr, err := validateSender(contractAddress, packetSenderAddress) + if err != nil { + return err + } + + msg := wasmvmtypes.IBCSourceCallbackMsg{ + Acknowledgement: &wasmvmtypes.IBCAckCallbackMsg{ + Acknowledgement: wasmvmtypes.IBCAcknowledgement{Data: acknowledgement}, + OriginalPacket: newIBCPacket(packet), + Relayer: relayer.String(), + }, + } + err = i.keeper.IBCSourceCallback(cachedCtx, contractAddr, msg) + if err != nil { + return errorsmod.Wrap(err, "on source chain callback ack") + } + + return nil +} + +// IBCOnTimeoutPacketCallback implements 
the IBC Callbacks ContractKeeper interface +// see https://github.com/cosmos/ibc-go/blob/main/docs/architecture/adr-008-app-caller-cbs.md#contractkeeper +func (i IBCHandler) IBCOnTimeoutPacketCallback( + cachedCtx sdk.Context, + packet channeltypes.Packet, + relayer sdk.AccAddress, + contractAddress, + packetSenderAddress string, +) error { + contractAddr, err := validateSender(contractAddress, packetSenderAddress) + if err != nil { + return err + } + + msg := wasmvmtypes.IBCSourceCallbackMsg{ + Timeout: &wasmvmtypes.IBCTimeoutCallbackMsg{ + Packet: newIBCPacket(packet), + Relayer: relayer.String(), + }, + } + err = i.keeper.IBCSourceCallback(cachedCtx, contractAddr, msg) + if err != nil { + return errorsmod.Wrap(err, "on source chain callback timeout") + } + return nil +} + +// IBCReceivePacketCallback implements the IBC Callbacks ContractKeeper interface +// see https://github.com/cosmos/ibc-go/blob/main/docs/architecture/adr-008-app-caller-cbs.md#contractkeeper +func (i IBCHandler) IBCReceivePacketCallback( + cachedCtx sdk.Context, + packet ibcexported.PacketI, + ack ibcexported.Acknowledgement, + contractAddress string, +) error { + // sender validation makes no sense here, as the receiver is never the sender + contractAddr, err := sdk.AccAddressFromBech32(contractAddress) + if err != nil { + return err + } + + msg := wasmvmtypes.IBCDestinationCallbackMsg{ + Ack: wasmvmtypes.IBCAcknowledgement{Data: ack.Acknowledgement()}, + Packet: newIBCPacket(packet), + } + + err = i.keeper.IBCDestinationCallback(cachedCtx, contractAddr, msg) + if err != nil { + return errorsmod.Wrap(err, "on destination chain callback") + } + + return nil +} + +func validateSender(contractAddr, senderAddr string) (sdk.AccAddress, error) { + contractAddress, err := sdk.AccAddressFromBech32(contractAddr) + if err != nil { + return nil, errorsmod.Wrapf(err, "contract address") + } + senderAddress, err := sdk.AccAddressFromBech32(senderAddr) + if err != nil { + return nil, errorsmod.Wrapf(err, 
"packet sender address") + } + + // We only allow the contract that sent the message to receive source chain callbacks for it. + if !contractAddress.Equals(senderAddress) { + return nil, errorsmod.Wrapf(types.ErrExecuteFailed, "contract address %s does not match packet sender %s", contractAddr, senderAddress) + } + + return contractAddress, nil +} + +func newIBCPacket(packet ibcexported.PacketI) wasmvmtypes.IBCPacket { + timeout := wasmvmtypes.IBCTimeout{ + Timestamp: packet.GetTimeoutTimestamp(), + } + timeoutHeight := packet.GetTimeoutHeight() + if !timeoutHeight.IsZero() { + timeout.Block = &wasmvmtypes.IBCTimeoutBlock{ + Height: timeoutHeight.GetRevisionHeight(), + Revision: timeoutHeight.GetRevisionNumber(), + } + } + + return wasmvmtypes.IBCPacket{ + Data: packet.GetData(), + Src: wasmvmtypes.IBCEndpoint{ChannelID: packet.GetSourceChannel(), PortID: packet.GetSourcePort()}, + Dest: wasmvmtypes.IBCEndpoint{ChannelID: packet.GetDestChannel(), PortID: packet.GetDestPort()}, + Sequence: packet.GetSequence(), + Timeout: timeout, + } +} + +func ValidateChannelParams(channelID string) error { + // NOTE: for escrow address security only 2^32 channels are allowed to be created + // Issue: https://github.com/cosmos/cosmos-sdk/issues/7737 + channelSequence, err := channeltypes.ParseChannelSequence(channelID) + if err != nil { + return err + } + if channelSequence > math.MaxUint32 { + return errorsmod.Wrapf(types.ErrMaxIBCChannels, "channel sequence %d is greater than max allowed transfer channels %d", channelSequence, math.MaxUint32) + } + return nil +} + +// CreateErrorAcknowledgement turns an error into an error acknowledgement. +// +// This function is x/wasm specific and might include the full error text in the future +// as we gain confidence that it is deterministic. Don't use it in other contexts. +// See also https://github.com/CosmWasm/wasmd/issues/1740. 
+func CreateErrorAcknowledgement(err error) ibcexported.Acknowledgement {
+	return channeltypes.NewErrorAcknowledgementWithCodespace(err)
+}