diff --git a/core/src/main/scala-3/MLContext.scala b/core/src/main/scala-3/MLContext.scala index b90d07f8..aa59bdd1 100644 --- a/core/src/main/scala-3/MLContext.scala +++ b/core/src/main/scala-3/MLContext.scala @@ -13,7 +13,7 @@ import scala.collection.mutable class MLContext() { - val dialectOpContext: mutable.Map[String, OperationObject] = mutable.Map() + val dialectOpContext: mutable.Map[String, MLIROperationObject] = mutable.Map() val dialectAttrContext: mutable.Map[String, AttributeObject] = mutable.Map() def getOperation(name: String) = dialectOpContext.get(name) diff --git a/core/src/main/scala-3/Parser.scala b/core/src/main/scala-3/Parser.scala index 046eaf08..83aceba3 100644 --- a/core/src/main/scala-3/Parser.scala +++ b/core/src/main/scala-3/Parser.scala @@ -147,13 +147,13 @@ object Parser { var parentScope: Option[Scope] = None, var valueMap: mutable.Map[String, Value[Attribute]] = mutable.Map.empty[String, Value[Attribute]], - var valueWaitlist: mutable.Map[Operation, ListType[ + var valueWaitlist: mutable.Map[MLIROperation, ListType[ (String, Attribute) - ]] = mutable.Map.empty[Operation, ListType[(String, Attribute)]], + ]] = mutable.Map.empty[MLIROperation, ListType[(String, Attribute)]], var blockMap: mutable.Map[String, Block] = mutable.Map.empty[String, Block], - var blockWaitlist: mutable.Map[Operation, ListType[String]] = - mutable.Map.empty[Operation, ListType[String]] + var blockWaitlist: mutable.Map[MLIROperation, ListType[String]] = + mutable.Map.empty[MLIROperation, ListType[String]] ) { def defineValues( @@ -636,12 +636,12 @@ class Parser(val context: MLContext, val args: Args = Args()) // [x] toplevel := (operation | attribute-alias-def | type-alias-def)* // shortened definition TODO: finish... - def TopLevel[$: P]: P[Operation] = P( + def TopLevel[$: P]: P[MLIROperation] = P( Start ~ (Operations(0)) ~ E({ currentScope.checkValueWaitlist() currentScope.checkBlockWaitlist() }) ~ End - ).map((toplevel: ListType[Operation]) => + ).map((toplevel: ListType[MLIROperation]) => toplevel.toList match { case (head: ModuleOp) :: Nil => head case _ => @@ -705,7 +705,7 @@ class Parser(val context: MLContext, val args: Args = Args()) attributes: DictType[String, Attribute] = DictType(), resultsTypes: Seq[Attribute] = Seq(), operandsTypes: Seq[Attribute] = Seq() - ): Operation = { + ): MLIROperation = { if (operandsNames.length != operandsTypes.length) { throw new Exception( @@ -720,7 +720,7 @@ class Parser(val context: MLContext, val args: Args = Args()) val useAndRefBlockSeqs: (ListType[Block], ListType[String]) = currentScope.useBlocks(successorsNames) - val opObject: Option[OperationObject] = ctx.getOperation(opName) + val opObject: Option[MLIROperationObject] = ctx.getOperation(opName) val op = opObject match { case Some(x) => @@ -767,10 +767,12 @@ class Parser(val context: MLContext, val args: Args = Args()) return op } - def Operations[$: P](at_least_this_many: Int = 0): P[ListType[Operation]] = + def Operations[$: P]( + at_least_this_many: Int = 0 + ): P[ListType[MLIROperation]] = P(OperationPat.rep(at_least_this_many).map(_.to(ListType))) - def OperationPat[$: P]: P[Operation] = P( + def OperationPat[$: P]: P[MLIROperation] = P( OpResultList.orElse(Seq())./.flatMap(Op(_)) ~/ TrailingLocation.? ) @@ -844,7 +846,7 @@ class Parser(val context: MLContext, val args: Args = Args()) def createBlock( // name arguments operations - uncutBlock: (String, Seq[(String, Attribute)], ListType[Operation]) + uncutBlock: (String, Seq[(String, Attribute)], ListType[MLIROperation]) ): Block = { val newBlock = new Block( operations = uncutBlock._3, @@ -874,7 +876,9 @@ class Parser(val context: MLContext, val args: Args = Args()) // \/ // [x] - region ::= `{` operation* block* `}` - def defineRegion(parseResult: (ListType[Operation], Seq[Block])): Region = { + def defineRegion( + parseResult: (ListType[MLIROperation], Seq[Block]) + ): Region = { return parseResult._1.length match { case 0 => val region = new Region(blocks = parseResult._2) diff --git a/core/src/main/scala-3/Printer.scala b/core/src/main/scala-3/Printer.scala index dacbd484..fa3d1667 100644 --- a/core/src/main/scala-3/Printer.scala +++ b/core/src/main/scala-3/Printer.scala @@ -120,11 +120,14 @@ case class Printer( || OPERATION PRINTER || \*≡==---==≡≡≡≡≡≡≡≡≡==---==≡*/ - def printCustomOperation(op: Operation, indentLevel: Int = 0): String = { + def printCustomOperation(op: MLIROperation, indentLevel: Int = 0): String = { indent * indentLevel + op.print(this) } - def printGenericOperation(op: Operation, indentLevel: Int = 0): String = { + def printGenericMLIROperation( + op: MLIROperation, + indentLevel: Int = 0 + ): String = { var results: Seq[String] = Seq() var resultsTypes: Seq[String] = Seq() var operands: Seq[String] = Seq() @@ -191,17 +194,23 @@ case class Printer( return s"${"\""}${op.name}${"\""}($operationOperands)$operationSuccessors$dictionaryProperties$operationRegions$dictionaryAttributes : $functionType" } - def printOperation(op: Operation, indentLevel: Int = 0): String = { + def printOperation(op: MLIROperation, indentLevel: Int = 0): String = { val results = op.results.map(printValue(_)).mkString(", ") + (if op.results.nonEmpty then " = " else "") indent * indentLevel + results + (if strictly_generic then - printGenericOperation(op, indentLevel) + printGenericMLIROperation( + op, + indentLevel + ) else op.custom_print(this)) } - def printOperations(ops: Seq[Operation], indentLevel: Int = 0): String = { + def printOperations( + ops: Seq[MLIROperation], + indentLevel: Int = 0 + ): String = { (for { op <- ops } yield printOperation(op, indentLevel)).mkString("\n") } diff --git a/core/src/main/scala-3/builtin/Builtin.scala b/core/src/main/scala-3/builtin/Builtin.scala index f761013c..9d1ca217 100644 --- a/core/src/main/scala-3/builtin/Builtin.scala +++ b/core/src/main/scala-3/builtin/Builtin.scala @@ -459,14 +459,14 @@ case class AffineSetAttr(val affine_set: AffineSet) // ModuleOp // // ==------== // -object ModuleOp extends OperationObject { +object ModuleOp extends MLIROperationObject { override def name = "builtin.module" override def factory = ModuleOp.apply // ==--- Custom Parsing ---== // override def parse[$: P]( parser: Parser - ): P[Operation] = + ): P[MLIROperation] = P( parser.Region ).map((x: Region) => ModuleOp(regions = ListType(x))) diff --git a/core/src/main/scala-3/ir/Block.scala b/core/src/main/scala-3/ir/Block.scala index 8259e39c..197324fd 100644 --- a/core/src/main/scala-3/ir/Block.scala +++ b/core/src/main/scala-3/ir/Block.scala @@ -18,22 +18,23 @@ object Block { def apply( arguments_types: Iterable[Attribute] | Attribute = Seq(), - operations: Iterable[Operation] | Operation = Seq() + operations: Iterable[MLIROperation] | MLIROperation = Seq() ): Block = new Block(arguments_types, operations) - def apply(operations: Iterable[Operation] | Operation): Block = new Block( - operations - ) + def apply(operations: Iterable[MLIROperation] | MLIROperation): Block = + new Block( + operations + ) def apply( arguments_types: Iterable[Attribute], - operations_expr: Iterable[Value[Attribute]] => Iterable[Operation] + operations_expr: Iterable[Value[Attribute]] => Iterable[MLIROperation] ): Block = new Block(arguments_types, operations_expr) def apply( arguments_types: Attribute, - operations_expr: Value[Attribute] => Iterable[Operation] + operations_expr: Value[Attribute] => Iterable[MLIROperation] ): Block = new Block(arguments_types, operations_expr) @@ -48,7 +49,7 @@ object Block { */ case class Block private ( val arguments: ListType[Value[Attribute]], - val operations: ListType[Operation] + val operations: ListType[MLIROperation] ) { /** Constructs a Block instance with the given argument types and operations. @@ -57,12 +58,12 @@ case class Block private ( * The types of the arguments, either as a single Attribute or an Iterable * of Attributes. * @param operations - * The operations, either as a single Operation or an Iterable of - * Operations. + * The operations, either as a single MLIROperation or an Iterable of + * MLIROperations. */ def this( arguments_types: Iterable[Attribute] | Attribute = Seq(), - operations: Iterable[Operation] | Operation = Seq() + operations: Iterable[MLIROperation] | MLIROperation = Seq() ) = this( ListType.from((arguments_types match { @@ -70,8 +71,8 @@ case class Block private ( case multiple: Iterable[Attribute] => multiple }).map(Value(_))), ListType.from((operations match { - case single: Operation => Seq(single) - case multiple: Iterable[Operation] => multiple + case single: MLIROperation => Seq(single) + case multiple: Iterable[MLIROperation] => multiple })) ) @@ -85,7 +86,7 @@ case class Block private ( private def this( args: ( Iterable[Value[Attribute]] | Value[Attribute], - Iterable[Operation] | Operation + Iterable[MLIROperation] | MLIROperation ) ) = this( @@ -94,8 +95,8 @@ case class Block private ( case multiple: Iterable[Value[Attribute]] => multiple }), ListType.from(args._2 match { - case single: Operation => Seq(single) - case multiple: Iterable[Operation] => multiple + case single: MLIROperation => Seq(single) + case multiple: Iterable[MLIROperation] => multiple }) ) @@ -103,10 +104,10 @@ case class Block private ( * arguments. * * @param operations - * The operations, either as a single Operation or an Iterable of - * Operations. + * The operations, either as a single MLIROperation or an Iterable of + * MLIROperations. */ - def this(operations: Iterable[Operation] | Operation) = + def this(operations: Iterable[MLIROperation] | MLIROperation) = this(Seq(), operations) /** Constructs a Block instance with the given argument type and a function to @@ -119,8 +120,8 @@ case class Block private ( */ def this( argument_type: Iterable[Attribute], - operations_expr: Iterable[Value[Attribute]] => Iterable[Operation] | - Operation + operations_expr: Iterable[Value[Attribute]] => Iterable[MLIROperation] | + MLIROperation ) = this({ val args = argument_type.map(Value(_)) @@ -138,7 +139,8 @@ case class Block private ( */ def this( argument_type: Attribute, - operations_expr: Value[Attribute] => Iterable[Operation] | Operation + operations_expr: Value[Attribute] => Iterable[MLIROperation] | + MLIROperation ) = this({ val arg = Value(argument_type) @@ -147,7 +149,7 @@ case class Block private ( var container_region: Option[Region] = None - private def attach_op(op: Operation): Unit = { + private def attach_op(op: MLIROperation): Unit = { op.container_block match { case Some(x) => throw new Exception( @@ -165,13 +167,13 @@ case class Block private ( } } - def add_op(new_op: Operation): Unit = { + def add_op(new_op: MLIROperation): Unit = { val oplen = operations.length attach_op(new_op) operations.insertAll(oplen, ListType(new_op)) } - def add_ops(new_ops: Seq[Operation]): Unit = { + def add_ops(new_ops: Seq[MLIROperation]): Unit = { val oplen = operations.length for (op <- new_ops) { attach_op(op) @@ -179,7 +181,10 @@ case class Block private ( operations.insertAll(oplen, ListType(new_ops: _*)) } - def insert_op_before(existing_op: Operation, new_op: Operation): Unit = { + def insert_op_before( + existing_op: MLIROperation, + new_op: MLIROperation + ): Unit = { (existing_op.container_block equals Some(this)) match { case true => attach_op(new_op) @@ -193,8 +198,8 @@ case class Block private ( } def insert_ops_before( - existing_op: Operation, - new_ops: Seq[Operation] + existing_op: MLIROperation, + new_ops: Seq[MLIROperation] ): Unit = { (existing_op.container_block equals Some(this)) match { case true => @@ -210,7 +215,10 @@ case class Block private ( } } - def insert_op_after(existing_op: Operation, new_op: Operation): Unit = { + def insert_op_after( + existing_op: MLIROperation, + new_op: MLIROperation + ): Unit = { (existing_op.container_block equals Some(this)) match { case true => attach_op(new_op) @@ -224,8 +232,8 @@ case class Block private ( } def insert_ops_after( - existing_op: Operation, - new_ops: Seq[Operation] + existing_op: MLIROperation, + new_ops: Seq[MLIROperation] ): Unit = { (existing_op.container_block equals Some(this)) match { case true => @@ -246,7 +254,7 @@ case class Block private ( for (op <- operations) op.drop_all_references } - def detach_op(op: Operation): Operation = { + def detach_op(op: MLIROperation): MLIROperation = { (op.container_block equals Some(this)) match { case true => op.container_block = None @@ -254,19 +262,19 @@ case class Block private ( return op case false => throw new Exception( - "Operation can only be detached from a block in which it is contained." + "MLIROperation can only be detached from a block in which it is contained." ) } } - def erase_op(op: Operation) = { + def erase_op(op: MLIROperation) = { detach_op(op) op.erase() } - def getIndexOf(op: Operation): Int = { + def getIndexOf(op: MLIROperation): Int = { operations.lastIndexOf(op) match { - case -1 => throw new Exception("Operation not present in the block.") + case -1 => throw new Exception("MLIROperation not present in the block.") case x => x } diff --git a/core/src/main/scala-3/ir/Dialect.scala b/core/src/main/scala-3/ir/Dialect.scala index 949379cc..8ef6c4b9 100644 --- a/core/src/main/scala-3/ir/Dialect.scala +++ b/core/src/main/scala-3/ir/Dialect.scala @@ -12,6 +12,6 @@ package scair.ir \*≡==---==≡≡==---==≡*/ final case class Dialect( - val operations: Seq[OperationObject], + val operations: Seq[MLIROperationObject], val attributes: Seq[AttributeObject] ) {} diff --git a/core/src/main/scala-3/ir/IRUtils.scala b/core/src/main/scala-3/ir/IRUtils.scala index 79e641a0..008760e9 100644 --- a/core/src/main/scala-3/ir/IRUtils.scala +++ b/core/src/main/scala-3/ir/IRUtils.scala @@ -17,6 +17,22 @@ import scala.collection.mutable.ListBuffer // ╚██████╔╝ ░░░██║░░░ ██║ ███████╗ ██████╔╝ // ░╚═════╝░ ░░░╚═╝░░░ ╚═╝ ╚══════╝ ╚═════╝░ +/*≡==--==≡≡≡≡≡≡≡==--=≡≡*\ +|| OP INPUTS || +\*≡==---==≡≡≡≡≡==---==≡*/ +// for ClairV2 + +type Operand[+T <: Attribute] = Value[T] +case class Result[+T <: Attribute](val value: Value[T]) + +case class Property[+T <: Attribute]( + val typ: T +) + +case class Attr[+T <: Attribute]( + val typ: T +) + /*≡==--==≡≡≡==--=≡≡*\ || UTILS || \*≡==---==≡==---==≡*/ diff --git a/core/src/main/scala-3/ir/Operation.scala b/core/src/main/scala-3/ir/Operation.scala index 3c64f24b..24d47bf3 100644 --- a/core/src/main/scala-3/ir/Operation.scala +++ b/core/src/main/scala-3/ir/Operation.scala @@ -15,7 +15,7 @@ import scair.Printer || OPERATIONS || \*≡==---==≡≡==---==≡*/ -sealed abstract class Operation( +sealed abstract class MLIROperation( val name: String, val operands: ListType[Value[Attribute]] = ListType(), val successors: ListType[Block] = ListType(), @@ -28,7 +28,7 @@ sealed abstract class Operation( ) extends OpTrait { val results: ListType[Value[Attribute]] = results_types.map(Value(_)) - def op: Operation = this + def op: MLIROperation = this var container_block: Option[Block] = None @@ -87,7 +87,7 @@ sealed abstract class Operation( } def custom_print(p: Printer): String = - p.printGenericOperation(this) + p.printGenericMLIROperation(this) final def print(printer: Printer): String = { printer.printOperation(this) @@ -118,7 +118,7 @@ case class UnregisteredOperation( DictType.empty[String, Attribute], override val dictionaryAttributes: DictType[String, Attribute] = DictType.empty[String, Attribute] -) extends Operation( +) extends MLIROperation( name = name, operands, successors, @@ -138,7 +138,7 @@ class RegisteredOperation( DictType.empty[String, Attribute], dictionaryAttributes: DictType[String, Attribute] = DictType.empty[String, Attribute] -) extends Operation( +) extends MLIROperation( name = name, operands, successors, @@ -148,10 +148,10 @@ class RegisteredOperation( dictionaryAttributes ) -trait OperationObject { +trait MLIROperationObject { def name: String - def parse[$: P](parser: Parser): P[Operation] = + def parse[$: P](parser: Parser): P[MLIROperation] = throw new Exception( s"No custom Parser implemented for Operation '${name}'" ) @@ -163,7 +163,7 @@ trait OperationObject { ListType[Region] /* = regions */, DictType[String, Attribute], /* = dictProps */ DictType[String, Attribute] /* = dictAttrs */ - ) => Operation + ) => MLIROperation def factory: FactoryType @@ -176,7 +176,7 @@ trait OperationObject { DictType.empty[String, Attribute], dictionaryAttributes: DictType[String, Attribute] = DictType.empty[String, Attribute] - ): Operation = factory( + ): MLIROperation = factory( operands, successors, results_types, diff --git a/core/src/main/scala-3/ir/Region.scala b/core/src/main/scala-3/ir/Region.scala index 43ea4822..c6eb54b6 100644 --- a/core/src/main/scala-3/ir/Region.scala +++ b/core/src/main/scala-3/ir/Region.scala @@ -15,7 +15,7 @@ case class Region( blocks: Seq[Block] ) { - var container_operation: Option[Operation] = None + var container_operation: Option[MLIROperation] = None def drop_all_references: Unit = { container_operation = None diff --git a/core/src/main/scala-3/ir/Traits.scala b/core/src/main/scala-3/ir/Traits.scala index a1108e5c..e6e08f16 100644 --- a/core/src/main/scala-3/ir/Traits.scala +++ b/core/src/main/scala-3/ir/Traits.scala @@ -12,7 +12,7 @@ package scair.ir \*≡==---=≡=---==≡*/ abstract class OpTrait { - def op: Operation + def op: MLIROperation def trait_verify(): Unit = () } diff --git a/core/src/main/scala-3/ir/Value.scala b/core/src/main/scala-3/ir/Value.scala index a525eb18..3acd9e7e 100644 --- a/core/src/main/scala-3/ir/Value.scala +++ b/core/src/main/scala-3/ir/Value.scala @@ -13,7 +13,7 @@ package scair.ir // TO-DO: perhaps a linked list of a use to other uses within an operation // for faster use retrieval and index update -case class Use(val operation: Operation, val index: Int) { +case class Use(val operation: MLIROperation, val index: Int) { override def equals(o: Any): Boolean = o match { case Use(op, idx) => diff --git a/core/src/main/scala-3/scairdl/IRElements.scala b/core/src/main/scala-3/scairdl/IRElements.scala index 612d9b13..7544d89f 100644 --- a/core/src/main/scala-3/scairdl/IRElements.scala +++ b/core/src/main/scala-3/scairdl/IRElements.scala @@ -4,7 +4,7 @@ import fastparse.* import fastparse.ScalaWhitespace.* import scair.dialects.builtin.* import scair.ir.Attribute -import scair.ir.Operation +import scair.ir.MLIROperation import scair.scairdl.constraints.* import java.io.File @@ -47,7 +47,7 @@ abstract class EscapeHatch[T: ClassTag] { class AttrEscapeHatch[T <: Attribute: ClassTag]() extends EscapeHatch[T] -class OpEscapeHatch[T <: Operation: ClassTag]() extends EscapeHatch[T] +class OpEscapeHatch[T <: MLIROperation: ClassTag]() extends EscapeHatch[T] /*≡≡=--=≡≡≡=--=≡≡*\ || TYPES || @@ -257,7 +257,7 @@ case class OperationDef( s""" override def parse[$$:P]( parser: Parser - ): P[Operation] = { + ): P[MLIROperation] = { P( $combinedParsing ).map { @@ -753,7 +753,7 @@ case class OperationDef( """ def print(implicit indent: Int): String = s""" -object $className extends OperationObject { +object $className extends MLIROperationObject { override def name = "$name" override def factory = $className.apply ${assembly_format diff --git a/core/src/main/scala-3/transformations/Passes.scala b/core/src/main/scala-3/transformations/Passes.scala index 36041435..7890a173 100644 --- a/core/src/main/scala-3/transformations/Passes.scala +++ b/core/src/main/scala-3/transformations/Passes.scala @@ -11,5 +11,5 @@ import scair.ir.* abstract class ModulePass { val name: String - def transform(op: Operation): Operation = ??? + def transform(op: MLIROperation): MLIROperation = ??? } diff --git a/core/src/main/scala-3/transformations/PatternRewriter.scala b/core/src/main/scala-3/transformations/PatternRewriter.scala index f3074a07..e0ab2de3 100644 --- a/core/src/main/scala-3/transformations/PatternRewriter.scala +++ b/core/src/main/scala-3/transformations/PatternRewriter.scala @@ -25,7 +25,7 @@ import scala.collection.mutable.Stack object InsertPoint { - def before(op: Operation): InsertPoint = { + def before(op: MLIROperation): InsertPoint = { (op.container_block == None) match { case true => throw new Exception( @@ -37,7 +37,7 @@ object InsertPoint { } } - def after(op: Operation): InsertPoint = { + def after(op: MLIROperation): InsertPoint = { (op.container_block == None) match { case true => throw new Exception( @@ -68,7 +68,10 @@ object InsertPoint { } -case class InsertPoint(val block: Block, val insert_before: Option[Operation]) { +case class InsertPoint( + val block: Block, + val insert_before: Option[MLIROperation] +) { // custom constructor def this(block: Block) = { @@ -92,7 +95,7 @@ case class InsertPoint(val block: Block, val insert_before: Option[Operation]) { object RewriteMethods { - def erase_op(op: Operation) = { + def erase_op(op: MLIROperation) = { op.container_block match { case Some(block) => block.erase_op(op) @@ -103,12 +106,12 @@ object RewriteMethods { def insert_ops_at( insertion_point: InsertPoint, - ops: Operation | Seq[Operation] + ops: MLIROperation | Seq[MLIROperation] ): Unit = { val operations = ops match { - case x: Operation => Seq(x) - case y: Seq[Operation] => y + case x: MLIROperation => Seq(x) + case y: Seq[MLIROperation] => y } insertion_point.insert_before match { case Some(op) => @@ -122,22 +125,22 @@ object RewriteMethods { } def insert_ops_before( - op: Operation, - new_ops: Operation | Seq[Operation] + op: MLIROperation, + new_ops: MLIROperation | Seq[MLIROperation] ): Unit = { insert_ops_at(InsertPoint.before(op), new_ops) } def insert_ops_after( - op: Operation, - new_ops: Operation | Seq[Operation] + op: MLIROperation, + new_ops: MLIROperation | Seq[MLIROperation] ): Unit = { insert_ops_at(InsertPoint.after(op), new_ops) } def replace_op( - op: Operation, - new_ops: Operation | Seq[Operation], + op: MLIROperation, + new_ops: MLIROperation | Seq[MLIROperation], new_results: Option[Seq[Value[Attribute]]] = None ): Unit = { @@ -148,8 +151,8 @@ object RewriteMethods { } val ops = new_ops match { - case x: Operation => Seq(x) - case y: Seq[Operation] => y + case x: MLIROperation => Seq(x) + case y: Seq[MLIROperation] => y } val results = new_results match { @@ -180,11 +183,11 @@ object RewriteMethods { // OPERATION REWRITER // class PatternRewriter( - var current_op: Operation + var current_op: MLIROperation ) { var has_done_action: Boolean = false - def erase_op(op: Operation): Unit = { + def erase_op(op: MLIROperation): Unit = { RewriteMethods.erase_op(op) has_done_action = true } @@ -196,55 +199,59 @@ class PatternRewriter( def insert_op_at_location( insertion_point: InsertPoint, - ops: Operation | Seq[Operation] + ops: MLIROperation | Seq[MLIROperation] ): Unit = { RewriteMethods.insert_ops_at(insertion_point, ops) has_done_action = true } - def insert_op_before_matched_op(ops: Operation | Seq[Operation]): Unit = { + def insert_op_before_matched_op( + ops: MLIROperation | Seq[MLIROperation] + ): Unit = { RewriteMethods.insert_ops_before(current_op, ops) has_done_action = true } - def insert_op_after_matched_op(ops: Operation | Seq[Operation]): Unit = { + def insert_op_after_matched_op( + ops: MLIROperation | Seq[MLIROperation] + ): Unit = { RewriteMethods.insert_ops_before(current_op, ops) has_done_action = true } def insert_op_at_end_of( block: Block, - ops: Operation | Seq[Operation] + ops: MLIROperation | Seq[MLIROperation] ): Unit = { insert_op_at_location(InsertPoint.at_end_of(block), ops) } def insert_op_at_start_of( block: Block, - ops: Operation | Seq[Operation] + ops: MLIROperation | Seq[MLIROperation] ): Unit = { insert_op_at_location(InsertPoint.at_start_of(block), ops) } def insert_ops_before( - op: Operation, - new_ops: Operation | Seq[Operation] + op: MLIROperation, + new_ops: MLIROperation | Seq[MLIROperation] ): Unit = { RewriteMethods.insert_ops_before(op, new_ops) has_done_action = true } def insert_ops_after( - op: Operation, - new_ops: Operation | Seq[Operation] + op: MLIROperation, + new_ops: MLIROperation | Seq[MLIROperation] ): Unit = { RewriteMethods.insert_ops_after(op, new_ops) has_done_action = true } def replace_op( - op: Operation, - new_ops: Operation | Seq[Operation], + op: MLIROperation, + new_ops: MLIROperation | Seq[MLIROperation], new_results: Option[Seq[Value[Attribute]]] = None ): Unit = { RewriteMethods.replace_op(op, new_ops, new_results) @@ -253,7 +260,10 @@ class PatternRewriter( } abstract class RewritePattern { - def match_and_rewrite(op: Operation, rewriter: PatternRewriter): Unit = ??? + + def match_and_rewrite(op: MLIROperation, rewriter: PatternRewriter): Unit = + ??? + } // OPERATION REWRITE WALKER // @@ -261,20 +271,20 @@ class PatternRewriteWalker( val pattern: RewritePattern ) { - private var worklist = Stack[Operation]() + private var worklist = Stack[MLIROperation]() def rewrite_module(module: ModuleOp): Unit = { return rewrite_op(module) } - def rewrite_op(op: Operation): Unit = { + def rewrite_op(op: MLIROperation): Unit = { populate_worklist(op) var op_was_modified = process_worklist() return op_was_modified } - private def populate_worklist(op: Operation): Unit = { + private def populate_worklist(op: MLIROperation): Unit = { worklist.push(op) op.regions.reverseIterator.foreach((x: Region) => x.blocks.reverseIterator.foreach((y: Block) => diff --git a/core/src/test/scala-3/TraitTest.scala b/core/src/test/scala-3/TraitTest.scala index 6a18c574..05acfb03 100644 --- a/core/src/test/scala-3/TraitTest.scala +++ b/core/src/test/scala-3/TraitTest.scala @@ -8,7 +8,7 @@ import scair.ir.* import scala.collection.mutable -object FillerOp extends OperationObject { +object FillerOp extends MLIROperationObject { override def name: String = "filler" override def factory: FactoryType = FillerOp.apply } @@ -32,7 +32,7 @@ case class FillerOp( dictionaryAttributes ) {} -object TerminatorOp extends OperationObject { +object TerminatorOp extends MLIROperationObject { override def name: String = "terminator" override def factory: FactoryType = TerminatorOp.apply } @@ -57,7 +57,7 @@ case class TerminatorOp( ) with IsTerminator {} -object NoTerminatorOp extends OperationObject { +object NoTerminatorOp extends MLIROperationObject { override def name: String = "noterminator" override def factory: FactoryType = NoTerminatorOp.apply } diff --git a/dialects/src/main/scala-3/lingodb/DBOps.scala b/dialects/src/main/scala-3/lingodb/DBOps.scala index 1486bf54..92b00e50 100644 --- a/dialects/src/main/scala-3/lingodb/DBOps.scala +++ b/dialects/src/main/scala-3/lingodb/DBOps.scala @@ -262,14 +262,14 @@ case class DB_StringType(val typ: Seq[Attribute]) // ConstantOp // // ==----------== // -object DB_ConstantOp extends OperationObject { +object DB_ConstantOp extends MLIROperationObject { override def name: String = "db.constant" override def factory = DB_ConstantOp.apply // ==--- Custom Parsing ---== // override def parse[$: P]( parser: Parser - ): P[Operation] = P( + ): P[MLIROperation] = P( "(" ~ parser.Type ~ ")" ~ ":" ~ parser.Type ~ parser.OptionalAttributes ).map( @@ -332,14 +332,14 @@ case class DB_ConstantOp( // CompareOp // // ==----------== // -object DB_CmpOp extends OperationObject { +object DB_CmpOp extends MLIROperationObject { override def name: String = "db.compare" override def factory = DB_CmpOp.apply // ==--- Custom Parsing ---== // override def parse[$: P]( parser: Parser - ): P[Operation] = P( + ): P[MLIROperation] = P( DB_CmpPredicateAttr.caseParser ~ ValueId ~ ":" ~ parser.Type ~ "," ~ ValueId ~ ":" ~ parser.Type ~ parser.OptionalAttributes ).map( @@ -418,7 +418,7 @@ case class DB_CmpOp( // MulOp // // ==-----== // -object DB_MulOp extends OperationObject { +object DB_MulOp extends MLIROperationObject { override def name: String = "db.mul" override def factory = DB_MulOp.apply @@ -482,7 +482,7 @@ object DB_MulOp extends OperationObject { override def parse[$: P]( parser: Parser - ): P[Operation] = P( + ): P[MLIROperation] = P( ValueId ~ ":" ~ parser.Type ~ "," ~ ValueId ~ ":" ~ parser.Type ~ parser.OptionalAttributes ).map( @@ -562,7 +562,7 @@ case class DB_MulOp( // DivOp // // ==-----== // -object DB_DivOp extends OperationObject { +object DB_DivOp extends MLIROperationObject { override def name: String = "db.div" override def factory = DB_DivOp.apply @@ -629,7 +629,7 @@ object DB_DivOp extends OperationObject { override def parse[$: P]( parser: Parser - ): P[Operation] = P( + ): P[MLIROperation] = P( ValueId ~ ":" ~ parser.Type ~ "," ~ ValueId ~ ":" ~ parser.Type ~ parser.OptionalAttributes ).map( @@ -705,14 +705,14 @@ case class DB_DivOp( // AddOp // // ==-----== // -object DB_AddOp extends OperationObject { +object DB_AddOp extends MLIROperationObject { override def name: String = "db.add" override def factory = DB_AddOp.apply // ==--- Custom Parsing ---== // override def parse[$: P]( parser: Parser - ): P[Operation] = P( + ): P[MLIROperation] = P( ValueId ~ ":" ~ parser.Type ~ "," ~ ValueId ~ ":" ~ parser.Type ~ parser.OptionalAttributes ).map( @@ -790,14 +790,14 @@ case class DB_AddOp( // SubOp // // ==-----== // -object DB_SubOp extends OperationObject { +object DB_SubOp extends MLIROperationObject { override def name: String = "db.sub" override def factory = DB_SubOp.apply // ==--- Custom Parsing ---== // override def parse[$: P]( parser: Parser - ): P[Operation] = P( + ): P[MLIROperation] = P( ValueId ~ ":" ~ parser.Type ~ "," ~ ValueId ~ ":" ~ parser.Type ~ parser.OptionalAttributes ).map( @@ -873,14 +873,14 @@ case class DB_SubOp( // CastOp // // ==-----== // -object CastOp extends OperationObject { +object CastOp extends MLIROperationObject { override def name: String = "db.cast" override def factory = CastOp.apply // ==--- Custom Parsing ---== // override def parse[$: P]( parser: Parser - ): P[Operation] = P( + ): P[MLIROperation] = P( ValueId ~ ":" ~ parser.Type ~ "->" ~ parser.Type.rep ~ parser.OptionalAttributes ).map( diff --git a/dialects/src/main/scala-3/lingodb/RelAlgOps.scala b/dialects/src/main/scala-3/lingodb/RelAlgOps.scala index 06f11df9..d06eddc1 100644 --- a/dialects/src/main/scala-3/lingodb/RelAlgOps.scala +++ b/dialects/src/main/scala-3/lingodb/RelAlgOps.scala @@ -145,10 +145,12 @@ private def DialectRegion[$: P](parser: Parser) = P( b }) ~ "{" - ~ parser.Operations(1) ~ "}").map((b: Block, y: ListType[Operation]) => { - b.operations ++= y - new Region(Seq(b)) - }) + ~ parser.Operations(1) ~ "}").map( + (b: Block, y: ListType[MLIROperation]) => { + b.operations ++= y + new Region(Seq(b)) + } + ) ) ~ E({ parser.enterParentRegion }) @@ -156,7 +158,7 @@ private def DialectRegion[$: P](parser: Parser) = P( // BaseTableOp // // ==-----------== // -object BaseTableOp extends OperationObject { +object BaseTableOp extends MLIROperationObject { override def name: String = "relalg.basetable" override def factory = BaseTableOp.apply @@ -237,14 +239,14 @@ case class BaseTableOp( // SelectionOp // // ==-----------== // -object SelectionOp extends OperationObject { +object SelectionOp extends MLIROperationObject { override def name: String = "relalg.selection" override def factory = SelectionOp.apply // ==--- Custom Parsing ---== // override def parse[$: P]( parser: Parser - ): P[Operation] = P( + ): P[MLIROperation] = P( ValueId ~ DialectRegion(parser) ~ parser.OptionalKeywordAttributes ) @@ -314,14 +316,14 @@ case class SelectionOp( // MapOp // // ==-----== // -object MapOp extends OperationObject { +object MapOp extends MLIROperationObject { override def name: String = "relalg.map" override def factory = MapOp.apply // ==--- Custom Parsing ---== // override def parse[$: P]( parser: Parser - ): P[Operation] = P( + ): P[MLIROperation] = P( ValueId ~ "computes" ~ ":" ~ "[" ~ ColumnDefAttr.parse(parser).rep.map(ArrayAttribute(_)) ~ "]" @@ -409,14 +411,14 @@ case class MapOp( // AggregationOp // // ==-------------== // -object AggregationOp extends OperationObject { +object AggregationOp extends MLIROperationObject { override def name: String = "relalg.aggregation" override def factory = AggregationOp.apply // ==--- Custom Parsing ---== // override def parse[$: P]( parser: Parser - ): P[Operation] = P( + ): P[MLIROperation] = P( ValueId ~ "[" ~ ColumnRefAttr .parse(parser) @@ -527,14 +529,14 @@ case class AggregationOp( // CountRowsOp // // ==-----------== // -object CountRowsOp extends OperationObject { +object CountRowsOp extends MLIROperationObject { override def name: String = "relalg.count" override def factory = CountRowsOp.apply // ==--- Custom Parsing ---== // override def parse[$: P]( parser: Parser - ): P[Operation] = P( + ): P[MLIROperation] = P( ValueId ~ parser.OptionalAttributes ).map( ( @@ -604,14 +606,14 @@ case class CountRowsOp( // AggrFuncOp // // ==----------== // -object AggrFuncOp extends OperationObject { +object AggrFuncOp extends MLIROperationObject { override def name: String = "relalg.aggrfn" override def factory = AggrFuncOp.apply // ==--- Custom Parsing ---== // override def parse[$: P]( parser: Parser - ): P[Operation] = P( + ): P[MLIROperation] = P( RelAlg_AggrFunc.caseParser ~ ColumnRefAttr.parse(parser) ~ ValueId ~ ":" ~ parser.Type.rep(1) ~ parser.OptionalAttributes @@ -708,14 +710,14 @@ case class AggrFuncOp( // SortOp // // ==------== // -object SortOp extends OperationObject { +object SortOp extends MLIROperationObject { override def name: String = "relalg.sort" override def factory = SortOp.apply // ==--- Custom Parsing ---== // override def parse[$: P]( parser: Parser - ): P[Operation] = P( + ): P[MLIROperation] = P( ValueId ~ "[" ~ (SortSpecificationAttr .parse(parser)) @@ -805,14 +807,14 @@ case class SortOp( // MaterializeOp // // ==-------------== // -object MaterializeOp extends OperationObject { +object MaterializeOp extends MLIROperationObject { override def name: String = "relalg.materialize" override def factory = MaterializeOp.apply // ==--- Custom Parsing ---== // override def parse[$: P]( parser: Parser - ): P[Operation] = P( + ): P[MLIROperation] = P( ValueId ~ "[" ~ ColumnRefAttr .parse(parser) diff --git a/dialects/src/main/scala-3/lingodb/SubOperatorOps.scala b/dialects/src/main/scala-3/lingodb/SubOperatorOps.scala index ecbf1dac..7ee7f41b 100644 --- a/dialects/src/main/scala-3/lingodb/SubOperatorOps.scala +++ b/dialects/src/main/scala-3/lingodb/SubOperatorOps.scala @@ -80,14 +80,14 @@ case class ResultTable( // SetResultOp // // ==-----------== // -object SetResultOp extends OperationObject { +object SetResultOp extends MLIROperationObject { override def name: String = "subop.set_result" override def factory = SetResultOp.apply // ==--- Custom Parsing ---== // override def parse[$: P]( parser: Parser - ): P[Operation] = P( + ): P[MLIROperation] = P( parser.Type ~ ValueId ~ ":" ~ parser.Type ~ parser.OptionalAttributes ).map( diff --git a/dialects/src/main/scala-3/lingodb/TupleStream.scala b/dialects/src/main/scala-3/lingodb/TupleStream.scala index 04350074..cf7862ab 100644 --- a/dialects/src/main/scala-3/lingodb/TupleStream.scala +++ b/dialects/src/main/scala-3/lingodb/TupleStream.scala @@ -126,7 +126,7 @@ case class ColumnRefAttr(val refName: SymbolRefAttr) // ReturnOp // // ==--------== // -object ReturnOp extends OperationObject { +object ReturnOp extends MLIROperationObject { override def name: String = "tuples.return" override def factory = ReturnOp.apply @@ -140,7 +140,7 @@ object ReturnOp extends OperationObject { override def parse[$: P]( parser: Parser - ): P[Operation] = P( + ): P[MLIROperation] = P( parser.OptionalAttributes ~ (ValueId.rep(sep = ",") ~ ":" ~ parser.Type.rep(sep = ",")).orElse((Seq(), Seq())) @@ -192,14 +192,14 @@ case class ReturnOp( // GetColumnOp // // ==-----------== // -object GetColumnOp extends OperationObject { +object GetColumnOp extends MLIROperationObject { override def name: String = "tuples.getcol" override def factory = GetColumnOp.apply // ==--- Custom Parsing ---== // override def parse[$: P]( parser: Parser - ): P[Operation] = P( + ): P[MLIROperation] = P( ValueId ~ ColumnRefAttr.parse(parser) ~ ":" ~ parser.Type ~ parser.OptionalAttributes ).map( diff --git a/dialects/src/main/scala-3/math/Math.scala b/dialects/src/main/scala-3/math/Math.scala index 1a4e5279..5d4bf79d 100644 --- a/dialects/src/main/scala-3/math/Math.scala +++ b/dialects/src/main/scala-3/math/Math.scala @@ -15,13 +15,13 @@ import scala.collection.mutable // AbsfOp // // ==--------== // -object AbsfOp extends OperationObject { +object AbsfOp extends MLIROperationObject { override def name: String = "math.absf" override def factory: FactoryType = AbsfOp.apply override def parse[$: P]( parser: Parser - ): P[Operation] = { + ): P[MLIROperation] = { P( "" ~ Parser.ValueUse ~ ":" ~ parser.Type ).map { case (operandName, type_) => @@ -73,13 +73,13 @@ case class AbsfOp( // FPowIOp // // ==--------== // -object FPowIOp extends OperationObject { +object FPowIOp extends MLIROperationObject { override def name: String = "math.fpowi" override def factory = FPowIOp.apply override def parse[$: P]( parser: Parser - ): P[Operation] = { + ): P[MLIROperation] = { P( Parser.ValueUse ~ "," ~ Parser.ValueUse ~ ":" ~ parser.Type ~ "," ~ parser.Type ).map { diff --git a/dialects/src/main/scala-3/test/Test.scala b/dialects/src/main/scala-3/test/Test.scala index d59040f0..a062d990 100644 --- a/dialects/src/main/scala-3/test/Test.scala +++ b/dialects/src/main/scala-3/test/Test.scala @@ -21,7 +21,7 @@ case class TestOp( dictionaryAttributes ) -object TestOp extends OperationObject { +object TestOp extends MLIROperationObject { override def name = "test.op" override def factory = TestOp.apply } diff --git a/tests/filecheck/core/general/.scala-build/ide-inputs.json b/tests/filecheck/core/general/.scala-build/ide-inputs.json index 538857c6..fc7effc4 100644 --- a/tests/filecheck/core/general/.scala-build/ide-inputs.json +++ b/tests/filecheck/core/general/.scala-build/ide-inputs.json @@ -1 +1 @@ -{"args":["/home/maks/phd/scair/scair/tests/filecheck/core/general/hello_world.scala"]} \ No newline at end of file +{"args":["/home/maks/phd/scair-off/scair/tests/filecheck/core/general/hello_world.scala"]} \ No newline at end of file diff --git a/transformations/src/main/scala-3/CMathDummyTransformation.scala b/transformations/src/main/scala-3/CMathDummyTransformation.scala index a0473351..cc407920 100644 --- a/transformations/src/main/scala-3/CMathDummyTransformation.scala +++ b/transformations/src/main/scala-3/CMathDummyTransformation.scala @@ -12,7 +12,7 @@ import scair.transformations.RewritePattern object AddDummyAttributeToDict extends RewritePattern { override def match_and_rewrite( - op: Operation, + op: MLIROperation, rewriter: PatternRewriter ): Unit = { op match { @@ -35,7 +35,7 @@ object TestInsertingDummyOperation extends RewritePattern { def defDum(name: String) = new UnregisteredOperation(name) override def match_and_rewrite( - op: Operation, + op: MLIROperation, rewriter: PatternRewriter ): Unit = { @@ -82,7 +82,7 @@ object TestReplacingDummyOperation extends RewritePattern { ) override def match_and_rewrite( - op: Operation, + op: MLIROperation, rewriter: PatternRewriter ): Unit = { @@ -105,7 +105,7 @@ object TestReplacingDummyOperation extends RewritePattern { object DummyPass extends ModulePass { override val name = "dummy-pass" - override def transform(op: Operation): Operation = { + override def transform(op: MLIROperation): MLIROperation = { val prw = new PatternRewriteWalker(AddDummyAttributeToDict) prw.rewrite_op(op) @@ -117,7 +117,7 @@ object DummyPass extends ModulePass { object TestInsertionPass extends ModulePass { override val name = "test-ins-pass" - override def transform(op: Operation): Operation = { + override def transform(op: MLIROperation): MLIROperation = { val prw = new PatternRewriteWalker(TestInsertingDummyOperation) prw.rewrite_op(op) @@ -129,7 +129,7 @@ object TestInsertionPass extends ModulePass { object TestReplacementPass extends ModulePass { override val name = "test-rep-pass" - override def transform(op: Operation): Operation = { + override def transform(op: MLIROperation): MLIROperation = { val prw = new PatternRewriteWalker(TestReplacingDummyOperation) prw.rewrite_op(op)