D2M Pass 4: Tensor Layout #2205
Conversation
Force-pushed from fe34ef2 to d0fd266
It's off to a good start! Some comments inline
auto dpsOp = mlir::cast<DestinationStyleOpInterface>(op);
assert(dpsOp.getNumDpsInits() == 1 &&
       "Only one result tensor is supported for now");
dpsOp.getDpsInits()[0].setType(optimal_layout);
I don't think it's legal to use set* operations inside of a pattern rewriter. I think all graph modifications must happen through the PatternRewriter, so these need to be wrapped in rewriter.modifyOpInPlace.
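For illustration, a minimal sketch of that wrapping, reusing dpsOp and optimal_layout from the snippet above and assuming op and rewriter are the matched operation and the pattern's PatternRewriter:

// Sketch only: funnel the in-place type change through the rewriter so the
// driver is notified that op was modified.
rewriter.modifyOpInPlace(op, [&]() {
  dpsOp.getDpsInits()[0].setType(optimal_layout);
});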
Wrapped all uses of the set* operations inside of modifyOpInPlace callbacks.
auto result_encoding =
    mlir::dyn_cast_or_null<MetalLayoutAttr>(result_type.getEncoding());
assert(result_encoding && "Tensor type must have a MetalLayoutAttr encoding");
auto optimal_output_grid = getOptimalGrid(
Probably we need to assert that this layout's grid shape is all 1's since we're assuming here that the shard shape is the fully expanded physical/tiled shape.
I think the new check on line 82 should take care of this. We fail the rewrite if the grid size isn't 1x1.
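A hedged sketch of what such a guard could look like (the getGrid()/getShape() accessors are assumed from the nearby snippets, not taken verbatim from the pass):

// Fail the rewrite unless the existing layout's grid is all 1's, i.e. the
// shard shape is still the fully expanded physical/tiled shape.
if (!llvm::all_of(result_encoding.getGrid().getShape(),
                  [](int64_t d) { return d == 1; })) {
  return failure();
}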
    mlir::cast<MetalLayoutAttr>(optimal_layout.getEncoding()).getGrid());

return failure(); // need some better way to exit cond. the rewriter than
                  // always returning false!
Return success if you want to commit your changes and then recursively run the pattern. Return failure if you don't want to commit (it also signals to the rewrite driver that no update occurred, so no recursion is needed on behalf of this invocation).
Typically you need to test at the top to see if your change has already been applied, especially in this case where you're not changing the op type, so this pattern will be matched for every recursive step.
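A rough sketch of that structure; alreadyHasOptimalLayout and applyOptimalLayout are hypothetical placeholders for the pass's real checks and updates:

LogicalResult matchAndRewrite(Operation *op,
                              PatternRewriter &rewriter) const override {
  // Test at the top: if an earlier invocation already applied the change,
  // return failure so the driver knows nothing was modified and stops
  // recursing on this op.
  if (alreadyHasOptimalLayout(op)) {
    return failure();
  }
  // Otherwise commit the update through the rewriter and return success so
  // the driver re-runs patterns against the modified op.
  rewriter.modifyOpInPlace(op, [&]() { applyOptimalLayout(op); });
  return success();
}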
Updated the success/failure logic on most of the passes. Please take another look.
}
if (grid_shape.size() == i + 1) {
  continue;
}
I think let's not worry about pad for now, we don't have a way of implementing that kind of gather pattern yet and we have more fundamental things to get first. Let's just pick the first largest grid that divides memref dim[I]
I'm not sure if this interferes with our current padding support at all. In the case that the grid evenly divides the tensor, it will always pick that largest divisor. The padding here just safeguards us for odd shapes, and will essentially "round up" the tensor size to fit into tile aligned memrefs that are divisible by the grid. If this isn't what we want at this stage can you tell me a bit more about how we should handle odd tensor shapes for the interim?
I'm not sure if this interferes with our current padding support at all
I think this is what I'm getting at, we don't really have padding support anywhere else, so it's unclear if we're ready to lock in on this heuristic.
Nit: we should use camelCase.
It might be written more succinctly:
for (size_t i = 0; i < memrefShape.size(); i++) {
  int64_t dim = memrefShape[i];
  int64_t gridDim = deviceGridShape[i];
  int64_t shardDim = llvm::divideCeil(dim, gridDim);
  gridShape.push_back(llvm::divideCeil(dim, shardDim));
}
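A worked pass through that loop for one dimension, with illustrative numbers (a 19-tile dim on an 8-wide device grid), not values from the PR:

// dim = 19, gridDim = 8
// shardDim     = llvm::divideCeil(19, 8) = 3   // shard padded up to 3 tiles
// gridShape[i] = llvm::divideCeil(19, 3) = 7   // 7 cores, tensor rounded up to 21 tiles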
Just to be clear, should I just assert out for shapes that are not tile aligned already, and then we can just divide the grid? Even if the tensor is tile aligned, take something like 19x1 tiles: should the grid shape just remain 1x1 because we can't pad out to a divisor of the grid?
Yeah, we'll have to align up to tile for sure. I was thinking of just picking even divisors for now, which I understand is unfortunate for primes > 8. Maybe it doesn't matter either way; I guess it's fine if we do it this way for now. I didn't fully understand how the pad was working here at first, but I think we should be flexible to change it if we need to in the future.
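For contrast with the padded result above, the divisors-only alternative on the same illustrative dimension:

// Divisors of 19 are {1, 19}; only 1 fits on an 8-wide grid, so the grid
// stays 1 and the shard remains the full 19 tiles.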
Force-pushed from f8b952c to e6f0d4b
Force-pushed from e6f0d4b to 654de15
assert(op->getResults().size() == 1 &&
       "Only one result tensor is supported for now");
auto optimal_layout = getLocalLayout(op->getResult(0), rewriter, device);
if (genericOp.getGridAttr() != GridAttr::get(rewriter.getContext()) ||
Is this condition ever true if the latter condition isn't?
    mlir::dyn_cast_or_null<MetalLayoutAttr>(result_type.getEncoding());
assert(result_encoding && "Tensor type must have a MetalLayoutAttr encoding");
auto optimal_output_grid = getOptimalGrid(
    tensor.getContext(), result_encoding.getMemref().getShape(),
I think result_encoding.getMemref().getShape() might be clearer as llvm::divideCeil(resultEncoding.getPhysicalShape(...), 32).
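Since llvm::divideCeil operates on scalars, the suggestion presumably means an element-wise division; a hedged sketch, where physicalShape stands in for whatever getPhysicalShape(...) returns (its arguments are elided in the comment above):

llvm::SmallVector<int64_t> tiledShape;
for (int64_t dim : physicalShape) {
  tiledShape.push_back(llvm::divideCeil(dim, 32)); // 32 = tile edge length
}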
Ticket
#1908
Problem description
This is a complex problem. The tensor layout pass will evolve over time to make better and more efficient decisions.
What's changed
This initial iteration of the pass will do four things:
1. Rewrite the function return to add a reinterpret_layout/to_layout from whatever layout the final op produces to the expected output of the original function. Insert an empty tensor op to hold this result, and rewrite the function return op.
2. Rewrite the outputs of each ttir generic op to use the optimal tensor layout. The initial heuristic pads each dimension up to tile alignment and picks, per dimension, the largest grid that the padded shape divides across evenly. This step also rewrites the worker grid attribute of the generic to match the output grid.
3. Rewrite the function operands to their optimal layouts. This does not change the function operands themselves; it only changes the uses of these operands to be references to reinterpret_layout/to_layout ops, and allocates the empty tensor ops required to store this transformation.
4. Rewrite generic block memrefs to match the new tensor layouts.
Checklist