Skip to content

Commit

Permalink
Merge pull request tensor-compiler#526 from zhang677/test-target
Browse files Browse the repository at this point in the history
Forall Context
  • Loading branch information
weiya711 authored Aug 2, 2022
2 parents cb00a90 + b0788bb commit 2b8ece4
Show file tree
Hide file tree
Showing 7 changed files with 435 additions and 183 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ lib/
*cmake_install.cmake
CMakeCache.txt
doc

.idea/
apps/tensor_times_vector/tensor_times_vector
53 changes: 53 additions & 0 deletions include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,35 @@ struct SuchThatNode;
class IndexExprVisitorStrict;
class IndexStmtVisitorStrict;

/// Describe the relation between indexVar sets of lhs and rhs in an Assignment node.
/// equal: lhs = rhs
/// none: lhs and rhs are mutually exclusive. And lhs and rhs are not empty sets.
/// lcr: rhs is a proper subset of lhs. (lhs contains rhs)
/// rcl: lhs is a proper subset of rhs. (rhs contains lhs)
/// inter: lhs and rhs share common elements but are not equal or empty. Some examples:
/// ```
/// // equal
/// ws(i1) += A(i1) // i1 is a child index node
/// ws(i) = A(i) // i is a parent index node
///
/// // none
/// ws(i1) += A(i) // i1 is a child of i
/// B_new(i) = B(i1)
///
/// // lcr
/// ws(i,k) = A(i) * B(i)
///
/// // rcl
/// ws(i) += A(i,k) * B(i,k)
///
/// // inter
/// ws(i,j) += A(i,k) * B(k,j)
/// ```
///
enum IndexSetRel {
equal, none, lcr, rcl, inter
};

/// Return true if the index statement is of the given subtype. The subtypes
/// are Assignment, Forall, Where, Sequence, and Multi.
template <typename SubType> bool isa(IndexExpr);
Expand Down Expand Up @@ -768,6 +797,18 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
IndexStmt assemble(TensorVar result, AssembleStrategy strategy,
bool separately_schedulable = false) const;

/// The wsaccel primitive specifies the dimensions of a workspace that will be accelerated.
/// Acceleration means adding compressed acceleration datastructures (bitmap, coordinate list) to a dense workspace.
/// shouldAccel controls whether acceleration will be applied.
/// When shouldAccel is true, if accelIndexVars is empty, then all dimensions should be accelerated.
/// When shouldAccel is true, if accelIndexVars is not empty, then dimensions in accelIndexVars will be accelerated.
/// When shouldAccel is false, accelIndexVars is ignored.
/// Currently, it only supports one-dimension acceleration. Acceleration is used by default.
///
/// Precondition:
/// Workspace can be accessed by the IndexVars in the accelIndexVars.
IndexStmt wsaccel(TensorVar& ws, bool shouldAccel = true,const std::vector<IndexVar>& accelIndexVars ={});

/// Casts index statement to specified subtype.
template <typename SubType>
SubType as() {
Expand Down Expand Up @@ -820,6 +861,9 @@ class Assignment : public IndexStmt {
/// Return the reduction index variables i nthe assign
std::vector<IndexVar> getReductionVars() const;

/// Return the set relation of indexVars in lhs and rhs
IndexSetRel getIndexSetRel() const;

typedef AssignmentNode Node;
};

Expand Down Expand Up @@ -1143,6 +1187,15 @@ class TensorVar : public util::Comparable<TensorVar> {
/// Gets the fill value of the tensor variable. May be left undefined.
const Literal& getFill() const;

/// Gets the acceleration dimensions
const std::vector<IndexVar>& getAccelIndexVars() const;

/// Gets the acceleration flag
bool getShouldAccel() const;

/// Set the acceleration dimensions
void setAccelIndexVars(const std::vector<IndexVar>& accelIndexVars, bool shouldAccel);

/// Set the fill value of the tensor variable
void setFill(const Literal& fill);

Expand Down
87 changes: 87 additions & 0 deletions src/index_notation/index_notation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2048,6 +2048,32 @@ IndexStmt IndexStmt::assemble(TensorVar result, AssembleStrategy strategy,
return transformed;
}

IndexStmt IndexStmt::wsaccel(TensorVar& ws, bool shouldAccel, const std::vector<IndexVar>& accelIndexVars) {
if (accelIndexVars.size() == 0) {
ws.setAccelIndexVars(accelIndexVars, shouldAccel);
return *this;
}
set<IndexVar> TempVars;
match(*this,
std::function<void(const WhereNode*,Matcher*)>([&](const WhereNode* where,Matcher* ctx) {
auto Temp = getResultAccesses(where->producer).first[0];
if (Temp.getTensorVar() == ws) {
for (auto i :getIndexVars()){
TempVars.insert(i);
}
}
ctx->match(where->producer);
ctx->match(where->consumer);
}));
for (auto i : accelIndexVars) {
if (TempVars.find(i) == TempVars.end()) {
taco_uerror << "No matching indexVars in the Accel";
}
}
ws.setAccelIndexVars(accelIndexVars, shouldAccel);
return *this;
}

std::ostream& operator<<(std::ostream& os, const IndexStmt& expr) {
if (!expr.defined()) return os << "IndexStmt()";
IndexNotationPrinter printer(os);
Expand Down Expand Up @@ -2102,6 +2128,50 @@ std::vector<IndexVar> Assignment::getReductionVars() const {
return reductionVars;
}

IndexSetRel Assignment::getIndexSetRel() const {
vector<IndexVar> freeVars = getLhs().getIndexVars();
set<IndexVar> lseen(freeVars.begin(), freeVars.end());
vector<IndexVar> RVars ;
match(getRhs(),
std::function<void(const AccessNode*)>([&](const AccessNode* op) {
for (auto& var : op->indexVars) {
RVars.push_back(var);
}
}));
set<IndexVar> rseen(RVars.begin(), RVars.end());
IndexSetRel rel = equal;
std::vector<IndexVar> v_inter;
int lnum = lseen.size();
int rnum = rseen.size();
int rcl_num = 0;
for (auto & var : rseen){
if (util::contains(lseen, var)) {
rcl_num += 1;
}
}
if (rcl_num == 0) {
rel = none;
}
else if ((rcl_num<lnum) && (rcl_num == rnum)){
rel = lcr;
}
else if ((rcl_num<lnum) && (rcl_num<rnum)){
rel = inter;
} else if ((rcl_num == lnum) && (rcl_num == rnum)){
rel = equal;
} else if ((rcl_num == lnum) && (rcl_num<rnum)) {
rel = rcl;
}
else {
rel = none;
}

if (lnum == 0 && rel == none) {
rel = rcl;
}
return rel;
}

template <> bool isa<Assignment>(IndexStmt s) {
return isa<AssignmentNode>(s.ptr);
}
Expand Down Expand Up @@ -2476,6 +2546,8 @@ struct TensorVar::Content {
Format format;
Schedule schedule;
Literal fill;
std::vector<IndexVar> accelIndexVars;
bool shouldAccel;
};

TensorVar::TensorVar() : content(nullptr) {
Expand Down Expand Up @@ -2508,6 +2580,8 @@ TensorVar::TensorVar(const int& id, const string& name, const Type& type, const
content->type = type;
content->format = format;
content->fill = fill.defined()? fill : Literal::zero(type.getDataType());
content->accelIndexVars = std::vector<IndexVar> {};
content->shouldAccel = true;
}

int TensorVar::getId() const {
Expand Down Expand Up @@ -2551,6 +2625,19 @@ const Literal& TensorVar::getFill() const {
return content->fill;
}

const std::vector<IndexVar>& TensorVar::getAccelIndexVars() const {
return content->accelIndexVars;
}

bool TensorVar::getShouldAccel() const {
return content->shouldAccel;
}

void TensorVar::setAccelIndexVars(const std::vector<IndexVar>& accelIndexVars, bool shouldAccel) {
content->shouldAccel = shouldAccel;
content->accelIndexVars = accelIndexVars;
}

void TensorVar::setFill(const Literal &fill) {
content->fill = fill;
}
Expand Down
Loading

0 comments on commit 2b8ece4

Please sign in to comment.