Skip to content

Commit

Permalink
Hotfix: runtime error when trying to shift a negative amount in RISC-V
Browse files Browse the repository at this point in the history
  • Loading branch information
gboncoffee committed Jul 11, 2024
1 parent 3cb5549 commit 447e3eb
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions riscv/riscv.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func parseJ(i uint32) (uint8, uint32) {
return rd, signExtend(imm, 20)
}

func (m *RiscV) execArithmetic(rd uint8, rs1 uint8, rs2 uint8, func3 uint8, func7 uint8) {
func (m *RiscV) execArithmetic(rd uint8, rs1 uint8, rs2 uint8, func3 uint8, func7 uint8) error {
rs1v64, _ := m.GetRegister(uint64(rs1))
rs1v := int32(rs1v64)
rs2v64, _ := m.GetRegister(uint64(rs2))
Expand All @@ -146,7 +146,7 @@ func (m *RiscV) execArithmetic(rd uint8, rs1 uint8, rs2 uint8, func3 uint8, func
case 0x2:
rs1v64 := int64(signExtend64(uint32(rs1v)))
rs2v64 := int64(uint64(rs2v)) & 0x00000000ffffffff
tmp := int64(rs1v64 * rs2v64) >> 32
tmp := int64(rs1v64*rs2v64) >> 32
r = int32(tmp)
case 0x3:
rs1v64 := uint64(rs1v) & 0x00000000ffffffff
Expand Down Expand Up @@ -187,7 +187,7 @@ func (m *RiscV) execArithmetic(rd uint8, rs1 uint8, rs2 uint8, func3 uint8, func
r = rs1v + rs2v
}
case 0x1:
r = rs1v << rs2v
r = int32(uint32(rs1v) << uint32(rs2v))
case 0x2:
// In C this would look terrible but in Go bools and ints are different.
if rs1v < rs2v {
Expand All @@ -206,6 +206,9 @@ func (m *RiscV) execArithmetic(rd uint8, rs1 uint8, rs2 uint8, func3 uint8, func
r = rs1v ^ rs2v
case 0x5:
if func7 == 0x20 {
if rs2v < 0 {
return fmt.Errorf("error executing sra: negative shift amount")
}
r = rs1v >> rs2v
} else {
r = int32(uint32(rs1v) >> uint32(rs2v))
Expand All @@ -220,6 +223,8 @@ func (m *RiscV) execArithmetic(rd uint8, rs1 uint8, rs2 uint8, func3 uint8, func
m.SetRegister(uint64(rd), uint64(r))

m.pc += 4

return nil
}

func (m *RiscV) execImmArithmetic(rd uint8, rs1 uint8, imm uint32, func3 uint8) {
Expand Down Expand Up @@ -395,10 +400,12 @@ func (m *RiscV) execAuipc(rd uint8, imm uint32) {

func (m *RiscV) execute(i uint32) (*machine.Call, error) {
opcode := i & 0b01111111
var err error = nil

switch opcode {
case 0b0110011:
rd, rs1, rs2, func3, func7 := parseR(i)
m.execArithmetic(rd, rs1, rs2, func3, func7)
err = m.execArithmetic(rd, rs1, rs2, func3, func7)
case 0b0010011:
rd, rs1, imm, func3 := parseI(i)
m.execImmArithmetic(rd, rs1, imm, func3)
Expand Down Expand Up @@ -439,7 +446,7 @@ func (m *RiscV) execute(i uint32) (*machine.Call, error) {
return nil, fmt.Errorf(machine.InterCtx.Get("unknown opcode: %b"), opcode)
}

return nil, nil
return nil, err
}

func (m *RiscV) LoadProgram(program []uint8) error {
Expand Down

0 comments on commit 447e3eb

Please sign in to comment.