Reweighted Proximal Pruning for Large-Scale Language Representation

  • 2019-09-27 04:10:10
  • Fu-Ming Guo, Sijia Liu, Finlay S. Mungall, Xue Lin, Yanzhi Wang

Abstract

Recently, pre-trained language representations such as BERT have flourished as the mainstay of the natural language understanding community. These pre-trained language representations achieve state-of-the-art results on a wide range of downstream tasks. Along with continuous, significant performance improvements, the size and complexity of these pre-trained neural models continue to increase rapidly. Is it possible to compress these large-scale language representation models? How will the pruned language representation affect the downstream multi-task transfer learning objectives? In this paper, we propose Reweighted Proximal Pruning (RPP), a new pruning method specifically designed for large-scale language representation models. Through experiments on SQuAD and the GLUE benchmark suite, we show that proximal-pruned BERT maintains high accuracy on both the pre-training task and multiple downstream fine-tuning tasks at high pruning ratios. RPP provides a new perspective to help us analyze what large-scale language representations might learn. Additionally, RPP makes it possible to deploy a large state-of-the-art language representation model such as BERT on a range of devices (e.g., online servers, mobile phones, and edge devices).

 


Fu-Ming Guo (Northeastern University), Sijia Liu (MIT-IBM Watson AI Lab, IBM Research), Finlay S. Mungall (United States Federal Aviation Administration), Xue Lin and Yanzhi Wang (Northeastern University)

1 Introduction

Pre-trained language representations such as GPT (radford2018improving), BERT (devlin2019bert), and XLNet (yang2019xlnet) have shown substantial performance improvements using self-supervised training on large-scale corpora (dai2015semi; peters2018deep; radford2018improving; liu2019roberta). More interestingly, the pre-trained BERT model can be fine-tuned with just one additional output layer to create state-of-the-art models for a wide range of tasks, such as question answering (rajpurkar2016squad; rajpurkar2018know) and language inference (bowman2015large; williams2017broad), without substantial task-specific architecture modifications. BERT is conceptually simple and empirically powerful (devlin2019bert).

However, along with the significant performance enhancement, the parameter volume and complexity of these pre-trained language representations have increased significantly. As a result, it becomes difficult to deploy these large-scale language representations on real-life, computation-constrained devices such as mobile phones and edge devices. In this paper, we attempt to answer the following questions.

Question 1: Is it possible to compress large-scale language representations such as BERT via weight pruning?

Question 2: How would the weight-pruned, pre-trained model affect the performance of the downstream multi-task transfer learning objectives?

The problem of weight pruning has been studied for many types of deep neural networks (DNNs) (goodfellow2016deep), such as AlexNet (krizhevsky2012imagenet), VGG (simonyan2014very), ResNet (he2016deep), and MobileNet (howard2017mobilenets). Weight pruning has been shown to notably reduce model size. A suite of weight pruning techniques has been developed, such as non-structured weight pruning (han2015learning), structured weight pruning (wen2016learning), filter pruning (li2016pruning), channel pruning (he2017channel), ADMM-NN (ren2019admm), and PCONV (ma2019pconv), to name a few. Unlike pruning CNN-type models, pruning BERT must account not only for the metrics on the pre-training task but also for the downstream multi-task transfer learning objectives. Thus, the desired weight pruning needs to preserve the capacity for transfer learning from a sparse pre-trained model to downstream fine-tuning tasks.

In this work, we investigate irregular weight pruning techniques on the BERT model, including the iterative pruning method (han2015learning) and the one-shot pruning method (liu2018rethinking). However, these methods fail to converge to a sparse pre-trained model without incurring a significant accuracy drop, or in many cases do not converge at all (see supporting results in the Appendix). Note that the aforementioned weight pruning techniques are built on different sparsity-promoting regularization schemes (han2015learning; wen2016learning), e.g., lasso regression ($\ell_1$ regularization) and ridge regression ($\ell_2$ regularization). We find that the failure of previous methods on weight pruning of BERT is possibly due to the inaccurate sparse pattern learned from the simple $\ell_1$- or $\ell_2$-based sparsity-promoting regularizer. In fact, the difficulty of applying regularization to generate weight sparsity coincides with the observation in (loshchilov2018decoupled) on the incompatibility of conventional weight decay ($\ell_2$ regularization) with training super-deep DNNs such as BERT. loshchilov2018decoupled point out that the main reason is that directly adding a regularization penalty term causes divergence from the original loss function and has a negative effect on the effectiveness of gradient-based updates. To mitigate this limitation, (loshchilov2018decoupled) modified the regularization in Adam by decoupling weight decay from the gradient-based update, and the resulting optimizer has been used to achieve state-of-the-art results on downstream multi-task transfer learning objectives (devlin2019bert).

In this work, we aim at a more accurate sparse pattern search, motivated by our experiments and the conclusion of loshchilov2018decoupled. We propose Reweighted Proximal Pruning (RPP), which integrates reweighted $\ell_1$ minimization (candes2008enhancing) with the proximal algorithm (parikh2014proximal). RPP consists of two parts (see Appendix A for an overview of our approach): the reweighted $\ell_1$ minimization and the proximal operator. Compared with $\ell_1$ regularization, reweighted $\ell_1$ minimization generates sparsity in DNN models in a way that better matches the nature of weight pruning. Thanks to the closed-form solution of the proximal operation on a weighted $\ell_1$ norm, RPP decouples the sparsity pattern search from computing the gradient of the training loss. In this way, the aforementioned pitfall of prior weight pruning techniques on BERT can be avoided. To the best of our knowledge, RPP is the first method to achieve effective weight pruning on BERT. Experimental results demonstrate that the proximal-pruned BERT model keeps high accuracy on a wide range of downstream tasks, including SQuAD (rajpurkar2016squad; rajpurkar2018know) and GLUE (wang2018glue).

We summarize our contributions as follows.

  • We develop the pruning algorithm Reweighted Proximal Pruning (RPP), which achieves the first effective weight pruning result on a large pre-trained language representation model, BERT. RPP achieves 59.3% weight sparsity without performance loss on both pre-training and fine-tuning tasks.

  • We spotlight the relationship between the pruning ratio of the pre-trained DNN model and the performance on the downstream multi-task transfer learning objectives. We show that many downstream tasks other than SQuAD allow a pruning ratio of at least 80%, compared with 59.3% for the more challenging SQuAD task.

  • We show that, unlike in weight pruning for image classification tasks, a sparsity pattern exists in pre-trained language representations. Our approach not only provides an effective weight pruning algorithm but also offers a new perspective on analyzing large-scale language representations.

2 Related Work

BERT and prior work on model compression

BERT (devlin2019bert) is a self-supervised approach for pre-training a deep transformer encoder (vaswani2017attention) before fine-tuning it for particular downstream tasks. Pre-training of BERT optimizes two training objectives, masked language modeling (MLM) and next sentence prediction (NSP), which require a large collection of unlabeled text. As in devlin2019bert, we use BookCorpus (800M words) (zhu2015aligning) and English Wikipedia (2,500M words) as the pre-training corpus. For detailed information about the BERT model, readers can refer to the original paper (devlin2019bert).

michel2019sixteen mask some heads in the multi-head attention modules of BERT and then evaluate the performance on the machine translation task. Similarly, hao2019multi eliminate certain heads in the multi-head attention module. First, this limited previous work does not consider the pre-training metrics or the other downstream multi-task transfer learning objectives. It considers only the machine translation task (out of over 10 transfer tasks), which is a single specific fine-tuning task and is therefore of limited relevance to a universal pre-trained language representation (BERT). Second, the multi-head attention module uses a weight sharing mechanism (vaswani2017attention), so masking some heads does not reduce the weight volume. Finally, multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions, whereas a single attention head inhibits this effect (vaswani2017attention). As a result, masking some heads in multi-head attention harms the weight sharing mechanism without reducing the weight volume. In summary, the limited previous work in this area does not provide an effective weight pruning method for BERT. shen2019q report quantization results for the BERT model, which are orthogonal to our work and can be combined with it for further compression/acceleration.

Reweighted $\ell_1$ and proximal algorithm

candes2008enhancing present the reweighted $\ell_1$ algorithm and demonstrate its remarkable performance and broad applicability in statistical estimation, error correction, and image processing. Proximal algorithms can be viewed as an analogous tool for non-smooth, constrained, large-scale, or distributed versions of these problems (parikh2014proximal). To the best of our knowledge, we are the first to apply reweighted $\ell_1$ and the proximal algorithm in the DNN weight pruning domain, and to achieve effective weight pruning on BERT.

3 Reweighted Proximal Pruning for large-scale language representation during pre-training

Pruning for pre-trained language representations should consider not only the performance on the pre-training objectives but also the downstream fine-tuning (transfer learning) tasks. Let $f_i$ denote the loss function of the network for downstream task $\mathcal{T}_i \sim p(\mathcal{T})$, where $p(\mathcal{T})$ denotes the distribution of tasks. Let $\mathbf{w}$ denote the parameters of the pre-trained model (pre-training in BERT), and $\mathbf{z}_i$ denote the task-specific model parameters for the $i$-th task (fine-tuning in BERT). The downstream tasks have separate fine-tuned models, even though they are all initialized with the same pre-trained parameters (devlin2019bert). Starting from the pre-trained parameters $\mathbf{w}$, the parameters $\mathbf{z}_i(\mathbf{w})$ are obtained through fine-tuning:

$$\operatorname*{minimize}_{\mathbf{w} \in \mathbb{R}^d} \; f_i(\mathbf{w}) \qquad (1)$$

3.1 Pruning formulation in transfer learning

Following the conventional weight pruning formulation, we first consider the problem of weight pruning during pre-training:

$$\operatorname*{minimize}_{\mathbf{w} \in \mathbb{R}^d} \; f_0(\mathbf{w}) + \gamma \|\mathbf{w}\|_p \qquad (2)$$

where $f_0$ is the loss function used for pruning (the pre-training loss), $p \in \{0, 1\}$ denotes the type of regularization norm, and $\gamma$ is a regularization parameter. We note that the sparsity-promoting regularizer in the objective could also be replaced with a hard $\ell_p$ constraint, $\|\mathbf{w}\|_p \le \tau$ for some $\tau$.

Let $\hat{\mathbf{w}}$ denote the solution to problem (2); the corresponding sparse pattern $\mathcal{S}_{\hat{\mathbf{w}}}$ is given by

$$\mathcal{S}_{\hat{\mathbf{w}}} = \{\, i \mid \hat{w}_i = 0,\ i \in [d] \,\} \qquad (3)$$

For a specific transfer task $i$, we allow an additional retraining/fine-tuning step that trains (fine-tunes) the weights starting from the pre-training result $\hat{\mathbf{w}}$, subject to the determined, fixed sparse pattern $\mathcal{S}_{\hat{\mathbf{w}}}$; we denote the result as $\mathbf{z}_i(\hat{\mathbf{w}}; \mathcal{S}_{\hat{\mathbf{w}}})$. That is, we solve the following modified version of problem (1):

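$$\operatorname*{minimize}_{\mathbf{z}_i} \; f_i\big(\mathbf{z}_i(\hat{\mathbf{w}}; \mathcal{S}_{\hat{\mathbf{w}}})\big) \qquad (4)$$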

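Here, unlike in (1), the task-specific fine-tuning weight variable $\mathbf{z}_i(\hat{\mathbf{w}}; \mathcal{S}_{\hat{\mathbf{w}}})$ is now defined over the sparse pattern $\mathcal{S}_{\hat{\mathbf{w}}}$: the entries indexed by $\mathcal{S}_{\hat{\mathbf{w}}}$ remain zero.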

Our goal is to find a sparse (weight-pruned) model during pre-training, with weights $\hat{\mathbf{w}}$ and sparsity $\mathcal{S}_{\hat{\mathbf{w}}}$, that performs as well as the original pre-trained model on multiple new tasks (indexed by $i$). These fine-tuned models $\mathbf{z}_i(\hat{\mathbf{w}}; \mathcal{S}_{\hat{\mathbf{w}}})$ (for different $i$) share the identical universal sparsity $\mathcal{S}_{\hat{\mathbf{w}}}$.
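As a minimal sketch of (3) and (4), the Python below computes the sparse pattern of a pre-trained weight vector and then fine-tunes a copy of it with plain gradient descent while keeping every index in $\mathcal{S}_{\hat{\mathbf{w}}}$ at zero. The flat weight vector, the toy quadratic task loss, the learning rate and the helper names are illustrative assumptions, not the paper's setup (which fine-tunes BERT kernel matrices with AdamW).

```python
from typing import Callable, List, Set


def sparse_pattern(w_hat: List[float]) -> Set[int]:
    """Eq. (3): indices i in [d] whose pre-trained weight is exactly zero."""
    return {i for i, wi in enumerate(w_hat) if wi == 0.0}


def fine_tune_on_pattern(
    w_hat: List[float],
    grad_fi: Callable[[List[float]], List[float]],
    lr: float = 0.1,
    steps: int = 100,
) -> List[float]:
    """Gradient descent on a task loss f_i, starting from w_hat, with S held at zero."""
    pruned = sparse_pattern(w_hat)
    z = list(w_hat)  # z_i is initialized from the sparse pre-trained weights
    for _ in range(steps):
        g = grad_fi(z)
        # Only unpruned coordinates move; pruned ones stay exactly zero.
        z = [0.0 if j in pruned else zj - lr * gj
             for j, (zj, gj) in enumerate(zip(z, g))]
    return z


# Hypothetical task loss f_i(z) = 0.5 * ||z - target||^2, whose gradient is z - target.
target = [1.0, 2.0, 3.0, 4.0]
w_hat = [0.5, 0.0, -1.2, 0.0]
z = fine_tune_on_pattern(w_hat, lambda z: [zj - tj for zj, tj in zip(z, target)])
```

Here `z[1]` and `z[3]` stay at 0.0 even though the task would prefer 2.0 and 4.0, while the unpruned entries approach 1.0 and 3.0; every task-specific model built this way inherits the same sparsity.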

3.2 Reweighted Proximal Pruning

In order to improve the performance of pruning pre-trained language representations across multi-task downstream transfer learning objectives, we propose Reweighted Proximal Pruning (RPP). RPP consists of two parts: the reweighted $\ell_1$ minimization and the proximal operator. Compared with $\ell_1$ regularization, reweighted $\ell_1$ minimization generates sparsity in DNN models in a way that better matches the natural objective of weight pruning. The proximal algorithm then separates the gradient computation from the proximal operation over a weighted $\ell_1$ norm, without adding any penalty term to the original objective function of the DNN model. This is necessary for the weight pruning of super-deep language representation models.

3.2.1 Reweighted $\ell_1$ minimization

In previous pruning methods (han2015learning; wen2016learning), $\ell_1$ regularization is used to generate sparsity. However, consider two weights $w_i, w_j$ with $|w_i| < |w_j|$ in the DNN model that are penalized through $\ell_1$ regularization. The larger weight $w_j$ is penalized more heavily than the smaller weight $w_i$, which violates the original intention of weight pruning, “removing the unimportant connections” (parameters close to zero) (han2015learning). To address this imbalance, we introduce reweighted $\ell_1$ minimization (candes2008enhancing) to the DNN pruning domain. Our reweighted $\ell_1$ minimization operates in a systematic and iterative manner (the detailed process is shown in Algorithm 3.2.1), and its first iteration is exactly $\ell_1$ regularization. This design lets us observe the performance difference between $\ell_1$ and reweighted $\ell_1$ minimization. It also means that reweighted $\ell_1$ minimization builds directly on $\ell_1$ regularization, since the latter is the single, first step of the former.

Consider the regularized weight pruning problem (reweighted $\ell_1$ minimization):

$$\operatorname*{minimize}_{\mathbf{w}} \; f_0(\mathbf{w}) + \gamma \sum_i \alpha_i |w_i| \qquad (5)$$

where each factor $\alpha_i > 0$ balances the penalty on weight $w_i$; note that $\alpha_i$ is a reweighting factor, not a weight of the DNN model. The $\alpha_i$ factors are updated systematically in the iterative reweighted $\ell_1$ minimization procedure (Step 2 in Algorithm 3.2.1) (candes2008enhancing). If we set $T = 1$ for reweighted $\ell_1$, it reduces to $\ell_1$ sparse training.
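The imbalance described above can be seen numerically. The pure-Python sketch below evaluates the penalty term of (5) and the reweighting rule $\alpha_i = 1/(|w_i|^{q} + \epsilon)$, where the exponent $q$ (called `power` here, our name) is 1 in the classic reweighted $\ell_1$ scheme of candes2008enhancing and equals $t$ in Step 2 of Algorithm 3.2.1. Computing $\alpha$ from the same vector that is being penalized is a simplification for illustration; in the algorithm, the factors come from the previous iterate.

```python
from typing import List


def reweighted_l1_penalty(w: List[float], alpha: List[float], gamma: float) -> float:
    """Penalty term of Eq. (5): gamma * sum_i alpha_i * |w_i|."""
    return gamma * sum(a * abs(wi) for a, wi in zip(alpha, w))


def reweight(w: List[float], eps: float = 1e-3, power: float = 1.0) -> List[float]:
    """Reweighting factors alpha_i = 1 / (|w_i|**power + eps)."""
    return [1.0 / (abs(wi) ** power + eps) for wi in w]


w = [0.01, 2.0, 0.0]
plain_l1 = reweighted_l1_penalty(w, [1.0, 1.0, 1.0], gamma=1.0)  # 2.01, dominated by the 2.0 entry
reweighted = reweighted_l1_penalty(w, reweight(w), gamma=1.0)    # about 1.91, roughly 1 per nonzero
```

With $\ell_1$ the large weight contributes 200 times more penalty than the small one; after reweighting, each nonzero weight contributes close to 1, so the penalty behaves more like a count of nonzeros and no longer singles out the large, important weights.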

Algorithm 3.2.1: RPP procedure for reweighted $\ell_1$ minimization

Input: initial pre-trained model $\mathbf{w}^{(0)}$, initial reweighted $\ell_1$ minimization ratio $\gamma$, initial positive values $\boldsymbol{\alpha}^{(0)} = \mathbf{1}$
1: for $t = 1, 2, \ldots, T$ do
2:   $\mathbf{w} = \mathbf{w}^{(t-1)}$, $\boldsymbol{\alpha} = \boldsymbol{\alpha}^{(t-1)}$
3:   Step 1: Solve problem (5) to obtain a solution $\mathbf{w}^{(t)}$ via the iterative proximal algorithm (6)
4:   Step 2: Update the reweighted factors $\alpha_i^{(t)} = 1 / \big(|w_i^{(t)}|^{t} + \epsilon\big)$, where $w_i^{(t)}$ denotes the weight $w_i$ in iteration $t$, the outer superscript $t$ is an exponent, and $\epsilon$ is a small constant, e.g., $\epsilon = 0.001$
5: end for
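To make the two nested loops concrete, here is a small pure-Python sketch of Algorithm 3.2.1 on a hypothetical separable loss $f_0(\mathbf{w}) = \tfrac{1}{2}\|\mathbf{w} - \mathbf{c}\|_2^2$ instead of the BERT pre-training loss. The inner loop is the proximal iteration (6) with plain gradient descent, the prox is the closed form (8), and Step 2 uses the exponent $t$ as written in the algorithm. The loss, the vector `c`, the learning rate, the step counts and the function names are assumptions chosen so the result can be checked by hand.

```python
from typing import List


def soft_threshold(a: float, thresh: float) -> float:
    """Closed-form prox of thresh * |w| for one coordinate (Eq. (8))."""
    if abs(a) <= thresh:
        return 0.0
    return (1.0 - thresh / abs(a)) * a


def rpp_toy(c: List[float], w0: List[float], gamma: float, T: int,
            lr: float = 0.5, inner_steps: int = 60, eps: float = 1e-3) -> List[float]:
    """Algorithm 3.2.1 on the toy loss f0(w) = 0.5 * ||w - c||^2 (an assumption)."""
    w = list(w0)
    alpha = [1.0] * len(w)  # alpha^(0) = 1, so iteration t = 1 is plain l1 training
    for t in range(1, T + 1):
        # Step 1: iterative proximal algorithm, Eq. (6).
        for _ in range(inner_steps):
            a = [wi - lr * (wi - ci) for wi, ci in zip(w, c)]  # gradient step on f0
            w = [soft_threshold(ai, lr * gamma * al) for ai, al in zip(a, alpha)]
        # Step 2: reweighting with exponent t.
        alpha = [1.0 / (abs(wi) ** t + eps) for wi in w]
    return w


c = [3.0, 0.05, -2.0, 0.02]
w_l1 = rpp_toy(c, [1.0] * 4, gamma=0.1, T=1)   # plain l1 (single iteration)
w_rpp = rpp_toy(c, [1.0] * 4, gamma=0.1, T=3)  # three reweighting iterations
```

With $\gamma = 0.1$, the first iteration (plain $\ell_1$) returns roughly [2.9, 0, -1.9, 0]: the two small entries are pruned, but the large ones are also shrunk by $\gamma$. After three iterations the pruned entries stay at zero (their $\alpha_i = 1/\epsilon$ is huge), while the surviving weights move back toward 3 and -2 (about 2.989 and -1.974), i.e., the bias on large weights shrinks.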

3.2.2 Proximal method

In previous pruning methods (han2015learning; wen2016learning), the $\ell_1$ regularization loss is added directly to the original training loss of the DNN model, and a hard threshold is applied to execute the pruning in the final step (all weights below the hard threshold become zero). We cannot add our reweighted $\ell_1$ regularization to the loss in this way, because, as explained before, doing so has a negative effect on the gradient update. The remaining challenge is therefore to derive an effective solution to problem (5) for given $\{\alpha_i\}$, namely Step 1 in Algorithm 3.2.1, in which the back-propagation-based gradient update is applied only to $f_0(\mathbf{w})$ and not to $\gamma \sum_i \alpha_i |w_i|$.

We adopt the proximal algorithm (parikh2014proximal) to satisfy this requirement through decoupling: the sparsity pattern search is decoupled from computing the gradient of the training loss. The proximal algorithm is shown in (parikh2014proximal) to be highly effective (compared with the original solution) on a wide set of non-convex optimization problems. Additionally, the proximal operator of our reweighted $\ell_1$ minimization (5) has an analytical solution.

To solve problem (5) for a given α, the proximal algorithm operates in an iterative manner:

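$$\mathbf{w}_k = \operatorname{prox}_{\lambda_k, \text{rw-}\ell_1}\big(\mathbf{w}_{k-1} - \lambda_k \nabla_{\mathbf{w}} f_0(\mathbf{w}_{k-1})\big) \qquad (6)$$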

where the subscript $k$ denotes the time step of the training process inside RPP, $\lambda_k > 0$ is the learning rate, and the initial $\mathbf{w}$ is set to $\mathbf{w}^{(t-1)}$ from the previous iteration of reweighted $\ell_1$. The proximal operator $\operatorname{prox}_{\lambda_k, \text{rw-}\ell_1}(\mathbf{a})$ is the solution to the problem

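$$\operatorname*{minimize}_{\mathbf{w}} \; \gamma \sum_i \alpha_i |w_i| + \frac{1}{2\lambda_k} \|\mathbf{w} - \mathbf{a}\|_2^2 \qquad (7)$$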

where $\mathbf{a} = \mathbf{w}_{k-1} - \lambda_k \nabla_{\mathbf{w}} f_0(\mathbf{w}_{k-1})$. The above problem has the following analytical solution (liu2014sparsity):

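$$w_{i,k} = \begin{cases} \left(1 - \dfrac{\gamma \lambda_k \alpha_i}{|a_i|}\right) a_i, & |a_i| > \lambda_k \gamma \alpha_i, \\ 0, & |a_i| \le \lambda_k \gamma \alpha_i. \end{cases} \qquad (8)$$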
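The closed-form solution (8) is a weighted soft-thresholding. The short standard-library Python sketch below applies it coordinate-wise and also evaluates the objective (7), so that the closed form can be checked against small perturbations. The function names are ours, not from any released code.

```python
from typing import List


def prox_weighted_l1(a: List[float], alpha: List[float],
                     gamma: float, lam: float) -> List[float]:
    """Eq. (8): coordinate-wise minimizer of Eq. (7)."""
    out = []
    for ai, al in zip(a, alpha):
        thresh = lam * gamma * al  # lambda_k * gamma * alpha_i
        if abs(ai) > thresh:
            out.append((1.0 - thresh / abs(ai)) * ai)  # shrink toward zero by thresh
        else:
            out.append(0.0)  # pruned
    return out


def objective_7(w: List[float], a: List[float], alpha: List[float],
                gamma: float, lam: float) -> float:
    """Eq. (7): gamma * sum_i alpha_i |w_i| + ||w - a||^2 / (2 * lambda_k)."""
    reg = gamma * sum(al * abs(wi) for wi, al in zip(w, alpha))
    fit = sum((wi - ai) ** 2 for wi, ai in zip(w, a)) / (2.0 * lam)
    return reg + fit


a = [0.8, -0.05, -0.6]
alpha = [1.0, 1.0, 2.0]
w = prox_weighted_l1(a, alpha, gamma=0.5, lam=0.2)  # thresholds 0.1, 0.1, 0.2
```

Here the inputs 0.8, -0.05 and -0.6 map to 0.7, 0 and -0.4: an entry whose magnitude is below its own threshold $\lambda_k \gamma \alpha_i$ is pruned, and the rest are shrunk toward zero by exactly that threshold. A larger $\alpha_i$ (the third entry) means a larger threshold.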

We remark that the updating rule (6) can be interpreted as the proximal step (8) applied after the gradient descent step $\mathbf{w}_{k-1} - \lambda_k \nabla_{\mathbf{w}} f_0(\mathbf{w}_{k-1})$. Such a descent step can also be obtained through other optimizers such as AdamW. We use AdamW (loshchilov2018decoupled) as our optimizer, as in (devlin2019bert). The concrete procedure of AdamW with the proximal operator is given in Algorithm C in Appendix C.

Why choose AdamW rather than Adam? loshchilov2018decoupled propose AdamW to improve the generalization ability of Adam (kingma2014adam). They show that $\ell_2$ regularization (i.e., directly adding a regularization penalty term) is inherently ineffective in Adam and has a negative effect on the effectiveness of gradient-based updates, which explains the difficulty of applying adaptive gradient algorithms to super-deep DNN training for NLU applications (like BERT). loshchilov2018decoupled mitigate this limitation and improve the regularization of Adam by decoupling weight decay from the gradient-based update. AdamW is widely adopted in pre-training large language representations, e.g., BERT (devlin2019bert), GPT (radford2018improving), and XLNet (yang2019xlnet). Our proposed proximal pruning also benefits from this decoupling design philosophy. The difference is that proximal pruning aims to generate sparsity, whereas decoupled weight decay in AdamW aims to avoid overfitting.

Our new and working baseline: New Iterative Pruning (NIP).

To obtain the identical universal sparsity $\mathcal{S}_{\mathbf{w}}$, we tried a series of pruning techniques, including the iterative pruning method (han2015learning) and the one-shot pruning method (liu2018rethinking), but these methods do not converge to a viable solution. A possible reason for the non-convergence of the iterative pruning method is that directly promoting $\ell_p$ ($p \in \{1, 2\}$) sparsity on the original pre-training loss function might be insufficient to prune BERT. To circumvent the convergence issue of conventional iterative pruning methods, we propose a new iterative pruning (NIP) method. Unlike iterative pruning (han2015learning), NIP reflects the naturally progressive pruning performance without any externally introduced penalty. Other pruning methods should not perform worse than NIP; otherwise, the newly introduced sparsity-promoting regularization has a negative effect. We will show that NIP is able to successfully prune BERT to certain pruning ratios. We refer readers to Appendix B for our full baseline algorithm.

4 Experiments

In this section, we describe the experiments on pruning pre-trained BERT and demonstrate the performance on 10 downstream transfer learning tasks.

4.1 Experiment Setup

We use the official BERT model from Google as the starting point. Following the notation of devlin2019bert, we denote the number of layers (i.e., transformer blocks) as L, the hidden size as H, and the number of self-attention heads as A. We prune two BERT models: BERT-Base (L = 12, H = 768, A = 12, total parameters = 110M) and BERT-Large (L = 24, H = 1024, A = 16, total parameters = 340M). As the kernel weights account for most of the weight volume in each layer (i.e., transformer block), they are our pruning target.

Data: In pre-training, we use the same pre-training corpora as devlin2019bert: BookCorpus (800M words) (zhu2015aligning) and English Wikipedia (2,500M words). Based on these corpora, we use the same preprocessing script (https://github.com/google-research/bert) to create the pre-training data. In fine-tuning, we report results on the Stanford Question Answering Dataset (SQuAD) and the General Language Understanding Evaluation (GLUE) benchmark (wang2018glue). We use two versions of SQuAD: V1.1 and V2.0 (rajpurkar2016squad; rajpurkar2018know). GLUE is a collection of datasets/tasks for evaluating natural language understanding systems; the datasets/tasks are CoLA (warstadt2018neural), Stanford Sentiment Treebank (SST) (socher2013recursive), Microsoft Research Paraphrase Corpus (MRPC) (dolan2005automatically), Semantic Textual Similarity Benchmark (STS) (agirre2007semeval), Quora Question Pairs (QQP), Multi-Genre NLI (MNLI) (williams2017broad), Question NLI (QNLI) (rajpurkar2016squad), Recognizing Textual Entailment (RTE), and Winograd NLI (WNLI) (levesque2012winograd).

Input/Output representations: We follow the input/output representation setting of devlin2019bert. We use WordPiece (wu2016google) embeddings with a 30,000-token vocabulary. The first token of every sequence is always a special classification token ([CLS]), and sentences are separated by a special token ([SEP]). Following (devlin2019bert), we use the same input/output representations for both pre-training and fine-tuning.

Evaluation: In pre-training, BERT considers two objectives: masked language modeling (MLM) and next sentence prediction (NSP). For MLM, a random sample of the tokens in the input sequence is selected and replaced with the special token ([MASK]). The MLM objective is a cross-entropy loss on predicting the masked tokens. NSP is a binary classification loss for predicting whether two segments follow each other in the original text. We use MLM and NSP to pre-train, retrain and evaluate the pre-trained BERT model. In fine-tuning, F1 scores are reported for SQuAD, QQP and MRPC. Accuracy scores are reported for the other tasks.
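A simplified sketch of the MLM input corruption described above, in standard-library Python. The 15% masking rate follows devlin2019bert; unlike BERT, this sketch always substitutes [MASK] (it omits the 80/10/10 replacement rule), and it never masks the special tokens. The function name and the token list are hypothetical.

```python
import random
from typing import List, Tuple

MASK, SPECIAL = "[MASK]", {"[CLS]", "[SEP]"}


def mask_tokens(tokens: List[str], rate: float = 0.15,
                seed: int = 0) -> Tuple[List[str], List[int]]:
    """Replace a random sample of non-special tokens with [MASK].

    Returns the corrupted sequence and the masked positions; the MLM loss is the
    cross-entropy of predicting the original tokens at exactly those positions.
    """
    rng = random.Random(seed)
    candidates = [i for i, tok in enumerate(tokens) if tok not in SPECIAL]
    n = max(1, round(rate * len(candidates)))  # mask at least one token
    positions = sorted(rng.sample(candidates, n))
    corrupted = list(tokens)
    for i in positions:
        corrupted[i] = MASK
    return corrupted, positions


tokens = ["[CLS]", "the", "cat", "sat", "on", "the", "mat", "[SEP]"]
corrupted, positions = mask_tokens(tokens)
```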

All experiments run on one Google Cloud TPU V3-512 cluster, three Google Cloud TPU V2-512 clusters, and 110 Google Cloud TPU V3-8/V2-8 instances.

Baseline: As no effective BERT pruning method is publicly available, we use the proposed NIP method as the baseline. The detailed algorithm is shown in Appendix B. The progressive pruning ratio is 10% (10% more weights are pruned in each iteration). Starting from the official BERT-Base, we use 9 iterations. In each iteration $t$ of NIP, we obtain a sparse BERT-Base with a specific sparsity, $(\mathbf{w}^{t}; \mathcal{S}_{\mathbf{w}^{t}})$, and then retrain the sparse BERT-Base $\mathbf{w}^{t}$ over the sparsity $\mathcal{S}_{\mathbf{w}^{t}}$. In the retraining process, the initial learning rate is $2 \times 10^{-5}$, the batch size is 1024, and retraining lasts for 10,000 steps (around 16 epochs). For the other hyperparameters, we follow the original BERT paper (devlin2019bert). In each iteration, the retrained sparse BERT-Base is the starting point for the fine-tuning tasks and for the next iteration.

Figure 1: Performance of BERT-Base pruned with NIP and RPP, respectively (MLM and NSP accuracy on the pre-training data and F1 score of fine-tuning on SQuAD 1.1).

4.2 Reweighted Proximal Pruning (RPP)

We apply the proposed Reweighted Proximal Pruning (RPP) method to both BERT-Base and BERT-Large and demonstrate the resulting performance improvement. The detailed RPP procedure is given in Appendix C.

For BERT-Base, we use exactly the same hyperparameters as in our NIP experiments. The initial learning rate is $\lambda = 2 \times 10^{-5}$ and the batch size is 1024. We run RPP for six iterations ($T = 6$), each lasting 100,000 steps (around 16 epochs). The total number of epochs in RPP is smaller than in NIP when reaching 90% sparsity (96 < 144). There is no retraining process in RPP. We set $\gamma \in \{10^{-2}, 10^{-3}, 10^{-4}, 10^{-5}\}$ and $\epsilon = 10^{-9}$ in Algorithm 3.2.1. Recall that RPP reduces to $\ell_1$ sparse training when $t = 1$.

In Figure 1, we present the accuracy versus the pruning ratio for the pre-training tasks MLM and NSP and for the fine-tuning task SQuAD 1.1, comparing RPP with NIP. As RPP continues to iterate, its performance becomes notably higher than that of NIP for both the pre-training and fine-tuning tasks, and the gap widens with more RPP iterations. Figure 1 also shows that NSP accuracy is very robust to pruning: even when 90% of the attention weights are pruned, NSP accuracy stays above 95% with RPP and around 90% with NIP. For MLM accuracy and SQuAD F1 score, performance drops quickly as the prune ratio increases, and RPP slows this decline to a great extent. On SQuAD 1.1, RPP keeps the F1 score of BERT-Base at 88.5 (0 degradation compared with the original BERT) at a 41.2% prune ratio, while the F1 score of BERT-Base pruned with NIP drops to 84.6 (3.9 degradation) at a 40% prune ratio. At an 80% prune ratio, RPP keeps the F1 score of BERT-Base at 84.7 (3.8 degradation), while the F1 score with NIP drops to 68.8 (19.7 degradation compared with the original BERT). The other transfer learning tasks show the same trend as SQuAD 1.1 (RPP consistently outperforms NIP); detailed results are reported in the Appendix.

For BERT-Large, we use exactly the same hyperparameters as in our NIP experiments except for the batch size. The initial learning rate is $2 \times 10^{-5}$ and the batch size is 512. We run RPP for four iterations ($T = 4$), each lasting 100,000 steps (around 8 epochs). Again, there is no retraining process. We set $\gamma \in \{10^{-2}, 10^{-3}\}$ and $\epsilon = 10^{-9}$ in Algorithm 3.2.1. The results of pruning BERT-Large and then fine-tuning are shown in Table 1.

Table 1: BERT-Large pruning results on a set of transfer learning tasks. Values in parentheses are the degradation relative to the original (unpruned) BERT after fine-tuning.
Method  Prune ratio (%)  SQuAD 1.1      QQP           MNLI            MRPC          CoLA
NIP     50.0             85.3 (-5.6)    85.1 (-6.1)   77.0 (-9.1)     83.5 (-5.5)   76.3 (-5.2)
NIP     80.0             75.1 (-15.8)   81.1 (-10.1)  73.81 (-12.29)  68.4 (-20.5)  69.13 (-12.37)
RPP     59.3             90.23 (-0.67)  91.2 (-0.0)   86.1 (-0.0)     88.1 (-1.2)   82.8 (+1.3)
RPP     88.4             81.69 (-9.21)  89.2 (-2.0)   81.4 (-4.7)     81.9 (-7.1)   79.3 (-2.2)

Method  Prune ratio (%)  SQuAD 2.0      QNLI          MNLIM           SST-2         RTE
NIP     50.0             75.3 (-6.6)    90.2 (-1.1)   82.5 (-3.4)     91.3 (-1.9)   68.6 (-1.5)
NIP     80.0             70.1 (-11.8)   80.5 (-10.8)  78.4 (-7.5)     88.7 (-4.5)   62.8 (-7.3)
RPP     59.3             81.3 (-0.6)    92.3 (+1.0)   85.7 (-0.2)     92.4 (-0.8)   70.1 (-0.0)
RPP     88.4             80.7 (-1.2)    88.0 (-3.3)   81.8 (-4.1)     90.5 (-2.7)   67.5 (-2.6)

4.3 Visualizing Attention Pattern in BERT

We visualize the sparse pattern of the kernel weights in the sparse BERT model pruned with RPP and present several examples in Figure 2. Because we directly visualize the identical universal sparsity $\mathcal{S}_{\mathbf{w}}$, rather than an activation map or any auxiliary function, the attention pattern is universal and data-independent.

Figure 2: Visualization of the sparsity $\mathcal{S}$ in the pruned BERT model $\mathbf{w}$ (RPP at 99% prune ratio). The tiny yellow spots are the remaining values, and the blue background represents all values that are zero after pruning.

BERT’s model architecture is a multi-layer bidirectional transformer encoder based on the original implementation (vaswani2017attention). Following (vaswani2017attention), the transformer architecture is based on “scaled dot-product attention.” The input consists of queries, keys, and values, denoted as matrices $Q$, $K$, and $V$, respectively. The output of the attention model is computed as

$$\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right) V \qquad (9)$$

where $d_k$ is the dimension of the queries and keys. We visualize the sparse matrices $Q$, $K$, and $V$ successively in Figure 2. The sparse $Q$ shows an obvious vertical distribution, while the sparse $K$ is mainly vertically distributed with some strong horizontal linking lines. The visualization indicates that the query matrix $Q$ mainly models the information inside each sequence, while the key matrix $K$ uses the strong horizontal linking lines to build relationships between different sequences in the context. In summary, the sparse attention pattern exhibits an obvious structured distribution.
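Equation (9) can be spelled out in a few lines of standard-library Python. This is a minimal, single-head sketch without the learned projections; in BERT, $Q$, $K$ and $V$ are obtained by multiplying the layer input by the (here pruned) kernel weight matrices, so sparsifying those kernels directly changes the inputs to this computation. The helper names are ours.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    """Plain matrix product A @ B."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def softmax_rows(S: Matrix) -> Matrix:
    out = []
    for row in S:
        m = max(row)  # subtract the row max for numerical stability
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out


def attention(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    """Eq. (9): softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(K[0])
    K_T = [list(col) for col in zip(*K)]
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(Q, K_T)]
    return matmul(softmax_rows(scores), V)


V = [[1.0, 2.0], [3.0, 4.0]]
out = attention([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]], V)
```

Each query scores highest against its matching key, so `out[0]` leans toward `V[0]` and `out[1]` toward `V[1]`; a query that scores all keys equally returns the plain average of the value rows.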

4.4 t-SNE Visualization

Figure 3: t-SNE visualization of BERT-Large embeddings with RPP applied. (From left to right: t-SNE of the original BERT embedding at viewpoint A, t-SNE of the sparse BERT embedding at viewpoint A, t-SNE of the original BERT embedding at viewpoint B, and t-SNE of the sparse BERT embedding at viewpoint B.)

t-Distributed Stochastic Neighbor Embedding (t-SNE) is a dimensionality reduction technique that is particularly well suited for visualizing high-dimensional datasets (maaten2008visualizing). Pre-trained word embeddings are an integral part of modern NLP systems (devlin2019bert), and one contribution of BERT is pre-trained contextual embeddings. Hence, in Figure 3 we use t-SNE to visualize the word embeddings of the original BERT model and of the BERT model pruned with RPP. The similarity between the two sets of visualized word embeddings illustrates the effectiveness of RPP.

5 Conclusions and Future Work

This paper presents the pruning algorithm RPP, which achieves the first effective weight pruning result on a large pre-trained language representation model, BERT. RPP achieves 59.3% weight sparsity without performance loss on both pre-training and fine-tuning tasks. We spotlight the relationship between the pruning ratio of the pre-trained DNN model and the performance on the downstream multi-task transfer learning objectives. We show that many downstream tasks other than SQuAD allow a pruning ratio of at least 80%, compared with 59.3% for SQuAD. Our proposed Reweighted Proximal Pruning provides a new perspective for analyzing what a large language representation (BERT) learns.

Acknowledgments

This research is supported by National Science Foundation under grants CCF-1919117, CNS-1704662, CCF-193750, CCF-1733701, and CCF-1901378. The authors gratefully acknowledge the support of the computing resources sponsor: Google LLC. We would like to thank Google for providing the computing resources: one Google Cloud TPU V3-512 cluster, three Google Cloud TPU V2-512 clusters and 110 Google Cloud TPU V3-8/V2-8 instances.

References

Appendix

Appendix A Overview of Proposed BERT Pruning

Figure A1: Overview of pruning BERT using the Reweighted Proximal Pruning algorithm.

Figure A1 shows an overview of pruning BERT using RPP and then fine-tuning it on a wide range of downstream transfer learning tasks. Through RPP, we find a single universal sparsity pattern that can be shared when fine-tuning on all of the downstream transfer learning tasks.
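How the shared sparsity is preserved during fine-tuning is not spelled out here. One minimal reading, sketched below under our own assumptions (plain SGD, and a 0/1 mask re-applied after every update so that pruned weights stay exactly zero), fine-tunes the same sparse model on several tasks. The names `masked_finetune` and `toy_task`, the quadratic task losses, and the learning rate are hypothetical toy choices, not the paper's implementation.

```python
from typing import Callable, Dict, List

Grad = Callable[[List[float]], List[float]]

def masked_finetune(weights: List[float], mask: List[int], grad_fn: Grad,
                    lr: float = 0.1, steps: int = 10) -> List[float]:
    """Fine-tune a copy of the sparse weights; the 0/1 mask is re-applied after each step."""
    w = [wi * mi for wi, mi in zip(weights, mask)]  # start from the sparse model
    for _ in range(steps):
        g = grad_fn(w)
        # Gradient step, then mask, so the sparsity pattern never changes.
        w = [(wi - lr * gi) * mi for wi, gi, mi in zip(w, g, mask)]
    return w

def toy_task(target: List[float]) -> Grad:
    """Gradient of a hypothetical task loss 0.5 * ||w - target||^2."""
    return lambda w: [wi - ti for wi, ti in zip(w, target)]

pretrained = [0.8, -0.05, 0.4, 0.01]
mask = [1, 0, 1, 0]  # one sparsity pattern shared by every downstream task
tasks: Dict[str, Grad] = {"task_a": toy_task([1.0, 1.0, 0.0, 1.0]),
                          "task_b": toy_task([-1.0, 2.0, 0.5, -3.0])}
finetuned = {name: masked_finetune(pretrained, mask, g) for name, g in tasks.items()}
for name, w in finetuned.items():
    print(name, [round(x, 3) for x in w])
```

Every fine-tuned copy ends with zeros in exactly the same positions, which is what "one universal sparsity" means operationally in this sketch.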

Appendix B Algorithm of New Iterative Pruning

Algorithm B shows the detailed procedure of our proposed NIP algorithm.

Algorithm B: New Iterative Pruning (NIP)
Input: initial model weights 𝐰, initial prune ratio p = 0%, progressive prune ratio Δp
for t = 1, 2, ..., T do
  𝐰 = 𝐰^(t-1)
  Sample a batch of data from the pre-training data
  Obtain the sparsity 𝒮_𝐰 through hard-threshold pruning with prune ratio p_t = t·Δp
  Retrain 𝐰 under the sparsity constraint 𝒮_𝐰
  for all tasks in {𝒯_i} do
    Fine-tune 𝐳_i(𝐰; 𝒮_𝐰) under the sparsity 𝒮_𝐰 (if the desired prune ratio p_t has been reached for downstream task i)
  end for
end for
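A toy, pure-Python sketch of the NIP loop on a flat weight list follows. It assumes magnitude-based hard thresholding, plain SGD for both retraining and fine-tuning, and quadratic stand-in losses; the real algorithm retrains BERT on sampled pre-training batches. We read the fine-tuning condition as "fine-tune task i once p_t reaches that task's desired prune ratio"; that reading, and all function and task names, are our assumptions.

```python
from typing import Callable, Dict, List, Tuple

Grad = Callable[[List[float]], List[float]]

def hard_threshold_mask(w: List[float], prune_ratio: float) -> List[int]:
    """0/1 mask that removes the prune_ratio fraction of smallest-magnitude weights."""
    n_prune = int(round(prune_ratio * len(w)))
    order = sorted(range(len(w)), key=lambda i: abs(w[i]))  # ascending |w_i|
    mask = [1] * len(w)
    for i in order[:n_prune]:
        mask[i] = 0
    return mask

def masked_sgd(w: List[float], mask: List[int], grad_fn: Grad,
               lr: float, steps: int) -> List[float]:
    """Plain SGD under the sparsity constraint: pruned weights stay exactly zero."""
    w = [wi * mi for wi, mi in zip(w, mask)]
    for _ in range(steps):
        g = grad_fn(w)
        w = [(wi - lr * gi) * mi for wi, gi, mi in zip(w, g, mask)]
    return w

def nip(w: List[float], delta_p: float, T: int, pretrain_grad: Grad,
        tasks: Dict[str, Tuple[float, Grad]], lr: float = 0.1,
        steps: int = 20) -> Tuple[List[float], List[int], Dict[str, List[float]]]:
    """tasks maps a (hypothetical) task name to (desired prune ratio, gradient fn)."""
    mask = [1] * len(w)
    finetuned: Dict[str, List[float]] = {}
    for t in range(1, T + 1):
        p_t = min(t * delta_p, 1.0)                        # progressive ratio p_t = t * delta_p
        mask = hard_threshold_mask(w, p_t)                 # hard-threshold pruning
        w = masked_sgd(w, mask, pretrain_grad, lr, steps)  # retrain under the sparsity
        for name, (target, task_grad) in tasks.items():
            # Assumed reading: fine-tune task i once p_t reaches its desired ratio.
            if name not in finetuned and p_t >= target:
                finetuned[name] = masked_sgd(w, mask, task_grad, lr, steps)
    return w, mask, finetuned

def toward(target: List[float]) -> Grad:
    """Gradient of the toy loss 0.5 * ||w - target||^2."""
    return lambda w: [wi - ti for wi, ti in zip(w, target)]

w0 = [0.9, -0.1, 0.5, 0.05, -0.7, 0.2, 0.3, -0.02]
tasks = {"task_a": (0.25, toward([1.0] * 8)), "task_b": (0.5, toward([1.0] * 8))}
w, mask, finetuned = nip(w0, delta_p=0.25, T=2, pretrain_grad=toward(w0), tasks=tasks)
print(mask)  # → [1, 0, 1, 0, 1, 0, 1, 0]
```

Here `task_a` is fine-tuned at 25% sparsity and `task_b` at 50%, mirroring how different downstream tasks can tolerate different prune ratios.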

Appendix C Algorithm of Reweighted Proximal Pruning (RPP)

Algorithm C shows the detailed procedure of our enhanced AdamW (loshchilov2018decoupled) with a proximal operator.

Algorithm C: Our enhanced AdamW (loshchilov2018decoupled) with proximal operator
Given α = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 10^-6, λ ∈ ℝ
Initialize time step k ← 0, parameters of the pre-trained model 𝐰, first moment vector 𝐦_{t=0} ← 𝟎, second moment vector 𝐯_{t=0} ← 𝟎, schedule multiplier η_{k=0} ∈ ℝ
repeat
  k ← k + 1
  ∇f_k(𝐰_{k-1}) ← SelectBatch(𝐰_{k-1})
  𝐠_k ← ∇f_k(𝐰_{k-1})
  𝐦_k ← β1 𝐦_{k-1} + (1 - β1) 𝐠_k
  𝐯_k ← β2 𝐯_{k-1} + (1 - β2) 𝐠_k²
  𝐦̂_k ← 𝐦_k / (1 - β1^k)
  𝐯̂_k ← 𝐯_k / (1 - β2^k)
  η_k ← SetScheduleMultiplier(k)
  𝐚 ← 𝐰_{k-1} - η_k (α 𝐦̂_k / (√𝐯̂_k + ϵ) + λ 𝐰_{k-1})
  𝐰_k ← prox_{λ_k, rw-ℓ1}(𝐚)
until stopping criterion is met
return optimized sparse model 𝐰 in pre-training
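One iteration of the algorithm above can be sketched in pure Python on a flat weight list. The AdamW part follows the listed update (α, β1, β2, ϵ defaults as given; `weight_decay` plays the role of λ). For the proximal step we assume the standard reweighted-ℓ1 form, prox(𝐚) = argmin_𝐰 λ_k Σ_i γ_i |w_i| + ½‖𝐰 - 𝐚‖², whose solution is an elementwise soft threshold with threshold λ_k γ_i, and we assume γ_i = 1/(|w_i| + δ) is recomputed from the previous iterate. The reweighting rule and schedule, δ, the constant η_k = 1, and the toy quadratic objective are our assumptions; the paper's exact definitions are in the main text.

```python
import math
from typing import List, Tuple

def reweighted_l1_prox(a: List[float], gamma: List[float], lam: float) -> List[float]:
    """Solve argmin_w lam * sum_i gamma_i * |w_i| + 0.5 * ||w - a||^2.

    The problem separates per coordinate; each coordinate is soft-thresholded
    with its own threshold lam * gamma_i.
    """
    return [math.copysign(max(abs(ai) - lam * gi, 0.0), ai) for ai, gi in zip(a, gamma)]

def rpp_adamw_step(w: List[float], g: List[float], m: List[float], v: List[float], k: int,
                   alpha: float = 1e-3, beta1: float = 0.9, beta2: float = 0.999,
                   eps: float = 1e-6, weight_decay: float = 1e-4, eta: float = 1.0,
                   prox_lambda: float = 1e-5, delta: float = 1e-3
                   ) -> Tuple[List[float], List[float], List[float]]:
    """One iteration k >= 1: AdamW update followed by a reweighted-l1 proximal step."""
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]       # first moment
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]  # second moment
    m_hat = [mi / (1 - beta1 ** k) for mi in m]                        # bias correction
    v_hat = [vi / (1 - beta2 ** k) for vi in v]
    # AdamW step with decoupled weight decay, as in the listing (eta = schedule multiplier).
    a = [wi - eta * (alpha * mh / (math.sqrt(vh) + eps) + weight_decay * wi)
         for wi, mh, vh in zip(w, m_hat, v_hat)]
    # Assumed reweighting: gamma_i = 1 / (|w_i| + delta) from the previous iterate.
    gamma = [1.0 / (abs(wi) + delta) for wi in w]
    return reweighted_l1_prox(a, gamma, prox_lambda), m, v

# Toy objective 0.5 * ||w - w_pre||^2 standing in for the pre-training loss.
w_pre = [1.0, 0.0005, -0.8, -0.0002]
w, m, v = list(w_pre), [0.0] * 4, [0.0] * 4
for k in range(1, 2001):
    g = [wi - pi for wi, pi in zip(w, w_pre)]
    w, m, v = rpp_adamw_step(w, g, m, v, k)
print([round(x, 4) for x in w])
```

Because γ_i is large for small weights, the tiny entries are driven to exactly zero and stay there, while the large entries remain close to their pre-trained values; this is the intended effect of reweighting the ℓ1 penalty.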

Appendix D Downstream Transfer Learning Tasks

As mentioned in our main paper, we prune the pre-trained BERT model (using NIP and RPP) and then fine-tune the sparse pre-trained model on different downstream transfer learning tasks. In this section, we report the performance of BERT pruned with NIP and RPP on a wide range of downstream transfer learning tasks to support the conclusions of the main paper.

D.1 QQP

Quora Question Pairs is a binary classification task where the goal is to determine if two questions asked on Quora are semantically equivalent.

Figure A2: Performance of BERT_BASE pruned using NIP and RPP, respectively (MLM and NSP accuracy on the pre-training data and F1 score of fine-tuning on QQP are reported).

D.2 MRPC

Microsoft Research Paraphrase Corpus consists of sentence pairs automatically extracted from online news sources, with human annotations for whether the sentences in the pair are semantically equivalent (dolan2005automatically).

Figure A3: Performance of BERT_BASE pruned using NIP and RPP, respectively (MLM and NSP accuracy on the pre-training data and F1 score of fine-tuning on MRPC are reported).

D.3 MNLI

Multi-Genre Natural Language Inference is a large-scale, crowdsourced entailment classification task (williams2017broad). Given a pair of sentences, the goal is to predict whether the second sentence is an entailment, contradiction, or neutral with respect to the first one.

Figure A4: Performance of BERT_BASE pruned using NIP and RPP, respectively (MLM and NSP accuracy on the pre-training data and accuracy of fine-tuning on MNLI are reported).

D.4 MNLIM

Multi-Genre Natural Language Inference has a separate evaluation, MNLIM. Following (devlin2019bert), the fine-tuning process on MNLIM is kept separate from that on MNLI, so we present our results on MNLIM in this subsection.

Figure A5: Performance of BERT_BASE pruned using NIP and RPP, respectively (MLM and NSP accuracy on the pre-training data and accuracy of fine-tuning on MNLIM are reported).

D.5 QNLI

Question Natural Language Inference is a version of the Stanford Question Answering Dataset (Rajpurkar et al., 2016) that has been converted to a binary classification task (Wang et al., 2018a). The positive examples are (question, sentence) pairs in which the sentence contains the correct answer, and the negative examples are (question, sentence) pairs from the same paragraph in which the sentence does not contain the answer.
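The labeling rule described above can be illustrated with a few lines of Python. This is a simplified sketch: the function name, the example paragraph, and the use of plain substring matching to decide whether a sentence contains the answer are our own illustrative choices, and the actual GLUE construction of QNLI may involve further steps not described here.

```python
from typing import List, Tuple

def squad_to_qnli_pairs(question: str, sentences: List[str],
                        answer: str) -> List[Tuple[str, str, int]]:
    """Pair the question with every sentence of its paragraph.

    Label 1 if the answer string occurs in the sentence (positive example),
    else 0 (negative example). Substring matching is a simplification.
    """
    return [(question, sentence, int(answer in sentence)) for sentence in sentences]

paragraph = ["The Eiffel Tower is located in Paris.",
             "It was completed in 1889."]
pairs = squad_to_qnli_pairs("Where is the Eiffel Tower located?", paragraph, "Paris")
for q, s, label in pairs:
    print(label, s)
```

Both examples share the same question and paragraph, so the classifier has to decide from the sentence content alone whether it answers the question.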

Figure A6: Performance of BERT_BASE pruned using NIP and RPP, respectively (MLM and NSP accuracy on the pre-training data and accuracy of fine-tuning on QNLI are reported).

D.6 SST-2

The Stanford Sentiment Treebank is a binary single-sentence classification task consisting of sentences extracted from movie reviews with human annotations of their sentiment (socher2013recursive).

Figure A7: Performance of BERT_BASE pruned using NIP and RPP, respectively (MLM and NSP accuracy on the pre-training data and accuracy of fine-tuning on SST-2 are reported).

D.7 CoLA

The Corpus of Linguistic Acceptability is a binary single-sentence classification task, where the goal is to predict whether an English sentence is linguistically “acceptable” or not (warstadt2018neural).

Figure A8: Performance of BERT_BASE pruned using NIP and RPP, respectively (MLM and NSP accuracy on the pre-training data and accuracy of fine-tuning on CoLA are reported).

Appendix E Non-convergence of Pruning BERT Using Previous Methods

(a) training loss curve 1
(b) training loss curve 2
Figure A9: Training loss curves when the ℓ1 loss is added directly to the original objective loss function of BERT.
(a) training loss curve 1
(b) training loss curve 2
Figure A10: Training loss curves when the ℓ2 loss is added directly to the original objective loss function of BERT.

As mentioned in our main paper, we investigate a series of techniques for pruning BERT, including the iterative pruning method (han2015learning) and one-shot pruning (liu2018rethinking). However, most previous pruning techniques require adding an ℓ1/ℓ2 regularization loss directly to the original training objective of the DNN model. We conduct a series of experiments and find that this kind of regularization is not compatible with BERT. The theoretical analysis of this incompatibility is in our main paper; in this section we show the corresponding experimental results. For a fair comparison, we not only apply the same hyperparameters as in our NIP and RPP experiments to iterative pruning and one-shot pruning, but also try a wide range of other hyperparameters to make iterative pruning and one-shot pruning work. We set the learning rate λ ∈ {2×10^-4, 10^-4, 5×10^-5, 3×10^-5, 2×10^-5, 1×10^-5, 1×10^-6, 1×10^-7, 1×10^-8} and the batch size B ∈ {256, 512, 1024, 2048}. The number of training epochs reaches a maximum of 128 in Figure A10(b); this maximum number of retraining epochs exceeds the number of epochs (40) used to pre-train BERT from scratch (devlin2019bert). Despite using the same hyperparameters as NIP and RPP and attempting many others, neither iterative pruning nor one-shot pruning converges to a valid solution.

In Figure A9, we add the ℓ1 regularization loss directly to the original loss function of BERT to induce sparsity in the BERT model. We find that this kind of regularization easily leads to non-convergence (Figure A9) and often causes gradient exceptions.

In Figure A10, we add the ℓ2 regularization loss directly to the original loss function of BERT to induce sparsity in the BERT model. Gradient exceptions occur when the ℓ2 regularization loss is added directly to the original training objectives of BERT (Figure A10(a)). After we add a protective bound to the ℓ2 term to avoid gradient exceptions, the training process lasts longer (128 epochs in Figure A10(b)). However, the training loss curve shows that it still does not converge to a valid solution (Figure A10(b)).
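The paper's own analysis of this incompatibility is in the main text and is not reproduced here. As a generic one-dimensional illustration (plain gradient descent, not BERT and not AdamW), the sketch below contrasts two ways of using an ℓ1 penalty on the toy objective ½(w - 0.05)² + λ|w|, whose minimizer is exactly w = 0: adding the penalty to the loss, so that its subgradient λ·sign(w) enters every update, versus applying it through a proximal (soft-thresholding) step after each gradient step, as proximal methods such as RPP do. λ, the step size, and the target value are arbitrary toy values, and the sketch does not reproduce the divergence or gradient exceptions reported above.

```python
import math
from typing import List

LAM, LR, W_TARGET = 0.1, 0.1, 0.05  # arbitrary toy values

def grad_smooth(w: float) -> float:
    """Gradient of the smooth part 0.5 * (w - W_TARGET)^2."""
    return w - W_TARGET

def subgradient_run(w: float, steps: int) -> List[float]:
    """l1 term added to the loss: its subgradient LAM * sign(w) enters every update."""
    traj = []
    for _ in range(steps):
        sign = (w > 0) - (w < 0)
        w = w - LR * (grad_smooth(w) + LAM * sign)
        traj.append(w)
    return traj

def proximal_run(w: float, steps: int) -> List[float]:
    """l1 term handled by soft-thresholding after each gradient step."""
    traj = []
    for _ in range(steps):
        a = w - LR * grad_smooth(w)
        w = math.copysign(max(abs(a) - LR * LAM, 0.0), a)
        traj.append(w)
    return traj

sub_tail = subgradient_run(1.0, 200)[-20:]
prox_tail = proximal_run(1.0, 200)[-20:]
print("subgradient:", [round(x, 4) for x in sub_tail[-4:]])
print("proximal:   ", prox_tail[-4:])
```

The subgradient iterates keep oscillating around zero without settling on it, whereas the proximal iterates reach exactly zero and stay there, which is why a proximal step is the natural way to obtain exact sparsity from an ℓ1-type penalty.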