
[discuss] Remove predicate for non-tail block after loop split #419

Open
yzh119 opened this issue Jul 23, 2021 · 1 comment

@yzh119 (Collaborator) commented Jul 23, 2021

When the loop extent is not a multiple of the split factor, split inserts a predicate to avoid out-of-bounds memory access. However, this produces many redundant if statements when I unroll the inner loop after the split:

import tvm
from tvm import tir
from tvm.script import ty

@tvm.script.tir
def f(a: ty.handle) -> None:
    A = tir.match_buffer(a, [100,], 'float32')
    with tir.block([100], 'A') as i:
        A[i] = 1.

sch = tir.Schedule(f, debug_mode=True)
blk = sch.get_block('A')
i, = sch.get_loops(blk)
io, ii = sch.split(i, factor=32)  # 100 is not a multiple of 32, so a predicate is inserted
sch.unroll(ii)
print(tvm.lower(sch.mod['main']))
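To make the split semantics concrete, here is a minimal plain-Python model (not TVM code) of the iteration space after split(i, factor=32). It assumes the usual mapping i = io * 32 + ii, an outer extent of ceildiv(100, 32), and the bounds predicate i < 100. The helper name split_with_predicate is hypothetical.

```python
def split_with_predicate(extent, factor):
    """Model of a loop split: i = io * factor + ii, guarded by i < extent.

    Hypothetical helper for illustration only; not part of the TVM API.
    Returns the outer extent and the indices actually written, in order.
    """
    outer = (extent + factor - 1) // factor  # ceildiv: 100 / 32 -> 4 outer iterations
    visited = []
    for io in range(outer):
        for ii in range(factor):
            i = io * factor + ii
            if i < extent:  # the bounds predicate that split has to insert
                visited.append(i)
    return outer, visited


outer, visited = split_with_predicate(100, 32)
print(outer)                        # 4
print(visited == list(range(100)))  # True
print(outer * 32 - len(visited))    # 28 iterations masked off by the predicate
```

Only the last outer iteration ever fails the predicate, yet after unrolling the check is materialized in every iteration of the outer loop.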

The lowered result is:

primfn(a: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_1: Pointer(float32), float32, [100], [])}
  buffer_map = {a: A} {
  for (i0_outer: int32, 0, 4) {
    A_1[(i0_outer*32)] = 1f32
    A_1[((i0_outer*32) + 1)] = 1f32
    A_1[((i0_outer*32) + 2)] = 1f32
    A_1[((i0_outer*32) + 3)] = 1f32
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 4)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 5)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 6)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 7)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 8)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 9)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 10)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 11)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 12)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 13)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 14)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 15)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 16)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 17)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 18)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 19)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 20)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 21)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 22)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 23)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 24)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 25)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 26)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 27)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 28)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 29)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 30)] = 1f32
    }
    if (i0_outer < 3) {
      A_1[((i0_outer*32) + 31)] = 1f32
    }
  }
}
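The pattern in the output above follows from the predicate io * 32 + ii < 100. For offsets ii < 100 - 3 * 32 = 4 it holds for every io, so no guard is needed; for the remaining offsets it simplifies to i0_outer < 3. A small Python sketch of that arithmetic (the helper guarded_offsets is hypothetical and only illustrates the reasoning, not how TVM's simplifier is implemented):

```python
def guarded_offsets(extent, factor):
    """Map each unrolled inner offset ii that needs a guard to its bound on io.

    Hypothetical helper for illustration only. For integer io,
    io * factor + ii < extent  <=>  io < ceil((extent - ii) / factor).
    """
    outer = (extent + factor - 1) // factor
    guards = {}
    for ii in range(factor):
        bound = (extent - ii + factor - 1) // factor  # ceil((extent - ii) / factor)
        if bound < outer:  # the predicate can fail for some io, so a guard is emitted
            guards[ii] = bound
    return guards


guards = guarded_offsets(100, 32)
print(min(guards), max(guards))  # 4 31
print(set(guards.values()))      # {3}
```

This reproduces the lowered code: offsets 0 to 3 are unguarded, and offsets 4 to 31 each carry the same condition i0_outer < 3.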

The inserted if statements can hurt instruction-level parallelism (ILP).

Do we currently have a way to merge these if statements in TIR? If not, could we decompose the outer loop into two parts (non-tail and tail) and unroll only the non-tail part? For example:

for (i0_outer: int32, 0, 3) {
  A_1[(i0_outer*32)] = 1f32
  A_1[((i0_outer*32) + 1)] = 1f32
  A_1[((i0_outer*32) + 2)] = 1f32
  A_1[((i0_outer*32) + 3)] = 1f32
  A_1[((i0_outer*32) + 4)] = 1f32
  A_1[((i0_outer*32) + 5)] = 1f32
  A_1[((i0_outer*32) + 6)] = 1f32
  A_1[((i0_outer*32) + 7)] = 1f32
  A_1[((i0_outer*32) + 8)] = 1f32
  A_1[((i0_outer*32) + 9)] = 1f32
  A_1[((i0_outer*32) + 10)] = 1f32
  A_1[((i0_outer*32) + 11)] = 1f32
  A_1[((i0_outer*32) + 12)] = 1f32
  A_1[((i0_outer*32) + 13)] = 1f32
  A_1[((i0_outer*32) + 14)] = 1f32
  A_1[((i0_outer*32) + 15)] = 1f32
  A_1[((i0_outer*32) + 16)] = 1f32
  A_1[((i0_outer*32) + 17)] = 1f32
  A_1[((i0_outer*32) + 18)] = 1f32
  A_1[((i0_outer*32) + 19)] = 1f32
  A_1[((i0_outer*32) + 20)] = 1f32
  A_1[((i0_outer*32) + 21)] = 1f32
  A_1[((i0_outer*32) + 22)] = 1f32
  A_1[((i0_outer*32) + 23)] = 1f32
  A_1[((i0_outer*32) + 24)] = 1f32
  A_1[((i0_outer*32) + 25)] = 1f32
  A_1[((i0_outer*32) + 26)] = 1f32
  A_1[((i0_outer*32) + 27)] = 1f32
  A_1[((i0_outer*32) + 28)] = 1f32
  A_1[((i0_outer*32) + 29)] = 1f32
  A_1[((i0_outer*32) + 30)] = 1f32
  A_1[((i0_outer*32) + 31)] = 1f32
}
for (i0_inner: int32, 0, 100 - 3 * 32) {
  A_1[3 * 32 + i0_inner] = 1f32
}
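The proposed decomposition is essentially loop peeling: run extent // factor full tiles with no predicate, then a tail of extent % factor iterations. A minimal Python sketch, assuming a hypothetical helper peel_tail that only records the visit order, checks that the two parts together cover exactly the same indices as the original loop:

```python
def peel_tail(extent, factor):
    """Split a loop into a predicate-free main part and a short tail.

    Hypothetical helper illustrating the proposed decomposition; it returns
    the number of full tiles, the tail length, and the visit order.
    """
    full_tiles = extent // factor       # 100 // 32 -> 3 full tiles, safe to unroll
    tail_start = full_tiles * factor    # 96
    order = []
    for io in range(full_tiles):        # non-tail part: no predicate needed
        for ii in range(factor):        # this inner loop is the one that gets unrolled
            order.append(io * factor + ii)
    for ii in range(extent - tail_start):  # tail part: 100 - 3 * 32 = 4 iterations
        order.append(tail_start + ii)
    return full_tiles, extent - tail_start, order


full, tail, order = peel_tail(100, 32)
print(full, tail)                 # 3 4
print(order == list(range(100)))  # True
```

Because the main part only iterates over full tiles, every unrolled statement in it is in bounds by construction, and the bounds handling is confined to the short tail loop.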

@yzh119 (Collaborator, Author) commented Jul 23, 2021

I noticed a related discussion in the TVM issue apache/tvm#1979, which is still open.
