diff --git a/riscv/riscv.go b/riscv/riscv.go index 76351eb..e89f89c 100644 --- a/riscv/riscv.go +++ b/riscv/riscv.go @@ -127,7 +127,7 @@ func parseJ(i uint32) (uint8, uint32) { return rd, signExtend(imm, 20) } -func (m *RiscV) execArithmetic(rd uint8, rs1 uint8, rs2 uint8, func3 uint8, func7 uint8) { +func (m *RiscV) execArithmetic(rd uint8, rs1 uint8, rs2 uint8, func3 uint8, func7 uint8) error { rs1v64, _ := m.GetRegister(uint64(rs1)) rs1v := int32(rs1v64) rs2v64, _ := m.GetRegister(uint64(rs2)) @@ -146,7 +146,7 @@ func (m *RiscV) execArithmetic(rd uint8, rs1 uint8, rs2 uint8, func3 uint8, func case 0x2: rs1v64 := int64(signExtend64(uint32(rs1v))) rs2v64 := int64(uint64(rs2v)) & 0x00000000ffffffff - tmp := int64(rs1v64 * rs2v64) >> 32 + tmp := int64(rs1v64*rs2v64) >> 32 r = int32(tmp) case 0x3: rs1v64 := uint64(rs1v) & 0x00000000ffffffff @@ -187,7 +187,7 @@ func (m *RiscV) execArithmetic(rd uint8, rs1 uint8, rs2 uint8, func3 uint8, func r = rs1v + rs2v } case 0x1: - r = rs1v << rs2v + r = int32(uint32(rs1v) << uint32(rs2v)) case 0x2: // In C this would look terrible but in Go bools and ints are different. if rs1v < rs2v { @@ -206,6 +206,9 @@ func (m *RiscV) execArithmetic(rd uint8, rs1 uint8, rs2 uint8, func3 uint8, func r = rs1v ^ rs2v case 0x5: if func7 == 0x20 { + if rs2v < 0 { + return fmt.Errorf("error executing sra: negative shift amount") + } r = rs1v >> rs2v } else { r = int32(uint32(rs1v) >> uint32(rs2v)) @@ -220,6 +223,8 @@ func (m *RiscV) execArithmetic(rd uint8, rs1 uint8, rs2 uint8, func3 uint8, func m.SetRegister(uint64(rd), uint64(r)) m.pc += 4 + + return nil } func (m *RiscV) execImmArithmetic(rd uint8, rs1 uint8, imm uint32, func3 uint8) { @@ -395,10 +400,12 @@ func (m *RiscV) execAuipc(rd uint8, imm uint32) { func (m *RiscV) execute(i uint32) (*machine.Call, error) { opcode := i & 0b01111111 + var err error = nil + switch opcode { case 0b0110011: rd, rs1, rs2, func3, func7 := parseR(i) - m.execArithmetic(rd, rs1, rs2, func3, func7) + err = m.execArithmetic(rd, rs1, rs2, func3, func7) case 0b0010011: rd, rs1, imm, func3 := parseI(i) m.execImmArithmetic(rd, rs1, imm, func3) @@ -439,7 +446,7 @@ func (m *RiscV) execute(i uint32) (*machine.Call, error) { return nil, fmt.Errorf(machine.InterCtx.Get("unknown opcode: %b"), opcode) } - return nil, nil + return nil, err } func (m *RiscV) LoadProgram(program []uint8) error {