Large pretrained Transformer models have proven to be extremely capable at tackling natural language tasks. However, handling long textual sequences continues to be a significant challenge for Transformer models.
What is the problem?
The memory consumption of the attention computation in Transformers grows quadratically with the input sequence length.
What is this article all about?
An empirical investigation, conducted in the paper “Investigating Efficiently Extending Transformers for Long Input Summarization” by Jason Phang, Yao Zhao and Peter J. Liu, into the architectural changes, model configurations and pretraining schemes that work best for training Transformer models on long input summarization. The experiments conducted and their findings are discussed in this post.
What are the key findings?
They evaluate a set of efficient Transformer variants and propose a simpler blockwise local Transformer architecture with staggered blocks and global tokens that strikes a good balance of performance and memory efficiency.
PEGASUS-X and PEGASUS-XBase:-
Based on the findings from their empirical investigation, the authors adapt the pretrained PEGASUSLarge model to tackle long input summarization with inputs of up to 16K tokens. The resulting model, which they call PEGASUS-X, attains top scores on long-input summarization tasks, outperforming much larger models like LongT5 (Guo et al., 2021) in some cases, and sets the state of the art on two tasks: GovReport and PubMed.
Code and Weights:- https://github.com/google-research/pegasus
Let's dive deep into each of these segments:
Key challenges faced in long input summarization:-
- Computational Challenge:- Let’s understand the quadratic scaling with an example. Consider a summarization model with 12 encoder and 12 decoder layers, pretrained on an input length of 512 and fine-tuned on a task with an input sequence length of 16,384, using an output length of 512 in both cases. Pretraining is typically done with shorter sequences, while fine-tuning uses long inputs. Fine-tuning can therefore be more resource intensive and slower than pretraining, which is contrary to the conventional paradigm. Since the encoder inputs have grown 32 times, the memory requirements of full self-attention grow by roughly 1,024 times; even efficient Transformers that achieve linear memory scaling still consume 32 times more memory (see the sketch after this list). This leads to several design choices, which are evaluated through experimentation.
- Dataset Challenge:- The main issues in current datasets are: the relative simplicity of the summarization task, a lack of diverse inputs, potential data leakage due to the data collection procedures, and a low number of training examples.
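To make the arithmetic above concrete, here is a small back-of-the-envelope calculation in plain Python (illustrative only) comparing how encoder self-attention memory grows when the input length increases from 512 to 16,384 tokens:

# Back-of-the-envelope scaling of encoder self-attention memory.
# Full attention materializes an L x L score matrix per head, so memory
# grows quadratically with input length; efficient variants grow linearly.
pretrain_len = 512       # input length used during pretraining
finetune_len = 16_384    # input length used during fine-tuning

length_ratio = finetune_len / pretrain_len   # 32x longer inputs
quadratic_growth = length_ratio ** 2         # full attention: 32^2 = 1024x
linear_growth = length_ratio                 # linear-memory attention: 32x

print(f"Input length grows by:   {length_ratio:.0f}x")
print(f"Full attention memory:   {quadratic_growth:.0f}x")
print(f"Linear-memory attention: {linear_growth:.0f}x")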
Experiments and their findings:-
1. Encoder Architecture:-
The investigation begins by evaluating the efficacy of swapping the standard encoder for an efficient Transformer encoder. Big Bird takes the approach of sparse attention computation, combining sliding window attention, random attention and a set of global-attention tokens.
Conversely, Performer takes the approach of factoring attention matrices through orthogonal random features.
In addition, the authors introduce two simple variants of local-attention Transformer encoders. First, a simple block-local Transformer (Local), where the encoder input tokens are divided into non-overlapping blocks and tokens can only attend to other tokens within the same block.
Second, they extend this local Transformer by adding a set of global tokens with learnable embeddings, which can attend to, as well as be attended to by, every encoder token (Global-Local).
They opt for the simpler block-local attention rather than sliding window attention, and compensate for the lack of overlapping blocks by staggering the local attention blocks (described below).
Among the short tasks, the full-attention Transformer performs best, followed by BigBird. On the long tasks, BigBird and Global-Local models perform the best, but BigBird consumes significantly more memory and trains much slower than the other architectures.
On the other hand, the authors find that the Local and Global-Local encoders strike a good balance of performance and efficiency. The simple local-attention encoder, which uses a block-local attention mechanism, attains performance surprisingly close to that of BigBird.
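As an illustration of the attention patterns being compared, here is a minimal NumPy sketch (not the authors' implementation) of the boolean attention masks used by block-local attention and by the Global-Local variant, where a small set of global tokens can attend to, and be attended to by, every token:

import numpy as np

def block_local_mask(seq_len: int, block_size: int) -> np.ndarray:
    # Token i may attend to token j only if both fall in the same block.
    blocks = np.arange(seq_len) // block_size
    return blocks[:, None] == blocks[None, :]

def global_local_mask(seq_len: int, block_size: int, num_global: int) -> np.ndarray:
    # Prepend num_global learnable global tokens that attend to (and are
    # attended to by) every token; the remaining tokens stay block-local.
    total = num_global + seq_len
    mask = np.zeros((total, total), dtype=bool)
    mask[:num_global, :] = True
    mask[:, :num_global] = True
    mask[num_global:, num_global:] = block_local_mask(seq_len, block_size)
    return mask

# Example with the block size and global token count used later in the ablations.
mask = global_local_mask(seq_len=256, block_size=64, num_global=32)
print(mask.shape)   # (288, 288)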
Local and Global-Local configurations:-
Staggering:- The first change is the staggering of local attention blocks. Unlike in sliding window attention, in block-local attention tokens can only attend to other tokens within the same block. If the input tokens are divided into the same blocks in every layer, no information is exchanged across blocks through the entire encoder. To address this pitfall, the authors introduce a small architectural change: the block allocation is staggered across alternating layers, as shown in Figure 2 of the paper. Concretely, the attention blocks are staggered by shifting the block boundaries by half a block every other layer, as sketched below:
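Here is a minimal sketch of the staggering idea (illustrative only, not the paper’s exact implementation): on every other layer the block boundaries are shifted by half a block, so tokens near a block edge share a block with their neighbours in the next layer.

import numpy as np

def block_ids(seq_len: int, block_size: int, layer_idx: int) -> np.ndarray:
    # Assign each token to a local-attention block, shifting the block
    # boundaries by half a block on every other (staggered) layer.
    shift = (block_size // 2) if layer_idx % 2 == 1 else 0
    return (np.arange(seq_len) + shift) // block_size

# With block_size=64: tokens 60 and 70 sit in different blocks on even layers,
# but share a block on odd (staggered) layers, so information can flow across
# block boundaries as depth increases.
for layer in (0, 1):
    ids = block_ids(seq_len=128, block_size=64, layer_idx=layer)
    print(f"layer {layer}: token 60 -> block {ids[60]}, token 70 -> block {ids[70]}")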
The results of these changes are shown in Table 2 of the paper: staggering local blocks improves performance by a noticeable amount in both Local and Global-Local models. Notably, the improvement holds even for the Global-Local models, which already have global tokens to exchange information across blocks.
Global-Local: Block Size and Number of Global Tokens:-
Larger block sizes and/or larger numbers of global tokens lead to improved performance, although the effect saturates.
For the remainder of the ablation experiments, the authors stick to a block size of 64 and 32 global tokens for consistency.
Positional Encoding Scheme:- Sinusoidal position encodings remain a good choice for long-input Transformers.
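For reference, the standard sinusoidal position encoding from the original Transformer can be written as a short NumPy function (a generic implementation, not copied from the PEGASUS codebase):

import numpy as np

def sinusoidal_position_encoding(seq_len: int, d_model: int) -> np.ndarray:
    # Fixed sin/cos encodings at geometrically spaced frequencies.
    positions = np.arange(seq_len)[:, None]            # (seq_len, 1)
    dims = np.arange(0, d_model, 2)[None, :]           # (1, d_model / 2)
    angles = positions / np.power(10_000.0, dims / d_model)
    enc = np.zeros((seq_len, d_model))
    enc[:, 0::2] = np.sin(angles)
    enc[:, 1::2] = np.cos(angles)
    return enc

pe = sinusoidal_position_encoding(seq_len=16_384, d_model=1024)
print(pe.shape)   # (16384, 1024)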
Scaling Encoder and Decoder Layers:- The authors fix the total number of layers to 24 and consider both encoder-heavy and decoder-heavy distributions, for both Local and Global-Local models. They observe that the distribution of encoder and decoder layers has a relatively small impact on performance. For Local models, they see a slight boost from decoder-heavy models. For Global-Local models, a balanced encoder-decoder outperforms both encoder-heavy and decoder-heavy models, which perform comparably to each other.
Pretraining vs Fine-tuning Architectures:- For Local models, the authors found that pretraining with local attention using small block sizes tends to hurt performance, but at moderate block sizes (e.g. 64) there is little difference between the two approaches. In contrast, for Global-Local models, pretraining with the efficient architecture tends to perform better. They hypothesize that this difference arises from the learned global embedding tokens, which are randomly initialized when adapting from a pretrained Transformer and hence benefit from being pretrained and jointly trained with the local attention.
Pretraining Scheme:- The authors consider two setups for pretraining: short-input pretraining, with 512 input tokens and 256 output tokens, and long-input pretraining, with 4096 input tokens and 256 output tokens.
Given a fixed compute budget, allocating some portion of training to long-input pretraining can improve performance, although the precise optimal allocation is difficult to determine. Exclusively long-input pretraining results in worse performance.
Partial Cross-Attention:-
Another major memory bottleneck is the encoder-decoder cross-attention. The authors find that reducing the number of cross-attention layers leads to a drop in performance, but the impact is smaller than might be expected: with cross-attention only on the first and sixth decoder layers, the Global-Local model still outperforms a Local model. Reducing the number of cross-attention layers also correspondingly speeds up training steps and reduces memory consumption; a structural sketch follows below.
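As a rough sketch of what partial cross-attention means structurally (placeholder Python, not the actual PEGASUS-X modules): only a chosen subset of decoder layers carries a cross-attention sub-block, while every layer keeps self-attention and the feed-forward sub-block.

# Illustrative sketch: a 12-layer decoder where only the first and sixth
# layers (indices 0 and 5) include encoder-decoder cross-attention.
CROSS_ATTENTION_LAYERS = {0, 5}

def build_decoder_layer(layer_idx: int) -> dict:
    # Describe the sub-blocks that make up one decoder layer.
    sublayers = ["self_attention", "feed_forward"]
    if layer_idx in CROSS_ATTENTION_LAYERS:
        sublayers.insert(1, "cross_attention")   # attends over encoder outputs
    return {"layer": layer_idx, "sublayers": sublayers}

decoder = [build_decoder_layer(i) for i in range(12)]
for layer in decoder[:6]:
    print(layer)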
PEGASUS-X:- The authors experiment with two model sizes: PEGASUS-X (PEGASUS eXtended), based on PEGASUSLarge, and PEGASUS-XBase, based on a newly trained PEGASUSBase model which they call PEGASUSBase+.
Only two new sets of parameters were introduced: the global token embeddings, and a separate LayerNorm for the global input representations in each Transformer layer. This is approximately 1M more parameters for PEGASUS-XBase and 2M more for PEGASUS-X.
Table 11 in the paper compares the performance of the PEGASUS models to that of PEGASUS-X on three long-input summarization tasks: arXiv, Big Patent and PubMed. On all three tasks there is a significant improvement of PEGASUS-XBase over PEGASUSBase+, and of PEGASUS-X over PEGASUSLarge.
Fine-tuning PEGASUS-XBase with JAX/Flax on E2E Cloud:-
1. Launch an A100 GPU node on E2E Cloud
If you need any help launching the node, please follow https://docs.e2enetworks.com/computes/nodes/launchnode.html#how-to-launch-nodes-from-myaccount-portal
2. Install dependencies:- Clone the PEGASUS repository and install its dependencies as described in the repository's README:
git clone https://github.com/google-research/pegasus
3. Download the tokenizer:-
wget "https://storage.googleapis.com/pegasus_ckpt/c4.unigram.newline.10pct.96000.model"
4. Download a fine-tuned checkpoint (here, the PEGASUS-X checkpoint fine-tuned on the SCROLLS SummScreen task):-
wget "https://storage.googleapis.com/pegasus_ckpt/px/tuned/large/scrolls_summscreen.ckpt"
5. Fine-tuning
First, we need to prepare the data for fine-tuning using TFDS (see the sketch below for what a TFDS summarization dataset looks like).
To fine-tune the model, we need to modify a config file in the repository. Modify this file if you want to change any settings, and then launch fine-tuning with the training command described in the repository's README.
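As a quick illustration of what a TFDS summarization dataset looks like, the snippet below loads the public scientific_papers dataset with tensorflow_datasets; the exact dataset registration and flags expected by the PEGASUS repository are described in its README and may differ from this sketch.

# Hypothetical illustration only: inspect a TFDS summarization dataset.
# Requires tensorflow and tensorflow-datasets to be installed.
import tensorflow_datasets as tfds

ds = tfds.load("scientific_papers/arxiv", split="train[:2]")
for example in tfds.as_numpy(ds):
    article = example["article"].decode("utf-8")
    abstract = example["abstract"].decode("utf-8")
    print("ARTICLE (truncated):", article[:200])
    print("SUMMARY (truncated):", abstract[:200])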
E2E Networks is the leading accelerated cloud computing player, providing the latest cloud GPUs at great value. Connect with us at sales@e2enetworks.com
References:-
- Jason Phang, Yao Zhao, Peter J. Liu. “Investigating Efficiently Extending Transformers for Long Input Summarization.” 2022.
- Mandy Guo et al. “LongT5: Efficient Text-to-Text Transformer for Long Sequences.” 2021.
- Code and weights: https://github.com/google-research/pegasus