SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training

  • 2021-06-02 17:51:05
  • Gowthami Somepalli, Micah Goldblum, Avi Schwarzschild, C. Bayan Bruss, Tom Goldstein
Tabular data underpins numerous high-impact applications of machine learning from fraud detection to genomics and healthcare. Classical approaches to solving tabular problems, such as gradient boosting and random forests, are widely used by practitioners. However, recent deep learning methods have achieved a degree of performance competitive with popular techniques. We devise a hybrid deep learning approach to solving tabular data problems. Our method, SAINT, performs attention over both rows and columns, and it includes an enhanced embedding method. We also study a new contrastive self-supervised pre-training method for use when labels are scarce. SAINT consistently improves performance over previous deep learning methods, and it even outperforms gradient boosting methods, including XGBoost, CatBoost, and LightGBM, on average over a variety of benchmark tasks.
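To make the distinction between attention over columns and attention over rows concrete, here is a minimal NumPy sketch. It is an illustration under stated assumptions, not the authors' implementation: it uses a single head with no learned query, key, or value projections, and the function names (`self_attention`, `column_attention`, `row_attention`) and the choice to flatten each row into one token are assumptions made for this example.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(tokens):
    # Single-head scaled dot-product self-attention.
    # Assumption: no learned projections, so queries = keys = values = tokens.
    # tokens has shape (..., n_tokens, dim).
    d = tokens.shape[-1]
    scores = tokens @ np.swapaxes(tokens, -1, -2) / np.sqrt(d)
    return softmax(scores, axis=-1) @ tokens

def column_attention(x):
    # x: (batch, n_features, dim). Within each row, the feature
    # embeddings attend to one another; rows do not interact.
    return self_attention(x)

def row_attention(x):
    # Assumption: each row is flattened into a single token, and the
    # rows of the batch then attend to one another.
    b, n, d = x.shape
    flat = x.reshape(1, b, n * d)
    return self_attention(flat).reshape(b, n, d)

# Hypothetical batch: 4 rows, 3 features, embedding size 8.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3, 8))
col_out = column_attention(x)
row_out = row_attention(x)
```

In this sketch, the output of `column_attention` for a given row does not change when other rows are added to the batch, whereas the output of `row_attention` does, which is the property that lets a model draw on similar samples when representing a row.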


