Implicit Bias and Fast Convergence Rates for Self-attention

Abstract

Self-attention, the core mechanism of transformers, distinguishes them from traditional neural networks and drives their outstanding performance. Towards developing the fundamental optimization principles of self-attention, we investigate the implicit bias of gradient descent (GD) in training a self-attention layer with a fixed linear decoder in binary classification. Drawing inspiration from the study of GD in linear logistic regression over separable data, recent work demonstrates that as the number of iterations t approaches infinity, the key-query matrix W_t converges locally (with respect to the initialization direction) to a hard-margin SVM solution W_mm. Our work strengthens this result in four respects. First, we identify non-trivial data settings for which convergence is provably global, thus shedding light on the optimization landscape. Second, we provide the first finite-time convergence rate for W_t to W_mm, and we quantify the rate of sparsification in the attention map. Third, through an analysis of normalized GD and the Polyak step-size, we demonstrate analytically that adaptive step-size rules can accelerate the convergence of self-attention. Finally, we remove the restriction of prior work to a fixed linear decoder. Our results reinforce the implicit-bias perspective of self-attention and strengthen its connections to implicit bias in linear logistic regression, despite the intricate non-convex nature of the former.
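
As a rough illustration of the setting described in the abstract, the following is a minimal sketch of normalized GD on the key-query matrix W of a single self-attention layer with a fixed linear decoder v. The specific model form (first token used as the query, logistic loss, zero initialization) and all function names are assumptions made for this sketch, not details taken from the paper.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over a 1-D array.
    z = z - np.max(z)
    e = np.exp(z)
    return e / e.sum()

def predict(X, W, v):
    # X: (T, d) token matrix, W: (d, d) key-query matrix, v: (d,) fixed decoder.
    # Assumption: the first token X[0] plays the role of the query.
    attn = softmax(X @ W @ X[0])        # attention weights over the T tokens
    return float(v @ (X.T @ attn))      # decoder applied to the attended token mix

def loss_and_grad(Xs, ys, W, v):
    # Average logistic loss and its gradient with respect to W only (v is fixed).
    loss, grad = 0.0, np.zeros_like(W)
    for X, y in zip(Xs, ys):
        q = X[0]
        attn = softmax(X @ W @ q)
        scores = X @ v                            # per-token decoder scores
        margin = y * (scores @ attn)
        loss += np.logaddexp(0.0, -margin)        # log(1 + exp(-margin))
        dl_dout = -y / (1.0 + np.exp(margin))     # derivative of loss w.r.t. output
        dout_ds = attn * (scores - scores @ attn) # softmax Jacobian applied to scores
        grad += dl_dout * np.outer(X.T @ dout_ds, q)
    n = len(Xs)
    return loss / n, grad / n

def normalized_gd(Xs, ys, v, eta=0.1, steps=200):
    # Normalized GD: each update has Frobenius norm exactly eta.
    d = Xs[0].shape[1]
    W = np.zeros((d, d))
    for _ in range(steps):
        _, g = loss_and_grad(Xs, ys, W, v)
        norm = np.linalg.norm(g)
        if norm == 0.0:
            break
        W = W - eta * g / norm
    return W
```

Tracking `softmax(X @ W @ X[0])` across iterations is one way to observe the attention map becoming sparser as the norm of W grows. A Polyak-style variant would, under the classical rule with optimal value zero, replace the factor `eta / norm` by `loss / norm**2`; whether this matches the exact rule analyzed in the paper is not stated in the abstract.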

Publication
BGPT Workshop@ICLR'24; Under review