From ee8280aca23215796bc86b6249ff9beff4b29d49 Mon Sep 17 00:00:00 2001 From: Akuli Date: Fri, 10 Jan 2025 03:52:53 +0200 Subject: [PATCH] Clean up compiler/ast.jou (#578) --- compiler/ast.jou | 301 ++++++++++++++++++++++++++--------------- compiler/build_cfg.jou | 40 +++--- compiler/parser.jou | 22 +-- compiler/typecheck.jou | 74 +++++----- 4 files changed, 258 insertions(+), 179 deletions(-) diff --git a/compiler/ast.jou b/compiler/ast.jou index 4e6a96e2..ba27104b 100644 --- a/compiler/ast.jou +++ b/compiler/ast.jou @@ -1,3 +1,10 @@ +# This file defines data structures of the Abstract Syntax Tree. They are +# constructed in parser.jou. +# +# Many classes in this file have .print(), which can be used during debugging. If +# .print() exists, you probably don't need to call .print_with_tree_printer(), +# which is basically an internal detail of the .print() implementation. + import "stdlib/io.jou" import "stdlib/str.jou" import "stdlib/mem.jou" @@ -6,22 +13,24 @@ import "./errors_and_warnings.jou" # TODO: move to stdlib declare isprint(b: int) -> int + enum AstTypeKind: Named Pointer Array + class AstArrayType: member_type: AstType* length: AstExpression* - # TODO: use this def free(self) -> None: self->member_type->free() self->length->free() free(self->member_type) free(self->length) + class AstType: kind: AstTypeKind location: Location @@ -31,19 +40,15 @@ class AstType: value_type: AstType* # AstTypeKind::Pointer array: AstArrayType # AstTypeKind::Array - # TODO: use this def is_void(self) -> bool: return self->kind == AstTypeKind::Named and strcmp(self->name, "void") == 0 - # TODO: use this def is_none(self) -> bool: return self->kind == AstTypeKind::Named and strcmp(self->name, "None") == 0 - # TODO: use this def is_noreturn(self) -> bool: return self->kind == AstTypeKind::Named and strcmp(self->name, "noreturn") == 0 - # TODO: use this def print(self, show_lineno: bool) -> None: if self->kind == AstTypeKind::Named: printf("%s", self->name) @@ -59,7 +64,6 @@ class AstType: if show_lineno: printf(" [line %d]", self->location.lineno) - # TODO: use this def free(self) -> None: if self->kind == AstTypeKind::Pointer: self->value_type->free() @@ -67,6 +71,7 @@ class AstType: if self->kind == AstTypeKind::Array: self->array.free() + # Statements and expressions can be printed in a tree. # To see a tree, run: # @@ -86,6 +91,49 @@ class TreePrinter: snprintf(subprinter.prefix, sizeof subprinter.prefix, "%s| ", self->prefix) return subprinter + +# Foo::Bar +class AstEnumMember: + enum_name: byte[100] + member_name: byte[100] + + +# foo.bar, foo->bar +class AstClassField: + instance: AstExpression* + uses_arrow_operator: bool # distinguishes foo.bar and foo->bar + field_name: byte[100] + + def free(self) -> None: + self->instance->free() + free(self->instance) + + +# Foo{bar = 1, baz = 2} +class AstInstantiation: + class_name_location: Location # TODO: probably not necessary, can use location of the instantiate expression + class_name: byte[100] + nfields: int + field_names: byte[100]* + field_values: AstExpression* + + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: + printf("instantiate \"%s\"\n", self->class_name) + for i = 0; i < self->nfields; i++: + sub = tp.print_prefix(i == self->nfields - 1) + printf("field \"%s\": ", self->field_names[i]) + self->field_values[i].print_with_tree_printer(sub) + + def free(self) -> None: + for i = 0; i < self->nfields; i++: + self->field_values[i].free() + free(self->field_names) + free(self->field_values) + + enum AstExpressionKind: String Int @@ -116,11 +164,11 @@ enum AstExpressionKind: PostDecr # x-- # binary operators Add # x+y - Subtract # x-y - Multiply # x*y - Divide # x/y + Sub # x-y + Mul # x*y + Div # x/y Indexing # x[y] - Modulo # x % y + Mod # x % y Eq # x == y Ne # x != y Gt # x > y @@ -145,21 +193,16 @@ class AstExpression: bool_value: bool call: AstCall instantiation: AstInstantiation - as_: AstAsExpression* # Must be pointer, because it contains an AstExpression + as_: AstAs* # Must be pointer, because it contains an AstExpression array: AstArray varname: byte[100] float_or_double_text: byte[100] operands: AstExpression* # Only for operators. Length is arity, see get_arity() - # TODO: use this - def print(self) -> None: - self->print_with_tree_printer(TreePrinter{}) - - # TODO: use this def print_with_tree_printer(self, tp: TreePrinter) -> None: printf("[line %d] ", self->location.lineno) if self->kind == AstExpressionKind::String: - printf("\"") + printf("string \"") for s = self->string; *s != 0; s++: if isprint(*s) != 0: putchar(*s) @@ -169,13 +212,13 @@ class AstExpression: printf("\\x%02x", *s) printf("\"\n") elif self->kind == AstExpressionKind::Short: - printf("%hd (16-bit signed)\n", self->short_value) + printf("short %hd\n", self->short_value) elif self->kind == AstExpressionKind::Int: - printf("%d (32-bit signed)\n", self->int_value) + printf("int %d\n", self->int_value) elif self->kind == AstExpressionKind::Long: - printf("%lld (64-bit signed)\n", self->long_value) + printf("long %lld\n", self->long_value) elif self->kind == AstExpressionKind::Byte: - printf("%d (8-bit unsigned)\n", self->byte_value) + printf("byte %d\n", self->byte_value) elif self->kind == AstExpressionKind::Float: printf("float %s\n", self->float_or_double_text) elif self->kind == AstExpressionKind::Double: @@ -194,13 +237,9 @@ class AstExpression: for i = 0; i < self->array.length; i++: self->array.items[i].print_with_tree_printer(tp.print_prefix(i == self->array.length-1)) elif self->kind == AstExpressionKind::Call: - if self->call.uses_arrow_operator: - printf("dereference and ") - printf("call %s \"%s\"\n", self->call.function_or_method(), self->call.name) - self->call.print(tp) + self->call.print_with_tree_printer(tp) elif self->kind == AstExpressionKind::Instantiate: - printf("instantiate \"%s\"\n", self->instantiation.class_name) - self->instantiation.print(tp) + self->instantiation.print_with_tree_printer(tp) elif self->kind == AstExpressionKind::Self: printf("self\n") elif self->kind == AstExpressionKind::GetVariable: @@ -212,15 +251,13 @@ class AstExpression: self->enum_member.enum_name, ) elif self->kind == AstExpressionKind::GetClassField: + printf("get class field \"%s\"", self->class_field.field_name) if self->class_field.uses_arrow_operator: - printf("dereference and ") - printf("get class field \"%s\"\n", self->class_field.field_name) + printf(" with the arrow operator") + printf("\n") self->class_field.instance->print_with_tree_printer(tp.print_prefix(True)) elif self->kind == AstExpressionKind::As: - printf("as ") - self->as_->type.print(True) - printf("\n") - self->as_->value.print_with_tree_printer(tp.print_prefix(True)) + self->as_->print_with_tree_printer(tp) elif self->kind == AstExpressionKind::SizeOf: printf("sizeof\n") elif self->kind == AstExpressionKind::AddressOf: @@ -241,13 +278,13 @@ class AstExpression: printf("post-decrement\n") elif self->kind == AstExpressionKind::Add: printf("add\n") - elif self->kind == AstExpressionKind::Subtract: + elif self->kind == AstExpressionKind::Sub: printf("sub\n") - elif self->kind == AstExpressionKind::Multiply: + elif self->kind == AstExpressionKind::Mul: printf("mul\n") - elif self->kind == AstExpressionKind::Divide: + elif self->kind == AstExpressionKind::Div: printf("div\n") - elif self->kind == AstExpressionKind::Modulo: + elif self->kind == AstExpressionKind::Mod: printf("mod\n") elif self->kind == AstExpressionKind::Eq: printf("eq\n") @@ -266,12 +303,11 @@ class AstExpression: elif self->kind == AstExpressionKind::Or: printf("or\n") else: - printf("?????\n") + printf("????? (kind=%d)\n", self->kind as int) for i = 0; i < self->get_arity(); i++: self->operands[i].print_with_tree_printer(tp.print_prefix(i == self->get_arity()-1)) - # TODO: use this def free(self) -> None: if self->kind == AstExpressionKind::Call: self->call.free() @@ -280,6 +316,10 @@ class AstExpression: free(self->as_) elif self->kind == AstExpressionKind::String: free(self->string) + elif self->kind == AstExpressionKind::Array: + self->array.free() + elif self->kind == AstExpressionKind::Instantiate: + self->instantiation.free() elif self->kind == AstExpressionKind::GetClassField: self->class_field.free() @@ -288,7 +328,6 @@ class AstExpression: self->operands[i].free() free(self->operands) - # TODO: use this # arity = number of operands, e.g. 2 for a binary operator such as "+" def get_arity(self) -> int: if ( @@ -305,11 +344,11 @@ class AstExpression: return 1 if ( self->kind == AstExpressionKind::Add - or self->kind == AstExpressionKind::Subtract - or self->kind == AstExpressionKind::Multiply - or self->kind == AstExpressionKind::Divide + or self->kind == AstExpressionKind::Sub + or self->kind == AstExpressionKind::Mul + or self->kind == AstExpressionKind::Div or self->kind == AstExpressionKind::Indexing - or self->kind == AstExpressionKind::Modulo + or self->kind == AstExpressionKind::Mod or self->kind == AstExpressionKind::Eq or self->kind == AstExpressionKind::Ne or self->kind == AstExpressionKind::Gt @@ -322,8 +361,7 @@ class AstExpression: return 2 return 0 - # TODO: use this - def can_have_side_effects(self) -> bool: + def is_valid_as_a_statement(self) -> bool: return ( self->kind == AstExpressionKind::Call or self->kind == AstExpressionKind::PreIncr @@ -332,37 +370,40 @@ class AstExpression: or self->kind == AstExpressionKind::PostDecr ) + +# [foo, bar, baz] class AstArray: length: int items: AstExpression* - # TODO: use this def free(self) -> None: for i = 0; i < self->length; i++: self->items[i].free() free(self->items) -class AstEnumMember: - enum_name: byte[100] - member_name: byte[100] -class AstClassField: - instance: AstExpression* - uses_arrow_operator: bool # distinguishes foo.bar and foo->bar - field_name: byte[100] - - def free(self) -> None: - self->instance->free() - free(self->instance) - -class AstAsExpression: +# foo as bar +class AstAs: value: AstExpression type: AstType + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: + printf("as ") + self->type.print(True) + printf("\n") + self->value.print_with_tree_printer(tp.print_prefix(True)) + def free(self) -> None: self->value.free() self->type.free() + +# foo(arg1, arg2, arg3) +# foo.bar(arg1, arg2, arg3) +# foo->bar(arg1, arg2, arg3) class AstCall: location: Location name: byte[100] # name of function or method @@ -371,14 +412,15 @@ class AstCall: nargs: int args: AstExpression* - # Useful for formatting error messages, but not much else. - def function_or_method(self) -> byte*: - if self->method_call_self == NULL: - return "function" - else: - return "method" + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: + printf("call %s \"%s\"", self->function_or_method(), self->name) + if self->uses_arrow_operator: + printf(" with the arrow operator") + printf("\n") - def print(self, tp: TreePrinter) -> None: if self->method_call_self != NULL: sub = tp.print_prefix(self->nargs == 0) printf("self: ") @@ -394,29 +436,21 @@ class AstCall: self->args[i].free() free(self->args) -class AstInstantiation: - class_name_location: Location # TODO: probably not necessary, can use location of the instantiate expression - class_name: byte[100] - nfields: int - field_names: byte[100]* - field_values: AstExpression* + # Useful for formatting error messages, but not much else. + # TODO: use this + def function_or_method(self) -> byte*: + if self->method_call_self == NULL: + return "function" + else: + return "method" - def print(self, tp: TreePrinter) -> None: - for i = 0; i < self->nfields; i++: - sub = tp.print_prefix(i == self->nfields - 1) - printf("field \"%s\": ", self->field_names[i]) - self->field_values[i].print_with_tree_printer(sub) - - def free(self) -> None: - for i = 0; i < self->nfields; i++: - self->field_values[i].free() - free(self->field_names) - free(self->field_values) +# assert foo class AstAssertion: condition: AstExpression condition_str: byte* + enum AstStatementKind: ExpressionStatement # Evaluate an expression. Discard the result. Assert @@ -430,10 +464,10 @@ enum AstStatementKind: DeclareLocalVar # x: SomeType = y (the "= y" is optional) Assign # x = y InPlaceAdd # x += y - InPlaceSubtract # x -= y - InPlaceMultiply # x *= y - InPlaceDivide # x /= y - InPlaceModulo # x %= y + InPlaceSub # x -= y + InPlaceMul # x *= y + InPlaceDiv # x /= y + InPlaceMod # x %= y Function Class Enum @@ -475,11 +509,9 @@ class AstStatement: if self->return_value != NULL: self->return_value->print_with_tree_printer(tp.print_prefix(True)) elif self->kind == AstStatementKind::If: - printf("if\n") - self->if_statement.print(tp) + self->if_statement.print_with_tree_printer(tp) elif self->kind == AstStatementKind::ForLoop: - printf("for loop\n") - self->for_loop.print(tp) + self->for_loop.print_with_tree_printer(tp) elif self->kind == AstStatementKind::WhileLoop: printf("while loop\n") self->while_loop.print_with_tree_printer(tp, True) @@ -496,29 +528,28 @@ class AstStatement: elif self->kind == AstStatementKind::InPlaceAdd: printf("in-place add\n") self->assignment.print_with_tree_printer(tp) - elif self->kind == AstStatementKind::InPlaceSubtract: + elif self->kind == AstStatementKind::InPlaceSub: printf("in-place sub\n") self->assignment.print_with_tree_printer(tp) - elif self->kind == AstStatementKind::InPlaceMultiply: + elif self->kind == AstStatementKind::InPlaceMul: printf("in-place mul\n") self->assignment.print_with_tree_printer(tp) - elif self->kind == AstStatementKind::InPlaceDivide: + elif self->kind == AstStatementKind::InPlaceDiv: printf("in-place div\n") self->assignment.print_with_tree_printer(tp) - elif self->kind == AstStatementKind::InPlaceModulo: + elif self->kind == AstStatementKind::InPlaceMod: printf("in-place mod\n") self->assignment.print_with_tree_printer(tp) elif self->kind == AstStatementKind::Function: + # TODO: this sucks, see #514 if self->function.body.nstatements == 0: printf("declare a function: ") else: printf("define a function: ") self->function.print_with_tree_printer(tp) elif self->kind == AstStatementKind::Class: - printf("define a ") self->classdef.print_with_tree_printer(tp) elif self->kind == AstStatementKind::Enum: - printf("define ") self->enumdef.print_with_tree_printer(tp) elif self->kind == AstStatementKind::GlobalVariableDeclaration: printf("declare global var ") @@ -543,6 +574,9 @@ class AstStatement: self->if_statement.free() if self->kind == AstStatementKind::ForLoop: self->for_loop.free() + if self->kind == AstStatementKind::Class: + self->classdef.free() + # Useful for e.g. "while condition: body", "if condition: body" class AstConditionAndBody: @@ -565,6 +599,8 @@ class AstConditionAndBody: self->condition.free() self->body.free() + +# foo = bar class AstAssignment: target: AstExpression value: AstExpression @@ -576,12 +612,25 @@ class AstAssignment: self->target.print_with_tree_printer(tp.print_prefix(False)) self->value.print_with_tree_printer(tp.print_prefix(True)) + +# if foo: +# ... +# elif bar: +# ... +# elif baz: +# ... +# else: +# ... class AstIfStatement: if_and_elifs: AstConditionAndBody* n_if_and_elifs: int # At least 1 (the if statement). The rest, if any, are elifs. else_body: AstBody # Empty if there is no else - def print(self, tp: TreePrinter) -> None: + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: + printf("if\n") for i = 0; i < self->n_if_and_elifs; i++: self->if_and_elifs[i].print_with_tree_printer(tp, i == self->n_if_and_elifs - 1 and self->else_body.nstatements == 0) @@ -596,17 +645,22 @@ class AstIfStatement: free(self->if_and_elifs) self->else_body.free() + +# for init; cond; incr: +# ...body... class AstForLoop: - # for init; cond; incr: - # ...body... - # # init and incr must be pointers because this struct goes inside AstStatement. init: AstStatement* cond: AstExpression incr: AstStatement* body: AstBody - def print(self, tp: TreePrinter) -> None: + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: + printf("for loop\n") + sub = tp.print_prefix(False) printf("init: ") self->init->print_with_tree_printer(sub) @@ -631,8 +685,9 @@ class AstForLoop: free(self->incr) self->body.free() + +# name: type = value class AstNameTypeValue: - # name: type = value name: byte[100] name_location: Location type: AstType @@ -660,6 +715,8 @@ class AstNameTypeValue: self->value->free() free(self->value) + +# typically multiple indented lines after a ":" at end of line class AstBody: statements: AstStatement* nstatements: int @@ -676,6 +733,8 @@ class AstBody: self->statements[i].free() free(self->statements) + +# function name and parameters in "def" or "declare" class AstSignature: name_location: Location name: byte[100] # name of function or method, after "def" keyword @@ -710,8 +769,13 @@ class AstSignature: printf("\n") def free(self) -> None: + for i = 0; i < self->nargs; i++: + self->args[i].free() + free(self->args) self->return_type.free() + +# import "./foo.jou" class AstImport: location: Location specified_path: byte* # Path in jou code e.g. "stdlib/io.jou" @@ -721,12 +785,15 @@ class AstImport: def print(self) -> None: printf( "line %d: Import \"%s\", which resolves to \"%s\".\n", - self->location.lineno, self->specified_path, self->resolved_path) + self->location.lineno, self->specified_path, self->resolved_path, + ) def free(self) -> None: free(self->specified_path) free(self->resolved_path) + +# Represents the AST of one Jou file. class AstFile: path: byte* # not owned imports: AstImport* @@ -746,6 +813,9 @@ class AstFile: free(self->imports) self->body.free() + +# def foo() -> bar: +# ... class AstFunctionOrMethod: signature: AstSignature body: AstBody # empty body means declaration, otherwise it's a definition @@ -761,11 +831,18 @@ class AstFunctionOrMethod: self->signature.free() self->body.free() + +# union: +# foo: type1 +# bar: type2 class AstUnionFields: fields: AstNameTypeValue* nfields: int - def print(self, tp: TreePrinter) -> None: + def print(self) -> None: + self->print_with_tree_printer(TreePrinter{}) + + def print_with_tree_printer(self, tp: TreePrinter) -> None: for i = 0; i < self->nfields; i++: subprinter = tp.print_prefix(i == self->nfields-1) self->fields[i].print_with_tree_printer(&subprinter) # TODO: does this need to be optional/pointer? @@ -775,11 +852,13 @@ class AstUnionFields: self->fields[i].free() free(self->fields) + enum AstClassMemberKind: Field Union Method +# Anything that goes inside "class" class AstClassMember: kind: AstClassMemberKind union: @@ -794,7 +873,7 @@ class AstClassMember: printf("\n") elif self->kind == AstClassMemberKind::Union: printf("union:\n") - self->union_fields.print(tp) + self->union_fields.print_with_tree_printer(tp) elif self->kind == AstClassMemberKind::Method: printf("method ") self->method.signature.print() @@ -812,6 +891,9 @@ class AstClassMember: else: assert False + +# class Foo: +# ...members... class AstClassDef: name: byte[100] name_location: Location @@ -831,6 +913,11 @@ class AstClassDef: self->members[i].free() free(self->members) + +# enum Foo: +# Member1 +# Member2 +# Member3 class AstEnumDef: name: byte[100] name_location: Location diff --git a/compiler/build_cfg.jou b/compiler/build_cfg.jou index 4a1cece7..b450a56f 100644 --- a/compiler/build_cfg.jou +++ b/compiler/build_cfg.jou @@ -269,13 +269,13 @@ def build_binop( k: CfInstructionKind if op == AstExpressionKind::Add: k = CfInstructionKind::NumAdd - elif op == AstExpressionKind::Subtract: + elif op == AstExpressionKind::Sub: k = CfInstructionKind::NumSub - elif op == AstExpressionKind::Multiply: + elif op == AstExpressionKind::Mul: k = CfInstructionKind::NumMul - elif op == AstExpressionKind::Divide: + elif op == AstExpressionKind::Div: k = CfInstructionKind::NumDiv - elif op == AstExpressionKind::Modulo: + elif op == AstExpressionKind::Mod: k = CfInstructionKind::NumMod elif op == AstExpressionKind::Eq: k = CfInstructionKind::NumEq @@ -826,10 +826,10 @@ def build_expression(st: State*, expr: AstExpression*) -> LocalVariable*: add_binary_op(st, expr->location, CfInstructionKind::NumSub, zero, temp, result) elif ( expr->kind == AstExpressionKind::Add - or expr->kind == AstExpressionKind::Subtract - or expr->kind == AstExpressionKind::Multiply - or expr->kind == AstExpressionKind::Divide - or expr->kind == AstExpressionKind::Modulo + or expr->kind == AstExpressionKind::Sub + or expr->kind == AstExpressionKind::Mul + or expr->kind == AstExpressionKind::Div + or expr->kind == AstExpressionKind::Mod or expr->kind == AstExpressionKind::Eq or expr->kind == AstExpressionKind::Ne or expr->kind == AstExpressionKind::Gt @@ -1036,10 +1036,10 @@ def build_statement(st: State*, stmt: AstStatement*) -> None: elif ( stmt->kind == AstStatementKind::InPlaceAdd - or stmt->kind == AstStatementKind::InPlaceSubtract - or stmt->kind == AstStatementKind::InPlaceMultiply - or stmt->kind == AstStatementKind::InPlaceDivide - or stmt->kind == AstStatementKind::InPlaceModulo + or stmt->kind == AstStatementKind::InPlaceSub + or stmt->kind == AstStatementKind::InPlaceMul + or stmt->kind == AstStatementKind::InPlaceDiv + or stmt->kind == AstStatementKind::InPlaceMod ): targetexpr = &stmt->assignment.target rhsexpr = &stmt->assignment.value @@ -1052,14 +1052,14 @@ def build_statement(st: State*, stmt: AstStatement*) -> None: if stmt->kind == AstStatementKind::InPlaceAdd: op = AstExpressionKind::Add - elif stmt->kind == AstStatementKind::InPlaceSubtract: - op = AstExpressionKind::Subtract - elif stmt->kind == AstStatementKind::InPlaceMultiply: - op = AstExpressionKind::Multiply - elif stmt->kind == AstStatementKind::InPlaceDivide: - op = AstExpressionKind::Divide - elif stmt->kind == AstStatementKind::InPlaceModulo: - op = AstExpressionKind::Modulo + elif stmt->kind == AstStatementKind::InPlaceSub: + op = AstExpressionKind::Sub + elif stmt->kind == AstStatementKind::InPlaceMul: + op = AstExpressionKind::Mul + elif stmt->kind == AstStatementKind::InPlaceDiv: + op = AstExpressionKind::Div + elif stmt->kind == AstStatementKind::InPlaceMod: + op = AstExpressionKind::Mod else: assert False diff --git a/compiler/parser.jou b/compiler/parser.jou index 84e1cf36..ba289ee6 100644 --- a/compiler/parser.jou +++ b/compiler/parser.jou @@ -49,20 +49,20 @@ def build_operator_expression(t: Token*, arity: int, operands: AstExpression*) - result.kind = AstExpressionKind::Add elif t->is_operator("-"): if arity == 2: - result.kind = AstExpressionKind::Subtract + result.kind = AstExpressionKind::Sub else: result.kind = AstExpressionKind::Negate elif t->is_operator("*"): if arity == 2: - result.kind = AstExpressionKind::Multiply + result.kind = AstExpressionKind::Mul else: result.kind = AstExpressionKind::Dereference elif t->is_operator("/"): assert arity == 2 - result.kind = AstExpressionKind::Divide + result.kind = AstExpressionKind::Div elif t->is_operator("%"): assert arity == 2 - result.kind = AstExpressionKind::Modulo + result.kind = AstExpressionKind::Mod elif t->is_keyword("and"): assert arity == 2 result.kind = AstExpressionKind::And @@ -87,13 +87,13 @@ def determine_the_kind_of_a_statement_that_starts_with_an_expression( if this_token_is_after_that_initial_expression->is_operator("+="): return AstStatementKind::InPlaceAdd if this_token_is_after_that_initial_expression->is_operator("-="): - return AstStatementKind::InPlaceSubtract + return AstStatementKind::InPlaceSub if this_token_is_after_that_initial_expression->is_operator("*="): - return AstStatementKind::InPlaceMultiply + return AstStatementKind::InPlaceMul if this_token_is_after_that_initial_expression->is_operator("/="): - return AstStatementKind::InPlaceDivide + return AstStatementKind::InPlaceDiv if this_token_is_after_that_initial_expression->is_operator("%="): - return AstStatementKind::InPlaceModulo + return AstStatementKind::InPlaceMod return AstStatementKind::ExpressionStatement class MemberInfo: @@ -690,8 +690,8 @@ class Parser: result = self->parse_expression_with_add() while self->tokens->is_keyword("as"): as_location = (self->tokens++)->location # TODO: shouldn't need so many parentheses - p: AstAsExpression* = malloc(sizeof(*p)) - *p = AstAsExpression{type = self->parse_type(), value = result} + p: AstAs* = malloc(sizeof(*p)) + *p = AstAs{type = self->parse_type(), value = result} result = AstExpression{ location = as_location, kind = AstExpressionKind::As, @@ -781,7 +781,7 @@ class Parser: expr = self->parse_expression() result.kind = determine_the_kind_of_a_statement_that_starts_with_an_expression(self->tokens) if result.kind == AstStatementKind::ExpressionStatement: - if not expr.can_have_side_effects(): + if not expr.is_valid_as_a_statement(): fail(expr.location, "not a valid statement") result.expression = expr else: diff --git a/compiler/typecheck.jou b/compiler/typecheck.jou index 15bff2ef..8856749a 100644 --- a/compiler/typecheck.jou +++ b/compiler/typecheck.jou @@ -130,19 +130,11 @@ def evaluate_array_length(expr: AstExpression*) -> int: return expr->int_value fail(expr->location, "cannot evaluate array length at compile time") -def is_void(t: AstType*) -> bool: - return t->kind == AstTypeKind::Named and strcmp(t->name, "void") == 0 - -def is_none(t: AstType*) -> bool: - return t->kind == AstTypeKind::Named and strcmp(t->name, "None") == 0 - -def is_noreturn(t: AstType*) -> bool: - return t->kind == AstTypeKind::Named and strcmp(t->name, "noreturn") == 0 def type_from_ast(ft: FileTypes*, asttype: AstType*) -> Type*: msg: byte[500] - if is_void(asttype) or is_none(asttype) or is_noreturn(asttype): + if asttype->is_void() or asttype->is_none() or asttype->is_noreturn(): snprintf(msg, sizeof(msg), "'%s' cannot be used here because it is not a type", asttype->name) fail(asttype->location, msg) @@ -170,7 +162,7 @@ def type_from_ast(ft: FileTypes*, asttype: AstType*) -> Type*: fail(asttype->location, msg) if asttype->kind == AstTypeKind::Pointer: - if is_void(asttype->value_type): + if asttype->value_type->is_void(): return voidPtrType return get_pointer_type(type_from_ast(ft, asttype->value_type)) @@ -246,10 +238,10 @@ def handle_signature(ft: FileTypes*, astsig: AstSignature*, self_class: Type*) - sig.argtypes[i] = argtype - sig.is_noreturn = is_noreturn(&astsig->return_type) - if is_none(&astsig->return_type) or is_noreturn(&astsig->return_type): + sig.is_noreturn = astsig->return_type.is_noreturn() + if astsig->return_type.is_none() or astsig->return_type.is_noreturn(): sig.returntype = NULL - elif is_void(&astsig->return_type): + elif astsig->return_type.is_void(): fail(astsig->return_type.location, "void is not a valid return type, use '-> None' if the function does not return a value") else: sig.returntype = type_from_ast(ft, &astsig->return_type) @@ -440,10 +432,10 @@ def short_expression_description(expr: AstExpression*) -> byte*: if ( expr->kind == AstExpressionKind::Add - or expr->kind == AstExpressionKind::Subtract - or expr->kind == AstExpressionKind::Multiply - or expr->kind == AstExpressionKind::Divide - or expr->kind == AstExpressionKind::Modulo + or expr->kind == AstExpressionKind::Sub + or expr->kind == AstExpressionKind::Mul + or expr->kind == AstExpressionKind::Div + or expr->kind == AstExpressionKind::Mod or expr->kind == AstExpressionKind::Negate ): return "the result of a calculation" @@ -701,13 +693,13 @@ def check_binop( do_what: byte* if op == AstExpressionKind::Add: do_what = "add" - elif op == AstExpressionKind::Subtract: + elif op == AstExpressionKind::Sub: do_what = "subtract" - elif op == AstExpressionKind::Multiply: + elif op == AstExpressionKind::Mul: do_what = "multiply" - elif op == AstExpressionKind::Divide: + elif op == AstExpressionKind::Div: do_what = "divide" - elif op == AstExpressionKind::Modulo: + elif op == AstExpressionKind::Mod: do_what = "take remainder with" elif ( op == AstExpressionKind::Eq @@ -784,10 +776,10 @@ def check_binop( if ( op == AstExpressionKind::Add - or op == AstExpressionKind::Subtract - or op == AstExpressionKind::Multiply - or op == AstExpressionKind::Divide - or op == AstExpressionKind::Modulo + or op == AstExpressionKind::Sub + or op == AstExpressionKind::Mul + or op == AstExpressionKind::Div + or op == AstExpressionKind::Mod ): return cast_type @@ -1267,10 +1259,10 @@ def typecheck_expression(ft: FileTypes*, expr: AstExpression*) -> ExpressionType elif ( expr->kind == AstExpressionKind::Add - or expr->kind == AstExpressionKind::Subtract - or expr->kind == AstExpressionKind::Multiply - or expr->kind == AstExpressionKind::Divide - or expr->kind == AstExpressionKind::Modulo + or expr->kind == AstExpressionKind::Sub + or expr->kind == AstExpressionKind::Mul + or expr->kind == AstExpressionKind::Div + or expr->kind == AstExpressionKind::Mod or expr->kind == AstExpressionKind::Eq or expr->kind == AstExpressionKind::Ne or expr->kind == AstExpressionKind::Gt @@ -1384,10 +1376,10 @@ def typecheck_statement(ft: FileTypes*, stmt: AstStatement*) -> None: elif ( stmt->kind == AstStatementKind::InPlaceAdd - or stmt->kind == AstStatementKind::InPlaceSubtract - or stmt->kind == AstStatementKind::InPlaceMultiply - or stmt->kind == AstStatementKind::InPlaceDivide - or stmt->kind == AstStatementKind::InPlaceModulo + or stmt->kind == AstStatementKind::InPlaceSub + or stmt->kind == AstStatementKind::InPlaceMul + or stmt->kind == AstStatementKind::InPlaceDiv + or stmt->kind == AstStatementKind::InPlaceMod ): targetexpr = &stmt->assignment.target valueexpr = &stmt->assignment.value @@ -1399,17 +1391,17 @@ def typecheck_statement(ft: FileTypes*, stmt: AstStatement*) -> None: if stmt->kind == AstStatementKind::InPlaceAdd: op = AstExpressionKind::Add opname = "addition" - elif stmt->kind == AstStatementKind::InPlaceSubtract: - op = AstExpressionKind::Subtract + elif stmt->kind == AstStatementKind::InPlaceSub: + op = AstExpressionKind::Sub opname = "subtraction" - elif stmt->kind == AstStatementKind::InPlaceMultiply: - op = AstExpressionKind::Multiply + elif stmt->kind == AstStatementKind::InPlaceMul: + op = AstExpressionKind::Mul opname = "multiplication" - elif stmt->kind == AstStatementKind::InPlaceDivide: - op = AstExpressionKind::Divide + elif stmt->kind == AstStatementKind::InPlaceDiv: + op = AstExpressionKind::Div opname = "division" - elif stmt->kind == AstStatementKind::InPlaceModulo: - op = AstExpressionKind::Modulo + elif stmt->kind == AstStatementKind::InPlaceMod: + op = AstExpressionKind::Mod opname = "modulo" else: assert False