From 2eb2336a0f25a45930acdcc2a89482cebffa8ad4 Mon Sep 17 00:00:00 2001 From: jay3332 <40323796+jay3332@users.noreply.github.com> Date: Fri, 15 Mar 2024 20:23:13 -0400 Subject: [PATCH] lower function args --- codegen/src/aot.rs | 76 +++++++++++------ codegen/src/lib.rs | 19 ++--- hir/src/check.rs | 9 +- hir/src/infer.rs | 209 ++++++++++++++++++++++++++++++++------------- hir/src/lib.rs | 100 ++++++++++++++-------- hir/src/lower.rs | 152 ++++++++++++++++++++++----------- hir/src/typed.rs | 13 +-- mir/src/lib.rs | 14 +-- mir/src/lower.rs | 80 ++++++++--------- src/main.rs | 2 +- test.trb | 10 ++- 11 files changed, 431 insertions(+), 253 deletions(-) diff --git a/codegen/src/aot.rs b/codegen/src/aot.rs index 2800526b..b21a994f 100644 --- a/codegen/src/aot.rs +++ b/codegen/src/aot.rs @@ -1,6 +1,6 @@ use common::span::Spanned; -use inkwell::attributes::{Attribute, AttributeLoc}; use inkwell::{ + attributes::Attribute, basic_block::BasicBlock, builder::Builder, context::Context, @@ -11,8 +11,8 @@ use inkwell::{ IntPredicate, }; use mir::{ - BlockId, Constant, Expr, Func, IntIntrinsic, IntSign, IntWidth, LocalEnv, LocalId, Node, - PrimitiveTy, Ty, UnaryIntIntrinsic, + BlockId, Constant, Expr, Func, Ident, IntIntrinsic, IntSign, IntWidth, LocalEnv, LocalId, + LookupId, Node, PrimitiveTy, Ty, UnaryIntIntrinsic, }; use std::{collections::HashMap, mem::MaybeUninit, ops::Not}; @@ -28,11 +28,12 @@ pub struct Compiler<'a, 'ctx> { pub builder: &'a Builder<'ctx>, pub fpm: &'a PassManager>, pub module: &'a Module<'ctx>, - pub func: &'a Func, + lowering: MaybeUninit, + fn_value: MaybeUninit>, + functions: HashMap>, locals: HashMap>>, blocks: HashMap>, - fn_value: MaybeUninit>, increment: usize, } @@ -44,6 +45,11 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> { self.context.custom_width_int_type(width as usize as _) } + #[inline] + fn lowering_mut(&mut self) -> &mut Func { + unsafe { self.lowering.assume_init_mut() } + } + #[inline] const fn fn_value(&self) -> FunctionValue<'ctx> { unsafe { self.fn_value.assume_init() } @@ -178,19 +184,17 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> { BasicValueEnum::IntValue(bool_value) } Expr::Call(func, args) => { - let f = self.module.get_function(&func.to_string()).unwrap(); let args = args .into_iter() .map(|arg| self.lower_expr(arg).unwrap().into()) .collect::>(); self.builder - .build_call(f, &args, &self.next_increment()) + .build_call(self.functions[&func], &args, &self.next_increment()) .try_as_basic_value() .left() .unwrap() } - _ => todo!(), }) } @@ -249,12 +253,12 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> { /// Lowers a block given its ID. pub fn lower_block(&mut self, block_id: BlockId) { - let block = self.func.blocks.get(&block_id).unwrap(); + let block = self.lowering_mut().blocks.remove(&block_id).unwrap(); self.builder .position_at_end(*self.blocks.get(&block_id).unwrap()); for node in block { - self.lower_node(node.clone()) + self.lower_node(node) } } @@ -294,10 +298,9 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> { .create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) } - /// Compiles the specified function into an LLVM `FunctionValue`. - fn compile_fn(&mut self) { - let (names, param_tys) = self - .func + /// Registers the specified function into an LLVM `FunctionValue`. + fn register_fn(&mut self, id: LookupId, func: &Func) -> Vec { + let (names, param_tys) = func .params .iter() .filter_map(|(name, ty)| { @@ -307,18 +310,29 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> { }) .unzip::<_, _, Vec<_>, Vec<_>>(); - let fn_ty = match self.func.ret_ty.is_zst() { + let fn_ty = match func.ret_ty.is_zst() { true => self.context.void_type().fn_type(¶m_tys, false), - false => self.lower_ty(&self.func.ret_ty).fn_type(¶m_tys, false), + false => self.lower_ty(&func.ret_ty).fn_type(¶m_tys, false), }; // TODO: qualified name - let name = self.func.name.to_string(); - self.fn_value - .write(self.module.add_function(&name, fn_ty, None)); + let name = func.name.to_string(); + let fn_value = self.module.add_function(&name, fn_ty, None); + self.functions.insert(id, fn_value); + names + } + + /// Compiles the body of the given function. + fn compile_fn(&mut self, fn_value: FunctionValue<'ctx>, func: Func, names: Vec) { + let block_ids = func.blocks.keys().copied().collect::>(); + self.lowering = MaybeUninit::new(func); + self.fn_value.write(fn_value); + self.locals.clear(); + self.blocks.clear(); + self.increment = 0; // Create blocks - for id in self.func.blocks.keys() { + for id in &block_ids { let bb = self .context .append_basic_block(self.fn_value(), &id.to_string()); @@ -343,8 +357,9 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> { } // Compile body - self.func.blocks.keys().for_each(|&id| self.lower_block(id)); + block_ids.into_iter().for_each(|id| self.lower_block(id)); self.fn_value().print_to_string(); + unsafe { self.lowering.assume_init_drop() }; // Verify and run optimizations if self.fn_value().verify(true) { @@ -362,20 +377,27 @@ impl<'a, 'ctx> Compiler<'a, 'ctx> { builder: &'a Builder<'ctx>, pass_manager: &'a PassManager>, module: &'a Module<'ctx>, - func: &'a Func, - ) -> FunctionValue<'ctx> { + functions: HashMap, + ) { let mut compiler = Self { context, builder, fpm: pass_manager, module, - func, + functions: HashMap::with_capacity(functions.len()), + lowering: MaybeUninit::uninit(), fn_value: MaybeUninit::uninit(), locals: HashMap::new(), - blocks: HashMap::with_capacity(func.blocks.len()), + blocks: HashMap::new(), increment: 0, }; - compiler.compile_fn(); - compiler.fn_value() + + let mut names = Vec::with_capacity(functions.len()); + for (id, func) in &functions { + names.push(compiler.register_fn(*id, func)); + } + for ((id, func), names) in functions.into_iter().zip(names) { + compiler.compile_fn(compiler.functions[&id], func, names); + } } } diff --git a/codegen/src/lib.rs b/codegen/src/lib.rs index 212d42bc..2fe4d401 100644 --- a/codegen/src/lib.rs +++ b/codegen/src/lib.rs @@ -11,14 +11,11 @@ pub use inkwell::{ }; use inkwell::passes::PassManager; -use mir::{Mir, ModuleId}; +use mir::{Func, LookupId}; +use std::collections::HashMap; -pub fn compile_llvm<'ctx>( - context: &'ctx Context, - mir: &Mir, - module_id: ModuleId, /*, options: CompileOptions*/ -) -> Module<'ctx> { - let module = context.create_module(&module_id.to_string()); +pub fn compile_llvm(context: &Context, functions: HashMap) -> Module { + let module = context.create_module("root"); let builder = context.create_builder(); // Create FPM @@ -33,12 +30,6 @@ pub fn compile_llvm<'ctx>( fpm.add_reassociate_pass(); fpm.initialize(); - for func in mir - .functions - .iter() - .filter_map(|(id, func)| id.0.eq(&module_id).then_some(func)) - { - aot::Compiler::compile(&context, &builder, &fpm, &module, func); - } + aot::Compiler::compile(&context, &builder, &fpm, &module, functions); module } diff --git a/hir/src/check.rs b/hir/src/check.rs index d02f130b..b9db107c 100644 --- a/hir/src/check.rs +++ b/hir/src/check.rs @@ -8,7 +8,7 @@ use crate::{ BinaryIntIntrinsic, BoolIntrinsic, Constraint, Expr, IntIntrinsic, LocalEnv, Relation, Ty, TypedExpr, UnaryIntIntrinsic, UnificationTable, }, - Hir, IntSign, ModuleId, Node, Op, Pattern, PrimitiveTy, Scope, ScopeId, + Hir, IntSign, Lookup, ModuleId, Node, Op, Pattern, PrimitiveTy, Scope, ScopeId, }; use common::span::{Spanned, SpannedExt}; @@ -447,8 +447,11 @@ impl<'a> TypeChecker<'a> { .expect("scope not found"); // Substitute over all functions in the scope - for (_, func) in &mut scope.funcs { - self.substitute_scope(module, func.body, table); + for (_, &Lookup(_, id)) in &scope.items { + let scope = self.thir_mut().funcs[&id].body; + self.substitute_scope(module, scope, table); + + let func = self.thir_mut().funcs.get_mut(&id).unwrap(); func.header.ret_ty.apply(&table.substitutions); } diff --git a/hir/src/infer.rs b/hir/src/infer.rs index 3b6b1409..97bce94b 100644 --- a/hir/src/infer.rs +++ b/hir/src/infer.rs @@ -5,8 +5,9 @@ use crate::{ UnificationTable, }, warning::Warning, - Expr, FloatWidth, Func, FuncHeader, FuncParam, Hir, Ident, IntSign, IntWidth, ItemId, Literal, - LogicalOp, Metadata, ModuleId, Node, Pattern, PrimitiveTy, ScopeId, TyParam, + Expr, FloatWidth, Func, FuncHeader, FuncParam, Hir, Ident, IntSign, IntWidth, ItemId, ItemKind, + Literal, LogicalOp, Lookup, LookupId, Metadata, ModuleId, Node, Pattern, PrimitiveTy, ScopeId, + TyParam, }; use common::span::{Span, Spanned, SpannedExt}; use std::{borrow::Cow, collections::HashMap}; @@ -76,10 +77,10 @@ impl UnificationTable { pub enum BindingKind { Var, Param, - Func, + Func(LookupId), } -struct Binding { +pub struct Binding { pub def_span: Span, pub ty: Ty, pub mutable: Option, @@ -131,7 +132,7 @@ pub struct Scope { kind: ScopeKind, ty_params: Vec>, locals: HashMap<(Ident, LocalEnv), Local>, - funcs: HashMap>, + funcs: HashMap)>, label: Option>, exited: Option, } @@ -198,6 +199,7 @@ impl TypeLowerer { scope_id: ScopeId, module: ModuleId, kind: ScopeKind, + bindings: Vec<(Ident, Binding)>, ty_params: Vec>, label: Option>, resolution: Option, @@ -209,7 +211,10 @@ impl TypeLowerer { module_id: module, kind, ty_params, - locals: HashMap::new(), + locals: bindings + .into_iter() + .map(|(ident, binding)| ((ident, self.local_env), Local::from_binding(binding))) + .collect(), funcs: HashMap::new(), label, exited: None, @@ -340,26 +345,25 @@ impl TypeLowerer { }); } - // Does a constant with this name exist? let item = ItemId(self.scope().module_id, *ident.value()); - if let Some(cnst) = self - .hir - .scopes - .get(&self.scope().id) - .and_then(|scope| scope.consts.get(&item)) - { - return Ok(Binding { - def_span: cnst.name.span(), - ty: self.lower_hir_ty(cnst.ty.clone()), - mutable: None, - initialized: true, - kind: BindingKind::Var, - }); - } - - // Does a function with this name exist? + // Does an item with this name exist? for scope in self.scopes.iter().rev() { - if let Some(header) = scope.funcs.get(&item).cloned() { + if let Some(cnst) = self.hir.scopes.get(&scope.id).and_then(|scope| { + scope + .items + .get(&item) + .and_then(|Lookup(_, id)| self.hir.consts.get(id)) + }) { + return Ok(Binding { + def_span: cnst.name.span(), + ty: self.lower_hir_ty(cnst.ty.clone()), + mutable: None, + initialized: true, + kind: BindingKind::Var, + }); + } + + if let Some((id, header)) = scope.funcs.get(&item).cloned() { return Ok(Binding { def_span: header.name.span(), ty: Ty::Func( @@ -372,7 +376,7 @@ impl TypeLowerer { ), mutable: None, initialized: true, - kind: BindingKind::Func, + kind: BindingKind::Func(id), }); } } @@ -384,7 +388,14 @@ impl TypeLowerer { fn lower_exit_in_context(&mut self, scope_id: ScopeId, divergent: bool) -> Result { let label = self.hir.scopes.get(&scope_id).unwrap().label.clone(); Ok( - match self.lower_scope(scope_id, ScopeKind::Block, Vec::new(), divergent, None)? { + match self.lower_scope( + scope_id, + ScopeKind::Block, + Vec::new(), + Vec::new(), + divergent, + None, + )? { // Does it exit *just* itself? If so, return the type of the expression ExitAction::FromBlock(None, ty, _) => ty, ExitAction::FromBlock(Some(lbl), ty, _) @@ -423,10 +434,9 @@ impl TypeLowerer { BindingKind::Var | BindingKind::Param => { TypedExpr(typed::Expr::Local(ident, args, self.local_env), binding.ty) } - BindingKind::Func => TypedExpr( - typed::Expr::Func(ident, args, ItemId(self.scope().module_id, ident.0)), - binding.ty, - ), + BindingKind::Func(id) => { + TypedExpr(typed::Expr::Func(ident, args, id), binding.ty) + } } } Expr::Tuple(exprs) => { @@ -570,18 +580,22 @@ impl TypeLowerer { } Expr::Loop(scope_id) => { let label = self.hir.scopes.get(&scope_id).unwrap().label.clone(); - let ty = - match self.lower_scope(scope_id, ScopeKind::Loop, Vec::new(), true, None)? { - // Are we exiting from just the loop? - ExitAction::FromBlock(Some(lbl), ty, _) - if label.is_some_and(|l| l == lbl) => - { - ty - } - ExitAction::FromNearestLoop(ty, _) => ty, - // Otherwise we are exiting further out from the loop. - r => Ty::Exit(Box::new(r)), - }; + let ty = match self.lower_scope( + scope_id, + ScopeKind::Loop, + Vec::new(), + Vec::new(), + true, + None, + )? { + // Are we exiting from just the loop? + ExitAction::FromBlock(Some(lbl), ty, _) if label.is_some_and(|l| l == lbl) => { + ty + } + ExitAction::FromNearestLoop(ty, _) => ty, + // Otherwise we are exiting further out from the loop. + r => Ty::Exit(Box::new(r)), + }; TypedExpr(typed::Expr::Loop(scope_id), ty) } Expr::CallOp(op, lhs, rhs) => TypedExpr( @@ -603,7 +617,7 @@ impl TypeLowerer { .map(|expr| self.lower_expr(expr)) .collect::>>()?; - let (arg_tys, return_ty) = match callee_ty { + let (arg_tys, mut return_ty) = match callee_ty { Ty::Func(args, ret_ty) => (Some(args), *ret_ty), Ty::Unknown(i) => { let arg_tys = args @@ -639,8 +653,7 @@ impl TypeLowerer { .constraints .push_back(Constraint(arg.value().1.clone(), ty.clone())); - let conflict = self.table.unify_all(); - if let Some(conflict) = conflict { + if let Some(conflict) = self.table.unify_all() { self.err_nonfatal(Error::TypeConflict { expected: (ty.into(), None), actual: arg.as_ref().map(|expr| expr.1.clone().into()), @@ -656,12 +669,27 @@ impl TypeLowerer { // Deduce the type of call (is it an intrinsic, function/method call, or something else?) let expr = match callee.value().0 { - typed::Expr::Func(_, _, item) => typed::Expr::CallFunc { - parent: None, - func: item, - args, - kwargs: Vec::new(), - }, + typed::Expr::Func(_, _, item) => { + let header = &self.thir.funcs[&item].header; + self.table + .constraints + .push_back(Constraint(header.ret_ty.clone(), return_ty.clone())); + if let Some(conflict) = self.table.unify_all() { + self.err_nonfatal(Error::TypeConflict { + expected: (header.ret_ty.clone(), header.ret_ty_span), + actual: Spanned(return_ty.clone().into(), span), + constraint: conflict, + }); + } + return_ty.apply(&self.table.substitutions); + + typed::Expr::CallFunc { + parent: None, + func: item, + args, + kwargs: Vec::new(), + } + } _ => unimplemented!(), }; @@ -939,9 +967,16 @@ impl TypeLowerer { ret_ty_span: header.ret_ty_span, }; + let mut bindings = Vec::new(); + for param in &header.params { + if let Err(why) = flatten_param(¶m.pat, param.ty.clone(), &mut bindings) { + self.err_nonfatal(why); + } + } self.lower_scope( func.body, ScopeKind::Func, + bindings, header.ty_params.clone(), true, Some((header.ret_ty.clone(), header.ret_ty_span)), @@ -965,6 +1000,7 @@ impl TypeLowerer { &mut self, scope_id: ScopeId, kind: ScopeKind, + bindings: Vec<(Ident, Binding)>, ty_params: Vec>, divergent: bool, resolution: Option, @@ -975,16 +1011,22 @@ impl TypeLowerer { scope_id, scope.module_id, kind, + bindings, ty_params, label, resolution, ); - let mut funcs = HashMap::new(); - for (name, func) in scope.funcs.drain() { + let mut items = HashMap::new(); + for (name, lookup @ Lookup(_, id)) in scope.items.extract_if(|_, l| l.0 == ItemKind::Func) { + let func = self.hir.funcs.remove(&id).expect("func not found"); let func = self.lower_func(func)?; - self.scope_mut().funcs.insert(name, func.header.clone()); - funcs.insert(name, func); + // register the function in the scope + self.scope_mut() + .funcs + .insert(name, (id, func.header.clone())); + self.thir.funcs.insert(id, func); + items.insert(name, lookup); } let mut exit_action = None; @@ -1012,11 +1054,7 @@ impl TypeLowerer { label, decorators: scope.decorators, children: lowered.spanned(full_span), - funcs, - aliases: HashMap::new(), - consts: HashMap::new(), - structs: HashMap::new(), - types: HashMap::new(), + items, }, ); let exit_action = @@ -1051,7 +1089,14 @@ impl TypeLowerer { let scope_id = *self.hir.modules.get(&module_id).unwrap(); // TODO: exit action can be non-standard in things like REPLs - match self.lower_scope(scope_id, ScopeKind::Block, Vec::new(), true, None)? { + match self.lower_scope( + scope_id, + ScopeKind::Block, + Vec::new(), + Vec::new(), + true, + None, + )? { _ => (), // TODO: For packaged code, check exit type is void } self.exit_scope_if_exists(); @@ -1094,6 +1139,48 @@ impl ExitAction { } } +/// Flattens a pattern parameter into a list of bindings to prepare for MIR lowering. +/// +/// # Example +/// ```terbium +/// func sum_tuple((a, b): (int, int)) = a + b; +/// +/// // Lowered to: +/// func sum_tuple(a: int, b: int) = a + b; +/// ``` +pub fn flatten_param( + pat: &Spanned, + ty: Ty, + bindings: &mut Vec<(Ident, Binding)>, +) -> Result<()> { + match (pat.value(), ty) { + (Pattern::Ident { ident, mut_kw }, ty) => bindings.push(( + ident.0, + Binding { + def_span: ident.1, + ty, + mutable: mut_kw.clone(), + initialized: true, + kind: BindingKind::Param, + }, + )), + (Pattern::Tuple(pats), Ty::Tuple(tys)) => { + if pats.len() != tys.len() { + return Err( + pat_errors::tuple_len_mismatch(pats.len().spanned(pat.span()), tys.len(), None) + ); + } + for (pat, ty) in pats.iter().zip(tys) { + flatten_param(pat, ty, bindings)?; + } + } + (Pattern::Tuple(_), ty) => { + return Err(pat_errors::tuple_mismatch(pat.span(), ty, None)); + } + } + Ok(()) +} + pub mod pat_errors { use super::*; diff --git a/hir/src/lib.rs b/hir/src/lib.rs index 833805ec..11904ad3 100644 --- a/hir/src/lib.rs +++ b/hir/src/lib.rs @@ -12,6 +12,7 @@ //! type checking and desugaring with the knowledge of the types of all expressions. This lowering //! is performed by [`TypeChecker`]. +#![feature(hash_extract_if)] #![feature(let_chains)] #![feature(more_qualified_paths)] @@ -92,7 +93,11 @@ impl From for ModuleId { } } -/// The ID of a top-level item. +/// Global item lookup ID. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct LookupId(pub usize); + +/// The ID of a top-level or order-agnostic item. #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub struct ItemId( /// The module in which the item is defined. @@ -147,6 +152,16 @@ pub struct Hir { pub modules: HashMap, /// A mapping of all lexical scopes within the program. pub scopes: HashMap>, + /// A mapping of all functions in the program. + pub funcs: HashMap>, + /// A mapping of all aliases in the program. + pub aliases: HashMap>, + /// A mapping of all constants in the program. + pub consts: HashMap>, + /// A mapping of all raw structs within the program. + pub structs: HashMap>, + /// A mapping of all types within the program. + pub types: HashMap>, } impl Default for Hir { @@ -154,6 +169,11 @@ impl Default for Hir { Self { modules: HashMap::new(), scopes: HashMap::new(), + funcs: HashMap::new(), + aliases: HashMap::new(), + consts: HashMap::new(), + structs: HashMap::new(), + types: HashMap::new(), } } } @@ -239,6 +259,18 @@ pub enum Node { ImplicitReturn(Spanned), } +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum ItemKind { + Func, + Alias, + Const, + Struct, + Type, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct Lookup(pub ItemKind, pub LookupId); + #[derive(Clone, Debug)] pub struct Scope { /// The module in which this scope is defined. @@ -249,16 +281,8 @@ pub struct Scope { pub label: Option>, /// The children of this scope. pub children: Spanned>>>, - /// A mapping of all top-level functions in the scope. - pub funcs: HashMap>, - /// A mapping of all aliases in the scope. - pub aliases: HashMap>, - /// A mapping of all constants in the scope. - pub consts: HashMap>, - /// A mapping of all raw structs within the scope. - pub structs: HashMap>, - /// A mapping of all types within the scope. - pub types: HashMap>, + /// A lookup of all items in the scope. + pub items: HashMap, } impl Scope { @@ -277,13 +301,21 @@ impl Scope { decorators: Vec::new(), label, children, - funcs: HashMap::new(), - aliases: HashMap::new(), - consts: HashMap::new(), - structs: HashMap::new(), - types: HashMap::new(), + items: HashMap::new(), } } + + #[inline] + #[must_use] + pub fn lookup_id(&self, id: ItemId) -> Option { + self.items.get(&id).map(|lookup| lookup.1) + } + + #[inline] + #[must_use] + pub(crate) fn lookup_id_or_panic(&self, id: ItemId) -> LookupId { + self.items[&id].1 + } } #[derive(Clone, Debug)] @@ -312,9 +344,9 @@ fn assert_equal_params_length( ty_name: Spanned, ty_params_len: usize, ty_args_len: usize, -) -> Result<(), error::Error> { +) -> Result<(), Error> { if ty_args_len != ty_params_len { - return Err(error::Error::IncorrectTypeArgumentCount { + return Err(Error::IncorrectTypeArgumentCount { span, ty: ty_name.as_ref().map(ToString::to_string), expected: ty_params_len, @@ -332,7 +364,7 @@ pub struct TyDef { } impl> TyDef { - pub fn apply_params(&self, span: Span, params: Vec) -> Result { + pub fn apply_params(&self, span: Span, params: Vec) -> Result { assert_equal_params_length( span, self.name, @@ -396,7 +428,7 @@ impl Display for FieldVisibility { impl FieldVisibility { pub fn from_ast(v: Spanned) -> error::Result { if v.0.get.0 < v.0.set.0 { - Err(error::Error::GetterLessVisibleThanSetter(v)) + Err(Error::GetterLessVisibleThanSetter(v)) } else { Ok(Self { get: v.0.get.0, @@ -700,7 +732,7 @@ impl StructTy { self, span: Option, params: Vec, - ) -> Result { + ) -> Result { assert_equal_params_length( span.unwrap_or(self.name.span()), self.name, @@ -947,22 +979,18 @@ where for decorator in &scope.decorators { format!("@!{decorator}").write_indent(f)?; } - for (item, func) in &scope.funcs { - writeln!(f, "{item}:")?; - WithHir(func, self).write_indent(f)?; - } - for (item, alias) in &scope.aliases { - writeln!(f, "{item}:")?; - WithHir(alias, self).write_indent(f)?; - } - for (item, cnst) in &scope.consts { - writeln!(f, "{item}:")?; - WithHir(cnst, self).write_indent(f)?; - } - for (item, strct) in &scope.structs { - writeln!(f, "{item}:")?; - WithHir(strct, self).write_indent(f)?; + + for (item_id, Lookup(kind, lookup)) in &scope.items { + writeln!(f, "{item_id}:")?; + match kind { + ItemKind::Func => WithHir(&self.funcs[&lookup], self).write_indent(f)?, + ItemKind::Alias => WithHir(&self.aliases[&lookup], self).write_indent(f)?, + ItemKind::Const => WithHir(&self.consts[&lookup], self).write_indent(f)?, + ItemKind::Struct => WithHir(&self.structs[&lookup], self).write_indent(f)?, + ItemKind::Type => continue, + } } + for line in scope.children.value() { WithHir(line, self).write_indent(f)?; } diff --git a/hir/src/lower.rs b/hir/src/lower.rs index 1315a2a8..01a75579 100644 --- a/hir/src/lower.rs +++ b/hir/src/lower.rs @@ -1,8 +1,8 @@ use crate::{ error::{Error, Result}, Const, Decorator, Expr, FieldVisibility, FloatWidth, Func, FuncHeader, FuncParam, Hir, Ident, - IntSign, IntWidth, ItemId, Literal, LogicalOp, ModuleId, Node, Op, Pattern, PrimitiveTy, Scope, - ScopeId, StructField, StructTy, Ty, TyDef, TyParam, + IntSign, IntWidth, ItemId, ItemKind, Literal, LogicalOp, Lookup, LookupId, ModuleId, Node, Op, + Pattern, PrimitiveTy, Scope, ScopeId, StructField, StructTy, Ty, TyDef, TyParam, }; use common::span::{Span, Spanned, SpannedExt}; use grammar::{ @@ -27,6 +27,8 @@ pub struct AstLowerer { sty_needs_field_resolution: HashMap, Vec)>, /// Build up outer decorators to apply to the next item. outer_decorators: Vec>, + /// Increment of global lookup ID. + lookup_id: usize, /// The HIR being constructed. pub hir: Hir, /// Non-fatal errors that occurred during lowering. @@ -84,6 +86,14 @@ fn ty_params_into_unbounded_ty_param(ty_params: &[ast::TyParam]) -> Vec .collect() } +macro_rules! insert_lookup { + ($self:ident, $target:ident, $kind:ident, $e:expr) => {{ + let id = $self.next_lookup_id(); + $self.hir.$target.insert(id, $e); + Lookup(ItemKind::$kind, id) + }}; +} + impl AstLowerer { /// Creates a new AST lowerer. pub fn new(root: Vec>) -> Self { @@ -91,6 +101,7 @@ impl AstLowerer { module_nodes: HashMap::from([(ModuleId::root(), root)]), sty_needs_field_resolution: HashMap::new(), outer_decorators: Vec::new(), + lookup_id: 0, hir: Hir::default(), errors: Vec::new(), } @@ -119,6 +130,17 @@ impl AstLowerer { } } + #[inline] + fn get_item_name(&self, Lookup(kind, lookup): &Lookup) -> Option> { + match kind { + ItemKind::Func => self.hir.funcs.get(lookup).map(|func| func.header.name), + ItemKind::Alias => self.hir.aliases.get(lookup).map(|alias| alias.name), + ItemKind::Const => self.hir.consts.get(lookup).map(|cnst| cnst.name), + ItemKind::Struct => self.hir.structs.get(lookup).map(|sty| sty.name), + ItemKind::Type => self.hir.types.get(lookup).map(|ty| ty.name), + } + } + /// Asserts the item ID is unique. #[inline] pub fn assert_item_unique( @@ -127,16 +149,12 @@ impl AstLowerer { item: &ItemId, src: Spanned, ) -> Result<()> { - let occupied = - if let Some(occupied) = scope.structs.get(item) { - Some(occupied.name.span()) - } else if let Some(occupied) = scope.consts.get(item) { - Some(occupied.name.span()) - } else { - None - }; - if let Some(occupied) = occupied { - return Err(Error::NameConflict(occupied, src)); + if let Some(occupied) = scope + .items + .get(item) + .and_then(|item| self.get_item_name(item)) + { + return Err(Error::NameConflict(occupied.span(), src)); } Ok(()) } @@ -154,6 +172,13 @@ impl AstLowerer { self.scope_ctx(*scope_id) } + #[inline] + pub fn next_lookup_id(&mut self) -> LookupId { + let id = self.lookup_id; + self.lookup_id += 1; + LookupId(id) + } + /// Completely performs a lowering pass over a module. pub fn resolve_module(&mut self, module: ModuleId, span: Span) -> Result<()> { // SAFETY: `children` is set later in this function. @@ -174,17 +199,23 @@ impl AstLowerer { /// Perform a pass over the AST to simply resolve all top-level types. pub fn resolve_types(&mut self, module: ModuleId, scope: &mut Scope) -> Result<()> { - let nodes = self.module_nodes.get(&module).expect("module not found"); + let nodes = self + .module_nodes + .get(&module) + .cloned() + .expect("module not found"); // Do a pass over all types to identify them - for node in nodes { + for node in &nodes { if let Some((item_id, ty_def)) = self.pass_over_ty_def(module, node)? { - scope.types.insert(item_id, ty_def); + scope + .items + .insert(item_id, insert_lookup!(self, types, Type, ty_def)); } } // Do a second pass to register and resolve types - for Spanned(node, _) in nodes.clone() { + for Spanned(node, _) in nodes { match node { ast::Node::Struct(sct) => { let sct_name = sct.name.clone(); @@ -193,11 +224,14 @@ impl AstLowerer { let sty = self.lower_struct_def_into_ty(module, sct.clone(), scope)?; // Update type parameters with their bounds - if let Some(ty_def) = scope.types.get_mut(&item_id) { + if let Some(ty_def) = self.hir.types.get_mut(&scope.lookup_id_or_panic(item_id)) + { ty_def.ty_params = sty.ty_params.clone(); } self.propagate_nonfatal(self.assert_item_unique(scope, &item_id, sct_name)); - scope.structs.insert(item_id, sty); + scope + .items + .insert(item_id, insert_lookup!(self, structs, Struct, sty)); } _ => (), } @@ -236,9 +270,10 @@ impl AstLowerer { return Err(Error::CircularTypeReference(cycle)); } - let fields = scope + let fields = self + .hir .structs - .get(&pid) + .get(&scope.lookup_id_or_panic(pid)) .cloned() .expect("struct not found, this is a bug") .into_adhoc_struct_ty_with_applied_ty_params(Some(dest.span()), args)? @@ -246,9 +281,10 @@ impl AstLowerer { removed.push(sid); for child in &removed { - let sty = scope + let sty = self + .hir .structs - .get_mut(&child) + .get_mut(&scope.lookup_id_or_panic(*child)) .expect("struct not found, this is a bug"); let mut fields = fields.clone(); @@ -289,7 +325,9 @@ impl AstLowerer { }; let item = ItemId(module, *ident.value()); self.propagate_nonfatal(self.assert_item_unique(scope, &item, name)); - scope.consts.insert(item, cnst); + scope + .items + .insert(item, insert_lookup!(self, consts, Const, cnst)); } } Ok(()) @@ -409,8 +447,12 @@ impl AstLowerer { /// struct A<__0> { a: __0 } /// ``` fn desugar_inferred_types_in_structs(&mut self, scope: &mut Scope) { - for sty in scope.structs.values_mut() { - // Desugar inference type into generics that will be inferred anyways + for Lookup(kind, id) in scope.items.values() { + if *kind != ItemKind::Struct { + continue; + } + let sty = self.hir.structs.get_mut(id).unwrap(); + // Desugar inference type into generics that will be inferred anyway for (i, ty) in sty .fields .iter_mut() @@ -500,13 +542,17 @@ impl AstLowerer { ModuleId(Intern::from_ref(ty_module.as_slice())) }; - let lookup = ItemId(mid, ident); - let ty_def = ctx + let err = Error::TypeNotFound(full_span, Spanned(tail, span), mid); + let Lookup(kind, id) = ctx .scope - .types - .get(&lookup) - .cloned() - .ok_or(Error::TypeNotFound(full_span, Spanned(tail, span), mid))?; + .items + .get(&ItemId(mid, ident)) + .ok_or(err.clone())?; + + if *kind != ItemKind::Type { + return Err(err.clone()); + } + let ty_def = self.hir.types.get(id).ok_or(err)?; let ty_params = match application { Some(app) => app @@ -595,7 +641,9 @@ impl AstLowerer { }; let item = ItemId(module, *ident.value()); self.propagate_nonfatal(self.assert_item_unique(ctx.scope, &item, name)); - scope.funcs.insert(item, func); + scope + .items + .insert(item, insert_lookup!(self, funcs, Func, func)); } } Ok(()) @@ -1095,27 +1143,29 @@ impl AstLowerer { app: Option, ) -> Result> { let item_id = ItemId(ctx.module(), *ident.value()); - Ok(if let Some(cnst) = ctx.scope.consts.get(&item_id) { - // TODO: true const-eval instead of inline (this will be replaced by `alias`) - if let Some(app) = app { - return Err(Error::ExplicitTypeArgumentsNotAllowed(app.span())); - } - cnst.value.clone() - } else { - Expr::Ident( - ident, - app.map(|app| { + Ok( + if let Some(Lookup(ItemKind::Const, id)) = ctx.scope.items.get(&item_id) { + // TODO: true const-eval instead of inline (this will be replaced by `alias`) + if let Some(app) = app { + return Err(Error::ExplicitTypeArgumentsNotAllowed(app.span())); + } + self.hir.consts[id].value.clone() + } else { + Expr::Ident( + ident, app.map(|app| { - app.into_iter() - .map(|ty| self.lower_ty(ctx, ty.into_value())) - .collect::>() + app.map(|app| { + app.into_iter() + .map(|ty| self.lower_ty(ctx, ty.into_value())) + .collect::>() + }) + .transpose() }) - .transpose() - }) - .transpose()?, - ) - .spanned(ident.span()) - }) + .transpose()?, + ) + .spanned(ident.span()) + }, + ) } #[inline] diff --git a/hir/src/typed.rs b/hir/src/typed.rs index 987f166c..0f334bc8 100644 --- a/hir/src/typed.rs +++ b/hir/src/typed.rs @@ -1,7 +1,7 @@ use crate::{ infer::{ExitAction, InferMetadata}, - Ident, IntSign, IntWidth, Intrinsic, ItemId, Literal, Node, Op, Pattern, PrimitiveTy, ScopeId, - StaticOp, WithHir, + Ident, IntSign, IntWidth, Intrinsic, ItemId, Literal, LookupId, Node, Op, Pattern, PrimitiveTy, + ScopeId, StaticOp, WithHir, }; use common::span::Spanned; use std::{ @@ -127,7 +127,7 @@ pub enum BoolIntrinsic { pub enum Expr { Literal(Literal), Local(Spanned, Option>>, LocalEnv), - Func(Spanned, Option>>, ItemId), + Func(Spanned, Option>>, LookupId), Type(Spanned), Tuple(Vec), Array(Vec), @@ -135,7 +135,7 @@ pub enum Expr { BoolIntrinsic(BoolIntrinsic), Intrinsic(Intrinsic, Vec), CallFunc { - func: ItemId, + func: LookupId, parent: Option, args: Vec, kwargs: Vec<(Ident, E)>, @@ -222,10 +222,11 @@ impl Display for WithHir<'_, TypedExpr, InferMetadata> { write!( f, - "{}{func}({args})", + "{}{}({args})", parent .as_ref() - .map_or_else(String::new, |ty| format!("{ty}.")) + .map_or_else(String::new, |ty| format!("{ty}.")), + self.1.funcs[&func].header.name, ) } Expr::CallOp(op, slf, args) => { diff --git a/mir/src/lib.rs b/mir/src/lib.rs index 1cbfe9af..1a478f1b 100644 --- a/mir/src/lib.rs +++ b/mir/src/lib.rs @@ -26,12 +26,12 @@ pub use hir::{ BinaryIntIntrinsic, LocalEnv, Ty, /* TODO: monomorphize types */ UnaryIntIntrinsic, }, - IntSign, IntWidth, ModuleId, PrimitiveTy, + Ident, IntSign, IntWidth, LookupId, ModuleId, PrimitiveTy, }; pub use lower::Lowerer; use common::span::Spanned; -use hir::{FloatWidth, Ident, ItemId, ScopeId}; +use hir::{FloatWidth, ItemId, ScopeId}; use indexmap::IndexMap; use std::{ collections::HashMap, @@ -95,14 +95,14 @@ impl LocalId { #[derive(Clone, Debug, Default)] pub struct Mir { /// All procedures, including the top-level and main procedures. - pub functions: HashMap, + pub functions: HashMap, } impl Display for Mir { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { self.functions - .values() - .map(|func| writeln!(f, "{func}")) + .iter() + .map(|(id, func)| writeln!(f, "[#{}] {func}", id.0)) .collect() } } @@ -216,7 +216,7 @@ pub enum Expr { Local(LocalId), IntIntrinsic(IntIntrinsic, IntSign, IntWidth), BoolIntrinsic(BoolIntrinsic), - Call(ItemId, Vec>), + Call(LookupId, Vec>), } impl Display for Expr { @@ -229,7 +229,7 @@ impl Display for Expr { } Self::BoolIntrinsic(b) => write!(f, "{b}"), Self::Call(i, args) => { - write!(f, "{}(", i)?; + write!(f, "[#{}](", i.0)?; write_comma_sep(f, args.iter())?; write!(f, ")") } diff --git a/mir/src/lower.rs b/mir/src/lower.rs index a8e6af0a..05337525 100644 --- a/mir/src/lower.rs +++ b/mir/src/lower.rs @@ -5,10 +5,11 @@ use crate::{ Node, TypedHir, }; use common::span::{Spanned, SpannedExt}; +use hir::infer::flatten_param; use hir::{ - pat_errors, typed, - typed::{LocalEnv, Ty, TypedExpr}, - Ident, IntSign, ItemId, Literal, ModuleId, Pattern, PrimitiveTy, ScopeId, + typed::{self, LocalEnv, Ty, TypedExpr}, + Ident, IntSign, ItemId, ItemKind, Literal, Lookup, LookupId, ModuleId, Pattern, PrimitiveTy, + ScopeId, }; use std::collections::HashMap; @@ -21,6 +22,8 @@ pub struct Lowerer { pub mir: Mir, /// Non-fatal errors that occurred during lowering. pub errors: Vec, + /// Cache of how to flatten calls to functions. + func_pats: HashMap>>, } pub struct Ctx<'a> { @@ -74,6 +77,7 @@ impl Lowerer { thir, mir: Mir::default(), errors: Vec::new(), + func_pats: HashMap::new(), } } @@ -102,7 +106,7 @@ impl Lowerer { ctx: &mut Ctx, Spanned(TypedExpr(expr, ty), span): Spanned, ) -> Spanned { - type HirExpr = hir::typed::Expr; + type HirExpr = typed::Expr; // SAFETY: the context is valid for the duration of the function let bctx = unsafe { &mut *ctx.bctx }; @@ -288,8 +292,13 @@ impl Lowerer { Expr::Local(result_local) } HirExpr::CallFunc { func, args, .. } => { - // TODO: flatten args - let args = args + let params = &self.func_pats[&func]; + let mut flattened = Vec::new(); + for (pat, arg) in params.iter().zip(args) { + apply_arg(pat, arg, &mut flattened).unwrap(); + } + + let args = flattened .into_iter() .map(|arg| self.lower_expr(ctx, arg)) .collect(); @@ -316,8 +325,22 @@ impl Lowerer { let scope = self.thir.scopes.remove(&scope).expect("no such scope"); // First, lower all static items in the scope - for (id, func) in scope.funcs { - let func = self.lower_func(id, func); + let mut funcs = Vec::new(); + for (item, Lookup(kind, id)) in scope.items { + match kind { + ItemKind::Func => { + let func = self.thir.funcs.remove(&id).expect("no such func"); + self.func_pats.insert( + id, + func.header.params.iter().map(|p| &p.pat).cloned().collect(), + ); + funcs.push((id, item, func)); + } + _ => continue, // TODO + } + } + for (id, item, func) in funcs { + let func = self.lower_func(item, func); self.mir.functions.insert(id, func); } @@ -442,7 +465,7 @@ impl Lowerer { } Func { name: item_id, - params, + params: params.into_iter().map(|(i, b)| (i, b.ty)).collect(), ret_ty: header.ret_ty, blocks: self.lower_map(body, true), } @@ -451,11 +474,11 @@ impl Lowerer { pub fn lower_module(&mut self, module: ModuleId) { let scope_id = self.thir.modules.get(&module).expect("no such module"); let blocks = self.lower_map(*scope_id, true); - let name = ItemId(module, "__root".into()); + self.mir.functions.insert( - name, + LookupId(usize::MAX), Func { - name, + name: ItemId(module, "__root".into()), params: Vec::new(), ret_ty: Ty::VOID, blocks, @@ -464,39 +487,6 @@ impl Lowerer { } } -/// Flattens a pattern parameter into a list of bindings to prepare for MIR lowering. -/// -/// # Example -/// ```terbium -/// func sum_tuple((a, b): (int, int)) = a + b; -/// -/// // Lowered to: -/// func sum_tuple(a: int, b: int) = a + b; -/// ``` -pub fn flatten_param( - pat: &Spanned, - ty: Ty, - bindings: &mut Vec<(Ident, Ty)>, -) -> hir::error::Result<()> { - match (pat.value(), ty) { - (Pattern::Ident { ident, .. }, ty) => bindings.push((ident.0, ty)), - (Pattern::Tuple(pats), Ty::Tuple(tys)) => { - if pats.len() != tys.len() { - return Err( - pat_errors::tuple_len_mismatch(pats.len().spanned(pat.span()), tys.len(), None) - ); - } - for (pat, ty) in pats.iter().zip(tys) { - flatten_param(pat, ty, bindings)?; - } - } - (Pattern::Tuple(_), ty) => { - return Err(pat_errors::tuple_mismatch(pat.span(), ty, None)); - } - } - Ok(()) -} - pub fn apply_arg( pat: &Spanned, arg: Spanned, diff --git a/src/main.rs b/src/main.rs index ba6aa381..23f6d884 100644 --- a/src/main.rs +++ b/src/main.rs @@ -117,7 +117,7 @@ fn main() -> Result<(), Box> { let start = std::time::Instant::now(); let ctx = codegen::Context::create(); - let module = compile_llvm(&ctx, &mir_lowerer.mir, ModuleId::root()); + let module = compile_llvm(&ctx, mir_lowerer.mir.functions); full += start.elapsed(); println!("=== [ LLVM IR ({:?} to compile) ] ===", start.elapsed()); diff --git a/test.trb b/test.trb index 6bfc50fd..b66f976f 100644 --- a/test.trb +++ b/test.trb @@ -1,2 +1,8 @@ -let x = 1; -x = 2; +func a(x: int32, y: int32) -> int32 { + x + y +} + +func b() { + a(6, 5) + 10 +} +