From 5fdee67ba5de6ce46135a689ed7830515e87e974 Mon Sep 17 00:00:00 2001 From: "Mr.UNIX" Date: Sun, 7 Jul 2024 18:52:07 +0100 Subject: [PATCH] VM: add support for multiple types in binary expressions --- CMakeLists.txt | 2 +- include/utils.h | 103 ++++++++++++++++++++++++++++++++++++++++++++- src/ast.cpp | 92 ++++++++++++++++++++++------------------ tests/vm_tests.cpp | 10 +++++ 4 files changed, 164 insertions(+), 43 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5754616..33e2610 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,7 @@ project(SPL) set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) -add_compile_options(-O3) +add_compile_options(-O0) file(GLOB_RECURSE TEST_SOURCES tests/*.cpp) file(GLOB_RECURSE SOURCES src/*.cpp) diff --git a/include/utils.h b/include/utils.h index 4d9856c..18b3efb 100644 --- a/include/utils.h +++ b/include/utils.h @@ -80,4 +80,105 @@ static inline Variable::Type varTypeConvert(AbstractSyntaxTree *ast) { #define VAR_CASE(OP, TYPE) \ case Variable::Type::TYPE: \ segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::OP##TYPE}); \ - break; \ No newline at end of file + break; + +static Variable::Type deduceType(Program &program, Segment &segment, AbstractSyntaxTree *ast) { + switch (ast->nodeType) { + case AbstractSyntaxTree::Type::Node: { + auto token = dynamic_cast(ast)->token; + switch (token.type) { + case Number: { + try { + std::stoi(token.value); + return Variable::Type::I32; + } catch (std::out_of_range &) { + goto long64; + } + long64: + try { + std::stol(token.value); + return Variable::Type::I64; + } catch (std::exception &) { + throw std::runtime_error("Invalid number: " + token.value); + } + } + case Identifier: { + if (segment.find_local(token.value) != -1) + return segment.locals[token.value].type; + if (program.find_global(token.value) != -1) + return program.segments[0].locals[token.value].type; + throw std::runtime_error("Identifier not found: " + token.value); + } + default: + throw std::runtime_error("Invalid type: " + token.value); + } + } + case AbstractSyntaxTree::Type::UnaryExpression: { + auto unary = dynamic_cast(ast); + return deduceType(program, segment, unary->expression); + } + case AbstractSyntaxTree::Type::BinaryExpression: { + auto binary = dynamic_cast(ast); + auto left = deduceType(program, segment, binary->left); + auto right = deduceType(program, segment, binary->right); + if ((left == Variable::Type::I32 || left == Variable::Type::I64) && + (right == Variable::Type::I32 || right == Variable::Type::I64)) + return left == Variable::Type::I64 || right == Variable::Type::I64 ? Variable::Type::I64 + : Variable::Type::I32; + if (left != right) + throw std::runtime_error("Type mismatch"); + return left; + } + case AbstractSyntaxTree::Type::FunctionCall: { + // TODO: add support for multiple return types + auto call = dynamic_cast(ast); + auto function = program.find_function(program.segments[segment.id], call->identifier.token.value); + if (function == -1) + throw std::runtime_error("Function not found: " + call->identifier.token.value); + return Variable::Type::I32; + } + default: + throw std::runtime_error("Invalid type: " + ast->typeStr); + } +} + +enum class GenericInstruction { + Add, + Sub, + Mul, + Div, + Mod, + Equal, + Less, + Greater, + GreaterEqual, + LessEqual, + NotEqual +}; + +#define TYPE_CASE(INS) \ + case GenericInstruction::INS: { \ + switch (type) { \ + case Variable::Type::I32: \ + return {Instruction::InstructionType::INS##I32}; \ + case Variable::Type::I64: \ + return {Instruction::InstructionType::INS##I64}; \ + default: \ + throw std::runtime_error("[getInstructionWithType] Invalid type"); \ + } \ + } +static inline Instruction getInstructionWithType(GenericInstruction instruction, Variable::Type type) { + switch (instruction) { + TYPE_CASE(Add) + TYPE_CASE(Sub) + TYPE_CASE(Mul) + TYPE_CASE(Div) + TYPE_CASE(Mod) + TYPE_CASE(Equal) + TYPE_CASE(Less) + TYPE_CASE(Greater) + TYPE_CASE(GreaterEqual) + TYPE_CASE(LessEqual) + TYPE_CASE(NotEqual) + } +} diff --git a/src/ast.cpp b/src/ast.cpp index c6871d6..c7d12bb 100644 --- a/src/ast.cpp +++ b/src/ast.cpp @@ -20,26 +20,27 @@ bool Node::operator==(const AbstractSyntaxTree &other) const { void Node::compile(Program &program, Segment &segment) const { switch (token.type) { case Number: { - try { - segment.instructions.push_back( - Instruction{ - .type = Instruction::InstructionType::LoadI32, - .params = {.i32 = std::stoi(token.value)}, - }); - return; - } catch (std::out_of_range &) { - goto long64; + auto type = deduceType(program, segment, (AbstractSyntaxTree *) this); + switch (type) { + case Variable::Type::I32: + return segment.instructions.push_back( + Instruction{ + .type = Instruction::InstructionType::LoadI32, + .params = {.i32 = std::stoi(token.value)}, + }); + case Variable::Type::I64: + return segment.instructions.push_back( + Instruction{ + .type = Instruction::InstructionType::LoadI64, + .params = {.i64 = std::stol(token.value)}, + }); + default: + throw std::runtime_error("[Node::compile] Invalid type: " + token.value); } - long64: - segment.instructions.push_back( - Instruction{ - .type = Instruction::InstructionType::LoadI64, - .params = {.i64 = std::stol(token.value)}, - }); - } break; + } case Identifier: { - emitLoad(program, segment, token.value); - } break; + return emitLoad(program, segment, token.value); + } default: throw std::runtime_error("[Node::compile] This should not be accessed!"); } @@ -60,50 +61,59 @@ bool BinaryExpression::operator==(const AbstractSyntaxTree &other) const { op == otherBinaryExpression.op; } -// TODO: Add support for other types void BinaryExpression::compile(Program &program, Segment &segment) const { if (op.type == Assign) { right->compile(program, segment); emitStore(program, segment, dynamic_cast(*left).token.value); return; } + auto leftType = deduceType(program, segment, left); + auto rightType = deduceType(program, segment, right); + auto finalType = deduceType(program, segment, (AbstractSyntaxTree*) this); left->compile(program, segment); + if (leftType != finalType) { + switch (leftType) { + case Variable::Type::I32: + segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::ConvertI32toI64}); + break; + default: + throw std::runtime_error("[BinaryExpression::compile] Invalid type: " + left->typeStr); + } + } right->compile(program, segment); + if (rightType != finalType) { + switch (rightType) { + case Variable::Type::I32: + segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::ConvertI32toI64}); + break; + default: + throw std::runtime_error("[BinaryExpression::compile] Invalid type: "+ right->typeStr); + } + } switch (op.type) { case Plus: - segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::AddI32}); - break; + return segment.instructions.push_back(getInstructionWithType(GenericInstruction::Add, finalType)); case Minus: - segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::SubI32}); - break; + return segment.instructions.push_back(getInstructionWithType(GenericInstruction::Sub, finalType)); case Multiply: - segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::MulI32}); - break; + return segment.instructions.push_back(getInstructionWithType(GenericInstruction::Mul, finalType)); case Divide: - segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::DivI32}); - break; + return segment.instructions.push_back(getInstructionWithType(GenericInstruction::Div, finalType)); case Modulo: - segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::ModI32}); - break; + return segment.instructions.push_back(getInstructionWithType(GenericInstruction::Mod, finalType)); case Greater: - segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::GreaterI32}); - break; + return segment.instructions.push_back(getInstructionWithType(GenericInstruction::Greater, finalType)); case Less: - segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::LessI32}); - break; + return segment.instructions.push_back(getInstructionWithType(GenericInstruction::Less, finalType)); case GreaterEqual: - segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::GreaterEqualI32}); - break; + return segment.instructions.push_back(getInstructionWithType(GenericInstruction::GreaterEqual, finalType)); case LessEqual: - segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::LessEqualI32}); - break; + return segment.instructions.push_back(getInstructionWithType(GenericInstruction::LessEqual, finalType)); case Equal: - segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::EqualI32}); - break; + return segment.instructions.push_back(getInstructionWithType(GenericInstruction::Equal, finalType)); case NotEqual: - segment.instructions.push_back(Instruction{.type = Instruction::InstructionType::NotEqualI32}); - break; + return segment.instructions.push_back(getInstructionWithType(GenericInstruction::NotEqual, finalType)); default: throw std::runtime_error("[BinaryExpression::compile] Invalid operator: " + op.value); } diff --git a/tests/vm_tests.cpp b/tests/vm_tests.cpp index 90cde6a..f21b911 100644 --- a/tests/vm_tests.cpp +++ b/tests/vm_tests.cpp @@ -67,6 +67,16 @@ TEST(VM, SimpleI64VariableDeclaration) { ASSERT_EQ(*static_cast(vm.topStack(sizeof(int64_t))), 42); } +TEST(VM, BinaryExpressionWithMultipleTypes) { + const char *input = "define a : i64 = 42;" + "define b : i32 = 42;" + "a + b;"; + VM vm; + auto program = compile(input); + vm.run(program); + ASSERT_EQ(*static_cast(vm.topStack(sizeof(int64_t))), 84); +} + TEST(VM, SimpleVariableAssignment) { const char *input = "define a : i32 = 42; a = 43; a;"; VM vm;