Reweighted Proximal Pruning for
Large-Scale Language Representation
Abstract
Recently, pretrained language representations such as BERT have flourished as the mainstay of the natural language understanding community. These pretrained language representations achieve state-of-the-art results on a wide range of downstream tasks. Along with continuous, significant performance improvements, the size and complexity of these pretrained neural models continue to increase rapidly. Is it possible to compress these large-scale language representation models? How will the pruned language representation affect the downstream multitask transfer learning objectives? In this paper, we propose Reweighted Proximal Pruning (RPP), a new pruning method specifically designed for large-scale language representation models. Through experiments on SQuAD and the GLUE benchmark suite, we show that proximal-pruned BERT keeps high accuracy for both the pretraining task and multiple downstream fine-tuning tasks at high prune ratios. RPP provides a new perspective to help us analyze what large-scale language representations might learn. Additionally, RPP makes it possible to deploy a large state-of-the-art language representation model such as BERT on a series of distinct devices (e.g., online servers, mobile phones, and edge devices).
1 Introduction
Pretrained language representations such as GPT (radford2018improving), BERT (devlin2019bert) and XLNet (yang2019xlnet) have shown substantial performance improvements using self-supervised training on large-scale corpora (dai2015semi; peters2018deep; radford2018improving; liu2019roberta). More interestingly, the pretrained BERT model can be fine-tuned with just one additional output layer to create state-of-the-art models for a wide range of tasks, such as question answering (rajpurkar2016squad; rajpurkar2018know) and language inference (bowman2015large; williams2017broad), without substantial task-specific architecture modifications. BERT is conceptually simple and empirically powerful (devlin2019bert).
However, along with the significant performance enhancement, the parameter volume and complexity of these pretrained language representations increase significantly. As a result, it becomes difficult to deploy these large-scale language representations on real-life, computation-constrained devices such as mobile phones and edge devices. Throughout this paper, we attempt to answer the following questions.
Question 1: Is it possible to compress large-scale language representations such as BERT via weight pruning?
Question 2: How would the weight-pruned, pretrained model affect the performance of the downstream multitask transfer learning objectives?
The problem of weight pruning has been studied for many types of deep neural networks (DNNs) (goodfellow2016deep), such as AlexNet (krizhevsky2012imagenet), VGG (simonyan2014very), ResNet (he2016deep), and MobileNet (howard2017mobilenets). Weight pruning has been shown to yield a notable reduction in model size. A suite of weight pruning techniques has been developed, such as non-structured weight pruning (han2015learning), structured weight pruning (wen2016learning), filter pruning (li2016pruning), channel pruning (he2017channel), ADMM-NN (ren2019admm) and PCONV (ma2019pconv), to name a few. Unlike pruning CNN-type models, pruning BERT must not only consider the metrics on the pretraining task but also make allowance for the downstream multitask transfer learning objectives. Thus, the desired weight pruning needs to preserve the capacity of transfer learning from a sparse pretrained model to downstream fine-tuning tasks.
In this work, we investigate irregular weight pruning techniques on the BERT model, including the iterative pruning method (han2015learning) and the one-shot pruning method (liu2018rethinking). However, these methods fail to converge to a sparse pretrained model without incurring a significant accuracy drop, or in many cases do not converge at all (see supporting results in the Appendix). Note that the aforementioned weight pruning techniques are built on different sparsity-promoting regularization schemes (han2015learning; wen2016learning), e.g., lasso regression ($\ell_1$ regularization) and ridge regression ($\ell_2$ regularization). We find that the failure of previous methods on weight pruning of BERT is possibly due to the inaccurate sparse pattern learnt from the simple $\ell_1$- or $\ell_2$-based sparsity-promoting regularizer. In fact, the difficulty of applying regularization to generate weight sparsity coincides with the observation in (loshchilov2018decoupled) on the incompatibility of conventional weight decay ($\ell_2$ regularization) with training super-deep DNNs such as BERT. loshchilov2018decoupled point out that the main reason is that directly adding a regularization penalty term causes divergence from the original loss function and has a negative effect on the effectiveness of gradient-based updates. To mitigate this limitation, (loshchilov2018decoupled) modified the regularization in Adam by decoupling weight decay regularization from the gradient-based update; this decoupled scheme has been used to achieve state-of-the-art results on downstream multitask transfer learning objectives (devlin2019bert).
In this work, we aim at a more accurate sparse pattern search, motivated by our experiments and the conclusion of loshchilov2018decoupled. We propose Reweighted Proximal Pruning (RPP), which integrates reweighted $\ell_1$ minimization (candes2008enhancing) with the proximal algorithm (parikh2014proximal). RPP consists of two parts (see Appendix A for an overview of our approach): the reweighted $\ell_1$ minimization and the proximal operator. Compared with $\ell_1$ regularization, reweighted $\ell_1$ minimization generates sparsity in DNN models in a way that better matches the nature of weight pruning. Thanks to the closed-form solution of the proximal operation on a weighted $\ell_1$ norm, the sparsity pattern search in RPP can be decoupled from computing the gradient of the training loss. In this way, the aforementioned pitfall of prior weight pruning techniques on BERT can be avoided. To the best of our knowledge, RPP is the first method to achieve effective weight pruning on BERT. Experimental results demonstrate that the proximal-pruned BERT model keeps high accuracy on a wide range of downstream tasks, including SQuAD (rajpurkar2016squad; rajpurkar2018know) and GLUE (wang2018glue).
We summarize our contributions as follows.

• We develop the pruning algorithm Reweighted Proximal Pruning (RPP), which achieves the first effective weight pruning result on a large pretrained language representation model, BERT. RPP achieves $59.3\%$ weight sparsity without performance loss on both pretraining and fine-tuning tasks.

• We spotlight the relationship between the pruning ratio of the pretrained DNN model and the performance on the downstream multitask transfer learning objectives. We show that many downstream tasks other than SQuAD allow a pruning ratio of at least $80\%$, compared with $59.3\%$ for the more challenging SQuAD task.

• We show that, unlike weight pruning in image classification tasks, there exists a sparsity pattern in the pretrained language representation. Our approach not only provides an effective weight pruning algorithm but also offers a new perspective on analyzing large-scale language representations.
2 Related Work
BERT and prior work on model compression
BERT (devlin2019bert) is a self-supervised approach for pretraining a deep transformer encoder (vaswani2017attention) before fine-tuning it for particular downstream tasks. Pretraining of BERT optimizes two training objectives, masked language modeling (MLM) and next sentence prediction (NSP), both of which require a large collection of unlabeled text. We use BookCorpus (800M words) (zhu2015aligning) and English Wikipedia (2,500M words) as the pretraining corpus, the same as devlin2019bert. For detailed information about the BERT model, readers can refer to the original paper (devlin2019bert).
michel2019sixteen mask some heads in the multi-head attention modules of BERT and then evaluate the performance on the machine translation task. Similarly, hao2019multi eliminate certain heads in the multi-head attention module. First, this limited previous work does not consider the pretraining metrics or the other downstream multitask transfer learning objectives. It only considers the machine translation task (out of more than 10 transfer tasks), which is a single fine-tuning task and is of limited relevance to the universal pretrained language representation (BERT). Second, the multi-head attention module uses a weight sharing mechanism (vaswani2017attention), so masking some heads does not reduce the weight volume. Finally, multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions, while a single attention head inhibits this effect (vaswani2017attention). As a result, masking some heads in multi-head attention harms the weight sharing mechanism without reducing the weight volume. In summary, the limited previous work in this area does not provide an effective weight pruning method for BERT. shen2019q report quantization results for the BERT model, which are orthogonal to our work and can be combined with it for further compression/acceleration.
Reweighted ${\mathrm{\ell}}_{1}$ and proximal algorithm
candes2008enhancing present the reweighted $\ell_1$ algorithm and demonstrate its remarkable performance and broad applicability in statistical estimation, error correction and image processing. Proximal algorithms can be viewed as an analogous tool for nonsmooth, constrained, large-scale, or distributed versions of these problems (parikh2014proximal). To the best of our knowledge, we are the first to apply reweighted $\ell_1$ and the proximal algorithm to DNN weight pruning, and to achieve effective weight pruning on BERT.
3 Reweighted Proximal Pruning for large-scale language representation during pretraining
Pruning for pretrained language representations should consider not only the performance on the pretraining objectives but also the downstream fine-tuning transfer learning tasks. Let $f_i$ denote the loss function of the network for downstream task $\mathcal{T}_i \sim p(\mathcal{T})$, where $p(\mathcal{T})$ denotes the distribution of tasks. Let $\mathbf{w}$ denote the parameters of the pretrained model (pretraining in BERT), and $\mathbf{z}_i$ denote the $i$th task-specific model parameters (fine-tuning in BERT). The downstream tasks have separate fine-tuned models, even though they are initialized with the same pretrained parameters (devlin2019bert). Starting from the pretrained parameters $\mathbf{w}$, the parameters $\mathbf{z}_i(\mathbf{w})$ are obtained through fine-tuning:
$$\underset{\mathbf{w}\in {\mathbb{R}}^{\mathrm{d}}}{\mathrm{minimize}}{\mathrm{f}}_{\mathrm{i}}(\mathbf{w})$$  (1) 
3.1 Pruning formulation in transfer learning
Following the conventional weight pruning formulation, we first consider the problem of weight pruning during pretraining:
$$\begin{array}{cc}& \underset{\mathbf{w}\in {\mathbb{R}}^{\mathrm{d}}}{\mathrm{minimize}}{\mathrm{f}}_{0}(\mathbf{w})+\gamma {\parallel \mathbf{w}\parallel}_{\mathrm{p}}\hfill \end{array}$$  (2) 
where $f_0$ is the loss function minimized during pruning (the pretraining loss), $p \in \{0, 1\}$ denotes the type of regularization norm, and $\gamma$ is a regularization parameter. We note that the sparsity-promoting regularizer in the objective could also be replaced with a hard $\ell_p$ constraint, $\|\mathbf{w}\|_p \le \tau$ for some $\tau$.
Let $\widehat{\mathbf{w}}$ denote the solution to problem (2), and the corresponding sparse pattern ${\mathcal{S}}_{\widehat{\mathbf{w}}}$ is given by
$$\mathcal{S}_{\widehat{\mathbf{w}}} = \{\, i \mid \widehat{w}_i = 0,\ i \in [d] \,\}$$  (3)
For a specific transfer task $i$, we allow an additional retraining/fine-tuning step that trains/fine-tunes the weights starting from the pretraining result $\widehat{\mathbf{w}}$, subject to the determined, fixed sparse pattern $\mathcal{S}_{\widehat{\mathbf{w}}}$; we denote the result by $\mathbf{z}_i(\widehat{\mathbf{w}}; \mathcal{S}_{\widehat{\mathbf{w}}})$. That is, we solve the following modification of problem (1):
$$\underset{{\mathbf{z}}_{\mathrm{i}}}{\mathrm{minimize}}{\mathrm{f}}_{\mathrm{i}}\left({\mathbf{z}}_{\mathrm{i}}(\widehat{\mathbf{w}};{\mathcal{S}}_{\widehat{\mathbf{w}}})\right)$$  (4) 
Here, unlike in (1), the task-specific fine-tuning weight variable $\mathbf{z}_i(\widehat{\mathbf{w}}; \mathcal{S}_{\widehat{\mathbf{w}}})$ is constrained to the fixed sparse pattern $\mathcal{S}_{\widehat{\mathbf{w}}}$, i.e., the weights indexed by $\mathcal{S}_{\widehat{\mathbf{w}}}$ stay zero.
Our goal is to seek a sparse (weight-pruned) model during pretraining, with weight collection $\widehat{\mathbf{w}}$ and sparsity pattern $\mathcal{S}_{\widehat{\mathbf{w}}}$, that performs as well as the original pretrained model over multiple new tasks (indexed by $i$). These fine-tuned models $\mathbf{z}_i(\widehat{\mathbf{w}}; \mathcal{S}_{\widehat{\mathbf{w}}})$ (for different $i$) share the identical universal sparsity $\mathcal{S}_{\widehat{\mathbf{w}}}$.
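To make the fixed-pattern fine-tuning of problem (4) concrete, here is a minimal sketch assuming plain gradient descent on a flat list of weights. The helper names `sparsity_pattern` and `masked_sgd_step` are hypothetical, and the task gradient is made up; the only point is that indices in $\mathcal{S}_{\widehat{\mathbf{w}}}$ stay zero for every task $i$.

```python
def sparsity_pattern(w_hat):
    """Indices of pruned (exactly zero) weights, i.e. the set S in Eq. (3)."""
    return {i for i, w in enumerate(w_hat) if w == 0.0}


def masked_sgd_step(z, grad, pattern, lr):
    """One fine-tuning step on task i that keeps every index in `pattern` at zero."""
    return [0.0 if i in pattern else zi - lr * gi
            for i, (zi, gi) in enumerate(zip(z, grad))]


# Toy example (hypothetical numbers): w_hat plays the role of the pruned
# pretrained weights; fine-tuning starts from it and keeps its zeros.
w_hat = [1.0, 0.0, -0.5, 0.0]
S = sparsity_pattern(w_hat)            # {1, 3}
grad = [0.5, 1.0, 0.5, 2.0]            # hypothetical task gradient
z = masked_sgd_step(list(w_hat), grad, S, lr=0.5)
print(S, z)                            # {1, 3} [0.75, 0.0, -0.75, 0.0]
```

Because the mask is derived once from $\widehat{\mathbf{w}}$ and reused for every downstream task, all fine-tuned models inherit the same universal sparsity.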
3.2 Reweighted Proximal Pruning
In order to enhance the performance of pruning pretrained language representations over multitask downstream transfer learning objectives, we propose Reweighted Proximal Pruning (RPP). RPP consists of two parts: the reweighted $\ell_1$ minimization and the proximal operator. Compared with $\ell_1$ regularization, reweighted $\ell_1$ minimization generates sparsity in DNN models in a way that better matches the natural objective of weight pruning. The proximal algorithm then separates the gradient computation from the proximal operation over a weighted $\ell_1$ norm, without adding any penalty term to the original objective function of the DNN model. This separation is necessary for weight pruning of super-deep language representation models.
3.2.1 Reweighted ${\mathrm{\ell}}_{1}$ minimization
In previous pruning methods (han2015learning; wen2016learning), $\ell_1$ regularization is used to generate sparsity. However, consider two weights $w_i$ and $w_j$ in the DNN model with $|w_i| < |w_j|$, both penalized through $\ell_1$ regularization. The larger weight $w_j$ is penalized more heavily than the smaller weight $w_i$, which violates the original intention of weight pruning, “removing the unimportant connections” (parameters close to zero) (han2015learning). To address this imbalance, we introduce reweighted $\ell_1$ minimization (candes2008enhancing) to the DNN pruning domain. Our reweighted $\ell_1$ minimization operates in a systematic and iterative manner (detailed in Algorithm 1), and its first iteration is exactly $\ell_1$ regularization. This design lets us observe the performance difference between $\ell_1$ and reweighted $\ell_1$ minimization. It also ensures that reweighted $\ell_1$ minimization starts from, and refines, $\ell_1$ regularization, since the latter is the single, first step of the former.
Consider the regularized weight pruning problem (reweighted ${\mathrm{\ell}}_{1}$ minimization):
$$\underset{\mathbf{w}}{\mathrm{minimize}}\quad f_0(\mathbf{w}) + \gamma \sum_i \alpha_i |w_i|$$  (5)
where each factor $\alpha_i > 0$ balances the penalty on weight $w_i$ and should not be confused with the weight $w_i$ itself. The factors $\alpha_i$ are updated in a systematic way during the iterative reweighted $\ell_1$ minimization procedure (Step 2 in Algorithm 1) (candes2008enhancing). If we set $T = 1$, reweighted $\ell_1$ minimization reduces to $\ell_1$ sparse training.
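A short calculation, following the standard motivation in candes2008enhancing rather than anything stated in this section, shows why reweighting removes the magnitude bias described above. Assume the factors come from the first reweighting step (exponent $1$), write $\bar{w}_i$ for the weight from which $\alpha_i$ was computed, and assume $w_i$ stays close to $\bar{w}_i$:

```latex
\alpha_i = \frac{1}{|\bar w_i| + \epsilon}
\quad\Longrightarrow\quad
\alpha_i |w_i| \approx \frac{|\bar w_i|}{|\bar w_i| + \epsilon} \approx
\begin{cases}
1, & |\bar w_i| \gg \epsilon,\\
0, & \bar w_i = 0,
\end{cases}
\qquad\text{so}\qquad
\gamma \sum_i \alpha_i |w_i| \approx \gamma\, \|\mathbf{w}\|_0 .
```

Under this approximation every surviving weight costs roughly $\gamma$ regardless of its magnitude, so the penalty behaves like a count of nonzero weights instead of shrinking large weights harder than small ones.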
Algorithm 1: Reweighted $\ell_1$ minimization
Input: initial pretrained model $\mathbf{w}^{0}$, initial reweighted $\ell_1$ minimization ratio $\gamma$, initial positive factors $\alpha^{0} = 1$
for $t = 1, 2, \dots, T$ do
    $\mathbf{w} = \mathbf{w}^{(t-1)}$, $\alpha = \alpha^{(t-1)}$
    Step 1: Solve problem (5) to obtain a solution $\mathbf{w}^{t}$ via the iterative proximal algorithm (6)
    Step 2: Update the reweighted factors $\alpha_i^{t} = \frac{1}{|w_i^{t}|^{(t)} + \epsilon}$, where the inner $w_i^{t}$ denotes the weight $w_i$ at iteration $t$, the outer $(t)$ denotes the exponent, and $\epsilon$ is a small constant, e.g., $\epsilon = 0.001$
end for
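The two-level loop above can be sketched end to end on a toy problem. This is a minimal illustration, not the authors' implementation: it assumes a separable quadratic stand-in loss $f_0(\mathbf{w}) = \tfrac{1}{2}\sum_i (w_i - c_i)^2$ in place of the MLM/NSP loss, uses plain gradient steps for Step 1 with the closed-form proximal step given later as Eq. (8), and applies the exponent $t$ in Step 2 exactly as written. The names `prox_weighted_l1` and `rpp_toy` and all numeric defaults are hypothetical.

```python
def prox_weighted_l1(a, thresholds):
    """Coordinate-wise weighted soft-thresholding (closed form of the prox step)."""
    return [(1.0 - th / abs(ai)) * ai if abs(ai) > th else 0.0
            for ai, th in zip(a, thresholds)]


def rpp_toy(c, gamma=0.1, lr=0.5, T=3, inner_steps=50, eps=1e-3):
    """Reweighted l1 (outer loop) + proximal gradient (inner loop) on the toy
    loss f0(w) = 0.5 * sum((w_i - c_i)^2), whose gradient is w - c."""
    w = list(c)                   # w^0: the "pretrained" weights
    alpha = [1.0] * len(c)        # alpha^0 = 1, so iteration t = 1 is plain l1
    for t in range(1, T + 1):
        # Step 1: proximal gradient iterations, warm-started from w^(t-1)
        for _ in range(inner_steps):
            grad = [wi - ci for wi, ci in zip(w, c)]
            a = [wi - lr * gi for wi, gi in zip(w, grad)]   # gradient step
            w = prox_weighted_l1(a, [lr * gamma * al for al in alpha])
        # Step 2: reweighting, with the exponent t as written in the algorithm
        alpha = [1.0 / (abs(wi) ** t + eps) for wi in w]
    return w


w = rpp_toy([2.0, 0.05, -1.5, 0.02])
print(w)   # small entries are exactly 0.0; large ones end near 1.97 and -1.45
```

In this toy run the two small coordinates are zeroed in the first ($\ell_1$) iteration and then receive huge factors $\alpha_i \approx 1/\epsilon$, so they stay pruned, while the shrinkage on the large coordinates shrinks as their factors fall.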
3.2.2 Proximal method
In previous pruning methods (han2015learning; wen2016learning), the $\ell_1$ regularization loss is added directly to the original training objective of the DNN model, and a hard threshold is applied in the final step of pruning (all weights below the hard threshold become zero). To avoid the negative effect on the gradient update explained above, we cannot add our reweighted $\ell_1$ penalty to the loss in this way. The remaining challenge is therefore to derive an effective solution to problem (5) for given $\{\alpha_i\}$, namely Step 1 in Algorithm 1, in which the backpropagation-based gradient update is applied only to $f_0(\mathbf{w})$ and not to $\gamma \sum_i \alpha_i |w_i|$.
We adopt the proximal algorithm (parikh2014proximal) to satisfy this requirement through decoupling: the sparsity pattern search is decoupled from computing the gradient of the training loss. The proximal algorithm is shown in (parikh2014proximal) to be highly effective on a wide set of nonconvex optimization problems, compared with solving the original problem directly. Additionally, our reweighted $\ell_1$ minimization (5) has an analytical solution through the proximal operator.
To solve problem (5) for a given $\alpha $, the proximal algorithm operates in an iterative manner:
$$\mathbf{w}_k = \mathrm{prox}_{\lambda_k, \mathrm{rw}\ell_1}\left(\mathbf{w}_{k-1} - \lambda_k \nabla_{\mathbf{w}} f_0(\mathbf{w}_{k-1})\right)$$  (6)
where the subscript $k$ denotes the time step of the training process inside RPP, $\lambda_k > 0$ is the learning rate, and the initial $\mathbf{w}$ is set to $\mathbf{w}^{(t-1)}$ from the previous iteration of reweighted $\ell_1$. The proximal operator $\mathrm{prox}_{\lambda_k, \mathrm{rw}\ell_1}(\mathbf{a})$ is the solution to the problem
$$\underset{\mathbf{w}}{\mathrm{minimize}}\quad \gamma \sum_i \alpha_i |w_i| + \frac{1}{2\lambda_k}\|\mathbf{w} - \mathbf{a}\|_2^2$$  (7)
where $\mathbf{a} = \mathbf{w}_{k-1} - \lambda_k \nabla_{\mathbf{w}} f_0(\mathbf{w}_{k-1})$. The above problem has the following analytical solution (liu2014sparsity):
$$w_{i,k} = \begin{cases} \left(1 - \dfrac{\gamma \lambda_k \alpha_i}{|a_i|}\right) a_i, & |a_i| > \lambda_k \gamma \alpha_i, \\ 0, & |a_i| \le \lambda_k \gamma \alpha_i. \end{cases}$$  (8)
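Equation (8) is a weighted soft-thresholding applied independently to each coordinate. A minimal sketch on plain Python lists (the function name `prox_rw_l1` is hypothetical) shows how the per-weight threshold $\lambda_k \gamma \alpha_i$ decides which entries are pruned:

```python
def prox_rw_l1(a, alpha, gamma, lam):
    """Closed-form solution (8) of problem (7), applied coordinate-wise."""
    out = []
    for a_i, alpha_i in zip(a, alpha):
        thr = lam * gamma * alpha_i            # per-weight threshold lambda_k*gamma*alpha_i
        if abs(a_i) > thr:
            out.append((1.0 - thr / abs(a_i)) * a_i)   # shrink toward zero
        else:
            out.append(0.0)                    # prune: exactly zero
    return out


# Uniform factors: plain l1 soft-thresholding with threshold 0.25.
print(prox_rw_l1([1.0, -0.5, 0.125], [1.0, 1.0, 1.0], gamma=0.5, lam=0.5))
# [0.75, -0.25, 0.0]

# Reweighted factors: the same inputs, but a larger alpha prunes the second entry
# and a smaller alpha shrinks the first one less.
print(prox_rw_l1([1.0, -0.5, 0.125], [0.5, 2.0, 1.0], gamma=0.5, lam=0.5))
# [0.875, 0.0, 0.0]
```

Because the operator has this closed form, no penalty gradient ever enters backpropagation: the loss gradient produces $\mathbf{a}$, and sparsity is imposed afterwards by thresholding.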
We remark that the update rule (6) can be interpreted as the proximal step (8) applied to the gradient descent step $\mathbf{w}_{k-1} - \lambda_k \nabla_{\mathbf{w}} f_0(\mathbf{w}_{k-1})$. Such a descent step can also be obtained through other optimizers such as AdamW. We use AdamW (loshchilov2018decoupled) as our optimizer, the same as (devlin2019bert). The concrete process of AdamW with the proximal operator is given in Appendix C.
Why AdamW rather than Adam? loshchilov2018decoupled propose AdamW to improve the generalization ability of Adam (kingma2014adam). They show that $\ell_2$ regularization (or, more generally, directly adding a regularization penalty term) is inherently ineffective in Adam and has a negative effect on the effectiveness of gradient-based updates, which explains the difficulty of applying adaptive gradient algorithms to super-deep DNN training for NLU applications such as BERT. AdamW mitigates this limitation and improves the regularization of Adam by decoupling weight decay regularization from the gradient-based update. AdamW is widely adopted in pretraining large language representations, e.g., BERT (devlin2019bert), GPT (radford2018improving) and XLNet (yang2019xlnet). Our proposed proximal pruning also benefits from this decoupling design philosophy. The difference is that proximal pruning aims to generate sparsity, rather than to avoid overfitting as decoupled weight decay in AdamW does.
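Appendix C is not reproduced in this section, so the following is only one plausible reading of "AdamW with proximal operator": a standard bias-corrected AdamW update (loshchilov2018decoupled) produces the descent point $\mathbf{a}$, and the weighted soft-threshold of Eq. (8) is then applied with threshold $\text{lr}\cdot\gamma\cdot\alpha_i$. The function name `adamw_prox_step`, the use of the plain learning rate inside the threshold, and all defaults are assumptions for illustration.

```python
import math


def adamw_prox_step(w, g, m, v, k, alpha, lr=1e-2, gamma=0.1,
                    beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.01):
    """One AdamW step on f0 followed by the weighted-l1 proximal step (8).
    k is the 1-based step count. Returns the new (w, m, v) lists."""
    new_w, new_m, new_v = [], [], []
    for w_i, g_i, m_i, v_i, al_i in zip(w, g, m, v, alpha):
        m_i = beta1 * m_i + (1 - beta1) * g_i           # first-moment estimate
        v_i = beta2 * v_i + (1 - beta2) * g_i * g_i     # second-moment estimate
        m_hat = m_i / (1 - beta1 ** k)                  # bias corrections
        v_hat = v_i / (1 - beta2 ** k)
        # Decoupled weight decay: applied to the weight directly,
        # not added to the loss or its gradient.
        a = w_i - lr * (m_hat / (math.sqrt(v_hat) + eps) + weight_decay * w_i)
        # Proximal (sparsity) step: another decoupled operation on the weight.
        thr = lr * gamma * al_i
        new_w.append((1 - thr / abs(a)) * a if abs(a) > thr else 0.0)
        new_m.append(m_i)
        new_v.append(v_i)
    return new_w, new_m, new_v


w, m, v = adamw_prox_step([1.0, 0.01], [0.1, 0.1], [0.0, 0.0], [0.0, 0.0],
                          k=1, alpha=[1.0, 1.0])
print(w)   # roughly [0.989, 0.0]: the small weight is pruned exactly to zero
```

Both decoupled operations act on the weights after the adaptive step, which mirrors the design ideology described above: neither the decay nor the sparsity term is routed through Adam's per-coordinate normalization.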
Our new and working baseline: New Iterative Pruning (NIP).
To obtain the identical universal sparsity $\mathcal{S}_{\mathbf{w}}$, we tried a series of pruning techniques, including the iterative pruning method (han2015learning) and the one-shot pruning method (liu2018rethinking). However, these methods do not converge to a viable solution. A possible reason for the nonconvergence of the iterative pruning method is that directly promoting $\ell_p$ ($p \in \{1, 2\}$) sparsity on the original pretraining objective might be insufficient to prune BERT. To circumvent the convergence issue of conventional iterative pruning methods, we propose a new iterative pruning (NIP) method. Unlike iterative pruning (han2015learning), NIP reflects the naturally progressive pruning performance without any externally introduced penalty. Any other pruning method should therefore perform no worse than NIP; otherwise, its newly introduced sparsity-promoting regularization has a negative effect. We will show that NIP is able to successfully prune BERT to certain pruning ratios. We refer readers to Appendix B for our full baseline algorithm.
4 Experiments
In this section, we describe the experiments on pruning pretrained BERT and demonstrate the performance on 10 downstream transfer learning tasks.
4.1 Experiment Setup
We use the official BERT model from Google as the starting point. Following the notation of devlin2019bert, we denote the number of layers (i.e., transformer blocks) as $L$, the hidden size as $H$, and the number of self-attention heads as $A$. We prune two kinds of BERT model: $\mathrm{BERT}_{\mathrm{BASE}}$ ($L=12, H=768, A=12$, total parameters $=110$M) and $\mathrm{BERT}_{\mathrm{LARGE}}$ ($L=24, H=1024, A=16$, total parameters $=340$M). As the kernel weights make up most of the weight volume in each layer (i.e., transformer block), the kernel weights are our pruning target.
Data: In pretraining, we use the same pretraining corpora as devlin2019bert: BookCorpus (800M words) (zhu2015aligning) and English Wikipedia (2,500M words). Based on the same corpora, we use the same preprocessing script (https://github.com/google-research/bert) to create the pretraining data. In fine-tuning, we report our results on the Stanford Question Answering Dataset (SQuAD) and the General Language Understanding Evaluation (GLUE) benchmark (wang2018glue). We use two versions of SQuAD: V1.1 and V2.0 (rajpurkar2016squad; rajpurkar2018know). GLUE is a collection of datasets/tasks for evaluating natural language understanding systems: CoLA (warstadt2018neural), Stanford Sentiment Treebank (SST) (socher2013recursive), Microsoft Research Paraphrase Corpus (MRPC) (dolan2005automatically), Semantic Textual Similarity Benchmark (STS) (agirre2007semeval), Quora Question Pairs (QQP), Multi-Genre NLI (MNLI) (williams2017broad), Question NLI (QNLI) (rajpurkar2016squad), Recognizing Textual Entailment (RTE) and Winograd NLI (WNLI) (levesque2012winograd).
Input/Output representations: We follow the input/output representation setting of devlin2019bert. We use WordPiece (wu2016google) embeddings with a 30,000-token vocabulary. The first token of every sequence is always a special classification token ([CLS]), and sentences are separated by a special token ([SEP]). Following (devlin2019bert), we use the same input/output representations for both pretraining and fine-tuning.
Evaluation: In pretraining, BERT considers two objectives: masked language modeling (MLM) and next sentence prediction (NSP). For MLM, a random sample of the tokens in the input sequence is selected and replaced with the special token [MASK]. The MLM objective is a cross-entropy loss on predicting the masked tokens. NSP is a binary classification loss for predicting whether two segments follow each other in the original text. We use MLM and NSP to pretrain, retrain and evaluate the pretrained BERT model. In fine-tuning, F1 scores are reported for SQuAD, QQP and MRPC; accuracy is reported for the other tasks.
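This section only says that a random sample of tokens is replaced with [MASK]. The sketch below follows the 15% selection rate and 80/10/10 replacement split described in devlin2019bert, simplified to an independent per-token draw (the reference preprocessing script instead samples a fixed number of positions per sequence). The function `mask_tokens` and its arguments are hypothetical.

```python
import random

MASK = "[MASK]"
SPECIAL = ("[CLS]", "[SEP]")


def mask_tokens(tokens, vocab, rng, select_prob=0.15):
    """Simplified BERT-style MLM corruption.
    Returns (corrupted, labels): labels[i] is the original token at selected
    positions (where the MLM cross-entropy is computed) and None elsewhere."""
    corrupted, labels = list(tokens), [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if tok in SPECIAL or rng.random() >= select_prob:
            continue                          # special tokens are never selected
        labels[i] = tok
        r = rng.random()
        if r < 0.8:
            corrupted[i] = MASK               # 80%: replace with [MASK]
        elif r < 0.9:
            corrupted[i] = rng.choice(vocab)  # 10%: replace with a random token
        # remaining 10%: keep the original token unchanged
    return corrupted, labels


tokens = ["[CLS]", "the", "cat", "sat", "[SEP]"]
corrupted, labels = mask_tokens(tokens, ["dog", "ran"], random.Random(0))
print(corrupted, labels)
```

Passing a seeded `random.Random` instance keeps the corruption reproducible, which matters when the same masked data is used to compare pruned and unpruned models.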
All experiments run on one Google Cloud TPU v3-512 cluster, three Google Cloud TPU v2-512 clusters and 110 Google Cloud TPU v3-8/v2-8 instances.
Baseline: As there is no publicly available effective BERT pruning method, we use the proposed NIP method on BERT as the baseline. The detailed algorithm is given in Appendix B. The progressive pruning ratio is $\nabla p = 10\%$ (prune $10\%$ more weights in each iteration). Starting from the official $\mathrm{BERT}_{\mathrm{BASE}}$, we use 9 iterations. In each iteration $t$ of NIP, we obtain a sparse $\mathrm{BERT}_{\mathrm{BASE}}$ with a specific sparsity, denoted $(\mathbf{w}^t; \mathcal{S}_{\mathbf{w}^t})$. We then retrain the sparse $\mathrm{BERT}_{\mathrm{BASE}}$ $\mathbf{w}^t$ over the sparsity pattern $\mathcal{S}_{\mathbf{w}^t}$. In the retraining process, the initial learning rate is $2 \cdot 10^{-5}$, the batch size is $1024$ and retraining lasts for $10,000$ steps (around 16 epochs). For the other hyperparameters, we follow the original BERT paper (devlin2019bert). In each iteration, the well-retrained sparse $\mathrm{BERT}_{\mathrm{BASE}}$ is the starting point for the fine-tuning tasks and for the next iteration.
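Appendix B is not reproduced here, so this sketch of the NIP schedule assumes that each iteration removes the smallest-magnitude weights until the next target sparsity ($10\%, 20\%, \dots, 90\%$) is reached, with no penalty term. Retraining on MLM and NSP is abstracted as a caller-supplied `retrain` function; both that function and the helper names are hypothetical.

```python
def magnitude_keep_mask(w, sparsity):
    """Keep-mask that zeroes the `sparsity` fraction of smallest-|w| entries."""
    n_prune = int(round(sparsity * len(w)))
    order = sorted(range(len(w)), key=lambda i: abs(w[i]))   # ascending |w|
    pruned = set(order[:n_prune])
    return [i not in pruned for i in range(len(w))]


def nip(w, retrain, step=0.10, iterations=9):
    """Progressive pruning to targets step, 2*step, ..., iterations*step.
    Already-pruned zeros have the smallest magnitude, so pruning is cumulative."""
    history = []
    for t in range(1, iterations + 1):
        target = round(step * t, 2)                # 0.1, 0.2, ..., 0.9
        keep = magnitude_keep_mask(w, target)
        w = [wi if k else 0.0 for wi, k in zip(w, keep)]
        w = retrain(w, keep)                       # hypothetical masked retraining
        history.append((target, w))
    return history


# Toy run with an identity "retrain" on ten weights of magnitude 1..10.
hist = nip([float(i) for i in range(1, 11)], lambda w, keep: w)
print(hist[-1])   # (0.9, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 10.0])
```

Each `(target, w)` pair corresponds to one $(\mathbf{w}^t; \mathcal{S}_{\mathbf{w}^t})$ checkpoint from which the fine-tuning tasks and the next iteration start.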
4.2 Reweighted Proximal Pruning (RPP)
We apply the proposed Reweighted Proximal Pruning (RPP) method to both $\mathrm{BERT}_{\mathrm{BASE}}$ and $\mathrm{BERT}_{\mathrm{LARGE}}$ and demonstrate its performance improvement over NIP. The detailed procedure of RPP is given in Appendix C.
For $\mathrm{BERT}_{\mathrm{BASE}}$, we use exactly the same hyperparameters as in our NIP experiments. The initial learning rate is $\lambda = 2 \cdot 10^{-5}$ and the batch size is 1024. We iterate RPP six times ($T = 6$), and each iteration lasts for $100,000$ steps (around 16 epochs). The total number of epochs in RPP is smaller than in NIP when reaching 90% sparsity. There is no retraining process in RPP. We set $\gamma \in \{10^{-2}, 10^{-3}, 10^{-4}, 10^{-5}\}$ and $\epsilon = 10^{-9}$ in Algorithm 1. Recall that the first iteration of RPP ($t = 1$) is $\ell_1$ sparse training.
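A back-of-the-envelope check of the epoch comparison, using only the approximate per-iteration figures quoted in this section (about 16 epochs per RPP iteration and per NIP retraining iteration) and ignoring fine-tuning runs:

```latex
\text{RPP: } T \times 16 = 6 \times 16 = 96 \text{ epochs},
\qquad
\text{NIP: } 9 \times 16 = 144 \text{ epochs},
\qquad 96 < 144 .
```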
In Figure 1, we present the accuracy versus the pruning ratio for the pretraining tasks MLM and NSP and the fine-tuning task SQuAD 1.1, comparing RPP with NIP. As RPP continues to iterate, its performance becomes notably higher than that of NIP on both the pretraining and the fine-tuning tasks, and the gap widens with more RPP iterations. Figure 1 also shows that NSP accuracy is very robust to pruning: even when $90\%$ of the attention weights are pruned, NSP accuracy stays above $95\%$ with RPP and around $90\%$ with NIP. For MLM accuracy and the SQuAD F1 score, performance drops quickly as the prune ratio increases, and RPP slows this decline to a great extent. On SQuAD 1.1, RPP keeps the F1 score of $\mathrm{BERT}_{\mathrm{BASE}}$ at 88.5 (no degradation compared with the original BERT) at a $41.2\%$ prune ratio, while the F1 score of $\mathrm{BERT}_{\mathrm{BASE}}$ pruned with NIP drops to 84.6 ($3.9$ degradation) at a $40\%$ prune ratio. At an $80\%$ prune ratio, RPP keeps the F1 score of $\mathrm{BERT}_{\mathrm{BASE}}$ at 84.7 ($3.8$ degradation), while NIP drops to 68.8 ($19.7$ degradation compared with the original BERT). Beyond SQuAD 1.1, the other transfer learning tasks show the same trend (RPP consistently outperforms NIP); detailed results are reported in the Appendix.
For $\mathrm{BERT}_{\mathrm{LARGE}}$, we use exactly the same hyperparameters as in our NIP experiments except for the batch size. The initial learning rate is $2 \cdot 10^{-5}$ and the batch size is 512. We iterate RPP four times ($T = 4$), and each iteration lasts for $100,000$ steps (around 8 epochs). Again, there is no retraining process. We set $\gamma \in \{10^{-2}, 10^{-3}\}$ and $\epsilon = 10^{-9}$ in Algorithm 1. The results of pruning $\mathrm{BERT}_{\mathrm{LARGE}}$ and then fine-tuning are shown in Table 1.
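Table 1: Fine-tuning results for pruned $\mathrm{BERT}_{\mathrm{LARGE}}$ (NIP vs. RPP). Numbers in parentheses are changes relative to the unpruned $\mathrm{BERT}_{\mathrm{LARGE}}$.

| Method | Prune Ratio (%) | SQuAD 1.1 | QQP | MNLI | MRPC | CoLA |
|---|---|---|---|---|---|---|
| NIP | 50.0 | 85.3 (-5.6) | 85.1 (-6.1) | 77.0 (-9.1) | 83.5 (-5.5) | 76.3 (-5.2) |
| NIP | 80.0 | 75.1 (-15.8) | 81.1 (-10.1) | 73.81 (-12.29) | 68.4 (-20.5) | 69.13 (-12.37) |
| RPP | 59.3 | 90.23 (-0.67) | 91.2 (0.0) | 86.1 (0.0) | 88.1 (-1.2) | 82.8 (+1.3) |
| RPP | 88.4 | 81.69 (-9.21) | 89.2 (-2.0) | 81.4 (-4.7) | 81.9 (-7.1) | 79.3 (-2.2) |

| Method | Prune Ratio (%) | SQuAD 2.0 | QNLI | MNLI-M | SST-2 | RTE |
|---|---|---|---|---|---|---|
| NIP | 50.0 | 75.3 (-6.6) | 90.2 (-1.1) | 82.5 (-3.4) | 91.3 (-1.9) | 68.6 (-1.5) |
| NIP | 80.0 | 70.1 (-11.8) | 80.5 (-10.8) | 78.4 (-7.5) | 88.7 (-4.5) | 62.8 (-7.3) |
| RPP | 59.3 | 81.3 (-0.6) | 92.3 (+1.0) | 85.7 (-0.2) | 92.4 (-0.8) | 70.1 (0.0) |
| RPP | 88.4 | 80.7 (-1.2) | 88.0 (-3.3) | 81.8 (-4.1) | 90.5 (-2.7) | 67.5 (-2.6) |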
4.3 Visualizing Attention Pattern in BERT
We visualize the sparse pattern of the kernel weights in sparse BERT model applied with RPP, and present several examples in Figure 2. Because we directly visualize the value of identical universal sparsity ${\mathcal{S}}_{\mathbf{w}}$ without any auxiliary function instead of activation map, the attention pattern is universal and data independent.
BERT’s model architecture is a multi-layer, bidirectional transformer encoder based on the original implementation (vaswani2017attention). Following (vaswani2017attention), the transformer architecture is based on “scaled dot-product attention.” The input consists of queries, keys and values, denoted as matrices $Q$, $K$ and $V$, respectively. The output of the attention model is computed as
$$\text{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^{\mathrm{T}}}{\sqrt{d_k}}\right)V \tag{9}$$
where $d_k$ is the dimension of the queries and keys. In Figure 2, we successively visualize the sparse kernel weights that produce $Q$, $K$ and $V$. The sparse $Q$ weights show a clear vertical distribution. The sparse $K$ weights are mainly vertically distributed, with some strong horizontal linking lines. We interpret this as follows: the query matrix $Q$ mainly models the information inside each sequence, while the key matrix $K$ uses its strong horizontal linking lines to build relationships between different sequences in the context. In summary, the sparse attention pattern exhibits a clearly structured distribution.
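The computation in Eq. (9), together with the data-independent weight masks discussed above, can be sketched in NumPy as follows. This is a minimal illustration, not the authors' code: the projection matrices `W_q`, `W_k`, `W_v`, the binary masks `S_q`, `S_k`, `S_v` (standing in for the sparsity $\mathcal{S}_{\mathbf{w}}$), and the helper names are hypothetical, and only a single unbatched attention head is shown.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row-wise maximum for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def scaled_dot_product_attention(Q, K, V):
    # Eq. (9): softmax(Q K^T / sqrt(d_k)) V, where d_k is the query/key dimension.
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    return softmax(scores, axis=-1) @ V


def sparse_projections(X, W_q, W_k, W_v, S_q, S_k, S_v):
    # Hypothetical helper: the binary masks act on the kernel weights, not on
    # activations, so the same sparsity pattern applies to every input X.
    return X @ (W_q * S_q), X @ (W_k * S_k), X @ (W_v * S_v)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n_tokens, d_model, d_k = 4, 8, 8
    X = rng.standard_normal((n_tokens, d_model))
    W_q, W_k, W_v = (rng.standard_normal((d_model, d_k)) for _ in range(3))
    # Random masks keeping roughly 40% of each weight matrix (illustrative only).
    S_q, S_k, S_v = (rng.random((d_model, d_k)) < 0.4 for _ in range(3))
    Q, K, V = sparse_projections(X, W_q, W_k, W_v, S_q, S_k, S_v)
    print(scaled_dot_product_attention(Q, K, V).shape)  # prints (4, 8)
```

Masking the weights rather than the scores is the design choice that matches the text: the mask is fixed once pruning ends, which is why the visualized pattern does not depend on the input.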
4.4 $t$-SNE Visualization
$t$-Distributed Stochastic Neighbor Embedding ($t$-SNE) is a dimensionality-reduction technique that is particularly well suited for visualizing high-dimensional datasets (maaten2008visualizing). Pretrained word embeddings are an integral part of modern NLP systems (devlin2019bert), and one contribution of BERT is pretrained contextual embedding. Hence, in Figure 3 we use $t$-SNE to visualize the word embeddings of the original BERT model and of the BERT model pruned with RPP. The similarity between the two visualizations indicates that RPP preserves the structure of the learned word embeddings, illustrating the effectiveness of RPP.
5 Conclusions and Future Work
This paper presents the pruning algorithm RPP, which achieves the first effective weight pruning result on a large pretrained language representation model, BERT. RPP achieves $59.3\%$ weight sparsity without inducing performance loss on either the pretraining or the finetuning tasks. We spotlight the relationship between the pruning ratio of the pretrained DNN model and its performance on the downstream multitask transfer learning objectives. We show that many downstream tasks other than SQuAD allow a pruning ratio of at least $80\%$, compared with $59.3\%$ for SQuAD. Our proposed Reweighted Proximal Pruning provides a new perspective for analyzing what a large language representation model (BERT) learns.
Acknowledgments
This research is supported by the National Science Foundation under grants CCF-1919117, CNS-1704662, CCF-193750, CCF-1733701, and CCF-1901378. The authors gratefully acknowledge the support of the computing resources sponsor: Google LLC. We would like to thank Google for providing the computing resources: one Google Cloud TPU V3-512 cluster, three Google Cloud TPU V2-512 clusters and 110 Google Cloud TPU V3-8/V2-8 instances.
References
Appendix
Appendix A Overview of Proposed BERT Pruning
Figure A1 shows an overview of pruning BERT using RPP and then finetuning it on a wide range of downstream transfer learning tasks. Through RPP, we find the identical universal sparsity $\mathcal{S}_{\mathbf{w}}$, on which the model can then be finetuned for each downstream transfer learning task.
Appendix B Algorithm of New Iterative Pruning
Algorithm B shows the detailed procedure of our proposed NIP algorithm.

Algorithm B New Iterative Pruning (NIP)
Input: initial model weights $\mathbf{w}$, initial prune ratio $p=0\%$, progressive prune ratio $\nabla p$
1. for $t=1,2,\dots,T$ do
2.   $\mathbf{w}=\mathbf{w}^{(t-1)}$
3.   Sample a batch of data from the pretraining data
4.   Obtain sparsity $\mathcal{S}_{\mathbf{w}}$ through hard-threshold pruning with prune ratio $p^{t}=t\cdot\nabla p$
5.   Retrain $\mathbf{w}$ under the sparsity constraint $\mathcal{S}_{\mathbf{w}}$
6.   for all tasks in $\{\mathcal{T}_i\}$ do
7.     Finetune $\mathbf{z}_i(\mathbf{w};\mathcal{S}_{\mathbf{w}})$ over sparsity $\mathcal{S}_{\mathbf{w}}$ (if the desired prune ratio $p^{t}$ has been reached for downstream task $i$)
8.   end for
9. end for
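The pruning loop of Algorithm B can be sketched in NumPy for a single weight tensor. This is a sketch under stated assumptions rather than the authors' implementation: `grad_fn` is a hypothetical stand-in for sampling a pretraining batch and computing the loss gradient, retraining is plain gradient descent with a hypothetical fixed learning rate `lr`, and the downstream finetuning loop over the tasks $\{\mathcal{T}_i\}$ is omitted.

```python
import numpy as np


def hard_threshold_mask(w, prune_ratio):
    # Keep the (1 - prune_ratio) fraction of entries with the largest magnitude.
    n_prune = int(round(prune_ratio * w.size))
    mask = np.ones(w.size, dtype=bool)
    if n_prune > 0:
        smallest = np.argsort(np.abs(w), axis=None)[:n_prune]
        mask[smallest] = False
    return mask.reshape(w.shape)


def nip(w, grad_fn, delta_p, T, retrain_steps=10, lr=0.1):
    # Progressive pruning: at iteration t the prune ratio is p^t = t * delta_p.
    masks = []
    for t in range(1, T + 1):
        mask = hard_threshold_mask(w, t * delta_p)
        w = w * mask
        # Retrain w under the sparsity constraint: pruned entries stay at zero.
        for _ in range(retrain_steps):
            w = (w - lr * grad_fn(w)) * mask
        masks.append(mask)
    return w, masks


if __name__ == "__main__":
    target = np.array([0.9, -0.1, 0.5, 0.05, -0.7, 0.3, -0.02, 0.8, -0.4, 0.2])
    # Toy quadratic loss 0.5 * ||w - target||^2, whose gradient is w - target.
    w_final, masks = nip(target.copy(), lambda w: w - target, delta_p=0.2, T=3)
    print([int(m.sum()) for m in masks])  # prints [8, 6, 4]
```

Because already-pruned entries are exactly zero, they remain among the smallest magnitudes at the next iteration, so in practice each mask is nested inside the previous one as the prune ratio grows.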
Appendix C Algorithm of Reweighted Proximal Pruning (RPP)
Algorithm C shows the detailed procedure of our enhanced AdamW (loshchilov2018decoupled) with a proximal operator.

Algorithm C Reweighted Proximal Pruning (RPP)
Given: $\alpha=0.001$, $\beta_1=0.9$, $\beta_2=0.999$, $\epsilon=10^{-6}$, $\lambda\in\mathbb{R}$
Initialize: time step $k\leftarrow 0$, parameters of the pretrained model $\mathbf{w}$, first moment vector $\mathbf{m}_{0}\leftarrow\mathbf{0}$, second moment vector $\mathbf{v}_{0}\leftarrow\mathbf{0}$, schedule multiplier $\eta_{0}\in\mathbb{R}$
1. repeat
2.   $k\leftarrow k+1$
3.   $\nabla f_k(\mathbf{w}_{k-1})\leftarrow\text{SelectBatch}(\mathbf{w}_{k-1})$
4.   $\mathbf{g}_k\leftarrow\nabla f_k(\mathbf{w}_{k-1})$
5.   $\mathbf{m}_k\leftarrow\beta_1\mathbf{m}_{k-1}+(1-\beta_1)\mathbf{g}_k$
6.   $\mathbf{v}_k\leftarrow\beta_2\mathbf{v}_{k-1}+(1-\beta_2)\mathbf{g}_k^{2}$
7.   $\hat{\mathbf{m}}_k\leftarrow\mathbf{m}_k/(1-\beta_1^{k})$
8.   $\hat{\mathbf{v}}_k\leftarrow\mathbf{v}_k/(1-\beta_2^{k})$
9.   $\eta_k\leftarrow\text{SetScheduleMultiplier}(k)$
10.  $\mathbf{a}\leftarrow\mathbf{w}_{k-1}-\eta_k\left(\alpha\hat{\mathbf{m}}_k/(\sqrt{\hat{\mathbf{v}}_k}+\epsilon)+\lambda\mathbf{w}_{k-1}\right)$
11.  $\mathbf{w}_k\leftarrow\mathrm{prox}_{\lambda_k,\,\mathrm{rw}\text{-}\ell_1}(\mathbf{a})$
12. until the stopping criterion is met
13. return the optimized sparse model $\mathbf{w}$ in pretraining
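One iteration of Algorithm C can be sketched in NumPy as follows. The exact reweighting rule is not restated in this appendix, so we assume the common reweighted-$\ell_1$ choice of per-weight penalties $1/(|w_i|+\epsilon)$; under that assumption the proximal operator is element-wise soft-thresholding with threshold $\lambda_k/(|w_i|+\epsilon)$. The names `adamw_prox_step`, `reweighted_penalties` and `prox_reweighted_l1` are hypothetical; `weight_decay` plays the role of $\lambda$ in the decoupled weight-decay term of the update of $\mathbf{a}$, and `lam_k` the role of $\lambda_k$ in the proximal step. For brevity the demo computes the penalties once from the initial weights.

```python
import numpy as np


def reweighted_penalties(w_ref, eps=1e-9):
    # Assumed reweighting rule: small weights receive large penalties,
    # which pushes them to exactly zero in the proximal step.
    return 1.0 / (np.abs(w_ref) + eps)


def prox_reweighted_l1(a, lam_k, penalties):
    # Proximal operator of lam_k * sum_i penalties_i * |w_i|:
    # element-wise soft-thresholding with per-weight thresholds.
    return np.sign(a) * np.maximum(np.abs(a) - lam_k * penalties, 0.0)


def adamw_prox_step(w, g, m, v, k, penalties, lam_k, alpha=1e-3, beta1=0.9,
                    beta2=0.999, eps=1e-6, weight_decay=0.0, eta_k=1.0):
    # Adam moment estimates with bias correction (k is the 1-based step count).
    m = beta1 * m + (1.0 - beta1) * g
    v = beta2 * v + (1.0 - beta2) * g ** 2
    m_hat = m / (1.0 - beta1 ** k)
    v_hat = v / (1.0 - beta2 ** k)
    # AdamW update with decoupled weight decay, scaled by the schedule multiplier.
    a = w - eta_k * (alpha * m_hat / (np.sqrt(v_hat) + eps) + weight_decay * w)
    # Proximal step: this is where small weights become exactly zero.
    return prox_reweighted_l1(a, lam_k, penalties), m, v


if __name__ == "__main__":
    target = np.array([1.0, 0.0, -2.0, 0.0])
    w = np.array([0.5, 0.3, -1.0, -0.2])
    m, v = np.zeros_like(w), np.zeros_like(w)
    penalties = reweighted_penalties(w)
    for k in range(1, 501):
        g = w - target  # gradient of the toy loss 0.5 * ||w - target||^2
        w, m, v = adamw_prox_step(w, g, m, v, k, penalties, lam_k=1e-2, alpha=0.05)
    print(np.round(w, 3))
```

Unlike an $\ell_1$ term added to the loss, the soft-thresholding step returns exact zeros, which is what lets the sparsity pattern be read off directly after pretraining.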
Appendix D Downstream Transfer Learning Tasks
As mentioned in the main paper, we prune the pretrained BERT model (using NIP and RPP) and then finetune the sparse pretrained model on different downstream transfer learning tasks. In this section, we report the performance of BERT pruned with NIP and RPP on a wide range of downstream transfer learning tasks to support the conclusions of the main paper.
D.1 QQP
Quora Question Pairs is a binary classification task where the goal is to determine if two questions asked on Quora are semantically equivalent.
D.2 MRPC
The Microsoft Research Paraphrase Corpus consists of sentence pairs automatically extracted from online news sources, with human annotations for whether the sentences in each pair are semantically equivalent (dolan2005automatically).
D.3 MNLI
Multi-Genre Natural Language Inference is a large-scale, crowdsourced entailment classification task (williams2017broad). Given a pair of sentences, the goal is to predict whether the second sentence is an entailment, contradiction, or neutral with respect to the first one.
D.4 MNLI-M
Multi-Genre Natural Language Inference has a separate evaluation, MNLI-M. Following (devlin2019bert), the finetuning process on MNLI-M is separate from that on MNLI, so we present our MNLI-M results in this subsection.
D.5 QNLI
Question Natural Language Inference is a version of the Stanford Question Answering Dataset (Rajpurkar et al., 2016) that has been converted to a binary classification task (Wang et al., 2018a). The positive examples are (question, sentence) pairs that do contain the correct answer, and the negative examples are (question, sentence) pairs from the same paragraph that do not contain the answer.
D.6 SST-2
The Stanford Sentiment Treebank is a binary single-sentence classification task consisting of sentences extracted from movie reviews with human annotations of their sentiment (socher2013recursive).
D.7 CoLA
The Corpus of Linguistic Acceptability is a binary single-sentence classification task, where the goal is to predict whether an English sentence is linguistically “acceptable” or not (warstadt2018neural).
Appendix E Non-convergence of Pruning BERT Using Previous Methods
As mentioned in the main paper, we investigate a series of pruning techniques for BERT, including iterative pruning (han2015learning) and one-shot pruning (liu2018rethinking). However, most previous pruning techniques require directly adding an $\ell_1/\ell_2$ regularization loss to the original training objective of the DNN model. Through a series of experiments, we find that this kind of regularization is not compatible with BERT. The theoretical analysis of this incompatibility is given in the main paper; in this section we report the corresponding experimental results. For a fair comparison, we not only apply the hyperparameters used in our NIP and RPP experiments to iterative and one-shot pruning, but also search a wide range of additional hyperparameters to make them work. We set the learning rate $\lambda \in \{2\cdot 10^{-4},10^{-4},5\cdot 10^{-5},3\cdot 10^{-5},2\cdot 10^{-5},1\cdot 10^{-5},1\cdot 10^{-6},1\cdot 10^{-7},1\cdot 10^{-8}\}$ and the batch size $B\in\{256,512,1024,2048\}$. The number of training epochs reaches a maximum of 128 (Figure A10(b)), which exceeds the 40 epochs used to pretrain BERT from scratch (devlin2019bert). Nevertheless, neither iterative pruning nor one-shot pruning converges to a valid solution.
In Figure A9, we directly add an $\ell_1$ regularization loss to the original loss function of BERT to induce sparsity in the model. This kind of regularization easily leads to non-convergence and often causes gradient exceptions (Figure A9).
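To make the contrast with the proximal step concrete, the toy example below compares adding an $\ell_1$ term directly to the loss (optimized with plain subgradient descent) against handling the same term with soft-thresholding. It uses a one-dimensional quadratic whose regularized minimizer is exactly zero; the loss, step size and penalty strength are hypothetical choices. The sketch only illustrates why a penalty inside the loss rarely yields exact zeros; it does not reproduce the BERT-specific non-convergence analyzed in the main paper.

```python
import numpy as np


def grad_smooth(w, c=0.05):
    # Gradient of the smooth toy loss 0.5 * (w - c)^2.
    return w - c


def subgradient_l1(w0, lam=0.1, lr=0.1, steps=200):
    # l1 term added directly to the loss: its subgradient is lam * sign(w).
    w = w0
    for _ in range(steps):
        w = w - lr * (grad_smooth(w) + lam * np.sign(w))
    return w


def proximal_l1(w0, lam=0.1, lr=0.1, steps=200):
    # Gradient step on the smooth part, then soft-thresholding (prox of lr*lam*|w|).
    w = w0
    for _ in range(steps):
        a = w - lr * grad_smooth(w)
        w = np.sign(a) * max(abs(a) - lr * lam, 0.0)
    return w


if __name__ == "__main__":
    # The minimizer of 0.5*(w - 0.05)^2 + 0.1*|w| is w = 0.
    print("subgradient:", subgradient_l1(1.0))
    print("proximal:   ", proximal_l1(1.0))
```

The subgradient iterate keeps oscillating in a small band around zero, whereas the proximal iterate lands on zero and stays there.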
In Figure A10, we directly add an $\ell_2$ regularization loss to the original loss function of BERT to induce sparsity in the model. Directly adding the $\ell_2$ regularization loss to BERT's original training objectives causes gradient exceptions (Figure A10(a)). After we add a protecting bound to the $\ell_2$ term to avoid gradient exceptions, training lasts longer (128 epochs in Figure A10(b)); however, the training loss curve shows that it still does not converge to a valid solution (Figure A10(b)).