Reduction

Author: Tianqi Chen

This is introductory material on how to do reduction in TVM. Associative reduction operators like sum/max/min are typical building blocks of linear algebra operations.

In this tutorial, we will demonstrate how to do reduction in TVM.

from __future__ import absolute_import, print_function

import tvm
import numpy as np

Describe Sum of Rows

Assume we want to compute the sum of rows as our example. In NumPy semantics this can be written as B = numpy.sum(A, axis=1).

The following lines describe the row sum operation. To create a reduction formula, we declare a reduction axis using tvm.reduce_axis, which takes the range of the reduction. tvm.sum takes the expression to be reduced as well as the reduction axis, and computes the sum of the values over all k in the declared range.

The equivalent C code is as follows:

for (int i = 0; i < n; ++i) {
  B[i] = 0;
  for (int k = 0; k < m; ++k) {
    B[i] = B[i] + A[i][k];
  }
}

n = tvm.var("n")
m = tvm.var("m")
A = tvm.placeholder((n, m), name='A')
# Reduction axis k ranges over [0, m)
k = tvm.reduce_axis((0, m), "k")
# B[i] = sum of A[i, k] over all k
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k], axis=k), name="B")
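As a reference point outside TVM, the computation that B describes can be checked against a plain loop nest that mirrors the C code above. This is a minimal sketch using only NumPy; the helper row_sum_loops and the names A_np/B_np are hypothetical and not part of TVM.

```python
import numpy as np

def row_sum_loops(A):
    """Hypothetical reference: mirrors the C loop nest for B = numpy.sum(A, axis=1)."""
    n, m = A.shape
    B = np.empty(n, dtype=A.dtype)
    for i in range(n):
        B[i] = 0                 # start from the identity element of +
        for k in range(m):       # k plays the role of the reduction axis
            B[i] = B[i] + A[i, k]
    return B

A_np = np.arange(12, dtype="float32").reshape(3, 4)
B_np = row_sum_loops(A_np)
assert np.allclose(B_np, np.sum(A_np, axis=1))   # rows sum to 6, 22, 38
```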

Schedule the Reduction

There are several ways to schedule a reduction. Before doing anything, let us print out the IR code of the default schedule.

s = tvm.create_schedule(B.op)
print(tvm.lower(s, [A, B], simple_mode=True))

Out:

produce B {
  for (i, 0, n) {
    B[i] = 0.000000f
    for (k, 0, m) {
      B[i] = (B[i] + A[((i*m) + k)])
    }
  }
}

The IR code is quite similar to the C code. The reduction axis behaves like a normal axis, so it can be split.

In the following code, we split both the row axis of B and the reduction axis by different factors. The result is a nested reduction.

ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
xo, xi = s[B].split(B.op.axis[0], factor=32)
print(tvm.lower(s, [A, B], simple_mode=True))

Out:

produce B {
  for (i.outer, 0, ((n + 31)/32)) {
    for (i.inner, 0, 32) {
      if (likely(((i.outer*32) < (n - i.inner)))) {
        B[((i.outer*32) + i.inner)] = 0.000000f
      }
      for (k.outer, 0, ((m + 15)/16)) {
        for (k.inner, 0, 16) {
          if (likely(((i.outer*32) < (n - i.inner)))) {
            if (likely(((k.outer*16) < (m - k.inner)))) {
              B[((i.outer*32) + i.inner)] = (B[((i.outer*32) + i.inner)] + A[(((((i.outer*32) + i.inner)*m) + (k.outer*16)) + k.inner)])
            }
          }
        }
      }
    }
  }
}
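To see why the lowered code needs the likely(...) guards, the split can be emulated in plain Python: an extent m split by factor becomes an outer loop of ceil(m / factor) iterations and an inner loop of factor iterations, and iterations whose reconstructed index outer * factor + inner falls outside [0, m) must be skipped. This is a sketch of the loop structure only; split_indices is a hypothetical helper, not a TVM function.

```python
def split_indices(extent, factor):
    """Hypothetical helper: list the indices visited by a split loop nest."""
    visited = []
    for outer in range((extent + factor - 1) // factor):   # ((m + 15)/16) in the IR
        for inner in range(factor):
            # Same condition as likely(((k.outer*16) < (m - k.inner))) in the IR
            if outer * factor < extent - inner:
                visited.append(outer * factor + inner)
    return visited

# With m = 37 and factor = 16 the nest runs 3 * 16 = 48 iterations,
# but the guard keeps exactly the 37 valid indices, each once and in order.
assert split_indices(37, 16) == list(range(37))
```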

If we are building a GPU kernel, we can bind the rows of B to GPU threads.

s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
print(tvm.lower(s, [A, B], simple_mode=True))

Out:

produce B {
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = ((n + 31)/32)
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 32
  if (likely(((blockIdx.x*32) < (n - threadIdx.x)))) {
    B[((blockIdx.x*32) + threadIdx.x)] = 0.000000f
  }
  for (k.outer, 0, ((m + 15)/16)) {
    for (k.inner, 0, 16) {
      if (likely(((blockIdx.x*32) < (n - threadIdx.x)))) {
        if (likely(((k.outer*16) < (m - k.inner)))) {
          B[((blockIdx.x*32) + threadIdx.x)] = (B[((blockIdx.x*32) + threadIdx.x)] + A[(((((blockIdx.x*32) + threadIdx.x)*m) + (k.outer*16)) + k.inner)])
        }
      }
    }
  }
}
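The effect of the two bind calls can be emulated serially: every (blockIdx.x, threadIdx.x) pair owns the row blockIdx.x * 32 + threadIdx.x, and the guard drops threads past the last row. This sketch assumes a CPU loop in place of concurrent GPU threads and, for brevity, uses an unsplit k loop instead of the k.outer/k.inner nest; emulate_bound_row_sum is a hypothetical name.

```python
import numpy as np

def emulate_bound_row_sum(A, block_size=32):
    """Hypothetical serial emulation of the bound kernel: one thread per row of B."""
    n, m = A.shape
    B = np.zeros(n, dtype=A.dtype)
    num_blocks = (n + block_size - 1) // block_size       # thread_extent of blockIdx.x
    for block_idx in range(num_blocks):                   # blockIdx.x
        for thread_idx in range(block_size):              # threadIdx.x
            row = block_idx * block_size + thread_idx
            if row < n:                                   # the likely(...) row guard
                B[row] = 0
                for k in range(m):                        # each thread reduces its row serially
                    B[row] += A[row, k]
    return B
```

Note that each thread still walks its whole row serially; the reduction axis itself is not parallelized yet.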

Reduction Factoring and Parallelization

One problem with building a reduction is that we cannot simply parallelize over the reduction axis. We need to divide the computation of the reduction, store the partial results in a temporary array, and then do a reduction over the temporary array.

The rfactor primitive performs this rewrite of the computation. In the following schedule, the result of B is written to a temporary result, B.rf. The factored dimension becomes the first dimension of B.rf.

s = tvm.create_schedule(B.op)
ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
BF = s.rfactor(B, ki)
print(tvm.lower(s, [A, B], simple_mode=True))

Out:

// attr [B.rf] storage_scope = "global"
allocate B.rf[float32 * 16 * n]
produce B.rf {
  for (k.inner, 0, 16) {
    for (i, 0, n) {
      B.rf[((k.inner*n) + i)] = 0.000000f
      for (k.outer, 0, ((m + 15)/16)) {
        if ((k.inner < (m - (k.outer*16)))) {
          B.rf[((k.inner*n) + i)] = (B.rf[((k.inner*n) + i)] + A[((k.inner + (i*m)) + (k.outer*16))])
        }
      }
    }
  }
}
produce B {
  for (ax0, 0, n) {
    B[ax0] = 0.000000f
    for (k.inner.v, 0, 16) {
      B[ax0] = (B[ax0] + B.rf[(ax0 + (k.inner.v*n))])
    }
  }
}
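The two stages above can be mirrored in NumPy to make the rewrite concrete: stage one fills B.rf with one strided partial sum per value of k.inner, and stage two sums over that first axis. This is a hypothetical sketch (rfactor_row_sum is not a TVM function) that follows the lowered IR, including its bound check on k.

```python
import numpy as np

def rfactor_row_sum(A, factor=16):
    """Hypothetical NumPy sketch of the two stages produced by s.rfactor(B, ki)."""
    n, m = A.shape
    # Stage 1 (B.rf): one partial sum per k.inner; the factored axis comes first.
    B_rf = np.zeros((factor, n), dtype=A.dtype)
    for k_inner in range(factor):
        for k_outer in range((m + factor - 1) // factor):
            k = k_outer * factor + k_inner
            if k < m:                          # same bound check as in the lowered IR
                B_rf[k_inner, :] += A[:, k]
    # Stage 2 (B): reduce over the first axis of B.rf.
    B = B_rf.sum(axis=0)
    return B_rf, B
```

The rows of B_rf for different k_inner values do not depend on each other, which is what makes the factored axis safe to parallelize.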

The scheduled operator of B also gets rewritten to be a sum over the first axis of the reduced result B.rf.

print(s[B].op.body)

Out:

[reduce(combiner=comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[0.000000f]), source=[B.rf(k.inner.v, ax0)], axis=[iter_var(k.inner.v, Range(min=0, extent=16))], where=(uint1)1, value_index=0)]

Cross Thread Reduction

We can now parallelize over the factored axis. Here the reduction axis of B is marked as a thread. TVM allows a reduction axis to be marked as a thread if it is the only axis in the reduction and cross-thread reduction is possible on the device.

This is indeed the case after the factoring. We can also compute BF directly at the reduction axis. The final generated kernel divides the rows among blockIdx.x and threadIdx.y, divides the columns among threadIdx.x, and finally does a cross-thread reduction over threadIdx.x.

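# Split the rows of B: blocks of 32 rows, one row per threadIdx.y
xo, xi = s[B].split(s[B].op.axis[0], factor=32)
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[B].bind(xi, tvm.thread_axis("threadIdx.y"))
# Bind the remaining reduction axis (extent 16 after rfactor) to threadIdx.x
tx = tvm.thread_axis("threadIdx.x")
s[B].bind(s[B].op.reduce_axis[0], tx)
# Compute each thread's partial sum (BF) at the cross-thread reduction axis
s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
# Only the thread with threadIdx.x == 0 writes the final result to B
s[B].set_store_predicate(tx.var.equal(0))
fcuda = tvm.build(s, [A, B], "cuda")
print(fcuda.imported_modules[0].get_source())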

Out:

extern "C" __global__ void default_function__kernel0( float* __restrict__ A,  float* __restrict__ B, int m, int n) {
   float B_rf[1];
  __shared__ float red_buf0[512];
  B_rf[0] = 0.000000e+00f;
  for (int k_outer = 0; k_outer < ((m + 15) / 16); ++k_outer) {
    if ((((int)blockIdx.x) * 32) < (n - ((int)threadIdx.y))) {
      if (((int)threadIdx.x) < (m - (k_outer * 16))) {
        B_rf[0] = (B_rf[0] + A[(((((((int)blockIdx.x) * 32) + ((int)threadIdx.y)) * m) + ((int)threadIdx.x)) + (k_outer * 16))]);
      }
    }
  }
  ((volatile __shared__ float*)red_buf0)[((((int)threadIdx.y) * 16) + ((int)threadIdx.x))] = (((((int)blockIdx.x) * 32) < (n - ((int)threadIdx.y))) ? B_rf[0] : 0.000000e+00f);
  __syncthreads();
  if (((int)threadIdx.x) < 8) {
    ((volatile __shared__ float*)red_buf0)[((((int)threadIdx.y) * 16) + ((int)threadIdx.x))] = (((volatile __shared__ float*)red_buf0)[((((int)threadIdx.y) * 16) + ((int)threadIdx.x))] + ((volatile __shared__ float*)red_buf0)[(((((int)threadIdx.y) * 16) + ((int)threadIdx.x)) + 8)]);
    ((volatile __shared__ float*)red_buf0)[((((int)threadIdx.y) * 16) + ((int)threadIdx.x))] = (((volatile __shared__ float*)red_buf0)[((((int)threadIdx.y) * 16) + ((int)threadIdx.x))] + ((volatile __shared__ float*)red_buf0)[(((((int)threadIdx.y) * 16) + ((int)threadIdx.x)) + 4)]);
    ((volatile __shared__ float*)red_buf0)[((((int)threadIdx.y) * 16) + ((int)threadIdx.x))] = (((volatile __shared__ float*)red_buf0)[((((int)threadIdx.y) * 16) + ((int)threadIdx.x))] + ((volatile __shared__ float*)red_buf0)[(((((int)threadIdx.y) * 16) + ((int)threadIdx.x)) + 2)]);
    ((volatile __shared__ float*)red_buf0)[((((int)threadIdx.y) * 16) + ((int)threadIdx.x))] = (((volatile __shared__ float*)red_buf0)[((((int)threadIdx.y) * 16) + ((int)threadIdx.x))] + ((volatile __shared__ float*)red_buf0)[(((((int)threadIdx.y) * 16) + ((int)threadIdx.x)) + 1)]);
  }
  __syncthreads();
  if ((((int)blockIdx.x) * 32) < (n - ((int)threadIdx.y))) {
    if (((int)threadIdx.x) == 0) {
      B[((((int)blockIdx.x) * 32) + ((int)threadIdx.y))] = ((volatile __shared__ float*)red_buf0)[(((int)threadIdx.y) * 16)];
    }
  }
}
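The kernel works in two phases per row: each threadIdx.x accumulates a strided partial sum into B_rf, then the 16 partials in red_buf0 are combined by halving the stride (8, 4, 2, 1). The serial sketch below reproduces that arithmetic for one row. It is hypothetical (tree_reduce and cross_thread_row_sum are illustrative names) and assumes a power-of-two thread count; the generated code keeps all eight threads active in every step and appears to rely on warp-synchronous execution with volatile shared memory, whereas this sketch only updates the lanes that feed red_buf0[0].

```python
def tree_reduce(buf):
    """Hypothetical serial emulation of the red_buf0 steps for one threadIdx.y row."""
    buf = list(buf)
    offset = len(buf) // 2                 # 8 for the 16-lane buffer in the kernel
    while offset >= 1:
        for tx in range(offset):           # lanes that feed red_buf0[0]
            buf[tx] = buf[tx] + buf[tx + offset]
        offset //= 2
    return buf[0]

def cross_thread_row_sum(row, num_threads=16):
    """Hypothetical emulation of one row: strided partial sums, then a tree reduction."""
    m = len(row)
    partials = []
    for tx in range(num_threads):          # threadIdx.x
        acc = 0.0                          # B_rf[0] = 0
        for k_outer in range((m + num_threads - 1) // num_threads):
            k = k_outer * num_threads + tx
            if k < m:                      # same bound check as the kernel
                acc += row[k]
        partials.append(acc)               # stored into red_buf0
    return tree_reduce(partials)
```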

Verify the correctness of the resulting kernel by comparing it to NumPy.

nn = 128
ctx = tvm.gpu(0)
# Random input matrix and a zeroed output buffer, both on the GPU
a = tvm.nd.array(np.random.uniform(size=(nn, nn)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx)
fcuda(a, b)
# Compare the kernel's row sums against NumPy
np.testing.assert_allclose(
    b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)

Describe Convolution via 2D Reduction

In TVM, we can describe convolution via 2D reduction in a simple way. Here is an example of a 2D convolution with filter size [3, 3] and strides [1, 1].

n = tvm.var('n')
Input = tvm.placeholder((n, n), name='Input')
Filter = tvm.placeholder((3, 3), name='Filter')
di = tvm.reduce_axis((0, 3), name='di')
dj = tvm.reduce_axis((0, 3), name='dj')
Output = tvm.compute(
    (n - 2, n - 2),
    lambda i, j: tvm.sum(Input[i + di, j + dj] * Filter[di, dj], axis=[di, dj]),
    name='Output')
s = tvm.create_schedule(Output.op)
print(tvm.lower(s, [Input, Filter, Output], simple_mode=True))

Out:

produce Output {
  for (i, 0, (n + -2)) {
    for (j, 0, (n + -2)) {
      Output[((i*(n + -2)) + j)] = 0.000000f
      for (di, 0, 3) {
        for (dj, 0, 3) {
          Output[((i*(n + -2)) + j)] = (Output[((i*(n + -2)) + j)] + (Input[((j + ((i + di)*n)) + dj)]*Filter[((di*3) + dj)]))
        }
      }
    }
  }
}
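The same computation can be written in NumPy by looping over the two reduction axes di and dj: each pair adds one shifted window of Input scaled by Filter[di, dj]. This is a hypothetical reference (conv2d_valid_3x3 is an illustrative name) that matches the index expression Input[i + di, j + dj] * Filter[di, dj]; note that, as in the TVM compute, the filter is not flipped, so strictly speaking this is a cross-correlation.

```python
import numpy as np

def conv2d_valid_3x3(inp, filt):
    """Hypothetical NumPy reference for Output in the TVM compute above."""
    n = inp.shape[0]
    out = np.zeros((n - 2, n - 2), dtype=np.result_type(inp, filt))
    # Loop over the two reduction axes; each (di, dj) adds one shifted, scaled window.
    for di in range(3):
        for dj in range(3):
            out += inp[di:di + n - 2, dj:dj + n - 2] * filt[di, dj]
    return out
```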

Define General Commutative Reduction Operation

Besides the built-in reduction operations like tvm.sum, tvm.min and tvm.max, you can also define your own commutative reduction operation with tvm.comm_reducer.

n = tvm.var('n')
m = tvm.var('m')
# First argument combines two values; second returns the identity element for a dtype
product = tvm.comm_reducer(lambda x, y: x*y,
                           lambda t: tvm.const(1, dtype=t), name="product")
A = tvm.placeholder((n, m), name='A')
k = tvm.reduce_axis((0, m), name='k')
B = tvm.compute((n,), lambda i: product(A[i, k], axis=k), name='B')
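Conceptually, a commutative reducer is just a combiner plus an identity element: the reduction starts from the identity and folds the combiner over the values. The plain-Python analogue below is a hypothetical sketch (make_reducer is not a TVM function). The last assertion hints at why commutativity and associativity matter: schedule transformations such as split and rfactor regroup the terms, and the result must not change.

```python
from functools import reduce

def make_reducer(combine, identity):
    """Hypothetical plain-Python analogue of a (combiner, identity) reducer."""
    def reducer(values):
        # Start from the identity element and fold the combiner over all values.
        return reduce(combine, values, identity)
    return reducer

product = make_reducer(lambda x, y: x * y, 1)

row = [4, 5, 6]
assert product(row) == 120
# Regrouping, as split/rfactor would do, gives the same result.
assert product([product(row[:1]), product(row[1:])]) == product(row)
```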

Note

Sometimes we would like to perform a reduction that involves multiple values, like argmax, which can be done with tuple inputs. See Describe Reduction with Collaborative Inputs for more details.
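The idea behind a tuple-valued reduction can be sketched in plain Python: the combiner works on (index, value) pairs and the identity pairs an invalid index with negative infinity. The referenced tutorial covers TVM's actual API; argmax_combine, ARGMAX_IDENTITY and argmax below are hypothetical names, and the sketch assumes finite input values.

```python
from functools import reduce

def argmax_combine(lhs, rhs):
    """Hypothetical combiner over (index, value) pairs.

    Keeps the pair with the larger value and breaks ties by the smaller index,
    so the result does not depend on the order in which pairs are combined.
    """
    lhs_idx, lhs_val = lhs
    rhs_idx, rhs_val = rhs
    if lhs_val > rhs_val or (lhs_val == rhs_val and lhs_idx < rhs_idx):
        return lhs
    return rhs

# Identity element: any finite value beats -inf.
ARGMAX_IDENTITY = (-1, float("-inf"))

def argmax(values):
    idx, _ = reduce(argmax_combine, enumerate(values), ARGMAX_IDENTITY)
    return idx
```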

Summary

This tutorial provides a walkthrough of reduction scheduling.

  • Describe reduction with reduce_axis.
  • Use rfactor to factor out an axis when we need parallelism.
  • Define new reduction operations with tvm.comm_reducer.
