[Roadmap] Sparse TensorIR #466

Open · 6 of 17 tasks
yzh119 opened this issue Sep 6, 2021 · 20 comments

yzh119 (Collaborator) commented Sep 6, 2021:

Design: #368

  1. Sparse block lowering. (transformation)
    1. Sparse/Dense coordinates transformation. (@MasterJH5574 WIP)
      1. The same sparse iterator viewed in different sparse buffers.
    2. + block body
    3. tir.reads/writes for sparse buffers.
    4. convert SparseBuffers into normal Buffers
    5. convert SparseBlocks into normal Blocks
    6. ...
  2. Python-side frontend.
    1. Python-binding of sparse buffers/blocks/formats. (@yzh119 WIP)
    2. Specialize sparse formats.
    3. ...
  3. TIR intrinsics.
    1. lower_bound (@yzh119 WIP)
      1. Return offset or value (prefer offset)
    2. boundary check
  4. Examples
    1. End-to-end BSR SpMM/SDDMM example and optimizations (@yzh119 )
    2. Sparse Softmax.

∞: tensorization.

junrushao changed the title from "[Checklist] Sparse TIR todo list" to "[Roadmap] Sparse TensorIR" on Sep 10, 2021
tqchen (Contributor) commented Sep 11, 2021:

Thanks @yzh119. Can you paste a TVMScript example mockup of the sparse TensorIR, and the transformations allowed? This would greatly help us understand the relation between the current design rationale and the new one.

yzh119 (Collaborator, Author) commented Sep 11, 2021:

Yes I'll elaborate more here.

yzh119 (Collaborator, Author) commented Sep 15, 2021:

Sparse Formats

In Sparse TIR we have four kinds of axes:

  1. Fixed Sparse
  2. Variable Sparse
  3. Fixed Dense
  4. Variable Dense

They can represent both sparse matrices and ragged tensors.
The length of a fixed axis is a constant, while the length of a variable axis depends on its coordinates.
We use a format tree to describe the dependency among axes.

For example, if we want to represent an irregular batched matrix multiplication:

for i in range(b):
    for j in range(N[i]):
        for k in range(M[i]):
            for l in range(K[i]):
                C[i, j, k] = C[i, j, k] + A[i, j, l] * B[i, l, k]

Its dependency tree is:

     i
    /|\
   j k l

The following syntax describes how we define such a structure in Sparse TIR.

i = tir.sp.FixedDenseAxis(b)
j = tir.sp.VariableDenseAxis(N)    # N, M, K are input 1-dim buffers in this case
k = tir.sp.VariableDenseAxis(M)
l = tir.sp.VariableDenseAxis(K)
fmt = tir.sp.format([[i, j], [i, k], [i, l]], (i, j, k, l))
A = tir.sp.match_buffer(A_handle, fmt, (i, j, l))
B = tir.sp.match_buffer(B_handle, fmt, (i, l, k))
C = tir.sp.match_buffer(C_handle, fmt, (i, j, k))

Sparse Blocks

A sparse block indicates

  1. the axes involved in the block,
  2. whether these axes are used as spatial or reduction axes, and
  3. how we reorder/fuse these axes.

with tir.sp.block([i, j, k], [spatial, spatial, reduce], [[0, 1], [2]]) as [vi, vj, vk]:
    pass

yzh119 (Collaborator, Author) commented Sep 25, 2021:

After discussion with @MasterJH5574 and @tqchen, we decided to update the syntax as follows:

def sddmm(a: ty.handle, b: ty.handle, c: ty.handle, fmt: ty.handle):
    N = tir.var('n')
    M = tir.var('m')
    BS = tir.var('bs')  # block size
    K = tir.var('k')
    i = tir.match_axis(fmt, 'i', N)
    j = tir.match_axis(fmt, 'j', M)
    k = tir.match_axis(fmt, 'k', K)
    bi = tir.match_axis(fmt, 'bi', BS)
    bj = tir.match_axis(fmt, 'bj', BS)
    A = tir.match_buffer(a, (i, bi, k), 'float32')
    B = tir.match_buffer(b, (tir.to_dense(j), bj, k), 'float32')
    C = tir.match_buffer(c, (i, j, bi, bj), 'float32')
    for vi, vj in tir.fuse(tir.cord(i), tir.cord(j)):
        for vk in tir.cord(k):
            for vbi in tir.cord(bi):
                for vbj in tir.cord(bj):
                    with tir.block([], 'sddmm'):
                        tir.block_attr({'sparse': True})
                        with tir.init():
                            C[vi, vj, vbi, vbj] = 0.
                        C[vi, vj, vbi, vbj] = C[vi, vj, vbi, vbj] +\
                            A[vi, vbi, vk] * B[vj, vbj, vk]

where tir.match_axis, tir.pos, tir.cord, tir.fuse are newly introduced keywords for sparse support in TIR.

Below we describe the detailed syntax of our new design.

Format Definition

We write the format definition in Python, outside of TIR scripts:

fmt = tir.format(
    {
         "i": (tir.kDenseFixed, None),
         "j": (tir.kSparseVariable, "i")
    }
)

We specify the format via a Python dictionary:

  1. The key indicates the axis name.
  2. The value is a tuple where the first element indicates the axis property and the second element refers to its parent axis.

Each axis has no more than one parent.
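
For instance, the ragged BMM format tree from the earlier post might be written as follows. This is a sketch: the kind name tir.kDenseVariable is an assumption here, analogous to the two kinds shown above.

fmt = tir.format(
    {
        "i": (tir.kDenseFixed, None),     # batch axis, no parent
        "j": (tir.kDenseVariable, "i"),   # length n[i] depends on i
        "k": (tir.kDenseVariable, "i"),   # length m[i] depends on i
        "l": (tir.kDenseVariable, "i"),   # length k[i] depends on i
    }
)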

Sparse Tensor Declaration

A sparse tensor is declared in Python as well.

indptr = [...]
indices = [...]
a = tir.sparse.tensor(
    data,
    indptr,
    indices,
)

where data is a TVM runtime array, indptr and indices are lists of TVM runtime arrays, and they both have inherent data types.
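
As a usage sketch, the buffers could be prepared from a SciPy CSR matrix like this. Note that tir.sparse.tensor is the constructor proposed above, not an API in mainline TVM; the SciPy/NumPy part is ordinary data preparation.

import scipy.sparse as sp
import tvm
from tvm import tir  # tir.sparse is the proposed namespace, not mainline TVM

# Pack a SciPy CSR matrix into the data/indptr/indices arrays described above.
csr = sp.random(128, 256, density=0.05, format="csr", dtype="float32")

data = tvm.nd.array(csr.data)                           # non-zero values
indptr = [tvm.nd.array(csr.indptr.astype("int32"))]     # one indptr array per variable axis
indices = [tvm.nd.array(csr.indices.astype("int32"))]   # one indices array per sparse axis

a = tir.sparse.tensor(data, indptr, indices)            # proposed constructor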

Sparse Support in TIR scripts.

match_axis

Matches an axis defined in the format and binds it with attributes such as length and number of columns (for ELLPACK-like formats).

match_buffer

Note that we need to support matching a buffer given a tuple of (possibly converted) sparse axes as its shape, as in the sddmm example above.

pos/cord

pos iterates over the non-zero elements of the given axis, and cord iterates over all elements of a given axis (from 0 to its length).
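
As a plain-Python illustration (not the actual TIR lowering), assuming a CSR-like layout with indptr/indices arrays, the difference between the two iteration modes is roughly:

# Hypothetical illustration of cord vs. pos iteration over a CSR-like axis.
indptr = [0, 2, 3, 6]          # row i owns positions indptr[i]..indptr[i+1]
indices = [0, 3, 1, 0, 2, 3]   # column coordinates of the non-zero elements
num_cols = 4

for i in range(len(indptr) - 1):
    # cord(j): iterate all coordinates 0..length of the column axis.
    for j_cord in range(num_cols):
        pass  # touches every (i, j_cord), including zero entries

    # pos(j): iterate only the non-zero positions of row i.
    for p in range(indptr[i], indptr[i + 1]):
        j_cord = indices[p]    # recover the coordinate from the position
        pass  # touches only the stored (i, j_cord) pairs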

fuse

Fuses several sparse/dense axes. Note that the index conversion is not a trivial affine map, so we use the keyword fuse to indicate such fusion.
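
For intuition only (not the proposed API), here is why the fused-to-unfused index conversion is not affine for sparse axes: recovering (vi, vj) from a fused non-zero position requires a search over the indptr array.

import bisect

indptr = [0, 2, 3, 6]   # CSR row pointer; 6 non-zeros in total

def unfuse(p):
    """Recover (row, position-within-row) from a fused non-zero index p."""
    vi = bisect.bisect_right(indptr, p) - 1   # binary search, not an affine map
    vj = p - indptr[vi]
    return vi, vj

assert unfuse(4) == (2, 1)   # the fifth non-zero lives in row 2, position 1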

to_dense

Converts a sparse axis to a dense one.

block with sparse attribute

Blocks with the sparse attribute need to be lowered via a schedule primitive called lower_sparse.
No dense schedules are allowed to be applied inside a block that has the sparse attribute.

Schedules

lower_sparse(loops)

lower_sparse does coordinate conversion according to the sparse formats for a specific loop or loops.
It inserts proper block isolation if there are dependencies between two axes.

reorder(loops)

Reorders loops and checks whether the reordering violates format constraints.

fuse(loops)

Fuses loops and checks whether the fusion violates format constraints.
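
Putting the three primitives together, a hypothetical scheduling session on the sddmm example might look like the sketch below. lower_sparse is the primitive proposed here; the surrounding Schedule calls are assumed to behave as they do for dense TIR, which is not guaranteed by this proposal.

# Hypothetical scheduling flow; not existing TVM behavior.
sch = tir.Schedule(sddmm)
i, j, k, bi, bj = sch.get_loops(sch.get_block("sddmm"))
ij = sch.fuse(i, j)          # non-affine fusion, checked against the format
sch.reorder(bi, bj, k)       # reorder, checked against format constraints
sch.lower_sparse(ij)         # rewrite coordinates, insert block isolation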

MasterJH5574 (Collaborator) commented:

@yzh119 Thanks for the update! I'd like to propose another design of formats as follows:

## Format proposal F2
fmt = tir.format({
  "name1": tir.DenseFixedAxis(),
  "name2": tir.DenseVariableAxis(name),  # `name` is the name of the axis it depends on
  "name3": tir.SparseFixedAxis(n_col),
  "name4": tir.SparseVariableAxis(),
})

The main difference from the format design above is that here we treat different axis kinds in different ways. This is reasonable because:

  • Only a dense-variable axis depends on some other axis, as the length of the axis is determined by the iterator value of the axis it depends on. Sparse-variable and sparse-fixed axes don't depend on other axes, because their lengths are known at compile-time. To be clear, sparse-variable iterators depend on other iterators, but sparse-variable axes do not depend on other axes. I will explain this point in the proposal of SparseIterator.
  • To define a sparse-fixed axis, we should provide its n_col field.
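
For instance, under this proposal a plain CSR matrix would presumably be written as the sketch below, using only the axis classes shown above.

## CSR: a dense-fixed row axis plus a sparse-variable column axis.
## The column axis stores no dependency; iterator dependencies are
## recovered from buffer accesses instead (see the explanation below).
csr_fmt = tir.format({
  "i": tir.DenseFixedAxis(),
  "j": tir.SparseVariableAxis(),
})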

yzh119 (Collaborator, Author) commented Sep 25, 2021:

Could you elaborate more on why sparse-variable axes do not depend on other axes? I don't get the point.

n_col could be specified in match_axis IMO.

yzh119 (Collaborator, Author) commented Sep 25, 2021:

@MasterJH5574 A second thought on our proposal of tir.cord(axis, reduction=True): I think it is better to bind the reduction information to blocks instead of loops.

Several blocks might share the same loop iterator, but one might view it as a parallel axis while another views it as a reduction axis.

MasterJH5574 (Collaborator) commented:

@yzh119 Here are more thoughts, which were sent to the Slack channel before 👀.

  1. The sddmm example is incomplete - we should provide the format so that the example is more understandable.
  2. In tir.match_buffer(...) we sometimes use tir.to_dense(axis) to convert a sparse-variable axis to the dense one. IMO we should have tir.to_sparse_variable(axis) and tir.to_sparse_fixed(axis, n_col) as well.
  3. We should provide some examples which contain dense-variable axis and sparse-fixed axis.
  4. As we use tir.cord(axis) and tir.pos(axis) to define iterators, I may need to tweak the detailed design of SparseIterator so that each iterator stores only the necessary information, leaving out some redundant information. Not a big deal. I'm going to do it today.

MasterJH5574 (Collaborator) commented:

I think it better to bind reduction information to block instead of loops.
Because several blocks might share the same loop iterator, but one might view it as parallel axis and another one view it as reduction axis.

Yes, you're right. It's possible that different blocks view a loop var differently. But in the above design we use opaque blocks (blocks that have no block iters) for sparse blocks. Therefore it's not very convenient to represent the reduction information in block signatures.

One possible design is to add some block iters, where each block iter in a sparse block is required to be bound to exactly one sparse iterator. An example might look like:

for vi, vj in T.fuse(T.cord(i), T.cord(j)):
    for vk in T.cord(k):
        for vbi in T.cord(bi):
            for vbj in T.cord(bj):
                with T.block('sddmm'):
                    T.block_attr({'sparse': True})
                    vi_ = T.sparse_axis(vi)
                    vj_ = T.sparse_axis(vj)
                    vk_ = T.sparse_axis(vk, reduction=True)
                    vbi_ = T.sparse_axis(vbi)
                    vbj_ = T.sparse_axis(vbj)
                    with T.init():
                        C[vi_, vj_, vbi_, vbj_] = 0.
                    C[vi_, vj_, vbi_, vbj_] = C[vi_, vj_, vbi_, vbj_] + A[vi_, vbi_, vk_] * B[vj_, vbj_, vk_]

(Just an example. The API of "sparse_axis" can change.)

Perhaps we should wait for the block iter/binding refactor in TVM script before converging to a detailed design.

MasterJH5574 (Collaborator) commented:

n_col could be specified in match_axis IMO.

Oh yes, I agree. So match_axis might take a variable number of arguments.


Could you elaborate more on sparse-variable axis do not depend on other axes? I don't get the point.

To be more specific:

  • What we are really concerned with is the sparse iterator dependency, as some schedule primitives like Fuse and Reorder rely on the iterator dependency - if the iterator dependency requirement isn't satisfied, the primitive cannot be applied.
  • Therefore, the key problem is to know, for a given iterator, which iterators it depends on.
  • As discussed in a meeting, dense-fixed and sparse-fixed iterators don't depend on any other iterator. A dense-variable iterator depends on a specific iterator, and a sparse-variable iterator depends on a bunch of other iterators. Specifically,
    • for dense-variable axes, consider the batch matrix multiplication workload. In BMM, the second and third axes of A and B are all dense-variable axes, and iterators vi, vj, vk all depend on iterator vb. How do we know that vi, vj, vk depend on vb? Because we defined that relationship in the format as "the second and third axes of A and B depend on the first axis". Note that this relationship can be stored nowhere but in the format. This is why we must store the axis being depended on when defining a dense-variable axis in the format.

      ## BMM example (not written in sparse TIR)
      for vb in range(0, B):
        for vi in range(0, M[vb]):
          for vj in range(0, N[vb]):
            C[vb, vi, vj] = 0.0
            for vk in range(0, K[vb]):
              C[vb, vi, vj] += A[vb, vi, vk] * B[vb, vk, vj]
    • for sparse-variable axes, suppose A is a buffer, and its i-th axis is a sparse-variable one. Now if a sparse-variable iterator vi goes over this axis, its extent relies on the values of the iterators that go over the first i - 1 axes. For example, in the code below we say vi depends on v1, ..., vi-1. How do we know this? By seeing the BufferLoad A[v1, v2, ..., vi-1, vi], we can immediately know the iterators that vi depends on. As a result, we don't need to store the relationship information in the format. This is why I say a sparse-variable axis never depends on other axes.

      ## An example that is not rigorous
      for v1 in range(0, l1):
        for v2 in range(0, l2):
          ...
          for vi-1 in range(0, li-1):
            for vi in range(0, li):
              ...
              A[v1, v2, ..., vi-1, vi, ...] = ...

I don't know whether my explanation could convince you. I'll post the detailed proposal for sparse iterators and the iterator dependency soon.

yzh119 (Collaborator, Author) commented Sep 26, 2021:

@MasterJH5574, I'm thinking of the sparse softmax example:

If we write it with opaque blocks, the program would look like this (it uses the subtract-the-max trick to avoid overflow):

for vi in tir.cord(i):
    for vj in tir.pos(j, reduction=True):
        with tir.block('A_max'):
            tir.block_attr({'sparse': True})
            with tir.init():
                A_max[vi] = tir.const(-float("inf"), "float32")
            A_max[vi] = tir.max(A_max[vi], A[vi, vj])
    for vj in tir.pos(j):
        with tir.block('A_minus_exp'):
            tir.block_attr({'sparse': True})
            A_exp[vi, vj] = tir.exp(A[vi, vj] - A_max[vi])
    for vj in tir.pos(j, reduction=True):
        with tir.block('Z'):
            tir.block_attr({'sparse': True})
            with tir.init():
                Z[vi] = 0.
            Z[vi] = Z[vi] + A_exp[vi, vj]
    for vj in tir.pos(j):
        with tir.block('out'):
            tir.block_attr({'sparse': True})
            out[vi, vj] = A_exp[vi, vj] / Z[vi]

However, we can fuse the Z block and the A_minus_exp block via compute_inline.

It would look more natural if we move the reduction attribute to blocks:

for vi in tir.cord(i):
    for vj in tir.pos(j):
        with tir.block(name='A_max', sparse_iters=[vi, tir.reduce_axis(vj)]):
            with tir.init():
                A_max[vi] = tir.const(-float("inf"), "float32")
            A_max[vi] = tir.max(A_max[vi], A[vi, vj])
    for vj in tir.pos(j):
        with tir.block(name='A_minus_exp', sparse_iters=[vi, vj]):
            A_exp[vi, vj] = tir.exp(A[vi, vj] - A_max[vi])
    for vj in tir.pos(j):
        with tir.block(name='Z', sparse_iters=[vi, tir.reduce_axis(vj)]):
            with tir.init():
                Z[vi] = 0.
            Z[vi] = Z[vi] + A_exp[vi, vj]
    for vj in tir.pos(j):
        with tir.block(name='out', sparse_iters=[vi, vj]):
            out[vi, vj] = A_exp[vi, vj] / Z[vi]

Then the Z block and the A_max block would be under the same loops, and we can use compute_at:

for vi in tir.cord(i):
    for vj in tir.pos(j):
        with tir.block(name='A_max', sparse_iters=[vi, tir.reduce_axis(vj)]):
            with tir.init():
                A_max[vi] = tir.const(-float("inf"), "float32")
            A_max[vi] = tir.max(A_max[vi], A[vi, vj])
    for vj in tir.pos(j):
        with tir.block(name='Z_A_minus_exp', sparse_iters=[vi, tir.reduce_axis(vj)]):
            with tir.init():
                Z[vi] = 0.
            Z[vi] = Z[vi] + tir.exp(A[vi, vj] - A_max[vi])
    for vj in tir.pos(j):
        with tir.block(name='out', sparse_iters=[vi, vj]):
            out[vi, vj] = tir.exp(A[vi, vj] - A_max[vi]) / Z[vi]

yzh119 (Collaborator, Author) commented Sep 26, 2021:

For the sparse-variable axis:

I think you assume that we record the axis dependency information in sparse tensors. But what if we don't do so?
By that I mean we don't store the axis dependency tree in sparse tensors but only store the lists of indptr and indices, and only match the buffer in TIR scripts.

MasterJH5574 (Collaborator) commented:

SparseIterVar Proposal 1

This post proposes the design of SparseIterVar, which behaves similarly to loop vars in TIR.

Note that the class name can change. Possible names are "SparseIterVar", "SparseIterator", "SparseLoopVar".

There are two main types of SparseIterVar: BasicSparseIterVar and FusedSparseIterVar. I'll elaborate on each of them.

BasicSparseIterVar

BasicSparseIterVar represents the basic ways we iterate. Just like SparseAxis which has four types, there are also four kinds of BasicSparseIterVar.

Definition

  • We use dense-fixed BasicSparseIterVar to iterate over dense values from 0 to a specific length. A dense-fixed BasicSparseIterVar only consists of an integer field length.
  • We use a dense-variable BasicSparseIterVar to iterate over dense values with regard to a given dense-variable axis. A dense-variable BasicSparseIterVar consists of a dense-variable SparseAxis field axis, which records the SparseAxis that the dense-variable axis depends on; this enables the compiler to figure out which SparseIterVar the dense-variable SparseIterVar depends on when accessing SparseBuffers.
  • We use sparse-fixed BasicSparseIterVar to iterate over a fixed number of non-zero values with regard to a given sparse-fixed axis. A sparse-fixed BasicSparseIterVar consists of a sparse-fixed SparseAxis field axis.
  • We use sparse-variable BasicSparseIterVar to iterate over non-zero values with regard to a given sparse-variable axis. A sparse-variable BasicSparseIterVar consists of a sparse-variable SparseAxis field axis.

How to define a BasicSparseIterVar in TVM script

As described in @yzh119's post, users can define a BasicSparseIterVar via tir.cord(axis) or tir.pos(axis) where axis is a SparseAxis. According to the type of axis, tir.cord(axis) and tir.pos(axis) return different kinds of BasicSparseIterVar. The rules are as follows:

  • tir.cord(axis) means iterating the dense values along axis. Therefore,
    • when axis is dense-fixed, sparse-fixed, or sparse-variable, the returned BasicSparseIterVar is dense-fixed, because a SparseAxis of such a type has a fixed length;
    • when axis is dense-variable, the returned BasicSparseIterVar is dense-variable.
  • tir.pos(axis) means iterating the non-zero values with regard to axis. Therefore,
    • when axis is dense-fixed, the returned BasicSparseIterVar is dense-fixed, because every value along this axis may be non-zero;
    • when axis is dense-variable, the returned BasicSparseIterVar is dense-variable for the same reason;
    • when axis is sparse-fixed, the returned BasicSparseIterVar is sparse-fixed;
    • when axis is sparse-variable, the returned BasicSparseIterVar is sparse-variable.

FusedSparseIterVar

A FusedSparseIterVar consists of an ordered array of BasicSparseIterVar, meaning that this FusedSparseIterVar is generated by fusing all BasicSparseIterVars in order.

We can use tir.fuse(tir.cord(i), tir.cord(j)) to define a FusedSparseIterVar. It means we first create two BasicSparseIterVars via tir.cord(i) and tir.cord(j) and then fuse them, yielding the resulting FusedSparseIterVar.

Note that we never expose FusedSparseIterVars to users in TVM script. As in the example below (same as @yzh119's example), in the frontend we only and always expose BasicSparseIterVars, without letting users notice the existence of FusedSparseIterVars. In the example, users only know that vi and vj were fused. Instead, the FusedSparseIterVars are created in the backend and used for SparseTIR's lowering process.

for vi, vj in tir.fuse(tir.cord(i), tir.cord(j)):
    for vk in tir.cord(k):
        for vbi in tir.cord(bi):
            for vbj in tir.cord(bj):
                with tir.block([], 'sddmm'):
                    tir.block_attr({'sparse': True})
                    with tir.init():
                        C[vi, vj, vbi, vbj] = 0.
                    C[vi, vj, vbi, vbj] += A[vi, vbi, vk] * B[vj, vbj, vk]

The lowering rules for SparseIterVars and SparseBuffer access will be posted in another proposal in the future.


@yzh119 You can take a look. Although it's super long 🤦‍♂️.

yzh119 (Collaborator, Author) commented Sep 26, 2021:

I'm okay with the SparseIterVar and FusedSparseIterVar design; just want to confirm: does our SparseIterVar work for all four kinds of axes?

MasterJH5574 (Collaborator) commented:

Just want to confirm if our SparseIterVar works for all four kinds of axes?

Yes. A dense-fixed SparseIterVar contains an integer length, and all other kinds of SparseIterVar have an axis pointer which points to a SparseAxis of the corresponding kind.

In buffer access, the only constraint is that dense-variable SparseIterVars are only allowed to index a dense-variable axis. There's no constraint for SparseIterVars of the other three kinds.
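For concreteness, here is a minimal Python sketch of the data layout being described. The class and field names follow the proposal's wording; this is an illustration, not an implementation.

from dataclasses import dataclass
from typing import Optional

@dataclass
class SparseAxis:
    kind: str                        # "dense_fixed" | "dense_variable" | "sparse_fixed" | "sparse_variable"
    parent: Optional[str] = None     # only meaningful for dense-variable axes
    n_col: Optional[int] = None      # only meaningful for sparse-fixed axes

@dataclass
class BasicSparseIterVar:
    kind: str                          # same four kinds as SparseAxis
    length: Optional[int] = None       # set only for dense-fixed iterators
    axis: Optional[SparseAxis] = None  # set for the other three kinds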

yzh119 (Collaborator, Author) commented Sep 27, 2021:

Proposal: gradual lowering of sparse iteration loops

Had some discussion w/ @MasterJH5574 on the possibility of gradually lowering sparse iteration loops.
I have some ideas about this: it's not easy to do index rewriting if we directly manipulate the data buffers. For gradual index rewriting, we can borrow the idea of sub-region buffer matching and create many intermediate buffers to simplify the rewriting process.

Example

Below is an example of my proposal. Let's assume we are trying to lower a program that reduces over the last dimension of a sparse tensor.

Original code:

I = tir.match_axis(fmt, "I")
J = tir.match_axis(fmt, "J")
K = tir.match_axis(fmt, "K")
A = tir.match_buffer(a, (I, J, K), "float32", "int32")
B = tir.match_buffer(b, (I, J), "float32", "int32")

for i in tir.cord(I):
    for j in tir.pos(J):
        for k in tir.pos(K):
            with tir.block(name='reduction', sparse_iters=[i, j, tir.reduction(k)]):
                with tir.init():
                    B[i, j] = 0.
                B[i, j] = B[i, j] + A[i, j, k]

where the fmt is created via

fmt = tir.sparse.format({
    "i": tir.DenseFixedAxis(),
    "j": tir.SparseVariableAxis("i"),
    "k": tir.SparseVariableAxis("j")
})

Then we lower the program itervar by itervar:

The first step is sch.lower_sparse(i):

for i in tir.grid(I.length):
    with tir.block([I.length]) as vi:
        tir.bind(vi, i)
        J_i = tir.match_axis(J[vi])
        K_i = tir.match_axis(K[vi])
        B_i = tir.match_buffer(B[vi], (J_i))
        A_i = tir.match_buffer(A[vi], (J_i, K_i))
        for j in tir.pos(J):
            for k in tir.pos(K):
                with tir.block(name='reduction', sparse_iters=[j, tir.reduction(k)]):
                    with tir.init():
                        B_i[j] = 0.
                    B_i[j] = B_i[j] + A_i[j, k]

Then sch.lower_sparse(j):

for i in tir.grid(I.length):
    with tir.block([I.length]) as vi:
        tir.bind(vi, i)
        J_i = tir.match_axis(J[vi])
        K_i = tir.match_axis(K[vi])
        B_i = tir.match_buffer(B[vi], (J_i))
        A_i = tir.match_buffer(A[vi], (J_i, K_i))
        for j in tir.grid(J_i.length):
            with tir.block([J_i.length]) as vj:
                tir.bind(vj, j)
                K_i_j = tir.match_axis(K_i[vj])
                A_i_j = tir.match_buffer(A_i[vj], (K_i_j,))
                B_i_j = tir.match_buffer(B_i[vj], (1,))
                for k in tir.pos(K):
                    with tir.block(name='reduction', sparse_iters=[tir.reduction(k)]):
                        with tir.init():
                            B_i_j[0] = 0.
                        B_i_j[0] = B_i_j[0] + A_i_j[k]

Then sch.lower_sparse(k):

for i in tir.grid(I.length):
    with tir.block([I.length]) as vi:
        tir.bind(vi, i)
        J_i = tir.match_axis(J[vi])
        K_i = tir.match_axis(K[vi])
        B_i = tir.match_buffer(B[vi], (J_i))
        A_i = tir.match_buffer(A[vi], (J_i, K_i))
        for j in tir.grid(J_i.length):
            with tir.block([J_i.length]) as vj:
                tir.bind(vj, j)
                K_i_j = tir.match_axis(K_i[vj])
                A_i_j = tir.match_buffer(A_i[vj], (K_i_j,))
                B_i_j = tir.match_buffer(B_i[vj], (1,))
                for k in tir.grid(K_i_j.length):
                    with tir.block([tir.reduce_axis((0, K_i_j.length))]) as vk:
                        tir.bind(vk, k)
                        A_i_j_k = tir.match_buffer(A_i_j[vk], (1,))
                        with tir.init():
                            B_i_j[0] = 0.
                        B_i_j[0] = B_i_j[0] + A_i_j_k[0]

What did A_i = tir.match_buffer(A[vi], (J_i, K_i)) do?

It basically assigns the pointer A[I.indptr[vi]] to A_i.

MasterJH5574 (Collaborator) commented:

@yzh119 Another question occurs to me about the gradual lowering. Is it possible to lower a SparseIterVar that was fused into a FusedSparseIterVar before? For example, in the code below

for vi, vj in tir.fuse(tir.cord(i), tir.cord(j)):
    blabla

is it possible to only lower vi without impacting vj? I tend to say no, but haven't thought deeply yet.

yzh119 (Collaborator, Author) commented Sep 28, 2021:

No, IMO we can only lower vi, vj together in this case.

MasterJH5574 (Collaborator) commented:

Discussion & Proposals: Output Buffer's indptr and indices

For workloads whose output buffer contains sparse axes, we cannot know the axes' indptr and indices at compile time, since they are determined by the input buffers. So the main task is to figure out what the indptr and indices are after compile time.

TACO's Method

TACO uses a two-round mechanism to solve the problem. The first round of execution is called "assembly", which collects the coordinates that appear as non-zero elements in the output buffer and then allocates the output buffer's indptr and indices. The second round is the real computation, which directly makes use of the output buffer's indptr and indices computed in the assembly round.

However, TACO's method may not be suitable for SparseTIR. Some reasons might be:

  • The way TACO infers indptr and indices is hard to implement on GPUs.
  • For the workloads that need inference (SpGEMM, for example), even though the workload is parallelizable, the compute intensity per warp is low and warp divergence is high.
  • GPUs don't really support dynamic memory allocation.

Therefore we decide not to adopt TACO's method.


Instead, we only support the workloads that don't need inference. In such workloads, a sparse axis in the output buffer must match a sparse axis (of the same kind) of some input buffer. The two axes have the same indptr and indices.
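
For example, in SDDMM the output inherits the sparsity pattern of the sampled (masking) input matrix, so its metadata can simply be reused. A NumPy-level sketch of what would be prepared, independent of which proposal below is chosen:

import numpy as np
import scipy.sparse as sp

# SDDMM: C = (A @ B) masked by the sparsity pattern of S.
# The output C has exactly the same indptr/indices as the input S,
# so no inference/assembly round is needed.
S = sp.random(64, 64, density=0.1, format="csr", dtype="float32")
C_indptr = S.indptr.copy()
C_indices = S.indices.copy()
C_data = np.zeros_like(S.data)   # values to be filled in by the kernel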

Proposal 1: User-Specified Axis Matching

To do this, one previous thought is to let users manually specify the axis matching (i.e., let users tell the compiler which sparse axis in the output buffer matches which axis in the input buffer). Then the compiler generates code that assigns values to the output buffer's indptr and indices according to the user-specified matching.

However, due to the inability to allocate memory dynamically (as mentioned above), we need to allocate the output buffer's indptr and indices outside the kernel. Note that this allocation requires knowing the sizes of indptr and indices, and thus requires the user's prior knowledge of the sizes.

Proposal 2: Believe-in-User

Another way is simpler: let users create the indices and indptr correctly (including assigning them the right values). Then we just trust that the indices and indptr they pass in are all correct. In this way we can use the output buffer's indptr and indices in the same way we use the input buffer's.

In essence, this method leaves all the problems to users: they are responsible for generating the output buffer's indptr and indices with correct sizes and correct values, which increases the burden on users. But the believe-in-user design eases the task of the compiler - the compiler won't need to generate code that assigns values to the output buffer's indptr and indices.

Summary

In a nutshell,

  • both methods require the user's prior knowledge of the sizes of indptr and indices (and also data);
  • in proposal 1, the users' task is relatively simple, while the compiler's task is hard;
  • in proposal 2, the users' task is hard, while the compiler's task is quite simple.

Personally I prefer the second proposal, because I think it's still acceptable for users to create the correct indptr and indices for the output buffer - after all, they are responsible for creating the input buffers correctly, so asking them to create one more buffer won't be too hard.

CC @yzh119. What opinions do you have?

MasterJH5574 (Collaborator) commented:

Proposal: Buffer Access Lowering

Recall of Axis/Iterator Design

According to our design:

  • There are four kinds of SparseAxis and four kinds of BasicSparseIterVar.
  • Dense-variable axes and sparse-variable axes each have a corresponding indptr array.
  • Sparse-fixed axes and sparse-variable axes each have a corresponding indices array.
  • Dense-variable axes each have corresponding runtime-determined lengths.

SparseBufferLoad & SparseBufferStore

Like BufferLoad and BufferStore, we now introduce SparseBufferLoad and SparseBufferStore to represent the read/write access to SparseBuffers.

  • SparseBufferLoad is a derived class of PrimExpr, which contains a SparseBuffer as the source of the read access, and contains an array of PrimExpr as the access indices.
  • SparseBufferStore is a derived class of Stmt, containing a SparseBuffer as the target of the write access, a value to be stored, and an array of PrimExpr as the indices.

Buffer Access Lowering

This section is about the method we convert SparseBufferLoad/SparseBufferStore to BufferLoad/BufferStore when lowering a whole SparseTIR to a normal TIR.

In this section, by a "sparse buffer access" we mean using an array of PrimExpr as indices to access a SparseBuffer. One of the tasks of SparseTIR lowering is to convert sparse buffer accesses to normal buffer accesses in TIR.

Without loss of generality, we suppose a SparseBuffer A has n axes, and we want to lower the buffer access A[v_0, v_1, ..., v_{n-1}], where v_0, ..., v_{n-1} are all PrimExprs.

Sparse Indices

To lower sparse buffer accesses, we should be able to convert the original indices [v_0, ..., v_{n-1}] into sparse indices [vs_0, ..., vs_{n-1}], where vs_i is the position among the non-zero elements that corresponds to v_i. (I don't know how to explain the idea clearly... Hope you can understand what I'm saying.)

Function F

Then we can define a group of functions on SparseBuffer A and the sparse indices [vs_0, ..., vs_{n-1}]. We define F_{A, i}(vs_0, vs_1, ..., vs_i) to be the flattened offset of the sparse indices [vs_0, vs_1, ..., vs_i] on A. The function name F means "flattened". F_{A, i} is written as F_i for short.

With the help of function F, the lowering becomes extremely easy: the lowered form of the sparse buffer access A[v_0, v_1, ..., v_{n-1}] is A_data[F_{n-1}(vs_0, vs_1, ..., vs_{n-1})]. This is because lowering in SparseTIR essentially means "get the flattened offset of a group of indices". Therefore the remaining task is to devise the algorithms to compute the sparse indices and the function F.
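
As a concrete case, for a 2-D CSR-like SparseBuffer (dense-fixed axis 0, sparse-variable axis 1), F reduces to the familiar CSR offset computation. A small sketch, assuming the sparse indices vs_0 and vs_1 are already known:

# F_0(vs_0) = vs_0 for the dense-fixed row axis, and
# F_1(vs_0, vs_1) = A_indptr_1[F_0(vs_0)] + vs_1 for the sparse-variable column axis,
# so A[v_0, v_1] lowers to A_data[A_indptr_1[vs_0] + vs_1].
def lower_csr_access(A_data, A_indptr_1, vs_0, vs_1):
    return A_data[A_indptr_1[vs_0] + vs_1]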

Algorithm of Computing Sparse Indices and Function F

For given original indices [v_0, ..., v_{n-1}], we compute the corresponding sparse indices and function F values alternately. That is, we first compute vs_0, then F_0(vs_0), then vs_1, F_1(vs_0, vs_1), and so on.

Before the main algorithm, we define F_{-1} = 0 to simplify the description of the algorithm.

Given A and [v_0, ..., v_{n-1}] as input, the algorithm is:

  • Enumerate i from 0 to n-1. For each i,
    • Firstly compute vs_i:
      • if v_i is a dense-fixed SparseIterVar,
        • if the i-th axis of A is dense-fixed, vs_i = v_i;
        • if the i-th axis of A is sparse-fixed or sparse-variable, we do binary search for value v_i between A_indices_i + F_i(vs_0, ..., vs_{i-1}, 0) and A_indices_i + F_i(vs_0, ..., vs_{i-1} + 1, 0). (Here we use the C++ style to represent array pointer.) The result of the binary search is vs_i.
        • it's not allowed that the i-th axis of A is dense-variable.
      • if v_i is a dense-variable SparseIterVar,
        • the i-th axis of A is only allowed to be dense-variable, and vs_i = v_i.
      • if v_i is a sparse-fixed SparseIterVar,
        • if the i-th axis of A is dense-fixed, suppose that v_i is iterating over sparse buffer B's axis j, then let f_B_j be the flattened offset of v_i on B (the offset is super easy to compute, so we omit the details here), and then vs_i = B_indices_j[f_B_j].
        • if the i-th axis of A is sparse-fixed,
          • if v_i is iterating the axis, then vs_i = v_i,
          • otherwise first use the way above to convert v_i to the dense value, then use the binary search method (same as the case where v_i is dense-fixed and the i-th axis of A is sparse) to convert the dense value to the desired sparse index.
        • if the i-th axis of A is sparse-variable, first convert v_i to the dense value and then apply binary search method.
        • it's not allowed that the i-th axis of A is dense-variable.
      • if v_i is a sparse-variable SparseIterVar,
        • if the i-th axis of A is dense-fixed, convert v_i to the dense value.
        • if the i-th axis of A is sparse-fixed, first convert v_i to the dense value and then apply binary search method.
        • if the i-th axis of A is sparse-variable,
          • if v_i is iterating over the axis, vs_i = v_i,
          • otherwise first convert v_i to the dense value and then apply binary search method.
        • it's not allowed that the i-th axis of A is dense-variable.
      • if v_i is not a BasicSparseIterVar, it's supposed to be a dense value (not sure, and will re-think in the future). In this case we use the binary search method to convert the dense value to a sparse index.
    • Then compute F_i(vs_0, ..., vs_i):
      • if the i-th axis of A is dense-fixed (or sparse-fixed), suppose the axis length (or number of columns) is n, then F_i(vs_0, ..., vs_i) = F_{i-1}(vs_0, ..., vs_{i-1}) * n + vs_i.
      • if the i-th axis of A is dense-variable or sparse-variable, F_i(vs_0, ..., vs_i) = A_indptr_i[F_{i-1}(vs_0, ..., vs_{i-1})] + vs_i.

After the above algorithm, we finally get the value F_{n-1}(vs_0, ..., vs_{n-1}). As described above, this value can be directly used to access A_data, which means the lowering for A[v_0, ..., v_{n-1}] is finished.
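
As a worked instance of the algorithm (a plain-Python sketch, not the generated TIR), consider the 2-D CSR case again and an access A[v_0, v_1] where v_0 is a dense-fixed iterator and v_1 is a plain dense coordinate, so the second index needs the binary-search step:

import numpy as np

def lower_access_csr(A_data, A_indptr_1, A_indices_1, v_0, v_1):
    # i = 0: the row axis is dense-fixed, so vs_0 = v_0 and F_0(vs_0) = vs_0.
    vs_0 = v_0
    f_0 = vs_0
    # i = 1: the column axis is sparse-variable and v_1 is a dense coordinate,
    # so binary-search it inside row vs_0's slice of the indices array
    # (assuming v_1 is actually one of the stored coordinates).
    lo, hi = A_indptr_1[f_0], A_indptr_1[f_0 + 1]
    vs_1 = int(np.searchsorted(A_indices_1[lo:hi], v_1))
    # F_1(vs_0, vs_1) = A_indptr_1[F_0(vs_0)] + vs_1 is the flattened offset.
    return A_data[A_indptr_1[f_0] + vs_1]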


CC @yzh119. It's very possible that this post has some typos. Feel free to tell me if there's anything you don't understand.
