Alias in-place buffers with matching fold factors and strides #604

Merged: 6 commits, Feb 18, 2025
Changes from 1 commit
21 changes: 12 additions & 9 deletions builder/optimizations.cc
@@ -775,6 +775,7 @@ class in_place_aliaser : public stmt_mutator {
     // The name of the allocation or output variable that this buffer is derived from.
     var root;

+    std::vector<dim_expr> dims;
     bool allow_alias = true;
   };
   // Tracks buffer symbols that are actually the same buffer.
@@ -796,8 +797,16 @@ class in_place_aliaser : public stmt_mutator {
     }
   }

+  bool fold_factors_strides_same(const std::vector<dim_expr>& dims_a, const std::vector<dim_expr>& dims_b) {
+    for (std::size_t ix = 0; ix < std::min(dims_a.size(), dims_b.size()); ++ix) {
Owner

I think this check needs to guarantee that the allocation being replaced isn't losing a stride/fold_factor constraint, which it might be if it has more dimensions than the alias target.

Basically, I think this check needs to know which dims belong to the allocation and which belong to the alias target, and it needs to verify that all of the allocation's strides/fold factors match. I also think that if the ranks don't match, it shouldn't allow any stride/fold factor on the allocation (like HEAD).
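
A minimal sketch of the stricter check described in this comment, not the code that was merged. It assumes a hypothetical `dim_info` struct in which `std::optional<int>` stands in for the symbolic `stride` and `fold_factor` expressions, and a hypothetical function name `alloc_constraints_preserved`. It only checks constraints requested by the allocation; whether constraints on the alias target must also match is left open here.

```cpp
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical stand-in for dim_expr. The real strides and fold factors are
// symbolic expressions; std::optional<int> only models "defined or not".
struct dim_info {
  std::optional<int> stride;
  std::optional<int> fold_factor;
};

// True if this dimension requests a stride or a fold factor.
bool has_constraint(const dim_info& d) {
  return d.stride.has_value() || d.fold_factor.has_value();
}

// Returns true only if replacing `alloc` with `target` cannot drop a stride or
// fold factor constraint requested by the allocation.
bool alloc_constraints_preserved(const std::vector<dim_info>& alloc,
                                 const std::vector<dim_info>& target) {
  if (alloc.size() != target.size()) {
    // Rank mismatch: fall back to the conservative rule and allow aliasing
    // only when the allocation has no constraints at all.
    for (const dim_info& d : alloc) {
      if (has_constraint(d)) return false;
    }
    return true;
  }
  for (std::size_t i = 0; i < alloc.size(); ++i) {
    // Every constraint on the allocation must be matched by the target.
    if (alloc[i].stride && alloc[i].stride != target[i].stride) return false;
    if (alloc[i].fold_factor && alloc[i].fold_factor != target[i].fold_factor) return false;
  }
  return true;
}
```

The argument order matters in this sketch: the first argument is always the allocation being replaced, so a constraint that exists only on the alias target does not block aliasing.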

Collaborator Author

Updated the comparison function, PTAL.

+      if ((dims_a[ix].stride.defined() || dims_b[ix].stride.defined()) && !prove_true(dims_a[ix].stride == dims_b[ix].stride)) return false;
+      if ((dims_a[ix].fold_factor.defined() || dims_b[ix].fold_factor.defined()) && !prove_true(dims_a[ix].fold_factor == dims_b[ix].fold_factor)) return false;
+    }
+    return true;
+  }
+
   void visit(const allocate* op) override {
-    auto set_buffer = set_value_in_scope(buffers, op->sym, {op->sym});
+    auto set_buffer = set_value_in_scope(buffers, op->sym, {op->sym, op->dims});
     auto set_back = set_value_in_scope(backward, op->sym, var());
     auto set_fwd = set_value_in_scope(forward, op->sym, var());
     auto set_used = set_value_in_scope(use_count, op->sym, 0);
@@ -813,18 +822,12 @@ class in_place_aliaser : public stmt_mutator {
       // problem with multiple uses is the buffer we use instead of this allocation might be bigger, and the other use
       // needs those values missing from this allocation.
       can_alias = false;
-    } else if (std::any_of(op->dims.begin(), op->dims.end(),
-                   [&](const dim_expr& i) { return i.stride.defined() || i.fold_factor.defined(); })) {
-      // Don't alias if doing so could drop a stride or fold factor constraint.
-      // TODO: We could relax this check to allow aliasing if we know that the stride and fold factor of the buffer we
-      // are aliasing is the same.
-      can_alias = false;
     }

-    if (can_alias && back && back->defined() && buffers.lookup(*back)) {
+    if (can_alias && back && back->defined() && buffers.lookup(*back) && fold_factors_strides_same(op->dims, buffers[*back]->dims)) {
       forward.erase(*back);
       set_result(crop_buffer::make(op->sym, *back, dims_bounds(op->dims), std::move(body)));
-    } else if (can_alias && fwd && fwd->defined() && buffers.lookup(*fwd)) {
+    } else if (can_alias && fwd && fwd->defined() && buffers.lookup(*fwd) && fold_factors_strides_same(op->dims, buffers[*fwd]->dims)) {
       backward.erase(*fwd);
       set_result(clone_buffer::make(op->sym, *fwd, std::move(body)));
     } else if (!body.same_as(op->body)) {
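
To make the behavioral change in the diff above concrete, here is a small model of the removed rule and the added rule side by side. This is a sketch under assumptions, not the project's code: `dim_model`, `old_rule_allows`, and `new_rule_allows` are hypothetical names, `std::optional<int>` stands in for the symbolic `stride` and `fold_factor` expressions, and plain equality stands in for `prove_true`.

```cpp
#include <algorithm>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical stand-in for dim_expr; an empty optional means "not defined".
struct dim_model {
  std::optional<int> stride;
  std::optional<int> fold_factor;
};

// Model of the removed rule: any stride or fold factor on the allocation
// blocks aliasing, regardless of the alias target.
bool old_rule_allows(const std::vector<dim_model>& alloc) {
  return std::none_of(alloc.begin(), alloc.end(), [](const dim_model& d) {
    return d.stride.has_value() || d.fold_factor.has_value();
  });
}

// Model of the added rule, following fold_factors_strides_same as shown in the
// diff: only the dimensions both buffers share are compared.
bool new_rule_allows(const std::vector<dim_model>& a, const std::vector<dim_model>& b) {
  for (std::size_t i = 0; i < std::min(a.size(), b.size()); ++i) {
    // Two empty optionals compare equal, which mirrors "neither is defined".
    if (a[i].stride != b[i].stride) return false;
    if (a[i].fold_factor != b[i].fold_factor) return false;
  }
  return true;
}
```

In this model, an allocation with a stride of 4 is rejected by the old rule but accepted by the new rule when the target also has a stride of 4. Because the loop stops at the smaller rank, a fold factor on a dimension that the target lacks is never compared, which is the case the review comment raises.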