D2M Pass 4: Tensor Layout #2205

Draft · jdesousa-TT wants to merge 12 commits into main from jdesousa/ttir-tensor-layout

Conversation

@jdesousa-TT (Contributor) commented Feb 18, 2025:

Ticket

#1908

Problem description

This is a complex problem. The tensor layout pass will evolve over time to make better and more efficient decisions.

What's changed

This initial iteration of the pass will do four things:

  1. Rewrite the function return to add a reinterpret_layout to_layout from whatever layout the final op produces to the expected output layout of the original function. Insert an empty tensor op to hold this result, and rewrite the function return op.

  2. Rewrite the outputs of each ttir generic op to use the optimal tensor layout. The initial heuristic is as follows:

    • Maximize the grid shape that the tensor will use
    • Minimize the padding needed to achieve this size
    • Never allow entire cores to work on padding alone

    This step also rewrites the worker grid attribute of the generic to match the output grid. (A sketch of the grid heuristic follows this list.)

  3. Rewrite the function operands to their optimal layouts. This does not change the function operands themselves; it only rewrites their uses to reference reinterpret_layout to_layouts, and it allocates the empty tensor ops required to store this transformation.

  4. Rewrite generic block memrefs to match the new tensor layouts.
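
A minimal sketch of the grid heuristic from step 2, assuming illustrative names (pickGridShape, memrefShape, deviceGridShape); the actual pass may choose differently:

#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/MathExtras.h"

// Illustrative only. For each tensor dim: shard it across the largest grid
// the device offers (maximize grid, minimize padding), then shrink the grid
// so that no core would hold padding alone.
llvm::SmallVector<int64_t> pickGridShape(llvm::ArrayRef<int64_t> memrefShape,
                                         llvm::ArrayRef<int64_t> deviceGridShape) {
  llvm::SmallVector<int64_t> gridShape;
  for (size_t i = 0; i < memrefShape.size(); i++) {
    int64_t dim = memrefShape[i];
    int64_t gridDim = deviceGridShape[i];
    // Shard size if dim is rounded up to a multiple of the grid dim.
    int64_t shardDim = llvm::divideCeil(dim, gridDim);
    // Number of cores that receive at least one element of real data.
    gridShape.push_back(llvm::divideCeil(dim, shardDim));
  }
  return gridShape;
}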

Checklist

  • New/Existing tests provide coverage for changes

@jdesousa-TT force-pushed the jdesousa/ttir-tensor-layout branch 2 times, most recently from fe34ef2 to d0fd266 on February 19, 2025 01:37

@nsmithtt (Contributor) left a comment:

It's off to a good start! Some comments inline

auto dpsOp = mlir::cast<DestinationStyleOpInterface>(op);
assert(dpsOp.getNumDpsInits() == 1 &&
       "Only one result tensor is supported for now");
dpsOp.getDpsInits()[0].setType(optimal_layout);

Contributor:
I don't think it's legal to call set* operations directly inside of a pattern rewrite. I think all graph modifications must go through the PatternRewriter, so these need to be wrapped in rewriter.modifyOpInPlace.
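
For illustration, that wrapping might look like this (a sketch only; op, dpsOp, and optimal_layout as in the snippet above):

// Route the in-place type update through the rewriter so the driver is
// notified of the modification.
rewriter.modifyOpInPlace(op, [&]() {
  dpsOp.getDpsInits()[0].setType(optimal_layout);
});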

Contributor Author:
Wrapped all uses of the set* operations inside of modifyOpInPlace callbacks.

auto result_encoding =
    mlir::dyn_cast_or_null<MetalLayoutAttr>(result_type.getEncoding());
assert(result_encoding && "Tensor type must have a MetalLayoutAttr encoding");
auto optimal_output_grid = getOptimalGrid(

Contributor:
Probably we need to assert that this layout's grid shape is all 1's since we're assuming here that the shard shape is the fully expanded physical/tiled shape.
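
One possible guard, as a sketch (assuming the layout's grid is reachable via getGrid().getShape(); llvm::all_of is from llvm/ADT/STLExtras.h):

// Sketch only: fail fast if the layout is already sharded, i.e. its grid is
// not all 1's, since the shard shape below is treated as the fully expanded
// physical/tiled shape.
assert(llvm::all_of(result_encoding.getGrid().getShape(),
                    [](int64_t d) { return d == 1; }) &&
       "expected an unsharded (all-ones) grid");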

Contributor Author:
I think the new check on line 82 should take care of this. We fail the rewrite if the grid size isn't 1x1.

mlir::cast<MetalLayoutAttr>(optimal_layout.getEncoding()).getGrid());

return failure(); // need some better way to exit cond. the rewriter than
                  // always returning false!

@nsmithtt (Contributor) commented Feb 19, 2025:
Return success if you want to commit your changes and then recursively run the pattern. Return failure if you don't want to commit (this also signals to the rewrite driver that no update occurred, so no recursion is needed on behalf of this invocation).

Typically you need to test at the top to see whether your change has already been applied, especially in this case where you're not changing the op type, so this pattern will be matched on every recursive step.
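
A rough sketch of that shape (illustrative; names like ttir::GenericOp, getLocalLayout, and device are borrowed from the surrounding snippets, and this is not the PR's actual pattern):

LogicalResult matchAndRewrite(ttir::GenericOp op,
                              PatternRewriter &rewriter) const override {
  auto optimalLayout = getLocalLayout(op->getResult(0), rewriter, device);

  // Already rewritten on an earlier driver iteration: report that nothing
  // changed so the driver stops recursing on this pattern.
  if (op->getResult(0).getType() == optimalLayout)
    return failure();

  // Commit the change through the rewriter and report success so the driver
  // re-runs the patterns on the updated op.
  rewriter.modifyOpInPlace(op, [&]() {
    op->getResult(0).setType(optimalLayout);
  });
  return success();
}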

Contributor Author:
Updated the success/failure logic on most of the passes. Please take another look.

}
if (grid_shape.size() == i + 1) {
continue;
}

Contributor:
I think let's not worry about padding for now; we don't have a way of implementing that kind of gather pattern yet, and we have more fundamental things to get to first. Let's just pick the largest grid that evenly divides memref dim[i].
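
For illustration, one way to express that (a hypothetical helper, not code from this PR):

// Largest grid dim, capped by the device grid, that evenly divides the
// memref dim; falls back to 1 when no larger divisor exists.
int64_t pickEvenGridDim(int64_t dim, int64_t deviceGridDim) {
  for (int64_t g = (dim < deviceGridDim ? dim : deviceGridDim); g > 1; --g) {
    if (dim % g == 0) {
      return g;
    }
  }
  return 1;
}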

Contributor Author:
I'm not sure if this interferes with our current padding support at all. In the case that the grid evenly divides the tensor, it will always pick that largest divisor. The padding here just safeguards us for odd shapes, and will essentially "round up" the tensor size to fit into tile-aligned memrefs that are divisible by the grid. If this isn't what we want at this stage, can you tell me a bit more about how we should handle odd tensor shapes in the interim?
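
(Illustrative numbers for the rounding described above: a dimension of 19 tiles on an 8-core grid gives a shard of ceil(19/8) = 3 tiles, so the grid shrinks to ceil(19/3) = 7 cores; the last core holds 1 real tile plus 2 tiles of padding, and no core works on padding alone.)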

Contributor:
I'm not sure if this interferes with our current padding support at all

I think this is what I'm getting at: we don't really have padding support anywhere else, so it's unclear whether we're ready to lock in on this heuristic.

Nit: we should use camelCase.

It might be written more succinctly:

for (size_t i = 0; i < memrefShape.size(); i++) {
    int64_t dim = memrefShape[i];
    int64_t gridDim = deviceGridShape[i];
    int64_t shardDim = llvm::divideCeil(dim, gridDim);
    gridShape.push_back(llvm::divideCeil(dim, shardDim));
}

Contributor Author:
Just to be clear, should I just assert out for shapes that are not already tile-aligned, and then we can just divide the grid? Even if the tensor is tile-aligned (take something like 19x1 tiles), should the grid shape just remain 1x1 because we can't pad out to a divisor of the grid?

Contributor:
Yeah, we'll have to align up to tile for sure. I was thinking of just picking even divisors for now, which I understand is unfortunate for primes > 8. Maybe it doesn't matter either way; I guess it's fine if we do it this way for now. I think I didn't fully understand how the padding was working here at first, but we should be flexible to change it if we need to in the future.

@jdesousa-TT force-pushed the jdesousa/ttir-tensor-layout branch 4 times, most recently from f8b952c to e6f0d4b on February 24, 2025 18:00
@jdesousa-TT force-pushed the jdesousa/ttir-tensor-layout branch from e6f0d4b to 654de15 on February 24, 2025 19:04

assert(op->getResults().size() == 1 &&
       "Only one result tensor is supported for now");
auto optimal_layout = getLocalLayout(op->getResult(0), rewriter, device);
if (genericOp.getGridAttr() != GridAttr::get(rewriter.getContext()) ||

Contributor:
Is this condition ever true if the latter condition isn't?

    mlir::dyn_cast_or_null<MetalLayoutAttr>(result_type.getEncoding());
assert(result_encoding && "Tensor type must have a MetalLayoutAttr encoding");
auto optimal_output_grid = getOptimalGrid(
    tensor.getContext(), result_encoding.getMemref().getShape(),

Contributor:

I think result_encoding.getMemref().getShape() might be clearer as llvm::divideCeil(resultEncoding.getPhysicalShape(...), 32).
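
As a sketch of that suggestion (the arguments to getPhysicalShape are elided as in the comment above, and its result is assumed to be a range of int64_t):

// Illustrative: derive the tiled shape from the physical shape rather than
// reading it back off the memref; 32 is the tile dimension from the comment.
llvm::SmallVector<int64_t> tiledShape;
for (int64_t d : resultEncoding.getPhysicalShape(/*...*/)) {
  tiledShape.push_back(llvm::divideCeil(d, 32));
}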
