Skip to content

Commit

Permalink
fix: mul 64 bit
Browse files Browse the repository at this point in the history
  • Loading branch information
danielvladco committed Feb 18, 2025
1 parent 3ef7b36 commit 13de0f6
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 48 deletions.
106 changes: 58 additions & 48 deletions internal/polkavm/interpreter/mutations.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package interpreter

import (
"math"
"math/big"
"math/bits"

"github.com/eigerco/strawberry/internal/polkavm"
Expand Down Expand Up @@ -39,7 +40,7 @@ func (i *Instance) StoreImmU32(address uint64, value uint64) error {

// StoreImmU64 store_imm_u64 μ′↺{νX...+8} = E8(νY)
func (i *Instance) StoreImmU64(address uint64, value uint64) error {
return i.store(address, uint64(value))
return i.store(address, value)
}

// Jump jump branch(νX , ⊺)
Expand Down Expand Up @@ -129,17 +130,17 @@ func (i *Instance) LoadU64(dst polkavm.Reg, address uint64) error {

// StoreU8 store_u8 μ′↺_νX = ωA mod 2^8
func (i *Instance) StoreU8(src polkavm.Reg, address uint64) error {
return i.store(address, uint8(uint32(i.regs[src])))
return i.store(address, uint8(i.regs[src]))
}

// StoreU16 store_u16 μ′↺_{νX...+2} = E2(ωA mod 2^16)
func (i *Instance) StoreU16(src polkavm.Reg, address uint64) error {
return i.store(address, uint16(uint32(i.regs[src])))
return i.store(address, uint16(i.regs[src]))
}

// StoreU32 store_u32 μ′↺_{νX...+4} = E4(ωA mod 2^32)
func (i *Instance) StoreU32(src polkavm.Reg, address uint64) error {
return i.store(address, uint32(uint32(i.regs[src])))
return i.store(address, uint32(i.regs[src]))
}

// StoreU64 store_u64 μ′↺_{νX...+8} = E8(ωA)
Expand Down Expand Up @@ -205,22 +206,22 @@ func (i *Instance) BranchGreaterUnsignedImm(regA polkavm.Reg, valueX uint64, tar

// BranchLessSignedImm branch_lt_s_imm branch(νY, Z8(ωA) < Z8(νX))
func (i *Instance) BranchLessSignedImm(regA polkavm.Reg, valueX uint64, target uint64) error {
return i.branch(int32(uint32(i.regs[regA])) < int32(valueX), target)
return i.branch(int64(i.regs[regA]) < int64(valueX), target)
}

// BranchLessOrEqualSignedImm branch_le_s_imm branch(νY , Z8(ωA) ≤ Z8(νX))
func (i *Instance) BranchLessOrEqualSignedImm(regA polkavm.Reg, valueX uint64, target uint64) error {
return i.branch(int32(uint32(i.regs[regA])) <= int32(valueX), target)
return i.branch(int64(i.regs[regA]) <= int64(valueX), target)
}

// BranchGreaterOrEqualSignedImm branch_ge_s_imm branch(νY, Z8(ωA) ≥ Z8(νX))
func (i *Instance) BranchGreaterOrEqualSignedImm(regA polkavm.Reg, valueX uint64, target uint64) error {
return i.branch(int32(uint32(i.regs[regA])) >= int32(valueX), target)
return i.branch(int64(i.regs[regA]) >= int64(valueX), target)
}

// BranchGreaterSignedImm branch_gt_s_imm branch(νY, Z8(ωA) > Z8(νX))
func (i *Instance) BranchGreaterSignedImm(regA polkavm.Reg, valueX uint64, target uint64) error {
return i.branch(int32(uint32(i.regs[regA])) > int32(valueX), target)
return i.branch(int64(i.regs[regA]) > int64(valueX), target)
}

// MoveReg move_reg ω′D = ωA
Expand Down Expand Up @@ -250,7 +251,7 @@ func (i *Instance) CountSetBits64(dst polkavm.Reg, s polkavm.Reg) {

// CountSetBits32 count_set_bits_32 ω′D = {31;i=0}∑ B4(ωA mod 2^32)_i
func (i *Instance) CountSetBits32(dst polkavm.Reg, s polkavm.Reg) {
i.setAndSkip(dst, uint64(uint32(bits.OnesCount32(uint32(i.regs[s])))))
i.setAndSkip(dst, uint64(bits.OnesCount32(uint32(i.regs[s]))))
}

// LeadingZeroBits64 leading_zero_bits_64 ω′D = max(n ∈ N65) where {i<n;i=0}∑ B8(ωA)_i = 0
Expand All @@ -260,7 +261,7 @@ func (i *Instance) LeadingZeroBits64(dst polkavm.Reg, s polkavm.Reg) {

// LeadingZeroBits32 leading_zero_bits_32 ω′D = max(n ∈ N33) where {i<n;i=0}∑ B4(ωA mod 232)_i = 0
func (i *Instance) LeadingZeroBits32(dst polkavm.Reg, s polkavm.Reg) {
i.setAndSkip(dst, uint64(uint32(bits.LeadingZeros32(uint32(i.regs[s])))))
i.setAndSkip(dst, uint64(bits.LeadingZeros32(uint32(i.regs[s]))))
}

// TrailingZeroBits64 trailing_zero_bits_64 ω′D = max(n ∈ N65) where {i<n;i=0}∑ B8(ωA)_63−i = 0
Expand All @@ -270,17 +271,17 @@ func (i *Instance) TrailingZeroBits64(dst polkavm.Reg, s polkavm.Reg) {

// TrailingZeroBits32 trailing_zero_bits_32 ω′D = max(n ∈ N33) where {i<n;i=0}∑ B4(ωA mod 232)_31−i = 0
func (i *Instance) TrailingZeroBits32(dst polkavm.Reg, s polkavm.Reg) {
i.setAndSkip(dst, uint64(uint32(bits.TrailingZeros32(uint32(i.regs[s])))))
i.setAndSkip(dst, uint64(bits.TrailingZeros32(uint32(i.regs[s]))))
}

// SignExtend8 sign_extend_8 ω′D = Z−1_8(Z_1(ωA mod 2^8))
func (i *Instance) SignExtend8(dst polkavm.Reg, s polkavm.Reg) {
i.setAndSkip(dst, uint64(int64(int8(uint8(i.regs[s])))))
i.setAndSkip(dst, uint64(int8(uint8(i.regs[s]))))
}

// SignExtend16 sign_extend_16 ω′D = Z−1_8(Z_2(ωA mod 2^16))
func (i *Instance) SignExtend16(dst polkavm.Reg, s polkavm.Reg) {
i.setAndSkip(dst, uint64(int64(int16(uint16(i.regs[s])))))
i.setAndSkip(dst, uint64(int16(uint16(i.regs[s]))))
}

// ZeroExtend16 zero_extend_16 ω′D = ωA mod 2^16
Expand Down Expand Up @@ -385,7 +386,7 @@ func (i *Instance) LoadIndirectU64(dst polkavm.Reg, base polkavm.Reg, offset uin

// AddImm32 add_imm_32 ω′A = X4((ωB + νX) mod 2^32)
func (i *Instance) AddImm32(dst polkavm.Reg, regA polkavm.Reg, value uint64) {
i.setAndSkip(dst, sext(uint64(uint32(i.regs[regA])+uint32(value)), 4))
i.setAndSkip(dst, sext(uint64(uint32(i.regs[regA]+value)), 4))
}

// AndImm and_imm ∀i ∈ N64 ∶ B8(ω′A)_i = B8(ωB)_i ∧ B8(νX)_i
Expand All @@ -403,9 +404,9 @@ func (i *Instance) OrImm(dst polkavm.Reg, regA polkavm.Reg, value uint64) {
i.setAndSkip(dst, (i.regs[regA])|value)
}

// MulImm32 mul_imm_32 ω′A = X4((ωB ⋅ νX ) mod 2^32)
// MulImm32 mul_imm_32 ω′A = X4((ωB ⋅ νX) mod 2^32)
func (i *Instance) MulImm32(dst polkavm.Reg, regA polkavm.Reg, value uint64) {
i.setAndSkip(dst, sext(uint64(uint32(i.regs[regA])*uint32(value)), 4))
i.setAndSkip(dst, sext(uint64(uint32(i.regs[regA]*value)), 4))
}

// SetLessThanUnsignedImm set_lt_u_imm ω′A = ωB < νX
Expand All @@ -415,15 +416,15 @@ func (i *Instance) SetLessThanUnsignedImm(dst polkavm.Reg, regA polkavm.Reg, val

// SetLessThanSignedImm set_lt_s_imm ω′A = Z8(ωB) < Z8(νX)
func (i *Instance) SetLessThanSignedImm(dst polkavm.Reg, regA polkavm.Reg, value uint64) {
i.setAndSkip(dst, bool2uint64(int32(uint32(i.regs[regA])) < int32(value)))
i.setAndSkip(dst, bool2uint64(int64(i.regs[regA]) < int64(value)))
}

// ShiftLogicalLeftImm32 shlo_l_imm_32 ω′A = X4((ωB ⋅ 2νX mod 32) mod 2^32)
// ShiftLogicalLeftImm32 shlo_l_imm_32 ω′A = X4((ωB ⋅ 2^νX mod 32) mod 2^32)
func (i *Instance) ShiftLogicalLeftImm32(dst polkavm.Reg, regA polkavm.Reg, value uint64) {
i.setAndSkip(dst, sext(uint64(uint32(i.regs[regA])<<value), 4))
}

// ShiftLogicalRightImm32 shlo_r_imm_32 ω′A = X4(⌊ ωB mod 232 ÷ 2νX mod 32 ⌋)
// ShiftLogicalRightImm32 shlo_r_imm_32 ω′A = X4(⌊ ωB mod 2^32 ÷ 2^νX mod 32 ⌋)
func (i *Instance) ShiftLogicalRightImm32(dst polkavm.Reg, regA polkavm.Reg, value uint64) {
i.setAndSkip(dst, sext(uint64(uint32(i.regs[regA])>>value), 4))
}
Expand All @@ -435,25 +436,25 @@ func (i *Instance) ShiftArithmeticRightImm32(dst polkavm.Reg, regA polkavm.Reg,

// NegateAndAddImm32 neg_add_imm_32 ω′A = X4((νX + 2^32 − ωB) mod 2^32)
func (i *Instance) NegateAndAddImm32(dst polkavm.Reg, regA polkavm.Reg, value uint64) {
i.setAndSkip(dst, sext(uint64(uint32(value)-uint32(i.regs[regA])), 4))
i.setAndSkip(dst, sext(uint64(uint32(value-i.regs[regA])), 4))
}

// SetGreaterThanUnsignedImm set_gt_u_imm ω′A = ωB > νX
func (i *Instance) SetGreaterThanUnsignedImm(dst polkavm.Reg, regA polkavm.Reg, value uint64) {
i.setAndSkip(dst, bool2uint64(i.regs[regA] > value))
}

// SetGreaterThanSignedImm set_gt_s_imm ω′A = Z8(ωB ) > Z8(νX)
// SetGreaterThanSignedImm set_gt_s_imm ω′A = Z8(ωB) > Z8(νX)
func (i *Instance) SetGreaterThanSignedImm(dst polkavm.Reg, regA polkavm.Reg, value uint64) {
i.setAndSkip(dst, bool2uint64(int32(uint32(i.regs[regA])) > int32(value)))
i.setAndSkip(dst, bool2uint64(int64(i.regs[regA]) > int64(value)))
}

// ShiftLogicalLeftImmAlt32 shlo_l_imm_alt_32 ω′A = X4((νX ⋅ 2ωB mod 32) mod 2^32)
func (i *Instance) ShiftLogicalLeftImmAlt32(dst polkavm.Reg, regB polkavm.Reg, value uint64) {
i.setAndSkip(dst, sext(uint64(uint32(value)<<uint32(i.regs[regB])), 4))
i.setAndSkip(dst, sext(uint64(uint32(value<<i.regs[regB])), 4))
}

// ShiftLogicalRightImmAlt32 shlo_r_imm_alt_32 ω′A = X4(⌊ νX mod 2^32 ÷ 2ωB mod 32 ⌋)
// ShiftLogicalRightImmAlt32 shlo_r_imm_alt_32 ω′A = X4(⌊ νX mod 2^32 ÷ 2^ωB mod 32 ⌋)
func (i *Instance) ShiftLogicalRightImmAlt32(dst polkavm.Reg, regB polkavm.Reg, value uint64) {
i.setAndSkip(dst, sext(uint64(uint32(value)>>uint32(i.regs[regB])), 4))
}
Expand All @@ -465,15 +466,15 @@ func (i *Instance) ShiftArithmeticRightImmAlt32(dst polkavm.Reg, regB polkavm.Re

// CmovIfZeroImm cmov_iz_imm ω′A = νX if ωB = 0 otherwise ωA
func (i *Instance) CmovIfZeroImm(dst polkavm.Reg, c polkavm.Reg, s uint64) {
if uint32(i.regs[c]) == 0 {
if i.regs[c] == 0 {
i.regs[dst] = s
}
i.skip()
}

// CmovIfNotZeroImm cmov_nz_imm ω′A = νX if ωB ≠ 0 otherwise ωA
func (i *Instance) CmovIfNotZeroImm(dst polkavm.Reg, c polkavm.Reg, s uint64) {
if uint32(i.regs[c]) != 0 {
if i.regs[c] != 0 {
i.regs[dst] = s
}

Expand Down Expand Up @@ -532,7 +533,7 @@ func (i *Instance) RotateRight64Imm(dst polkavm.Reg, regA polkavm.Reg, value uin

// RotateRight64ImmAlt rot_r_64_imm_alt ∀i ∈ N64 ∶ B8(ω′A)i = B8(νX)_{(i+ωB) mod 64}
func (i *Instance) RotateRight64ImmAlt(dst polkavm.Reg, regA polkavm.Reg, value uint64) {
i.setAndSkip(dst, bits.RotateLeft64(uint64(value), -int(i.regs[regA])))
i.setAndSkip(dst, bits.RotateLeft64(value, -int(i.regs[regA])))
}

// RotateRight32Imm rot_r_32_imm ω′A = X4(x) where x ∈ N2^32, ∀i ∈ N32 ∶ B4(x)_i = B4(ωB)_{(i+νX ) mod 32}
Expand All @@ -547,54 +548,54 @@ func (i *Instance) RotateRight32ImmAlt(dst polkavm.Reg, regA polkavm.Reg, value

// BranchEq branch_eq branch(νX, ωA = ωB)
func (i *Instance) BranchEq(regA polkavm.Reg, regB polkavm.Reg, target uint64) error {
return i.branch(uint32(i.regs[regA]) == uint32(i.regs[regB]), target)
return i.branch(i.regs[regA] == i.regs[regB], target)
}

// BranchNotEq branch_ne branch(νX, ωA ≠ ωB)
func (i *Instance) BranchNotEq(regA polkavm.Reg, regB polkavm.Reg, target uint64) error {
return i.branch(uint32(i.regs[regA]) != uint32(i.regs[regB]), target)
return i.branch(i.regs[regA] != i.regs[regB], target)
}

// BranchLessUnsigned branch_lt_u branch(νX, ωA < ωB)
func (i *Instance) BranchLessUnsigned(regA polkavm.Reg, regB polkavm.Reg, target uint64) error {
return i.branch(uint32(i.regs[regA]) < uint32(i.regs[regB]), target)
return i.branch(i.regs[regA] < i.regs[regB], target)
}

// BranchLessSigned branch_lt_s branch(νX, Z8(ωA) < Z8(ωB))
func (i *Instance) BranchLessSigned(regA polkavm.Reg, regB polkavm.Reg, target uint64) error {
return i.branch(int32(uint32(i.regs[regA])) < int32(uint32(i.regs[regB])), target)
return i.branch(int64(i.regs[regA]) < int64(i.regs[regB]), target)
}

// BranchGreaterOrEqualUnsigned branch_ge_u branch(νX, ωA ≥ ωB)
func (i *Instance) BranchGreaterOrEqualUnsigned(regA polkavm.Reg, regB polkavm.Reg, target uint64) error {
return i.branch(uint32(i.regs[regA]) >= uint32(i.regs[regB]), target)
return i.branch(i.regs[regA] >= i.regs[regB], target)
}

// BranchGreaterOrEqualSigned branch_ge_s branch(νX, Z8(ωA) ≥ Z8(ωB))
func (i *Instance) BranchGreaterOrEqualSigned(regA polkavm.Reg, regB polkavm.Reg, target uint64) error {
return i.branch(int32(uint32(i.regs[regA])) >= int32(uint32(i.regs[regB])), target)
return i.branch(int64(i.regs[regA]) >= int64(i.regs[regB]), target)
}

// LoadImmAndJumpIndirect load_imm_jump_ind djump((ωB + νY) mod 232), ω′A = νX
func (i *Instance) LoadImmAndJumpIndirect(ra polkavm.Reg, base polkavm.Reg, value, offset uint64) error {
target := i.regs[base] + offset
i.regs[ra] = uint64(value)
i.regs[ra] = value
return i.djump(target)
}

// Add32 add_32 ω′D = X4((ωA + ωB) mod 2^32)
func (i *Instance) Add32(dst polkavm.Reg, regA, regB polkavm.Reg) {
i.setAndSkip(dst, sext(uint64(uint32(i.regs[regA])+uint32(i.regs[regB])), 4))
i.setAndSkip(dst, sext(uint64(uint32(i.regs[regA]+i.regs[regB])), 4))
}

// Sub32 sub_32 ω′D = X4((ωA + 2^32 − (ωB mod 2^32)) mod 2^32)
func (i *Instance) Sub32(dst polkavm.Reg, regA, regB polkavm.Reg) {
i.setAndSkip(dst, sext(uint64(uint32(i.regs[regA])-uint32(i.regs[regB])), 4))
i.setAndSkip(dst, sext(uint64(uint32(i.regs[regA]-i.regs[regB])), 4))
}

// Mul32 ω′D = X4((ωA ⋅ ωB) mod 2^32)
// Mul32 mul_32 ω′D = X4((ωA ⋅ ωB) mod 2^32)
func (i *Instance) Mul32(dst polkavm.Reg, regA, regB polkavm.Reg) {
i.setAndSkip(dst, sext(uint64(uint32(i.regs[regA])*uint32(i.regs[regB])), 4))
i.setAndSkip(dst, sext(uint64(uint32(i.regs[regA]*i.regs[regB])), 4))
}

// DivUnsigned32 div_u_32 ω′D = 2^64 − 1 if ωB mod 2^32 = 0 otherwise X4(⌊ (ωA mod 2^32) ÷ (ωB mod 2^32) ⌋)
Expand Down Expand Up @@ -778,40 +779,49 @@ func (i *Instance) Or(dst polkavm.Reg, regA, regB polkavm.Reg) {

// MulUpperSignedSigned mul_upper_s_s ω′D = Z−1_8(⌊ (Z8(ωA) ⋅ Z8(ωB)) ÷ 2^64 ⌋)
func (i *Instance) MulUpperSignedSigned(dst polkavm.Reg, regA, regB polkavm.Reg) {
i.setAndSkip(dst, uint64(int32((int64(uint32(i.regs[regA]))*int64(uint32(i.regs[regB])))>>32)))
lhs := big.NewInt(int64(i.regs[regA]))
rhs := big.NewInt(int64(i.regs[regB]))
mul := lhs.Mul(lhs, rhs)
i.setAndSkip(dst, uint64(mul.Rsh(mul, 64).Int64()))
}

// MulUpperUnsignedUnsigned mul_upper_u_u ω′D = ⌊ (ωA ⋅ ωB ) ÷ 2^64 ⌋
func (i *Instance) MulUpperUnsignedUnsigned(dst polkavm.Reg, regA, regB polkavm.Reg) {
i.setAndSkip(dst, uint64(int32((int64(uint32(i.regs[regA]))*int64(uint32(i.regs[regB])))>>32)))
lhs := (&big.Int{}).SetUint64(i.regs[regA])
rhs := (&big.Int{}).SetUint64(i.regs[regB])
mul := lhs.Mul(lhs, rhs)
i.setAndSkip(dst, uint64(mul.Rsh(mul, 64).Int64()))
}

// MulUpperSignedUnsigned mul_upper_s_u ω′D = Z−1_8(⌊ (Z8(ωA) ⋅ ωB) ÷ 2^64 ⌋)
func (i *Instance) MulUpperSignedUnsigned(dst polkavm.Reg, regA, regB polkavm.Reg) {
i.setAndSkip(dst, uint64((int64(uint32(i.regs[regA]))*int64(uint32(i.regs[regB])))>>32))
lhs := big.NewInt(int64(i.regs[regA]))
rhs := (&big.Int{}).SetUint64(i.regs[regB])
mul := lhs.Mul(lhs, rhs)
i.setAndSkip(dst, uint64(mul.Rsh(mul, 64).Int64()))
}

// SetLessThanUnsigned set_lt_u ω′D = ωA < ωB
func (i *Instance) SetLessThanUnsigned(dst polkavm.Reg, regA, regB polkavm.Reg) {
i.setAndSkip(dst, bool2uint64(uint32(i.regs[regA]) < uint32(i.regs[regB])))
i.setAndSkip(dst, bool2uint64(i.regs[regA] < i.regs[regB]))
}

// SetLessThanSigned set_lt_s ω′D = Z8(ωA) < Z8(ωB)
func (i *Instance) SetLessThanSigned(dst polkavm.Reg, regA, regB polkavm.Reg) {
i.setAndSkip(dst, bool2uint64(int32(uint32(i.regs[regA])) < int32(uint32(i.regs[regB]))))
i.setAndSkip(dst, bool2uint64(int32(i.regs[regA]) < int32(i.regs[regB])))
}

// CmovIfZero cmov_iz ω′D = ωA if ωB = 0 otherwise ωD
func (i *Instance) CmovIfZero(dst polkavm.Reg, s, c polkavm.Reg) {
if uint32(i.regs[c]) == 0 {
if i.regs[c] == 0 {
i.regs[dst] = i.regs[s]
}
i.skip()
}

// CmovIfNotZero cmov_nz ω′D = ωA if ωB ≠ 0 otherwise ωD
func (i *Instance) CmovIfNotZero(dst polkavm.Reg, s, c polkavm.Reg) {
if uint32(i.regs[c]) != 0 {
if i.regs[c] != 0 {
i.regs[dst] = i.regs[s]
}
i.skip()
Expand All @@ -822,9 +832,9 @@ func (i *Instance) RotateLeft64(dst polkavm.Reg, regA, regB polkavm.Reg) {
i.setAndSkip(dst, bits.RotateLeft64(i.regs[regA], int(i.regs[regB])))
}

// RotateLeft32 rot_l_32 ω′D = X4(x) where x ∈ N232, ∀i ∈ N32 ∶ B4(x)_{(i+ωB) mod 32} = B4(ωA)_i
// RotateLeft32 rot_l_32 ω′D = X4(x) where x ∈ N2^32, ∀i ∈ N32 ∶ B4(x)_{(i+ωB) mod 32} = B4(ωA)_i
func (i *Instance) RotateLeft32(dst polkavm.Reg, regA, regB polkavm.Reg) {
i.setAndSkip(dst, sext(uint64(bits.RotateLeft32(uint32(i.regs[regA]), int(uint32(i.regs[regB])))), 4))
i.setAndSkip(dst, sext(uint64(bits.RotateLeft32(uint32(i.regs[regA]), int(i.regs[regB]))), 4))
}

// RotateRight64 rot_r_64 ∀i ∈ N64 ∶ B8(ω′D)_i = B8(ωA)_{(i+ωB ) mod 64}
Expand All @@ -834,7 +844,7 @@ func (i *Instance) RotateRight64(dst polkavm.Reg, regA, regB polkavm.Reg) {

// RotateRight32 rot_r_32 ω′D = X4(x) where x ∈ N2^32, ∀i ∈ N32 ∶ B4(x)_i = B4(ωA)_{(i+ωB) mod 32}
func (i *Instance) RotateRight32(dst polkavm.Reg, regA, regB polkavm.Reg) {
i.setAndSkip(dst, sext(uint64(bits.RotateLeft32(uint32(i.regs[regA]), -int(uint32(i.regs[regB])))), 4))
i.setAndSkip(dst, sext(uint64(bits.RotateLeft32(uint32(i.regs[regA]), -int(i.regs[regB]))), 4))
}

// AndInverted and_inv ∀i ∈ N64 ∶ B8(ω′D)_i = B8(ωA)i ∧ ¬B8(ωB)_i
Expand Down

0 comments on commit 13de0f6

Please sign in to comment.