Skip to content

Commit

Permalink
[Refactor] Introduce Memory Optimization based on MIP (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
hikettei authored Sep 11, 2024
1 parent 70becee commit 83e2933
Show file tree
Hide file tree
Showing 22 changed files with 874 additions and 451 deletions.
26 changes: 18 additions & 8 deletions source/ajit/backends/clang.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
(defmethod device-packed-by ((id Clang)) 4)
(defparameter *access* nil)
(defparameter *args* nil)
(defparameter *suffix* nil)
(defun args-p (id) (if (stringp id) (find (intern id) *args*) (find id *args*)))

(defun load-foreign-function (source &key (compiler "gcc") (lang "c") (compiler-flags))
Expand Down Expand Up @@ -177,8 +178,8 @@ Compiled with: ~a"
(let ((ref (render-isl-aref rhs :genid #'(lambda (x) (nth x *access*)))))
(if (string= ref "")
(if (args-p lhs)
(format nil "(*~(~a~))" lhs)
(format nil "~(~a~)" lhs))
(format nil "(*~(~a~)~a)" lhs (unroll-suffix rhs *suffix*))
(format nil "~(~a~)~a" lhs (unroll-suffix rhs *suffix*)))
(format nil "~(~a~)[~(~a~)]" lhs ref))))

(defmethod %render-expr ((lang Clang) (op (eql :INDEX-COMPONENTS)) lhs rhs z)
Expand Down Expand Up @@ -255,8 +256,9 @@ Compiled with: ~a"
(getattr node :_unrolled)
(list node)))
(let ((idx (getattr node :idx))
(args (map 'list #'(lambda (x) (r x)) (getattr node :args))))
(princ (%render-nodes kernel-lang (gethash idx (poly-pipeline polyhedral)) args indent) out))))))))))
(args (map 'list #'(lambda (x) (r x)) (getattr node :args)))
(*suffix* (getattr node :unroll-offsets)))
(princ (%render-nodes kernel-lang (gethash idx (poly-pipeline polyhedral)) args indent) out))))))))))

(defun ->cdtype (dtype)
(ecase dtype
Expand Down Expand Up @@ -297,8 +299,8 @@ Compiled with: ~a"
(let ((ref (render-isl-aref type :genid #'(lambda (x) (nth x access)))))
(if (string= ref "")
(if (args-p id)
(format nil "(*~(~a~))" id)
(format nil "~(~a~)" id))
(format nil "(*~(~a~)~a)" id (unroll-suffix type *suffix*))
(format nil "~(~a~)~a" id (unroll-suffix type *suffix*)))
(format nil "~(~a~)[~(~a~)]" id ref)))))
(loop with *access* = access
for node in (graph-nodes graph)
Expand All @@ -322,7 +324,15 @@ Compiled with: ~a"
(:WMMA
(multiple-value-bind (c a b) (apply #'values (node-reads node))
(multiple-value-bind (ct at bt) (apply #'values (relay-reads type))
(line "~(~a~) += ~(~a~) * ~(~a~);" (render-aref c ct) (render-aref a at) (render-aref b bt)))))
(line "~(~a~)~(~a~) += ~(~a~) * ~(~a~);"
(if (car (getattr node :declare-type))
(format nil "~a " (->cdtype (buffer-dtype ct)))
"")
(render-aref c ct) (render-aref a at) (render-aref b bt)))))
(:EXPR
(multiple-value-bind (at) (apply #'values (relay-writes type))
(line "~(~a~) = ~(~a~);" (render-aref (car (node-writes node)) at) (render-expr lang (getattr node :EXPR)))))))))))
(line "~(~a~)~(~a~) = ~(~a~);"
(if (car (getattr node :declare-type))
(format nil "~a " (->cdtype (buffer-dtype at)))
"")
(render-aref (car (node-writes node)) at) (render-expr lang (getattr node :EXPR)))))))))))
4 changes: 2 additions & 2 deletions source/ajit/caten.ajit.asd
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
(:file "scheduler")
(:file "isl-objects")
(:file "isl-ast-helpers")
(:file "transform")
(:file "graph")
(:file "render-graph")
(:file "kernel-info")
(:file "multiexpr")
(:file "transform")
(:file "memory-planner")
(:file "device")
(:file "renderer")
Expand Down
4 changes: 4 additions & 0 deletions source/ajit/device.lisp
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
(in-package :caten/ajit)
;;
;; Device. An abstraction class for the renderer.
;;
;;

(defclass Device () nil)
;; Configurations
Expand Down
55 changes: 46 additions & 9 deletions source/ajit/helpers.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,29 @@ should be used instead"
(push w seen)))
(reverse depends-on)))

(defun node-reads1 (node)
"Ignores the first read var of :EXPR"
(if (eql :EXPR (node-type node))
(cdr (node-reads node))
(node-reads node)))

(defun relay-reads1 (node)
"Ignores the first read var of :EXPR"
(if (eql :EXPR (node-type node))
(cdr (relay-reads (read-type-relay node)))
(relay-reads (read-type-relay node))))

(defun nodes-depends-on/buffers (nodes)
"Enumerates the unsolved buffer ids from the sched graph."
(declare (type list nodes))
(let ((seen `(t nil)) (depends-on))
(loop for node in nodes do
(loop for r in `(,@(node-reads node) ,@(getattr node :_loop_bound_nodes))
for typ in `(,@(relay-reads (read-type-relay node)) ,@(getattr node :_loop_bound_nodes_type))
(loop for r in `(,@(node-reads1 node) ,@(getattr node :_loop_bound_nodes))
for typ in `(,@(relay-reads1 node) ,@(getattr node :_loop_bound_nodes_type))
if (null (find r seen)) do
(when (symbolp r) (push (cons r typ) depends-on))
(push r seen))
(loop for read in (node-reads node) do
(loop for read in (node-reads1 node) do
(loop for shape in (buffer-reconstruct-view-args read)
if (null (find shape seen)) do
(push (cons shape (make-uconst-buffer)) depends-on)
Expand Down Expand Up @@ -421,8 +433,8 @@ in a single timestamp otherwise recursive dependencies will occur.
(every #'(lambda (node) (remattr node :_reads_old_for_multiexpr) node) (graph-nodes graph))
graph)

(defun optimize-non-in-place-buffers (base-avm avm refcounter graph seen verbose)
(declare (ignore refcounter))
(defun optimize-non-in-place-buffers (base-avm avm mp graph seen verbose kernel-args)
(declare (ignore mp))
(let* ((kernel-arg-symbols
(loop for node in (graph-nodes graph)
if (eql (node-type node) :JIT_KERNEL)
Expand All @@ -439,16 +451,30 @@ in a single timestamp otherwise recursive dependencies will occur.
collect k))
(extra-allocs
(loop for name in non-in-place-list
for node = (find name (graph-nodes (avm-graph avm)) :test #'find :key #'node-writes)
if node
for args = (loop for arg in kernel-args
if (eql (argument-name arg) name)
collect (argument-metadata arg))
if args
collect
(let* ((pos (position name (node-writes node)))
(typ (nth pos (relay-writes (read-type-relay node)))))
(let* ((typ
(if (= (length args) 1)
(car args)
(progn
;;(assert (every #'(lambda (x) (equal (buffer-orig-buffer-shape x) (buffer-orig-buffer-shape (car args)))) args))
(let ((s (find-if #'identity args :key #'buffer-orig-buffer-shape)))
(if (null s) ;; all tensors are contiguous
(car args)
(or
;; If there are multiple tensors, one is viewed, one is not viewed
;; Find out the original one and allocate it with the original shape.
(find (buffer-orig-buffer-shape s) args :key #'buffer-shape :test #'equal)
(error "Cannot determine the size of nested buffer! (It is a bug of Caten.)"))))))))
(make-node :Buffer :Allocate
(list name) (map 'list #'reveal-buffer `(,@(buffer-shape typ) ,@(buffer-stride typ)))
:nrank (buffer-nrank typ)
:dtype (buffer-dtype typ)
:_type_relay (make-inferred-type nil (list typ)))))))

(when verbose (format t "~%A number of buffers that failed to mutate in-place: ~a" (length extra-allocs)))
;; [TODO] Schedule to reuse the allocated buffer in non-in-place-list
;; Relocate to the most nearest
Expand Down Expand Up @@ -500,3 +526,14 @@ in a single timestamp otherwise recursive dependencies will occur.

(defun padding-list (list rank &key (with 0))
(append list (loop for i in (range 0 (- rank (length list))) collect with)))

(defun unroll-suffix (buffer unroll-offsets)
"Renders a suffix for a given number"
(flet ((f (idx) (find idx unroll-offsets :key #'car :test #'equalp)))
(apply
#'concatenate
'string
(loop for idx in (buffer-depend-idx-list buffer)
for unroll = (and idx (f idx))
if unroll
append (list "_" (format nil "~a" (cdr unroll)))))))
13 changes: 11 additions & 2 deletions source/ajit/isl-ast-helpers.lisp
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
(in-package :caten/ajit)
;; Translates From ISL_AST -> Lisp_AST
;; ~~ Lisp AST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
;; [isl-ast-helper.lisp]
;; Translator from ISL-AST to Lisp Object.
;; `Expr` is a slightly special, because the renderer also uses it.
;; other ASTs use it for ajit and eliminated in the compiling process
;; Expr is a lisp structure having following format:
;;
;; Expr(op, x, y, z)
;; op is a keyword indicating the object.
;; Expr is a very useful data structure, used for nested calculations, FOR and IF conditions, etc.
;; Identity can be verified by using expr-eq.
;; ~~ Lisp AST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
(defstruct (ASTBlock
(:constructor make-block (body)))
(body body :type list))
Expand Down
4 changes: 3 additions & 1 deletion source/ajit/isl-objects.lisp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
(in-package :caten/ajit)

;; isl-object.lisp
;; A helper to render ISL format from Lisp.
;; TODO: [Refactor] Replace this by caten/isl and delete this file.
(defmacro define-isl-object (print-name docstring ((&rest args) &rest slots) &body body)
(declare (type string print-name))
(let* ((name (intern (string-upcase print-name)))
Expand Down
7 changes: 7 additions & 0 deletions source/ajit/kernel-info.lisp
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
(in-package :caten/ajit)
;; kernel-info.lisp
;; A lisp-dumpable object for the compiled kernel, which is handled by AJIT.
;; Compiled kernels are finally represented as a node ":JIT_KERNEL"
;; Node[JIT_KERNEL] out1 out2 out3 <- jit_caller(out1 out2 out3)
;; where fname = function_name
;; jit-info = jit-info
;; ...
;; ~~ Fused Kernel Objects ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
(defstruct (JIT-Info)
(fname :nil :type keyword)
Expand Down
Loading

0 comments on commit 83e2933

Please sign in to comment.