RL-LIM: Reinforcement Learning-based
Locally Interpretable Modeling
Understanding black-box machine learning models is important for their widespread adoption. However, developing globally interpretable models that explain the behavior of the entire model is challenging. An alternative approach is to explain black-box models by explaining individual predictions using a locally interpretable model. In this paper, we propose a novel method for locally interpretable modeling – Reinforcement Learning-based Locally Interpretable Modeling (RL-LIM). RL-LIM employs reinforcement learning to select a small number of samples and distill the black-box model prediction into a low-capacity locally interpretable model. Training is guided by a reward that is obtained directly by measuring agreement of the predictions from the locally interpretable model with the black-box model. RL-LIM near-matches the overall prediction performance of black-box models while yielding human-like interpretability, and significantly outperforms state-of-the-art locally interpretable models in terms of overall prediction performance and fidelity.
|Department of Electrical and|
|Computer Engineering, UCLA, CA|
|Sercan Ö. Arık|
|Google Cloud AI|
|Google Cloud AI|
1 Introduction
Artificial Intelligence (AI) is advancing at a rapid pace, particularly with recent advances in deep neural networks and ensemble methods (Goodfellow et al., 2016; He et al., 2016; Chen & Guestrin, 2016; Ke et al., 2017). This progress has been fueled by ‘black-box’ machine learning models where the decision making is controlled by complex non-linear interactions between many parameters that are difficult for humans to understand and interpret. However, in many real-world applications AI systems are not only expected to perform well but are also required to be interpretable: doctors need to understand why a particular treatment is recommended, and financial institutions need to understand why a loan was declined. Use cases of model interpretability vary across applications: it can provide trust to users by showing rationales behind decisions, enable detection of systematic failure cases, and provide actionable feedback for improving models (Rudin, 2018).
Many studies have suggested a trade-off between performance and interpretability (Virág & Nyitrai, 2014; Johansson et al., 2011). This is correct in that globally interpretable models, which attempt to explain the entire model behavior, typically yield considerably worse performance than ‘black-box’ models (Lipton, 2016). To go beyond the performance limitations of globally interpretable models, another promising direction is locally interpretable models, which instead of explaining the entire model explain a single prediction (Ribeiro et al., 2016). Methodologically, while a globally interpretable model fits a single inherently interpretable model (such as a linear model or a shallow decision tree) to the entire training set, locally interpretable models aim to fit an inherently interpretable model locally, i.e. for each instance individually, by distilling knowledge from a high performance black-box model. Such locally interpretable models are very useful for real-world AI deployments to provide succinct and human-like explanations to users. They can be used to identify systematic failure cases (e.g. by seeking common trends in input dependence for failure cases), detect biases (e.g. by quantifying feature importance for a particular variable), and provide actionable feedback to improve a model (e.g. understand failure cases and what training data to collect).
To be useful in practice, locally interpretable models need to maximize two objectives: (i) the overall prediction performance (how well it predicts compared to the ground truth labels) – for the model to be accurate, and (ii) fidelity (how well it approximates the ‘black-box’ model predictions) – to ensure the model is reliably approximating the black-box model’s predictions in the neighborhood of interest (Plumb et al., 2019; Lakkaraju et al., 2019). To this end, a few methods have recently been proposed for locally interpretable modeling: Local Interpretable Model-agnostic Explanations (LIME) (Ribeiro et al., 2016), Supervised Local modeling methods (SILO) (Bloniarz et al., 2016), and Model Agnostic Supervised Local Explanations (MAPLE) (Plumb et al., 2018). LIME in particular has gained notable popularity and has been deployed in many applications due to its simplicity. However, the overall prediction performance and fidelity metrics are not reaching desired levels in many cases (Alvarez-Melis & Jaakkola, 2018; Zhang et al., 2019; Ribeiro et al., 2018; Lakkaraju et al., 2017). Indeed, as we show in our experiments, there are frequent cases where existing locally interpretable models even underperform commonly low-performing globally interpretable models.
One of the fundamental challenges in fitting a locally interpretable model is the difference in representational capacity between the two models involved in the distillation. Black-box machine learning models, such as deep neural networks or ensemble models, have much larger representational capacity than locally interpretable models. This can result in underfitting with conventional distillation techniques, leading to suboptimal performance (Hinton et al., 2015; Wang et al., 2019). We address this fundamental challenge by proposing a novel Reinforcement Learning-based method to fit Locally Interpretable Models which we call RL-LIM. RL-LIM efficiently utilizes the small representational capacity of locally interpretable models by training with a small number of samples that are determined to have the highest value contribution to the fitting of a locally interpretable model. In order to select these highest-value instances, we train instance-wise weight estimators (modeled with deep neural networks) using a reinforcement signal that quantifies the fidelity metric (i.e. how well the locally interpretable model approximates the black-box model predictions). The contributions of this paper can be summarized as follows:
We introduce the first method that tackles interpretability through data-weighted training, and show that reinforcement learning is highly effective for end-to-end training of such a model.
We show that distillation of a black-box model into a low-capacity interpretable model can be significantly improved by fitting with a small subset of relevant samples that is controlled efficiently by our method.
On various classification and regression datasets, we demonstrate that RL-LIM significantly outperforms alternative models (LIME, SILO and MAPLE) in overall prediction performance and fidelity metrics – in most cases, the overall performance of locally interpretable models obtained by RL-LIM is very similar to complex black-box models.
2 Related Work
Locally interpretable models: There are various approaches to interpret black-box models – (Gilpin et al., 2018) provides a good overview. One approach is to directly decompose the prediction into feature attributions by considering what-if cases. Shapley values (Štrumbelj & Kononenko, 2014) and their computationally-efficient variants (Lundberg & Lee, 2017) are commonly-used methods in this category. Other notable methods are based on activation differences, e.g. DeepLIFT (Shrikumar et al., 2017), or saliency maps using the gradient flows, e.g. CAM (Zhou et al., 2016) and Grad-CAM (Selvaraju et al., 2017). In this paper, we focus on the direction of locally interpretable modeling – distilling a black-box model into an interpretable model for each input instance.
Locally Interpretable Model-agnostic Explanation (LIME) (Ribeiro et al., 2016) is the most popular method for locally interpretable modeling. LIME is based on modifying a data instance by tweaking the feature values and then learning from the impact of the modifications on the output. A fundamental challenge for LIME is the need for a meaningful distance metric to determine neighborhoods, as simple metrics like Euclidean distance may yield poor fidelity in some cases and the estimation can be highly sensitive to normalization (Alvarez-Melis & Jaakkola, 2018), especially with categorical variables. Supervised Local modeling methods (SILO) (Bloniarz et al., 2016) aims to improve LIME by determining the neighborhoods for each instance using ad-hoc tree-based ensemble methods. Model Agnostic Supervised Local Explanations (MAPLE) (Plumb et al., 2018) further adds a method for feature selection on top of SILO – it utilizes ad-hoc tree-based ensemble methods to determine the weights of training instances for each target instance and uses the weights to optimize a locally interpretable model. However, SILO and MAPLE still have shortcomings because the tree-based ensemble methods are optimized independently from the locally interpretable model – the lack of joint optimization results in suboptimal fidelity for the locally interpretable model. Overall, to construct a locally interpretable model, a key problem is how to select the optimal training instances for each testing instance, because the selected training instances largely determine the constructed locally interpretable model. The number of possibilities for training instance selection is extremely large (exponential in the number of training instances). LIME heuristically utilizes Euclidean distances, whereas SILO and MAPLE use ad-hoc tree-based ensemble methods.
Our proposed method, RL-LIM, takes a very different perspective: to properly and efficiently explore the large possible solution space, RL-LIM utilizes reinforcement learning to find the optimal policy that selects the training instances that maximize the fidelity of the locally interpretable model.
Data-weighted training: Optimal weighting of training data is a paramount problem in machine learning. By upweighting valuable instances and downweighting low-quality or problematic instances, better performance can be obtained in certain learning scenarios, such as imbalanced or noisy labels (Jiang et al., 2018). One approach for data weighting is utilizing Influence Functions (Koh & Liang, 2017), which are based on oracle access to gradients and Hessian-vector products. Jointly-trained student-teacher methods constitute another approach (Jiang et al., 2018; Bengio et al., 2009) to learn a data-driven curriculum. Using the feedback from the teacher network, training instance-wise weights are learned for the student model. Aligned with our motivations, meta learning is considered for data weighting in Ren et al. (2018). Their proposed method utilizes gradient descent-based meta learning, guided by a small validation set, to maximize the target performance.
In this work we consider data-weighted training for a novel purpose: interpretability. Unlike gradient descent-based meta learning, our approach uses reinforcement learning to integrate the reward directly with the fidelity metric. Aforementioned works estimate the same ranking of training instances for the entire dataset. Instead, our method yields an instance-wise ranking of training data points, different for each testing instance. This enables efficient distillation of a black-box model prediction into a locally interpretable model.
3 Reinforcement Learning-based Modeling
We consider a training dataset $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^{N}$ for training of a black-box model $f$, where $x_i \in \mathcal{X}$ is the feature vector in a $d$-dimensional feature space $\mathcal{X}$ and $y_i \in \mathcal{Y}$ is the corresponding label in a label space $\mathcal{Y}$. We also assume that there exists a probe dataset $\mathcal{D}^p = \{(x_j^p, y_j^p)\}_{j=1}^{M}$, where $M$ is the number of probe instances. The probe dataset is used to evaluate the model performance to guide meta-learning as in Ren et al. (2018). If there is no explicit probe dataset, we can randomly partition a subset of the training dataset as the probe dataset and use the remainder as the training dataset. RL-LIM is composed of three models:
Black-box model $f$ – any machine learning model that needs to be explained (e.g. a deep neural network or a decision tree-based ensemble model),
Locally interpretable model $g_\theta$ – an inherently interpretable model by design (e.g. a linear model or a shallow decision tree),
Instance-wise weight estimation model $h_\phi$ – a function that outputs the instance-wise weights to fit the locally interpretable model. It uses the concatenation of a probe feature, a training feature, and the corresponding black-box model prediction on the training feature as its inputs. It can be a complex machine learning model – e.g. here a deep neural network.
Our objective is to construct an accurate locally interpretable model $g_\theta$ such that the predictions made by it are similar to the predictions of the given black-box model $f^*$ – i.e. the locally interpretable model has high fidelity. We use a loss function $\mathcal{L}$ to quantify the fidelity of the locally interpretable model (e.g. mean absolute error; lower is better).
The representational capacity difference between the black-box model and the locally interpretable model is the bottleneck we aim to address. Ideally, to avoid underfitting, locally interpretable models should be learned with a minimal number of training instances that are most effective in capturing the model behavior. We propose an instance-wise weight estimation model $h_\phi$ to estimate the probabilities of training instances that should be used for fitting the locally interpretable model. Integrating this with the goal of accurate locally interpretable modeling, we propose the following objective:
$$
\begin{aligned}
\min_{h_\phi}\ & \mathbb{E}_{x^p \sim P_X}\Big[\mathcal{L}\big(f^*(x^p),\, g^*_{\theta(x^p)}(x^p)\big)\Big] + \lambda\, \mathbb{E}_{x^p, x \sim P_X}\Big[h_\phi\big(x^p, x, f^*(x)\big)\Big] \\
\text{s.t.}\ & g^*_{\theta(x^p)} = \arg\min_{g_\theta}\ \mathbb{E}_{x \sim P_X}\Big[h_\phi\big(x^p, x, f^*(x)\big)\, \mathcal{L}_g\big(f^*(x),\, g_\theta(x)\big)\Big]
\end{aligned}
\tag{1}
$$
where the expectations are taken over the feature distribution $P_X$, $\lambda \geq 0$ is a hyper-parameter that controls the number of training instances used to fit the locally interpretable model (we study the impact of $\lambda$ on performance in Section 4.2), and $h_\phi(x^p, x, f^*(x))$ represents the instance-wise weight for each training pair $(x, f^*(x))$ for the probe data $x^p$. $\mathcal{L}_g$ is the loss function to fit the locally interpretable model, for which we use the mean squared error between predicted values for regression and logits for classification. $\theta$ and $\phi$ are the trainable parameters, whereas $f^*$ (the pre-trained black-box model) is fixed.
The first term in the objective function represents the local prediction differences between the black-box model and the locally interpretable model (referred to as the fidelity metric). The second term in the objective function represents the expected number of selected training points to fit the locally interpretable model. Lastly, the constraint ensures that the locally interpretable model is derived from a weighted loss function, where the weights are the output of the instance-wise weight estimator $h_\phi$. Our formulation does not assume any constraint on $g_\theta$ – it could be any inherently interpretable model suitable for the data type of interest. Next, we describe how Eq. (1) can be efficiently addressed with reinforcement learning.
3.1 Training and inference
The RL-LIM method, shown in Fig. 1, can be thought of as encompassing 5 stages:
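Stage 0 – Black-box model training: This stage is the preliminary stage for RL-LIM. Given the training set $\mathcal{D}$, the black-box model $f$ is trained to minimize a loss function $\mathcal{L}_f$ (e.g. mean squared error for regression or cross-entropy for classification), i.e., $f^* = \arg\min_{f} \frac{1}{N}\sum_{i=1}^{N} \mathcal{L}_f(f(x_i), y_i)$. If a pre-trained black-box model is already available, we can skip this stage and use the given pre-trained model as $f^*$.
Stage 1 – Auxiliary dataset construction: Using the pre-trained black-box model $f^*$, we create auxiliary training and probe datasets, as $\hat{\mathcal{D}} = \{(x_i, \hat{y}_i)\}_{i=1}^{N}$ (where $\hat{y}_i = f^*(x_i)$) and $\hat{\mathcal{D}}^p = \{(x_j^p, \hat{y}_j^p)\}_{j=1}^{M}$ (where $\hat{y}_j^p = f^*(x_j^p)$), respectively. These auxiliary datasets ($\hat{\mathcal{D}}$, $\hat{\mathcal{D}}^p$) are used for training the instance-wise weight estimation model and the locally interpretable model.
Stage 2 – Interpretable baseline training: To improve the stability of the instance-wise weight estimator training, a baseline model is observed to be beneficial. As the baseline model $g_b$, we use a globally interpretable model (such as a linear model or shallow decision tree) optimized to replicate the predictions of the black-box model: $g_b^* = \arg\min_{g_b} \frac{1}{N}\sum_{i=1}^{N} \mathcal{L}_g(g_b(x_i), \hat{y}_i)$.
Stage 3 – Instance-wise weight estimator training: We train an instance-wise weight estimator $h_\phi$ using the auxiliary datasets ($\hat{\mathcal{D}}$, $\hat{\mathcal{D}}^p$). To encourage exploration, we consider probabilistic selection, with a sampler block that is based on the output of the instance-wise weight estimator – $h_\phi(x_j^p, x_i, \hat{y}_i)$ represents the probability that $(x_i, \hat{y}_i)$ is selected to train the locally interpretable model for the probe instance $x_j^p$. Let the binary vector $c(x_j^p) \in \{0, 1\}^N$ represent the selection operation, such that $(x_i, \hat{y}_i)$ is selected for training the locally interpretable model for $x_j^p$ when $c_i(x_j^p) = 1$. Correspondingly, $\pi_\phi(x_j^p, c(x_j^p))$ is the probability mass function for $c(x_j^p)$ given $h_\phi$:
$$
\pi_\phi\big(x_j^p, c(x_j^p)\big) = \prod_{i=1}^{N} \Big[h_\phi(x_j^p, x_i, \hat{y}_i)^{c_i(x_j^p)} \cdot \big(1 - h_\phi(x_j^p, x_i, \hat{y}_i)\big)^{1 - c_i(x_j^p)}\Big]
$$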
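As the original form of the optimization problem in Eq. (1) is intractable due to the expectation operations, we employ approximations:
The sample mean over the $M$ probe instances is used as an approximation of the first term of the objective function as $\frac{1}{M}\sum_{j=1}^{M} \mathcal{L}\big(f^*(x_j^p),\, g^*_{\theta(x_j^p)}(x_j^p)\big)$.
The second term of the objective, which represents the average selection probability, is approximated as the number of selected instances (divided by $N$) to have $\frac{1}{MN}\sum_{j=1}^{M} \lVert c(x_j^p) \rVert_1$, where $c(x_j^p)$ is the binary selection vector for the probe instance $x_j^p$.
The constraint term is approximated using the sample mean of the training loss as $g^*_{\theta(x_j^p)} = \arg\min_{g_\theta} \frac{1}{N}\sum_{i=1}^{N} c_i(x_j^p)\, \mathcal{L}_g\big(f^*(x_i),\, g_\theta(x_i)\big)$.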
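The sampler block yields a non-differentiable objective, so we cannot train the instance-wise weight estimator using conventional gradient descent-based optimization. There are approximations such as training in expectation (Raffel et al., 2017) or Gumbel-softmax (Jang et al., 2016). Instead, motivated by its many successful applications (Ranzato et al., 2015; Zaremba & Sutskever, 2015; Zhang & Lapata, 2017), we use the REINFORCE algorithm (Williams, 1992) such that the selection action is rewarded by the performance of its impact. The loss function for the instance-wise weight estimator is expressed as:
$$
l(\phi) = \mathbb{E}_{x^p \sim P_X}\, \mathbb{E}_{c(x^p) \sim \pi_\phi(x^p, \cdot)}\Big[\mathcal{L}\big(f^*(x^p),\, g^*_{\theta(x^p)}(x^p)\big) + \tfrac{\lambda}{N} \lVert c(x^p) \rVert_1\Big]
$$
To apply the REINFORCE algorithm, we directly compute the gradient as:
$$
\nabla_\phi l(\phi) = \mathbb{E}_{x^p \sim P_X}\, \mathbb{E}_{c(x^p) \sim \pi_\phi(x^p, \cdot)}\Big[\Big(\mathcal{L}\big(f^*(x^p),\, g^*_{\theta(x^p)}(x^p)\big) + \tfrac{\lambda}{N} \lVert c(x^p) \rVert_1\Big)\, \nabla_\phi \log \pi_\phi\big(x^p, c(x^p)\big)\Big]
$$
Using the gradient $\nabla_\phi l(\phi)$, we employ the following steps iteratively to update the parameters $\phi$ of the instance-wise weight estimator:
Estimate the instance-wise weights $h_\phi(x_j^p, x_i, \hat{y}_i)$ and sample the instance-wise selection vector $c(x_j^p)$, with $c_i(x_j^p) \sim \mathrm{Bernoulli}\big(h_\phi(x_j^p, x_i, \hat{y}_i)\big)$, for each training and probe instance in a mini-batch.
Optimize the locally interpretable model with the selection vector for each probe instance:
$$
g^*_{\theta(x_j^p)} = \arg\min_{g_\theta} \frac{1}{N}\sum_{i=1}^{N} c_i(x_j^p)\, \mathcal{L}_g\big(f^*(x_i),\, g_\theta(x_i)\big)
$$
Update the instance-wise weight estimation model parameter $\phi$:
$$
\phi \leftarrow \phi - \frac{\alpha}{M}\sum_{j=1}^{M}\Big[\mathcal{L}\big(f^*(x_j^p),\, g^*_{\theta(x_j^p)}(x_j^p)\big) - \mathcal{L}_b(x_j^p) + \tfrac{\lambda}{N} \lVert c(x_j^p) \rVert_1\Big]\, \nabla_\phi \log \pi_\phi\big(x_j^p, c(x_j^p)\big)
$$
where $\alpha > 0$ is a learning rate and $\mathcal{L}_b(x_j^p) = \mathcal{L}\big(f^*(x_j^p),\, g_b^*(x_j^p)\big)$ is the baseline loss against which we benchmark the performance improvement. We repeat the steps above until convergence.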
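The iterative Stage 3 procedure can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the weight estimator is a hypothetical single-layer logistic model (the paper uses a deep neural network), the locally interpretable model is a closed-form weighted ridge regression, the fidelity loss is the absolute error, and all names (`weight_probs`, `fit_weighted_ridge`, `reinforce_step`, `lam`, `alpha`) are illustrative.

```python
import numpy as np

def weight_probs(phi, x_probe, X_train, y_hat):
    """Hypothetical logistic weight estimator: one selection probability per training point."""
    n = X_train.shape[0]
    # Input = concatenation of probe feature, training feature, black-box prediction.
    Z = np.hstack([np.tile(x_probe, (n, 1)), X_train, y_hat[:, None]])
    p = 1.0 / (1.0 + np.exp(-(Z @ phi)))
    return np.clip(p, 1e-6, 1 - 1e-6), Z

def fit_weighted_ridge(X, y, w, l2=1e-3):
    """Closed-form weighted ridge regression; the last coefficient is the intercept."""
    Xb = np.hstack([X, np.ones((X.shape[0], 1))])
    A = Xb.T @ (w[:, None] * Xb) + l2 * np.eye(Xb.shape[1])
    return np.linalg.solve(A, Xb.T @ (w * y))

def predict(theta, x):
    return np.append(x, 1.0) @ theta

def reinforce_step(phi, X_train, y_hat, X_probe, y_hat_probe, theta_base, lam, alpha, rng):
    """One REINFORCE update of phi over a mini-batch of probe instances."""
    grad = np.zeros_like(phi)
    n = X_train.shape[0]
    for x_p, y_p in zip(X_probe, y_hat_probe):
        p, Z = weight_probs(phi, x_p, X_train, y_hat)
        c = (rng.random(n) < p).astype(float)          # sample the selection vector
        theta = fit_weighted_ridge(X_train, y_hat, c)  # fit local model on selected points
        loss = abs(y_p - predict(theta, x_p))          # fidelity loss of the local model
        base = abs(y_p - predict(theta_base, x_p))     # loss of the global baseline
        # For Bernoulli selection with logistic probabilities,
        # the gradient of log pi with respect to phi is Z^T (c - p).
        grad += (loss - base + lam * c.mean()) * (Z.T @ (c - p))
    return phi - alpha * grad / len(X_probe)
```

In this sketch `theta_base` plays the role of the Stage 2 baseline and could be obtained with `fit_weighted_ridge(X_train, y_hat, np.ones(len(X_train)))`. Subtracting the baseline loss does not change the expected gradient but typically reduces its variance.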
Stage 4 – Interpretable inference: Unlike in training, we use a fixed instance-wise weight estimator (without the sampler and interpretable baseline) and merely fit the locally interpretable model at inference. Given the test instance $x^t$, we obtain the selection probabilities $h_\phi(x^t, x_i, \hat{y}_i)$ for each training pair $(x_i, \hat{y}_i)$ from the instance-wise weight estimator, and using these as the weights, we fit the locally interpretable model via weighted optimization. The outputs of the trained interpretable model are the instance-wise predictions and the corresponding explanations (e.g., local dynamics of the black-box model predictions at $x^t$ given by the coefficients of the fitted linear model).
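As a rough sketch of the inference stage, the snippet below fits a weighted ridge regression around one test instance and returns both the local prediction and the coefficients that serve as the explanation. It assumes the instance-wise weights are already available; here they are replaced by a hand-made stand-in (an indicator of the test instance's regime), and the piecewise-linear "black-box" outputs, the function name `explain_instance` and the `l2` value are all hypothetical.

```python
import numpy as np

def explain_instance(x_test, X_train, y_hat_train, weights, l2=1e-3):
    """Fit a weighted ridge regression for x_test; return (prediction, coefficients).

    Assumes at least one weight is positive, since the intercept is not penalized.
    """
    Xb = np.hstack([X_train, np.ones((X_train.shape[0], 1))])  # add intercept column
    reg = l2 * np.eye(Xb.shape[1])
    reg[-1, -1] = 0.0  # leave the intercept unpenalized
    A = Xb.T @ (weights[:, None] * Xb) + reg
    theta = np.linalg.solve(A, Xb.T @ (weights * y_hat_train))
    coef, intercept = theta[:-1], theta[-1]
    return float(x_test @ coef + intercept), coef

# Toy usage with a hypothetical piecewise-linear black-box output.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y_hat = np.where(X[:, 2] < 0, X[:, 0] + 2 * X[:, 1], 3 * X[:, 0] - X[:, 1])
x_test = np.array([0.5, -1.0, -0.3])
# Stand-in for the trained weight estimator: keep points from the same regime.
w = (X[:, 2] < 0).astype(float)
pred, coef = explain_instance(x_test, X, y_hat, w)
print(pred, coef)
```

With weights that isolate the relevant regime, the fitted coefficients approximate the local linear dynamics of that regime, which is the behaviour the learned weight estimator is intended to provide.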
3.2 Computational cost
In this subsection, we analyze the computational cost of RL-LIM for training and inference. As a representative and commonly used example, we assume linear regression as the locally interpretable model, which has a computational complexity of $O(d^2 N) + O(d^3)$ to fit, where $d$ is the number of features and $N$ is the number of training instances. When $N \gg d$ (which is often the case in practice), the training computational complexity is approximated as $O(d^2 N)$ (Tan, 2018).
Training: Given a pre-trained black-box model, Stage 1 involves running inference $N$ times, and the total complexity depends on the complexity of the black-box model. Unless the black-box model is very complex, the computational complexity of Stage 1 is much smaller than that of Stage 3. Stage 2 has negligible computational overhead. At Stage 3, we iteratively train the instance-wise weight estimator and fit the locally interpretable model from scratch using weighted optimization. Therefore, the computational complexity is $O(d^2 N N_I)$, where $N_I$ is the number of iterations in Stage 3 (run until convergence). Thus, the training complexity scales roughly linearly with the number of training instances.
Interpretable inference: To infer with the locally interpretable model, we need to fit the locally interpretable model after obtaining the instance-wise weights from the trained instance-wise weight estimator. Thus, for each testing instance, the computational complexity is $O(d^2 N)$. (Footnote 1: A subset of the training dataset can be used to reduce complexity, at the cost of decreased fidelity.)
For instance, on a single NVIDIA V100 GPU, on the Facebook Comment dataset (consisting of 600,000 samples), RL-LIM yields a training time of less than 5 hours (including Stages 1, 2 and 3) and an interpretable inference time of less than 10 seconds per testing instance. On the other hand, LIME results in a much longer interpretable inference time (around 30 seconds per testing instance) because it acquires a large number of black-box model predictions for the input perturbations, whereas SILO and MAPLE are similar to RL-LIM.
4 Experiments
We compare RL-LIM to multiple benchmarks on 3 synthetic datasets and 5 UCI public datasets. The source code can be found at https://github.com/google-research/google-research/tree/master/rllim.
Datasets: The 3 public datasets for regression problems are: (1) Blog Feedback, (2) Facebook Comment, (3) News Popularity; the other 2 public datasets for classification problems are: (4) Adult Income, (5) Weather. Details of the data descriptions can be found in the hyper-links of each dataset. Data statistics can be found in Table 3 in Appendix A. In this section, we mainly focus on the tabular datasets because the local dynamics are more important and useful to explain for them; however, the RL-LIM method can be generalized to other data types in a straightforward way.
Black-box models: We focus on approximating black-box models that are shown to yield competitive performance on the target tasks: 3 tree-based ensemble methods, (1) XGBoost (Chen & Guestrin, 2016), (2) LightGBM (Ke et al., 2017), (3) Random Forests (RF) (Breiman, 2001); and deep neural networks, (4) Multi-layer Perceptron (MLP). Also, we use (5) Ridge Regression (RR) and (6) Regression Tree (RT) (for regression) and (7) Logistic Regression (LR) and (8) Decision Tree (DT) (for classification) as globally interpretable models to benchmark. (Footnote 2: We use Python packages (including Sklearn and Tensorflow) to implement those predictive models; the details can be found in the hyper-links of each model and in Appendix B.) We focus on two types of locally interpretable models: (1) Ridge regression, (2) Shallow regression tree (with a max depth of 3). We report the performance with ridge regression for regression and with shallow regression tree for classification in this section. The results of the other two combinations (with ridge regression for classification and with shallow regression tree for regression) are described in Appendix E.
Comparisons to previous work: We compare the performance of RL-LIM with three competing methods: (1) Local Interpretable Model-agnostic Explanations (LIME) (Ribeiro et al., 2016), (2) Supervised Local modeling methods (SILO) (Bloniarz et al., 2016), (3) Model Agnostic Supervised Local Explanations (MAPLE) (Plumb et al., 2018).
Performance metrics: To evaluate the performance of locally interpretable models using real-world datasets, we quantify the overall prediction performance and the fidelity. We assume a disjoint testing dataset for evaluation. For the overall prediction performance, we compare the predictions of the locally interpretable models with the ground-truth labels. We use Mean Absolute Error (MAE) for regression and Average Precision Recall (APR) for classification. For fidelity, we compare the outputs (predicted values for regression and logits for classification) of the locally interpretable models and of the black-box model. We consider two metrics: $R^2$ score (Legates & McCabe, 1999) and Local MAE (LMAE). The details of the metrics are described in Appendix C.
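To make the distinction between the two kinds of metrics concrete, here is a small NumPy sketch. It assumes the standard definitions of MAE and of the coefficient of determination; the exact definitions used in the paper are in its Appendix C, which is not reproduced here, so the function names and the way fidelity is computed below are illustrative only.

```python
import numpy as np

def mae(reference, approx):
    """Mean absolute error between two arrays of equal length."""
    reference, approx = np.asarray(reference, float), np.asarray(approx, float)
    return float(np.mean(np.abs(reference - approx)))

def r2_score(reference, approx):
    """Coefficient of determination of `approx` with respect to `reference`.

    Assumes `reference` is not constant. The value is 1 for a perfect match,
    0 for the constant mean predictor, and negative for anything worse.
    """
    reference, approx = np.asarray(reference, float), np.asarray(approx, float)
    ss_res = np.sum((reference - approx) ** 2)
    ss_tot = np.sum((reference - reference.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)

# Overall performance compares local-model outputs with ground-truth labels;
# fidelity compares them with the black-box outputs instead.
y_true = np.array([1.0, 2.0, 3.0, 4.0])
black_box = np.array([1.1, 1.9, 3.2, 3.8])
local = np.array([1.0, 2.0, 3.0, 4.1])
overall_mae = mae(y_true, local)
fidelity_r2 = r2_score(black_box, local)
```

The sign convention of `r2_score` is what makes negative fidelity values interpretable: a negative score means the local model tracks the black-box outputs worse than a constant equal to their mean.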
Implementation details: We implement the instance-wise weight estimator using a multi-layer perceptron with tanh activation. The numbers of hidden units and layers are optimized by cross-validation. In most cases, a 5-layer perceptron with 100 hidden units performs reasonably well across all datasets. All features are normalized to be between zero and one using a standard min-max scaler. Categorical variables are transformed using one-hot encoding.
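The preprocessing described above can be sketched in plain NumPy as below. This is an illustrative assumption rather than the released code: the helper names are hypothetical, and fitting the scaling statistics on the training split only is a design choice made here, not something the text specifies.

```python
import numpy as np

def minmax_scale(X_train, X_other):
    """Scale features to [0, 1] using the minimum and maximum of the training split."""
    lo, hi = X_train.min(axis=0), X_train.max(axis=0)
    span = np.where(hi > lo, hi - lo, 1.0)  # avoid division by zero for constant columns
    return (X_train - lo) / span, (X_other - lo) / span

def one_hot(values, categories):
    """One-hot encode a 1-D array of categorical values against a fixed category list."""
    values = np.asarray(values)
    return (values[:, None] == np.asarray(categories)[None, :]).astype(float)

X_train = np.array([[0.0, 10.0], [5.0, 20.0], [10.0, 30.0]])
X_test = np.array([[2.5, 40.0]])
X_train_s, X_test_s = minmax_scale(X_train, X_test)
encoded = one_hot(["a", "b", "a"], ["a", "b"])
```

Because the statistics come from the training split, held-out values can fall slightly outside the unit interval, as the second feature of `X_test` does in this toy example.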
4.1 Experiments on synthetic datasets – Recovering local dynamics
On real-world datasets it is challenging to directly evaluate the explanation quality of the locally interpretable models due to the absence of ground-truth explanations. Thus we initially focus on synthetic datasets (with known ground-truth explanations) to directly evaluate how well the locally interpretable models can recover the underlying local dynamics. We construct three synthetic datasets such that the 11-dimensional input features are sampled from and are:
Syn1: if and if
Syn2: if and if
Syn3: if and if
All three datasets have different local dynamics in different input regimes. We directly use the ground truth function as the black-box model and focus on how well locally interpretable modeling can capture the local dynamics. We evaluate the performance of capturing local dynamics using the Absolute Weight Difference (AWD) between $w$ and $\hat{w}$, where $w$ denotes the ground-truth coefficients used to generate the labels and $\hat{w}$ denotes the coefficients derived from the locally interpretable models. We use the estimated coefficients of the ridge regression as the derived local dynamics ($\hat{w}$).
As shown in Fig. 2, RL-LIM significantly outperforms other benchmarks in discovering the local dynamics on all three datasets and in different regimes. RL-LIM can actively learn the linear and non-linear decision boundaries for the local dynamics. Note that LIME completely fails to recover the local dynamics as it uses the Euclidean distance uniformly for all features and cannot distinguish the special properties of the features that alter the local dynamics. SILO and MAPLE only use the predictions to discover the local dynamics; thus, it is hard for them to discover a decision boundary that depends on other variables which are independent of the predictions. Fig. 5 in Appendix D shows the learning curves of RL-LIM, demonstrating the efficiency of reinforcement learning.
4.2 The effect of the number of selected samples on fidelity
In RL-LIM, optimal distillation is enabled by using a small subset of training instances to fit the low-capacity locally interpretable model. The number of selected instances is controlled by the hyper-parameter $\lambda$ in our method – if $\lambda$ is high/low, RL-LIM penalizes the number of selected instances more/less; thus, fewer/more instances are selected to construct the locally interpretable model.
We analyze the efficacy of $\lambda$ in controlling the likelihood of selection and the dependency of fidelity on $\lambda$. We expect that if we select too few/too many training instances, the locally interpretable model will overfit/underfit, which negatively affects the fidelity in both cases. Fig. 3 shows that there is a clear relationship between $\lambda$ and the local fidelity. If $\lambda$ is too large, RL-LIM selects too few instances; thus, the fitted locally interpretable model is less accurate (due to overfitting). On the other hand, if $\lambda$ is too small, RL-LIM selects too many instances and fidelity deteriorates (due to underfitting). To find the optimal $\lambda$, we conduct cross-validation experiments and select the $\lambda$ which achieves the best validation fidelity (e.g. in Syn2). Fig. 3 also shows the average selection probability of the training instances for each $\lambda$. As $\lambda$ increases, the average selection probabilities monotonically decrease due to the higher penalty on the number of selected training instances. Note that even when using a small portion of the training instances, RL-LIM can accurately distill the predictions of black-box models into locally interpretable models, which is crucial for understanding and interpreting the predictions using the most relevant training instances.
4.3 Experiments on real datasets – Overall performance and fidelity
On multiple real datasets, we evaluate the overall prediction performance and fidelity. For the regression and classification problems, we use ridge regression and shallow regression trees as the locally interpretable model. More results can be found in Appendix E.
As can be seen in Table 1, the performance of globally interpretable ridge regression (trained on the entire dataset from scratch) is much worse than that of the complex non-linear models, implying that modeling non-linear relationships between the features and the labels is important for high prediction performance. For the other locally interpretable modeling methods (LIME, SILO, MAPLE), the performance is far worse than that of the original black-box model, showing that they fail at efficiently distilling the non-linear black-box models. In some cases (especially on the Facebook dataset), the performance of the benchmarks is even worse than the performance of global ridge regression (highlighted in red), questioning the value of using these locally interpretable models instead of globally interpretable ridge regression.
In contrast, RL-LIM achieves similar overall prediction performance to the black-box models and significantly outperforms global ridge regression. Table 1 also compares the fidelity in terms of $R^2$ score for regression using ridge regression as the locally interpretable model (LMAE results can be found in Appendix E.3). We observe that $R^2$ scores for some cases (especially on the Facebook dataset and for LIME) are negative, which indicates that the outputs of the locally interpretable models are even worse than the constant mean value estimator. On the other hand, RL-LIM consistently achieves higher, positive $R^2$ values than the other benchmarks for all datasets and black-box models.
Table 2 shows a similar analysis for classification using shallow regression trees (with max depth of 3) as the locally interpretable model. (Footnote 3: Regression trees are used to model logit outputs for classification.) The overall prediction performance of the four black-box models is significantly better than that of the globally interpretable decision tree, which demonstrates the superior fitting by complex black-box models. Among the locally interpretable models, RL-LIM achieves the best APR and $R^2$ score for most cases, underlining its strength in distilling the predictions of the black-box model accurately. In some cases, the benchmarks (especially LIME) achieve lower overall prediction performance than the globally interpretable decision tree (highlighted in red). The overall prediction performance and fidelity metrics of all locally interpretable models seem better for classification problems than for regression problems. We expect that this is because the predictions of black-box models are mostly highly confident, i.e. located near 0 or 1; thus, locally interpretable models can distill the predictions of the black-box models more easily for classification than for regression.
4.4 Qualitative analyses – Interpretations of RL-LIM on Adult Income dataset
We qualitatively analyze the local explanations provided by RL-LIM on the Adult Income dataset (qualitative analyses on the Weather dataset can be found in Appendix E.4). Although RL-LIM is able to provide a local explanation for each individual separately, we analyze its explanations at subgroup granularity for better visualization and understanding (instance granularity analyses are described in Appendix E.4). Fig. 4 shows the feature importance (derived by RL-LIM as the local explanations) for five subgroups in predicting the annual income using XGBoost as the black-box model. We use ridge regression as the locally interpretable model and the absolute value of the fitted coefficients as the estimated feature importance. As can be observed in Fig. 4, for the age subgroups, capital gain seems much more important for mature people (older than 25) than for young people (younger than 25). For the education subgroups, capital gain/loss, occupation, and native country are more critical for highly-educated people (Doctorate, Prof-school, and Masters graduates) than for the others. We do not discover notable biases of the black-box model for the gender, marital status, and race subgroups.
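The subgroup-level view described above can be sketched as follows: each instance has its own fitted local model, the absolute coefficients serve as per-instance feature importance, and these are aggregated within each subgroup. Averaging is an assumed aggregation rule (the text does not state how instance-level importances are combined), and the function name, feature names, and coefficient values are hypothetical.

```python
from collections import defaultdict

def subgroup_importance(coefficients, subgroups):
    """Average |coefficient| per feature within each subgroup.

    coefficients: list of per-instance coefficient lists (one local model each).
    subgroups: list of subgroup labels, aligned with `coefficients`.
    Returns a dict mapping subgroup label to a list of mean absolute coefficients.
    """
    sums = {}
    counts = defaultdict(int)
    for coefs, group in zip(coefficients, subgroups):
        if group not in sums:
            sums[group] = [0.0] * len(coefs)
        for j, c in enumerate(coefs):
            sums[group][j] += abs(c)
        counts[group] += 1
    return {g: [s / counts[g] for s in total] for g, total in sums.items()}

# Hypothetical coefficients of three local ridge models over two features.
feature_names = ["capital_gain", "occupation"]
coefficients = [[0.8, -0.2], [0.6, 0.4], [-0.1, 0.3]]
subgroups = ["older_than_25", "older_than_25", "younger_than_25"]

importance = subgroup_importance(coefficients, subgroups)
```

Because the absolute value is taken before averaging, coefficients of opposite sign within a subgroup do not cancel out.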
We propose a novel method for locally interpretable modeling of pre-trained black-box models. Our proposed method employs reinforcement learning to select a small number of valuable instances and uses them to train a low-capacity locally interpretable model. The selection mechanism is guided by a reward obtained from the similarity between the predictions of the locally interpretable model and those of the black-box model. Our approach near-matches the performance of black-box models and significantly outperforms alternative techniques in terms of overall prediction performance and fidelity metrics, consistently across various datasets and black-box models.
Discussions with Besim Avci, Henry Tappen and Zizhao Zhang are gratefully acknowledged.
- Alvarez-Melis & Jaakkola (2018) David Alvarez-Melis and Tommi S Jaakkola. On the robustness of interpretability methods. arXiv preprint arXiv:1806.08049, 2018.
- Bengio et al. (2009) Yoshua Bengio, Jérôme Louradour, Ronan Collobert, and Jason Weston. Curriculum learning. In International Conference on Machine Learning, pp. 41–48. ACM, 2009.
- Bloniarz et al. (2016) Adam Bloniarz, Ameet Talwalkar, Bin Yu, and Christopher Wu. Supervised neighborhoods for distributed nonparametric regression. In Artificial Intelligence and Statistics, pp. 1450–1459, 2016.
- Breiman (2001) Leo Breiman. Random forests. Machine Learning, 45(1):5–32, 2001.
- Chen & Guestrin (2016) Tianqi Chen and Carlos Guestrin. Xgboost: A scalable tree boosting system. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 785–794. ACM, 2016.
- Gilpin et al. (2018) L. H. Gilpin, D. Bau, B. Z. Yuan, A. Bajwa, M. Specter, and L. Kagal. Explaining explanations: An overview of interpretability of machine learning. In 2018 IEEE 5th International Conference on Data Science and Advanced Analytics (DSAA), pp. 80–89, Oct 2018.
- Goodfellow et al. (2016) Ian Goodfellow, Yoshua Bengio, and Aaron Courville. Deep Learning. MIT Press, 2016. http://www.deeplearningbook.org.
- He et al. (2016) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778, 2016.
- Hinton et al. (2015) Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015.
- Jang et al. (2016) Eric Jang, Shixiang Gu, and Ben Poole. Categorical reparameterization with gumbel-softmax. In International Conference on Learning Representations, 2016.
- Jiang et al. (2018) Lu Jiang, Zhengyuan Zhou, Thomas Leung, Li-Jia Li, and Li Fei-Fei. Mentornet: Learning data-driven curriculum for very deep neural networks on corrupted labels. In International Conference on Machine Learning, pp. 2309–2318, 2018.
- Johansson et al. (2011) Ulf Johansson, Cecilia Sönströd, Ulf Norinder, and Henrik Boström. Trade-off between accuracy and interpretability for predictive in silico modeling. Future medicinal chemistry, 3(6):647–663, 2011.
- Ke et al. (2017) Guolin Ke, Qi Meng, Thomas Finley, Taifeng Wang, Wei Chen, Weidong Ma, Qiwei Ye, and Tie-Yan Liu. Lightgbm: A highly efficient gradient boosting decision tree. In Advances in Neural Information Processing Systems, pp. 3146–3154, 2017.
- Koh & Liang (2017) Pang Wei Koh and Percy Liang. Understanding black-box predictions via influence functions. In International Conference on Machine Learning, pp. 1885–1894, 2017.
- Lakkaraju et al. (2017) Himabindu Lakkaraju, Ece Kamar, Rich Caruana, and Jure Leskovec. Interpretable & explorable approximations of black box models. arXiv preprint arXiv:1707.01154, 2017.
- Lakkaraju et al. (2019) Himabindu Lakkaraju, Ece Kamar, Rich Caruana, and Jure Leskovec. Faithful and customizable explanations of black box models. In Proceedings of the 2019 AAAI/ACM Conference on AI, Ethics, and Society, pp. 131–138. ACM, 2019.
- Legates & McCabe (1999) David R Legates and Gregory J McCabe. Evaluating the use of “goodness-of-fit” measures in hydrologic and hydroclimatic model validation. Water Resources Research, 35(1):233–241, 1999.
- Lipton (2016) Zachary C Lipton. The mythos of model interpretability. arXiv preprint arXiv:1606.03490, 2016.
- Lundberg & Lee (2017) Scott M Lundberg and Su-In Lee. A unified approach to interpreting model predictions. In Advances in Neural Information Processing Systems, pp. 4765–4774, 2017.
- Plumb et al. (2018) Gregory Plumb, Denali Molitor, and Ameet S Talwalkar. Model agnostic supervised local explanations. In Advances in Neural Information Processing Systems, pp. 2515–2524, 2018.
- Plumb et al. (2019) Gregory Plumb, Maruan Al-Shedivat, Eric Xing, and Ameet Talwalkar. Regularizing black-box models for improved interpretability. arXiv preprint arXiv:1902.06787, 2019.
- Raffel et al. (2017) Colin Raffel, Minh-Thang Luong, Peter J Liu, Ron J Weiss, and Douglas Eck. Online and linear-time attention by enforcing monotonic alignments. In International Conference on Machine Learning, pp. 2837–2846, 2017.
- Ranzato et al. (2015) Marc’Aurelio Ranzato, Sumit Chopra, Michael Auli, and Wojciech Zaremba. Sequence level training with recurrent neural networks. arXiv preprint arXiv:1511.06732, 2015.
- Ren et al. (2018) Mengye Ren, Wenyuan Zeng, Bin Yang, and Raquel Urtasun. Learning to reweight examples for robust deep learning. In International Conference on Machine Learning, pp. 4331–4340, 2018.
- Ribeiro et al. (2016) Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin. Why should I trust you?: Explaining the predictions of any classifier. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 1135–1144. ACM, 2016.
- Ribeiro et al. (2018) Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin. Anchors: High-precision model-agnostic explanations. In Thirty-Second AAAI Conference on Artificial Intelligence, 2018.
- Rudin (2018) Cynthia Rudin. Please Stop Explaining Black Box Models for High Stakes Decisions. arXiv:1811.10154, 2018.
- Selvaraju et al. (2017) Ramprasaath R Selvaraju, Michael Cogswell, Abhishek Das, Ramakrishna Vedantam, Devi Parikh, and Dhruv Batra. Grad-cam: Visual explanations from deep networks via gradient-based localization. In Proceedings of the IEEE International Conference on Computer Vision, pp. 618–626, 2017.
- Shrikumar et al. (2017) Avanti Shrikumar, Peyton Greenside, and Anshul Kundaje. Learning important features through propagating activation differences. In International Conference on Machine Learning, pp. 3145–3153, 2017.
- Štrumbelj & Kononenko (2014) Erik Štrumbelj and Igor Kononenko. Explaining prediction models and individual predictions with feature contributions. Knowledge and Information Systems, 41(3):647–665, 2014.
- Tan (2018) Pang-Ning Tan. Introduction to Data Mining. Pearson Education India, 2018.
- Virág & Nyitrai (2014) Miklós Virág and Tamás Nyitrai. Is there a trade-off between the predictive power and the interpretability of bankruptcy models? The case of the first Hungarian bankruptcy prediction model. Acta Oeconomica, 64(4):419–440, 2014.
- Wang et al. (2019) Tongzhou Wang, Jun-Yan Zhu, Antonio Torralba, and Alexei A. Efros. Dataset distillation, 2019. URL https://openreview.net/forum?id=Sy4lojC9tm.
- Williams (1992) Ronald J Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine Learning, 8(3-4):229–256, 1992.
- Zaremba & Sutskever (2015) Wojciech Zaremba and Ilya Sutskever. Reinforcement learning neural turing machines-revised. arXiv preprint arXiv:1505.00521, 2015.
- Zhang & Lapata (2017) Xingxing Zhang and Mirella Lapata. Sentence simplification with deep reinforcement learning. In Proceedings of the 2017 Conference on Empirical Methods in Natural Language Processing, pp. 584–594, 2017.
- Zhang et al. (2019) Yujia Zhang, Kuangyan Song, Yiming Sun, Sarah Tan, and Madeleine Udell. “why should you trust my explanation?” understanding uncertainty in lime explanations. arXiv preprint arXiv:1904.12991, 2019.
- Zhou et al. (2016) Bolei Zhou, Aditya Khosla, Agata Lapedriza, Aude Oliva, and Antonio Torralba. Learning deep features for discriminative localization. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2921–2929, 2016.
Appendix A Data statistics
| Problem | Data Name | Number of samples | Number of features | Label distribution |
Appendix B Hyper-parameters of the predictive models
In this paper, we use 8 different predictive models. For each predictive model, the corresponding hyper-parameters used in the experiments are as follows:
XGBoost: booster - gbtree, max depth - 6, learning rate - 0.3, number of estimators - 1000, reg alpha - 0
LightGBM: booster - gbdt, max depth - None, learning rate - 0.1, number of estimators - 1000, min data in leaf - 20
Random Forests: number of estimators - 1000, criterion - gini, max depth - None, warm start - False
Multi-layer Perceptron: number of layers - 4, hidden units - [feature dimensions, feature dimensions/2, feature dimensions/4, feature dimensions/8], activation function - relu, early stopping - True with patience 10, batch size - 256, maximum number of epochs - 200, optimizer - Adam
Ridge Regression: alpha - 1
Regression Tree: max depth - 3, criterion - gini
Logistic Regression: solver - lbfgs, no regularization
Decision Tree: max depth - 3, criterion - gini
We follow the default settings for the other hyper-parameters that are not mentioned here.
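For reproducibility, the settings listed above can be collected into a single configuration object. The sketch below uses illustrative key names that follow the wording of the list; they are not guaranteed to match the keyword arguments of any particular library. The helper `mlp_hidden_units` is hypothetical, and both its use of integer division and its floor of one unit per layer are assumptions, since the text does not say how non-divisible feature dimensions are rounded.

```python
def mlp_hidden_units(feature_dim):
    """Hidden-layer widths d, d/2, d/4, d/8 for a feature dimension d.

    Integer division and a minimum width of 1 are assumptions.
    """
    return [max(1, feature_dim // (2 ** k)) for k in range(4)]

# Key names are illustrative and mirror the list above, not any library's API.
HYPER_PARAMETERS = {
    "xgboost": {"booster": "gbtree", "max_depth": 6, "learning_rate": 0.3,
                "n_estimators": 1000, "reg_alpha": 0},
    "lightgbm": {"booster": "gbdt", "max_depth": None, "learning_rate": 0.1,
                 "n_estimators": 1000, "min_data_in_leaf": 20},
    "random_forest": {"n_estimators": 1000, "criterion": "gini",
                      "max_depth": None, "warm_start": False},
    "mlp": {"num_layers": 4, "activation": "relu", "early_stopping": True,
            "patience": 10, "batch_size": 256, "max_epochs": 200,
            "optimizer": "adam"},
    "ridge_regression": {"alpha": 1},
    "regression_tree": {"max_depth": 3, "criterion": "gini"},
    "logistic_regression": {"solver": "lbfgs", "regularization": None},
    "decision_tree": {"max_depth": 3, "criterion": "gini"},
}

# Example: layer widths for a hypothetical dataset with 64 features.
example_widths = mlp_hidden_units(64)
```

Keeping all eight models in one dictionary makes it easy to check that every predictive model in the experiments has an explicit entry.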
Appendix C Performance metrics
Mean Absolute Error (MAE):
Local MAE (LMAE):
$R^2$ score (Legates & McCabe, 1999):
If $R^2 = 1$, the predictions of the locally interpretable model perfectly match the predictions of the black-box model. If $R^2 = 0$, the locally interpretable model performs the same as the constant mean value estimator. If $R^2 < 0$, the locally interpretable model performs worse than the constant mean value estimator.
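The three metrics can be sketched in a few lines of code. This is a minimal illustration under stated assumptions: MAE is taken between the labels and the locally interpretable predictions, LMAE between the black-box and the locally interpretable predictions, and the fidelity score is assumed to have the coefficient-of-determination form, which reproduces the three cases described above. The function names and example values are hypothetical.

```python
def mae(y_true, y_pred):
    """Mean absolute error between two equally long sequences."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

def lmae(black_box_pred, local_pred):
    """Local MAE: absolute gap between black-box and local-model predictions."""
    return mae(black_box_pred, local_pred)

def r2_fidelity(black_box_pred, local_pred):
    """1 - (squared error of the local model) / (squared error of the mean).

    Undefined when the black-box predictions are all identical (zero variance).
    """
    mean_bb = sum(black_box_pred) / len(black_box_pred)
    ss_res = sum((b - l) ** 2 for b, l in zip(black_box_pred, local_pred))
    ss_tot = sum((b - mean_bb) ** 2 for b in black_box_pred)
    return 1.0 - ss_res / ss_tot

# Hypothetical black-box predictions and three candidate local models.
black_box = [1.0, 2.0, 3.0, 4.0]
perfect = [1.0, 2.0, 3.0, 4.0]        # matches the black box exactly
constant_mean = [2.5, 2.5, 2.5, 2.5]  # constant mean value estimator
reversed_fit = [4.0, 3.0, 2.0, 1.0]   # worse than predicting the mean

scores = [r2_fidelity(black_box, m) for m in (perfect, constant_mean, reversed_fit)]
# scores == [1.0, 0.0, -3.0]
```

The last case shows how a negative score arises: the local model's squared error (20) exceeds that of the constant mean value estimator (5).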
Appendix D Learning curves of RL-LIM
Appendix E Additional results
E.1 Regression with shallow regression tree as the locally interpretable model
E.2 Classification with ridge regression as the locally interpretable model
E.3 Regression with ridge regression as the locally interpretable model - Fidelity analysis in terms of Local MAE (LMAE)
E.4 Qualitative analyses – Interpretations of RL-LIM on Weather dataset
We qualitatively analyze the local explanations provided by RL-LIM on the Weather dataset at subgroup granularity. Fig. 6 shows the feature importance for six subgroups in predicting whether it will rain tomorrow, using XGBoost as the black-box model. We use ridge regression as the locally interpretable model and the absolute value of the fitted coefficients as the estimated feature importance. For the rain fall subgroups, humidity and wind gust speed seem more important for heavy rain (rain fall > 5) than for light rain (rain fall < 5). For the temperature subgroups, rain fall, wind gust speed, and humidity are more important for cold days (temperature (at 3pm) < 10) than for warm days (temperature (at 3pm) > 20). In general, humidity, wind gust speed, and rain fall are more critical for predicting whether it will rain tomorrow in the heavy rain, fast wind speed, low pressure, and low temperature subgroups than in the light rain, slow wind speed, high pressure, and high temperature subgroups. We do not discover notable biases of the black-box model for the humidity subgroups.
We further analyze the local explanations provided by RL-LIM on the Weather dataset at instance granularity. Fig. 6 shows the feature importance (derived by RL-LIM as the local explanations) for 10 instances belonging to a subgroup with ‘rain fall < 1, wind speed (at 3pm) < 5, and temperature (at 3pm) > 30’ and another 10 instances belonging to a subgroup with ‘rain fall > 15, wind speed (at 3pm) > 25, and temperature (at 3pm) < 10’. The other experimental settings are the same as in the previous subgroup granularity analyses. There are clear differences between the feature importance of the two subgroups (left and right in Fig. 6). Even within the same subgroup, we can observe differences in feature importance across instances, which RL-LIM provides efficiently.