# Machine Translation with Fairseq
Published on 11 Oct 2018
### Pre-processing

```
FAIR=~/lib/fairseq
python $FAIR/preprocess.py --source-lang en --target-lang hi \
    --trainpref train.bpe --validpref dev.bpe --testpref test.bpe \
    --destdir data-bin/
```

#### Pre-process other test sets

Use the dictionaries generated in the step above to preprocess (binarize) additional data:

```
python $FAIR/preprocess.py -s en -t fr \
    --srcdict data-bin/dict.en.txt --tgtdict data-bin/dict.fr.txt \
    --testpref newstest14.bpe,test.bpe --destdir test-bin
```

#### Pre-process (if only source-side test data is available)

```
python $FAIR/preprocess.py --srcdict data-bin/dict.hi.txt --only-source \
    --testpref test.bpe --destdir tmp/ -s hi -t en
```

### Training

```
mkdir -p /ssd_scratch/cvit/binu.jasim/checkpoints/fconv
python $FAIR/train.py data-bin/ \
    --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
    --arch fconv_iwslt_de_en --save-dir /ssd_scratch/cvit/binu.jasim/checkpoints/fconv
```

### Inference

Predict translations of the test sentences:

```
# To decode with an ensemble, pass a ':'-separated list of checkpoints instead:
# for ((i=20; i<25; i++)); do
#     CKPT=$CKPT:checkpoint/fconv/checkpoint$i.pt
# done
CKPT=checkpoint/checkpoint_best.pt
python $FAIR/generate.py data-bin/ --path $CKPT \
    --beam 5 --batch-size 128 \
    > tmp/gen.out
```

Post-process to separate the hypotheses and the targets:

```
grep ^H gen.out | cut -f3 > gen.out.sys
grep ^T gen.out | cut -f2 > gen.out.ref
```

Undo the SentencePiece segmentation (what `spm_decode` would do):

```
sed 's/ //g' | sed 's/▁/ /g' | sed 's/^ //g'
```

#### Translate raw text (Interactive)

- Option 1 (very slow):

  ```
  cat test.bpe.hi | python $FAIR/interactive.py data-bin/ --path $CKPT --buffer-size 256 > gen.out
  ```

- Option 2: binarize as described above.

## Transformer Hyperparameters

See [fairseq-examples-translation](https://github.com/pytorch/fairseq/tree/master/examples/translation) for some sample hyperparameter settings. Also read [Training Tips for the Transformer Model](https://ufal.mff.cuni.cz/pbml/110/art-popel-bojar.pdf) by Martin Popel and Ondřej Bojar.

```
python $FAIR/train.py data-bin --arch transformer \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
    --lr 0.001 --min-lr 1e-09 \
    --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 3584 --update-freq 32 --save-dir /ssd_scratch/cvit/binu.jasim/checkpoints/transformer_1M_en_fr
```

The above setting is from https://github.com/pytorch/fairseq/issues/187, but it takes a long time to run (around 12 hours on 1 million en-fr sentence pairs).

Latest setting:

```
python $FAIR/train.py data-bin --arch transformer \
    --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 7168 \
    --lr-scheduler inverse_sqrt \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --update-freq 16 --warmup-updates 16000 --warmup-init-lr 0.25 \
    --save-dir /ssd_scratch/cvit/binu.jasim/checkpoints/transformer_1M
```

(Use the first setting itself.) Still confused about the learning rate? Read [this nice piece of advice by Myle Ott](https://github.com/pytorch/fairseq/issues/417) on setting the learning rate for the Adam optimizer.
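Both transformer settings above simulate roughly the same effective batch size, since `--max-tokens` (tokens per GPU per step) gets multiplied by `--update-freq` and by the number of GPUs. A quick sanity check, assuming a single GPU:

```
# Effective tokens per update ≈ max_tokens × update_freq × num_GPUs (single GPU assumed here)
echo $((3584 * 32))   # first setting : 114688 tokens per update
echo $((7168 * 16))   # second setting: 114688 tokens per update
```

Either way this is comfortably above the ~25k-token effective batch size recommended in the notes below.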
### A Few Notes about Hyper-parameters

1. Use Adam over SGD (SGD is slow, and many SOTA systems use Adam, even though there are claims that SGD is typically used for SOTA results).
2. When using `transformer_wmt_en_de` (base), make sure to increase the learning rate. `--lr 0.0007` is a good learning rate for the base model with 8 GPUs.
3. An even higher `lr=0.001` can be used with `--update-freq 32` (or 16). This is a way to simulate a larger batch size (*wps*). Make sure the effective batch size (wps) is at least 25k tokens.
4. Higher learning rates can be used with higher batch sizes. Typically they result in faster training and even better results.
5. Does our GPU support `--fp16`? If yes, go for it: training is much faster.
6. `--share-all-embeddings` is a good option, but it requires a combined vocabulary. Otherwise use `--share-decoder-input-output-embed`.
7. Use warmup. `--warmup-updates 4000` is enough.
8. The stopping criterion is typically to watch the validation loss (manually!) and stop when it plateaus.
9. Best performance usually comes from averaging the last N checkpoints (with N between 5 and 10). You can do this with the `average_checkpoints` script:

   ```
   python scripts/average_checkpoints.py --inputs /path/to/checkpoints \
       --num-epoch-checkpoints 10 --output /path/to/checkpoints/averaged.pt
   ```
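The averaged checkpoint can then be used for decoding just like `checkpoint_best.pt` in the Inference section; a minimal sketch, reusing the paths from the averaging command above (the output file name is arbitrary):

```
python $FAIR/generate.py data-bin/ --path /path/to/checkpoints/averaged.pt \
    --beam 5 --batch-size 128 > tmp/gen.avg.out
```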