Skip to content

Commit

Permalink
changing operation name to MLIROperation
Browse files Browse the repository at this point in the history
  • Loading branch information
baymaks committed Feb 11, 2025
1 parent 576eb65 commit 48b1058
Show file tree
Hide file tree
Showing 23 changed files with 208 additions and 159 deletions.
2 changes: 1 addition & 1 deletion core/src/main/scala-3/MLContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import scala.collection.mutable

class MLContext() {

val dialectOpContext: mutable.Map[String, OperationObject] = mutable.Map()
val dialectOpContext: mutable.Map[String, MLIROperationObject] = mutable.Map()
val dialectAttrContext: mutable.Map[String, AttributeObject] = mutable.Map()

def getOperation(name: String) = dialectOpContext.get(name)
Expand Down
28 changes: 16 additions & 12 deletions core/src/main/scala-3/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,13 @@ object Parser {
var parentScope: Option[Scope] = None,
var valueMap: mutable.Map[String, Value[Attribute]] =
mutable.Map.empty[String, Value[Attribute]],
var valueWaitlist: mutable.Map[Operation, ListType[
var valueWaitlist: mutable.Map[MLIROperation, ListType[
(String, Attribute)
]] = mutable.Map.empty[Operation, ListType[(String, Attribute)]],
]] = mutable.Map.empty[MLIROperation, ListType[(String, Attribute)]],
var blockMap: mutable.Map[String, Block] =
mutable.Map.empty[String, Block],
var blockWaitlist: mutable.Map[Operation, ListType[String]] =
mutable.Map.empty[Operation, ListType[String]]
var blockWaitlist: mutable.Map[MLIROperation, ListType[String]] =
mutable.Map.empty[MLIROperation, ListType[String]]
) {

def defineValues(
Expand Down Expand Up @@ -636,12 +636,12 @@ class Parser(val context: MLContext, val args: Args = Args())
// [x] toplevel := (operation | attribute-alias-def | type-alias-def)*
// shortened definition TODO: finish...

def TopLevel[$: P]: P[Operation] = P(
def TopLevel[$: P]: P[MLIROperation] = P(
Start ~ (Operations(0)) ~ E({
currentScope.checkValueWaitlist()
currentScope.checkBlockWaitlist()
}) ~ End
).map((toplevel: ListType[Operation]) =>
).map((toplevel: ListType[MLIROperation]) =>
toplevel.toList match {
case (head: ModuleOp) :: Nil => head
case _ =>
Expand Down Expand Up @@ -705,7 +705,7 @@ class Parser(val context: MLContext, val args: Args = Args())
attributes: DictType[String, Attribute] = DictType(),
resultsTypes: Seq[Attribute] = Seq(),
operandsTypes: Seq[Attribute] = Seq()
): Operation = {
): MLIROperation = {

if (operandsNames.length != operandsTypes.length) {
throw new Exception(
Expand All @@ -720,7 +720,7 @@ class Parser(val context: MLContext, val args: Args = Args())
val useAndRefBlockSeqs: (ListType[Block], ListType[String]) =
currentScope.useBlocks(successorsNames)

val opObject: Option[OperationObject] = ctx.getOperation(opName)
val opObject: Option[MLIROperationObject] = ctx.getOperation(opName)

val op = opObject match {
case Some(x) =>
Expand Down Expand Up @@ -767,10 +767,12 @@ class Parser(val context: MLContext, val args: Args = Args())
return op
}

def Operations[$: P](at_least_this_many: Int = 0): P[ListType[Operation]] =
def Operations[$: P](
at_least_this_many: Int = 0
): P[ListType[MLIROperation]] =
P(OperationPat.rep(at_least_this_many).map(_.to(ListType)))

def OperationPat[$: P]: P[Operation] = P(
def OperationPat[$: P]: P[MLIROperation] = P(
OpResultList.orElse(Seq())./.flatMap(Op(_)) ~/ TrailingLocation.?
)

Expand Down Expand Up @@ -844,7 +846,7 @@ class Parser(val context: MLContext, val args: Args = Args())

def createBlock(
// name arguments operations
uncutBlock: (String, Seq[(String, Attribute)], ListType[Operation])
uncutBlock: (String, Seq[(String, Attribute)], ListType[MLIROperation])
): Block = {
val newBlock = new Block(
operations = uncutBlock._3,
Expand Down Expand Up @@ -874,7 +876,9 @@ class Parser(val context: MLContext, val args: Args = Args())
// \/
// [x] - region ::= `{` operation* block* `}`

def defineRegion(parseResult: (ListType[Operation], Seq[Block])): Region = {
def defineRegion(
parseResult: (ListType[MLIROperation], Seq[Block])
): Region = {
return parseResult._1.length match {
case 0 =>
val region = new Region(blocks = parseResult._2)
Expand Down
19 changes: 14 additions & 5 deletions core/src/main/scala-3/Printer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,14 @@ case class Printer(
|| OPERATION PRINTER ||
\*≡==---==≡≡≡≡≡≡≡≡≡==---==≡*/

def printCustomOperation(op: Operation, indentLevel: Int = 0): String = {
def printCustomOperation(op: MLIROperation, indentLevel: Int = 0): String = {
indent * indentLevel + op.print(this)
}

def printGenericOperation(op: Operation, indentLevel: Int = 0): String = {
def printGenericMLIROperation(
op: MLIROperation,
indentLevel: Int = 0
): String = {
var results: Seq[String] = Seq()
var resultsTypes: Seq[String] = Seq()
var operands: Seq[String] = Seq()
Expand Down Expand Up @@ -191,17 +194,23 @@ case class Printer(
return s"${"\""}${op.name}${"\""}($operationOperands)$operationSuccessors$dictionaryProperties$operationRegions$dictionaryAttributes : $functionType"
}

def printOperation(op: Operation, indentLevel: Int = 0): String = {
def printOperation(op: MLIROperation, indentLevel: Int = 0): String = {
val results =
op.results.map(printValue(_)).mkString(", ") + (if op.results.nonEmpty
then " = "
else "")
indent * indentLevel + results + (if strictly_generic then
printGenericOperation(op, indentLevel)
printGenericMLIROperation(
op,
indentLevel
)
else op.custom_print(this))
}

def printOperations(ops: Seq[Operation], indentLevel: Int = 0): String = {
def printOperations(
ops: Seq[MLIROperation],
indentLevel: Int = 0
): String = {
(for { op <- ops } yield printOperation(op, indentLevel)).mkString("\n")
}

Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala-3/builtin/Builtin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -459,14 +459,14 @@ case class AffineSetAttr(val affine_set: AffineSet)
// ModuleOp //
// ==------== //

object ModuleOp extends OperationObject {
object ModuleOp extends MLIROperationObject {
override def name = "builtin.module"
override def factory = ModuleOp.apply

// ==--- Custom Parsing ---== //
override def parse[$: P](
parser: Parser
): P[Operation] =
): P[MLIROperation] =
P(
parser.Region
).map((x: Region) => ModuleOp(regions = ListType(x)))
Expand Down
78 changes: 43 additions & 35 deletions core/src/main/scala-3/ir/Block.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,23 @@ object Block {

def apply(
arguments_types: Iterable[Attribute] | Attribute = Seq(),
operations: Iterable[Operation] | Operation = Seq()
operations: Iterable[MLIROperation] | MLIROperation = Seq()
): Block = new Block(arguments_types, operations)

def apply(operations: Iterable[Operation] | Operation): Block = new Block(
operations
)
def apply(operations: Iterable[MLIROperation] | MLIROperation): Block =
new Block(
operations
)

def apply(
arguments_types: Iterable[Attribute],
operations_expr: Iterable[Value[Attribute]] => Iterable[Operation]
operations_expr: Iterable[Value[Attribute]] => Iterable[MLIROperation]
): Block =
new Block(arguments_types, operations_expr)

def apply(
arguments_types: Attribute,
operations_expr: Value[Attribute] => Iterable[Operation]
operations_expr: Value[Attribute] => Iterable[MLIROperation]
): Block =
new Block(arguments_types, operations_expr)

Expand All @@ -48,7 +49,7 @@ object Block {
*/
case class Block private (
val arguments: ListType[Value[Attribute]],
val operations: ListType[Operation]
val operations: ListType[MLIROperation]
) {

/** Constructs a Block instance with the given argument types and operations.
Expand All @@ -57,21 +58,21 @@ case class Block private (
* The types of the arguments, either as a single Attribute or an Iterable
* of Attributes.
* @param operations
* The operations, either as a single Operation or an Iterable of
* Operations.
* The operations, either as a single MLIROperation or an Iterable of
* MLIROperations.
*/
def this(
arguments_types: Iterable[Attribute] | Attribute = Seq(),
operations: Iterable[Operation] | Operation = Seq()
operations: Iterable[MLIROperation] | MLIROperation = Seq()
) =
this(
ListType.from((arguments_types match {
case single: Attribute => Seq(single)
case multiple: Iterable[Attribute] => multiple
}).map(Value(_))),
ListType.from((operations match {
case single: Operation => Seq(single)
case multiple: Iterable[Operation] => multiple
case single: MLIROperation => Seq(single)
case multiple: Iterable[MLIROperation] => multiple
}))
)

Expand All @@ -85,7 +86,7 @@ case class Block private (
private def this(
args: (
Iterable[Value[Attribute]] | Value[Attribute],
Iterable[Operation] | Operation
Iterable[MLIROperation] | MLIROperation
)
) =
this(
Expand All @@ -94,19 +95,19 @@ case class Block private (
case multiple: Iterable[Value[Attribute]] => multiple
}),
ListType.from(args._2 match {
case single: Operation => Seq(single)
case multiple: Iterable[Operation] => multiple
case single: MLIROperation => Seq(single)
case multiple: Iterable[MLIROperation] => multiple
})
)

/** Constructs a Block instance with the given operations and no block
* arguments.
*
* @param operations
* The operations, either as a single Operation or an Iterable of
* Operations.
* The operations, either as a single MLIROperation or an Iterable of
* MLIROperations.
*/
def this(operations: Iterable[Operation] | Operation) =
def this(operations: Iterable[MLIROperation] | MLIROperation) =
this(Seq(), operations)

/** Constructs a Block instance with the given argument type and a function to
Expand All @@ -119,8 +120,8 @@ case class Block private (
*/
def this(
argument_type: Iterable[Attribute],
operations_expr: Iterable[Value[Attribute]] => Iterable[Operation] |
Operation
operations_expr: Iterable[Value[Attribute]] => Iterable[MLIROperation] |
MLIROperation
) =
this({
val args = argument_type.map(Value(_))
Expand All @@ -138,7 +139,8 @@ case class Block private (
*/
def this(
argument_type: Attribute,
operations_expr: Value[Attribute] => Iterable[Operation] | Operation
operations_expr: Value[Attribute] => Iterable[MLIROperation] |
MLIROperation
) =
this({
val arg = Value(argument_type)
Expand All @@ -147,7 +149,7 @@ case class Block private (

var container_region: Option[Region] = None

private def attach_op(op: Operation): Unit = {
private def attach_op(op: MLIROperation): Unit = {
op.container_block match {
case Some(x) =>
throw new Exception(
Expand All @@ -165,21 +167,24 @@ case class Block private (
}
}

def add_op(new_op: Operation): Unit = {
def add_op(new_op: MLIROperation): Unit = {
val oplen = operations.length
attach_op(new_op)
operations.insertAll(oplen, ListType(new_op))
}

def add_ops(new_ops: Seq[Operation]): Unit = {
def add_ops(new_ops: Seq[MLIROperation]): Unit = {
val oplen = operations.length
for (op <- new_ops) {
attach_op(op)
}
operations.insertAll(oplen, ListType(new_ops: _*))
}

def insert_op_before(existing_op: Operation, new_op: Operation): Unit = {
def insert_op_before(
existing_op: MLIROperation,
new_op: MLIROperation
): Unit = {
(existing_op.container_block equals Some(this)) match {
case true =>
attach_op(new_op)
Expand All @@ -193,8 +198,8 @@ case class Block private (
}

def insert_ops_before(
existing_op: Operation,
new_ops: Seq[Operation]
existing_op: MLIROperation,
new_ops: Seq[MLIROperation]
): Unit = {
(existing_op.container_block equals Some(this)) match {
case true =>
Expand All @@ -210,7 +215,10 @@ case class Block private (
}
}

def insert_op_after(existing_op: Operation, new_op: Operation): Unit = {
def insert_op_after(
existing_op: MLIROperation,
new_op: MLIROperation
): Unit = {
(existing_op.container_block equals Some(this)) match {
case true =>
attach_op(new_op)
Expand All @@ -224,8 +232,8 @@ case class Block private (
}

def insert_ops_after(
existing_op: Operation,
new_ops: Seq[Operation]
existing_op: MLIROperation,
new_ops: Seq[MLIROperation]
): Unit = {
(existing_op.container_block equals Some(this)) match {
case true =>
Expand All @@ -246,27 +254,27 @@ case class Block private (
for (op <- operations) op.drop_all_references
}

def detach_op(op: Operation): Operation = {
def detach_op(op: MLIROperation): MLIROperation = {
(op.container_block equals Some(this)) match {
case true =>
op.container_block = None
operations -= op
return op
case false =>
throw new Exception(
"Operation can only be detached from a block in which it is contained."
"MLIROperation can only be detached from a block in which it is contained."
)
}
}

def erase_op(op: Operation) = {
def erase_op(op: MLIROperation) = {
detach_op(op)
op.erase()
}

def getIndexOf(op: Operation): Int = {
def getIndexOf(op: MLIROperation): Int = {
operations.lastIndexOf(op) match {
case -1 => throw new Exception("Operation not present in the block.")
case -1 => throw new Exception("MLIROperation not present in the block.")
case x => x
}

Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala-3/ir/Dialect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ package scair.ir
\*≡==---==≡≡==---==≡*/

final case class Dialect(
val operations: Seq[OperationObject],
val operations: Seq[MLIROperationObject],
val attributes: Seq[AttributeObject]
) {}
Loading

0 comments on commit 48b1058

Please sign in to comment.