HeatViT: Hardware-Efficient Adaptive Token Pruning for Vision Transformers

Abstract In A Nutshell

To overcome the challenges that ViTs pose, namely sophisticated network architectures with high computation and memory requirements, the paper presents HeatViT, a hardware-efficient, image-adaptive token pruning framework for accelerating ViTs on embedded FPGAs. Its token selector can be inserted before transformer blocks to dynamically identify informative tokens in the input image and consolidate the non-informative ones. The token selector is implemented in hardware by adding control logic that reuses the hardware components already built for the ViT backbone, which is possible because the selector is made of the same kinds of layers used in ViTs. To further improve hardware efficiency, the framework applies 8-bit fixed-point quantization and uses polynomial approximations of nonlinear functions that are designed with a regularization effect to reduce quantization error.

Paper Details

Introduction

Transformers, including ViTs, have great potential to bring different domains together through general-purpose network architectures and to reduce reliance on scarce domain data, ultimately addressing two fundamental problems of deep learning:

  1. Strong dependence on domain data.
  2. Need for constant improvements in model architecture to serve evolving needs.

However, for ViTs and transformers to become indispensable, they have to tackle the following challenges:

  1. The self-attention block in the transformer architecture has quadratic time and memory complexity with respect to the number of tokens, which makes ViTs hard to scale.
  2. Most optimization work on ViTs has reused techniques developed for CNNs, such as conventional weight pruning, quantization, and compact architecture design, but these deliver limited gains in accuracy and speed.
  3. To reduce the quadratic complexity, some methods remove input tokens at a fixed ratio, but this ignores redundancy that depends on the input sample. Existing image-adaptive approaches simply discard non-informative tokens and do not fully explore the token-level redundancies of different attention heads.
  4. Transformer architectures contain more hardware-unfriendly computations than CNNs, such as nonlinear operations, and these must be addressed to benefit from the additional optimization opportunities.
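To make the quadratic term in item 1 concrete, the helper below gives a rough multiply-accumulate (MAC) count for one self-attention layer. It is a back-of-the-envelope sketch under stated assumptions: the hidden size `dim` is used for the Q, K, V, and output projections, softmax, normalization, and the MLP block are ignored, and the name `attention_macs` is hypothetical. Splitting `dim` across several heads leaves the totals unchanged.

```python
def attention_macs(n_tokens: int, dim: int) -> dict:
    """Rough MAC count for one self-attention layer (hypothetical helper).

    The projection terms grow linearly with n_tokens; the score and
    weighting terms grow with n_tokens ** 2, the cost token pruning targets.
    """
    qkv_proj = 3 * n_tokens * dim * dim   # X @ W_q, X @ W_k, X @ W_v
    scores = n_tokens * n_tokens * dim    # Q @ K^T -> (n_tokens x n_tokens)
    weighted = n_tokens * n_tokens * dim  # softmax(scores) @ V
    out_proj = n_tokens * dim * dim       # output projection
    return {
        "linear": qkv_proj + out_proj,
        "quadratic": scores + weighted,
        "total": qkv_proj + scores + weighted + out_proj,
    }


if __name__ == "__main__":
    # Assumed ViT-Base-like setting: 224x224 image, 16x16 patches,
    # 196 patch tokens + 1 class token, hidden size 768.
    full = attention_macs(197, 768)
    half = attention_macs(99, 768)  # roughly half of the tokens kept
    print(full["quadratic"] / half["quadratic"])  # close to 4
    print(full["linear"] / half["linear"])        # close to 2
```

Halving the token count cuts the quadratic part by about 4x but the linear part by only about 2x, which is why removing tokens is attractive for long token sequences.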

Software

From an extensive study of token pruning and an analysis of the ViT computation workflow, the authors make two observations:

  1. The information redundancy in input tokens differs among attention heads in ViTs.
  2. Tokens identified as non-informative in earlier transformer blocks may still carry information that is useful to later transformer blocks.

Based on these observations, the authors design a token selector module that can be inserted before each transformer block with negligible computation overhead. At each layer, the tokens pass through a token selector S, which scores the importance of each token while accounting for the different token-level redundancies across attention heads. Once the informative tokens are selected, the non-informative tokens are not discarded; instead, they are packaged into a single token that preserves their information for later transformer blocks. The informative tokens and this package token are concatenated and passed as input to the next layer. An example of how the token pruning is implemented is shown in the figure below:
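The selection-and-packaging step can be sketched in plain Python. This is a minimal illustration under stated assumptions, not HeatViT's actual selector: a single hypothetical scoring vector `weight` followed by a sigmoid stands in for the learned, head-aware selector S, a hypothetical `threshold` decides which tokens are kept (so the number kept varies per image), and the package token is formed as a score-weighted average of the pruned tokens.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def select_tokens(tokens, weight, threshold=0.5):
    """Keep informative tokens and package the rest into one token.

    tokens:    list of N token vectors, each a list of D floats
    weight:    hypothetical scoring vector of length D (stand-in for the
               learned token selector S)
    threshold: hypothetical score cut-off; tokens scoring above it are kept
    """
    # Score every token: linear projection followed by a sigmoid.
    scores = [sigmoid(sum(w * x for w, x in zip(weight, t))) for t in tokens]
    keep_idx = [i for i, s in enumerate(scores) if s > threshold]
    drop_idx = [i for i, s in enumerate(scores) if s <= threshold]

    kept = [tokens[i] for i in keep_idx]
    if not drop_idx:
        return kept  # every token is informative; nothing to package

    # Package the non-informative tokens into one score-weighted token
    # instead of discarding them (sigmoid scores are > 0, so total > 0).
    total = sum(scores[i] for i in drop_idx)
    dim = len(tokens[0])
    package = [
        sum(scores[i] * tokens[i][d] for i in drop_idx) / total
        for d in range(dim)
    ]
    # Dense output for the next block: informative tokens, then the package.
    return kept + [package]


if __name__ == "__main__":
    tokens = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, -1.0]]
    print(select_tokens(tokens, weight=[1.0, 0.0]))
    # [[1.0, 0.0], [2.0, 0.0], [0.0, 0.0]]
```

Four input tokens become two informative tokens plus one package token, so the next block sees a shorter but still dense sequence.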

Hardware

The token selector is built from linear layers so that it can reuse the GEneral Matrix Multiply (GEMM) hardware component designed for ViTs. After the token selector, the informative tokens and the single package token that aggregates the non-informative ones are concatenated into a dense token matrix, which avoids sparse computation on the hardware. Pruning leaves gaps in the token sequence wherever non-informative tokens were removed, so converting this sparse representation back into a dense one is necessary.
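The benefit of the dense layout can be illustrated with a plain matrix multiply standing in for the GEMM engine. This is an assumption-level sketch, not the accelerator's dataflow: `gemm` and `pack_tokens` are hypothetical helpers, and the keep mask is given directly rather than produced by a token selector. Gathering the kept rows into a contiguous matrix gives the same outputs for those tokens as multiplying the full matrix, but only the kept rows are computed.

```python
def gemm(a, b):
    """Plain (M x K) @ (K x N) matrix multiply on nested lists."""
    k_dim, n_dim = len(b), len(b[0])
    return [
        [sum(row[k] * b[k][j] for k in range(k_dim)) for j in range(n_dim)]
        for row in a
    ]


def pack_tokens(tokens, keep_mask):
    """Gather kept tokens into a contiguous (dense) token matrix."""
    return [t for t, keep in zip(tokens, keep_mask) if keep]


if __name__ == "__main__":
    tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
    keep_mask = [True, False, True, False]
    w = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]  # hypothetical 2 x 3 weight matrix

    dense = gemm(pack_tokens(tokens, keep_mask), w)  # 2 rows computed
    full = gemm(tokens, w)                           # 4 rows computed
    kept_rows = [r for r, keep in zip(full, keep_mask) if keep]
    print(dense == kept_rows)  # True
    # MACs: dense 2 * 2 * 3 = 12 vs. full 4 * 2 * 3 = 24
```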

8-bit fixed-point quantization is applied to all weights and activations of the network, and polynomial approximations are used for nonlinear functions such as GELU, Softmax, and Sigmoid. In addition, these polynomial approximations are designed to have a regularization effect on quantization error, which helps limit the accuracy loss caused by quantization.
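Both ingredients can be sketched briefly. The code assumes a signed 8-bit fixed-point format with a hypothetical `FRAC_BITS = 4` fractional bits, and for GELU it swaps in the second-order erf polynomial used in I-BERT's integer-only GELU (a = -0.2888, b = -1.769) instead of HeatViT's own regularized approximation, whose coefficients are not given here. It does not model the regularization effect on quantization error; it only checks the polynomial against the exact GELU computed with `math.erf`.

```python
import math

FRAC_BITS = 4  # hypothetical: 4 integer bits (incl. sign) + 4 fractional bits


def quantize(x: float, frac_bits: int = FRAC_BITS) -> int:
    """Round to the nearest signed 8-bit fixed-point code, saturating."""
    q = round(x * (1 << frac_bits))
    return max(-128, min(127, q))


def dequantize(q: int, frac_bits: int = FRAC_BITS) -> float:
    return q / (1 << frac_bits)


def erf_poly(x: float, a: float = -0.2888, b: float = -1.769) -> float:
    """Second-order polynomial approximation of erf (I-BERT coefficients)."""
    if x == 0.0:
        return 0.0
    clipped = min(abs(x), -b)  # beyond |x| = 1.769 the curve saturates at 1
    return math.copysign(1.0, x) * (a * (clipped + b) ** 2 + 1.0)


def gelu_poly(x: float) -> float:
    """GELU(x) = x/2 * (1 + erf(x / sqrt(2))) with erf replaced by erf_poly."""
    return 0.5 * x * (1.0 + erf_poly(x / math.sqrt(2.0)))


def gelu_exact(x: float) -> float:
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


if __name__ == "__main__":
    q = quantize(1.3)
    print(q, dequantize(q))  # 21 1.3125
    worst = max(abs(gelu_poly(v / 10) - gelu_exact(v / 10)) for v in range(-40, 41))
    print(worst < 0.03)  # True
```

With 4 fractional bits the quantization step is 1/16 = 0.0625 and the representable range is [-8, 7.9375]; the choice of fractional bits trades range against resolution.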

The design is implemented on an embedded FPGA accelerator for ViTs that contains a GEMM engine and the token selector. The control logic of the token selector reuses the components designed for ViT hardware acceleration.

Training

To reduce inference latency on hardware without sacrificing accuracy, the authors devise a latency-aware multi-stage training strategy to:

  1. Determine which transformer blocks require token selectors and insert them.
  2. Optimize the desired pruning rates for these token selectors.

The pruning rates are selected as follows. Latency-aware multi-stage training starts from a pre-trained ViT model. During fine-tuning, token selectors are inserted and trained in multiple stages, each lasting a small number of epochs. Training begins by inserting a single token selector before the final transformer block, because the early blocks are sensitive to the input data.

In each stage, the newly inserted token selector is trained while the rest of the network is fine-tuned. More token selectors are added in successive stages until the desired latency is reached, as long as the accuracy loss stays within a chosen threshold. The whole process takes about 90% of the time needed to train the backbone network from scratch.
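The stage-wise loop can be outlined as follows. This is a hypothetical sketch, not the authors' training code: `fine_tune`, `evaluate`, and `measure_latency` are placeholder callbacks, pruning rates are assumed to be handled inside `fine_tune`, and selectors are assumed to be added one block at a time, moving from the last block toward earlier ones.

```python
from typing import Callable, List, Sequence


def multi_stage_insertion(
    num_blocks: int,
    base_accuracy: float,
    target_latency: float,
    max_accuracy_drop: float,
    fine_tune: Callable[[Sequence[int]], None],
    evaluate: Callable[[Sequence[int]], float],
    measure_latency: Callable[[Sequence[int]], float],
) -> List[int]:
    """Add token selectors stage by stage, from the last block backward.

    The callbacks are hypothetical placeholders: fine_tune trains the
    selectors at the given block indices while fine-tuning the rest of the
    network, evaluate returns accuracy, and measure_latency returns the
    hardware latency of that configuration.
    """
    selectors: List[int] = []
    for block in range(num_blocks - 1, -1, -1):
        candidate = sorted(selectors + [block])
        fine_tune(candidate)  # one short stage of a few epochs
        if base_accuracy - evaluate(candidate) > max_accuracy_drop:
            # Too much accuracy lost: keep the previous set
            # (a real implementation would also restore its weights).
            break
        selectors = candidate
        if measure_latency(selectors) <= target_latency:
            break  # latency target reached
    return selectors


if __name__ == "__main__":
    # Toy stand-ins: each selector costs 0.1 accuracy points and saves 0.5 ms.
    chosen = multi_stage_insertion(
        num_blocks=12,
        base_accuracy=80.0,
        target_latency=10.0,
        max_accuracy_drop=1.0,
        fine_tune=lambda blocks: None,
        evaluate=lambda blocks: 80.0 - 0.1 * len(blocks),
        measure_latency=lambda blocks: 12.0 - 0.5 * len(blocks),
    )
    print(chosen)  # [8, 9, 10, 11]
```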

Evaluation

The evaluation compares HeatViT against static token pruning, adaptive token pruning, head pruning, and token channel pruning techniques.

Background (Related Work / Literature Survey)

Vision Transformer

ViTs have three main components: patch embedding, transformer encoders, and a classification output head.

Patch embedding. An image $\mathbf{X} \in \mathbb{R}^{H \times W \times C}$ is split into non-overlapping patches, each reshaped into $\mathbf{X}_p \in \mathbb{R}^{P^2 \times C}$, where $H$ and $W$ are the height and width of the image, $C$ is the number of channels, and $P$ is the patch size. The number of patches in an image is computed using,