Skip to content

Commit

Permalink
[Prepreq] Revisit the algorithm in transformer.lisp (#138)
Browse files Browse the repository at this point in the history
  • Loading branch information
hikettei authored Oct 10, 2024
1 parent a92a59d commit 0a42af9
Show file tree
Hide file tree
Showing 14 changed files with 839 additions and 267 deletions.
1 change: 1 addition & 0 deletions source/ajit/caten.ajit.asd
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
(:file "kernel-info")
(:file "multiexpr")
(:file "transform")
(:file "polyhedral-group")
(:file "memory-planner")
(:file "device")
(:file "backends/clang")
Expand Down
31 changes: 31 additions & 0 deletions source/ajit/device.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,34 @@
(defgeneric device-parallel-depth (device-prefix)
  (:documentation
   "Returns a fixnum N for the device designated by DEVICE-PREFIX: the N outermost loops are parallelized."))
(defgeneric device-packed-by (device-prefix)
  (:documentation
   "Returns the value by which a Funcall is packed for the device designated by DEVICE-PREFIX.
The catch-all method returns 1, in which case packing is ignored.")
  ;; Fallback for every device prefix without a more specific method.
  (:method ((device-prefix t))
    1))

(defclass ISL-Expr (Device)
  ()
  (:documentation
   "A Device subclass with no slots of its own. It exists as a dispatch target: `%render-expr` methods specialized on it render expressions as strings."))
(defmethod default-device ((device-prefix (eql :isl-expr)))
  "Returns a freshly created ISL-Expr instance for the :isl-expr device prefix."
  (make-instance 'ISL-Expr))

(defmethod %render-expr ((lang ISL-Expr) op lhs rhs z)
  "Renders the binary operation OP over LHS and RHS as a parenthesized infix string.

LHS and RHS are rendered recursively with `render-expr`. Z must be NIL: this
fallback method only handles two operands. Signals an error if either operand
is missing, if Z is supplied, or (via ECASE) if OP is not one of the listed
operators."
  (assert (and lhs rhs) () "~a is not implemented?" op)
  (assert (null z) () "~a takes exactly two operands, but a third one was given: ~a" op z)
  (format nil "(~a~(~a~)~a)"
          (render-expr lang lhs)
          (ecase op
            (:+ :+) (:- :-) (:* :*) (:/ :/)
            (:ADD :+) (:MUL :*) (:IDIV "/")
            ;; Word operators are padded with spaces. The :Const method prints a
            ;; bare token, so without padding (:AND a b) would come out as "(aandb)",
            ;; i.e. a single identifier rather than a conjunction.
            (:AND " and ") (:OR " or ") (:!= :!=) (:EQ :=)
            (:XOR " xor ")
            (:% :%) (:equal :=) (:<= :<=) (:>= :>=) (:< :<) (:> :>))
          (render-expr lang rhs)))

(defmethod %render-expr ((lang ISL-Expr) (op (eql :MAX)) lhs rhs z)
  "Renders max(LHS, RHS); when Z is given, renders the nested form max(LHS, max(RHS, Z))."
  (assert (and lhs rhs))
  (let ((first-arg (render-expr lang lhs))
        (second-arg (render-expr lang rhs)))
    (if (null z)
        (format nil "max(~a, ~a)" first-arg second-arg)
        (format nil "max(~a, max(~a, ~a))" first-arg second-arg (render-expr lang z)))))

(defmethod %render-expr ((lang ISL-Expr) (op (eql :MIN)) lhs rhs z)
  "Renders min(LHS, RHS); when Z is given, renders the nested form min(LHS, min(RHS, Z))."
  (assert (and lhs rhs))
  (let ((first-arg (render-expr lang lhs))
        (second-arg (render-expr lang rhs)))
    (if (null z)
        (format nil "min(~a, ~a)" first-arg second-arg)
        (format nil "min(~a, min(~a, ~a))" first-arg second-arg (render-expr lang z)))))

(defmethod %render-expr ((lang ISL-Expr) (op (eql :Const)) lhs rhs z)
  "Renders a constant: LHS is printed with ~a and downcased. RHS and Z are not used."
  (declare (ignorable rhs z))
  (format nil "~(~a~)" lhs))
11 changes: 8 additions & 3 deletions source/ajit/isl-objects.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@
(and
(not (eql (node-class node) :IR))
(not (eql (node-type node) :Allocate))))

;; ~~ AREF ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
(defun isl-access-expr-no-stride (gid stride upfrom by broadcast-p)
(declare (ignore stride))
Expand All @@ -99,7 +98,8 @@
(make-const gid nil))
(make-expr :MUL (make-const upfrom nil) (make-const stride nil))))))

(defun render-isl-aref (buffer &key (genid #'gid) (indexing #'isl-access-expr) (flatten nil) (strides nil) (use-permute nil) (upper nil) (mutate-scalar nil) &aux (c 0))
(defun render-isl-aref (buffer &key (genid #'gid) (indexing #'isl-access-expr) (flatten nil)
(strides nil) (use-permute nil) (upper nil) (mutate-scalar nil) (sum t) &aux (c 0))
"Renders the stride computation for ISL:
```
A[stride1 * view_info1 * index_component_0 + bias1 + stride2 * view_info2 * index_component_1 + bias2 + ...]
Expand Down Expand Up @@ -129,11 +129,16 @@ A[stride1 * view_info1 * index_component_0 + bias1 + stride2 * view_info2 * inde
#'concatenate 'string
(butlast
(loop for idx in (nconc indices (when upper (loop repeat (- upper c) collect (make-expr 0 nil))))
;; if (not (expr-eq idx (make-const 0 nil)))
append (list (render-expr (default-device :clang) idx) ", "))))
(flet ((add (x y) (make-expr :ADD x y)))
(if (null indices)
nil
(simplify-expr (reduce #'add indices)))))))
(if sum
(simplify-expr (reduce #'add indices))
(loop for e in (map 'list #'simplify-expr indices)
unless (expr-zero-p e)
collect e)))))))
;; ~~ DOMAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
(defun render-domain (pipeline target-keys &key (depends-on nil))
"Render the domain notation from the scheduled subgraphs
Expand Down
24 changes: 24 additions & 0 deletions source/ajit/memory-planner.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -494,12 +494,36 @@ Lifespan:
(loop while (r))
kernels))

(defmethod mp-auto-schedule! ((mp MemoryPlanner))
  "Pairs each entry of (mp-kernels mp) with the corresponding entry of (mp-groups mp),
converts every element of the kernel entry with `group->polyhedral-group`, and then
calls `affine-fusion` on each resulting list that is non-empty and consists only of
Polyhedral-Auto-Scheduler objects. Returns NIL; the results are not yet written back to MP."
  (let ((polyhedrons
          ;; MAPCAR stops at the shorter of the two lists, pairing kernels and groups by position.
          (mapcar #'(lambda (kernel group)
                      (mapcar #'(lambda (kr) (group->polyhedral-group group kr))
                              kernel))
                  (mp-kernels mp)
                  (mp-groups mp))))
    (flet ((auto-scheduler-p (x)
             (typep x 'Polyhedral-Auto-Scheduler)))
      (dolist (candidates polyhedrons)
        ;; Final chance to apply Loop Fusion.
        ;; Kernels with complicated memory access relations, like Matmul+Transpose or Conv, are first fused here.
        ;; Polyhedral Group is a result of splitting a group -> multiple groups?
        (unless (null candidates)
          (when (every #'auto-scheduler-p candidates)
            (affine-fusion candidates)))))
    ;; Tiling, Vectorizing, Parallelizing (CPU/GPU), Loop Fission here
    ;; [TODO] Apply the changes to mp-kernels, mp-groups
    ))

(defmethod retrive-kernels ((mp MemoryPlanner))
"Finalizes the result of memory-planner, retriving the final rendering-graph"
(flet ((prune ()
"Applies the dead code elimination"
(setf (mp-kernels mp) (dead-kernel-elimination (mp-groups mp) (mp-kernels mp) (append (avm-fw-outputs (mp-avm mp)) (avm-bw-outputs (mp-avm mp)))))))
(prune)

;; [Note] Auto_Scheduler is *work in progress*
;; (mp-auto-schedule! mp)
;; (prune)

;; 1. Mutate output buffers as a scalar
(optimize-memory-load mp)
;; 2. Hide Latency Optimization
Expand Down
Loading

0 comments on commit 0a42af9

Please sign in to comment.