Другим большим достижением я бы назвал работу про GShard середины 2020 года (“GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding”,
https://arxiv.org/abs/2006.16668). Среди авторов Дмитрий Лепихин (пользуясь случаем, хочу передать привет :) ) и другие, включая всё того же Noam Shazeer.
В этой работе MoE применили к трансформерам, в которых каждый второй feed-forward слой в энкодере и декодере заменили на MoE с top-2 активными в каждый момент времени экспертами. Модель отскейлили до 600B параметров (2048 экспертов, 36 слоёв трансформера) и обучали на 2048 TPU v3 под задачу мультиязычного машинного перевода, и получили заметное улучшение качества.
Чтобы трансформер с MoE можно было так отскейлить, разработали модуль GShard, предоставляющий API для шардирования моделей, по сути представляющий собой аннотации для тензоров.
В работе даже попробовали обучить модель на 1T параметров (2048 экспертов, 60 слоёв), но с ней были проблемы по части стабильности при обучении на bfloat16 (остальное обучалось на float32; напоминаю, что про форматы чисел с плавающей точкой есть подробная статья
https://moocaholic.medium.com/fp64-fp32-fp16-bfloat16-tf32-and-other-members-of-the-zoo-a1ca7897d407).
Ну и собственно Switch Transformer во многом продолжает эту линию. Алгоритм роутинга в MoE упростили, теперь активен только один эксперт всё в том же feed-forward слое (был отдельный эксперимент по добавлению MoE в слой внимания, и это даёт лучший результат, но работает на float32, а на bfloat16 расходится). Каждый токен отправляется в соответствующего ему эксперта.
Делали это поверх Mesh-Tensorflow, в работе есть сколько-то оптимизаций для улучшения обучения (например, везде используют bfloat16, что ощутимо облегчает коммуникацию, и только внутри функции роутера всё приводят к float32; более хитрые инициализации и регуляризации).
Сравниваются с моделями семейста T5: T5-base (223M), T5-large (739M), для сравнения с ними делают Switch-base (7.4B) и Switch-large (26.3B) так, чтобы по количеству вычислений (FLOPS) они совпадали.
Делают предобучение на Masked Language Model (MLM, по типу BERT’а) и файн-тюнинг на задачах из GLUE, SQuAD, SuperGLUE и т.д.
Был ещё дополнительный лосс, помогающий балансировать нагрузку по различным экспертам, чтобы никакой из них не перетрудился и не стал узким местом, к кому стоит очередь. Исключать, кстати, такого всё равно нельзя и на случай, когда кто-то окажется перегружен и не успеет обработать какой-то из входящих токенов, этот токен (вернее его эмбеддинг) пробрасывается на выход, благо там везде residual connections. Было предположение, что качество от этого страдает, поэтому сделали No-Token-Left-Behind роутинг, который, если обнаруживает, что токен отправлен на перегруженного эксперта, рероутит его на другого. Но это на практике ничего не дало.
Новых state-of-the-art здесь не получают, но зато показывают, что, во-первых, Switch Transformer превосходит обычный трансформер по качеству; во-вторых, он хорошо скейлится по числу экспертов (чем больше, тем выше качество; при этом не забывайте, что активен в каждый момент только один эксперт, так что это не ансамбль в традиционном виде), а также, в-третьих, при фиксированном объёме вычислений достигает лучшего результата, чем традиционный трансформер (в 7 раз быстрее для base варианта).
Большой Switch Transformer можно отдистиллировать (про дистилляцию у нас в канале вообще было много всего) в обычный, например, Switch-base в T5-base, и это даст качество выше, чем если с нуля обучать аналогичный обычный (T5-base), сохраняется примерно 30% улучшения. Сравнились также с мультиязычным mT5, получили прирост на всех языках (про T5 и mT5 мы писали
https://t.me/gonzo_ML/442).