๐Ÿง‘โ€๐Ÿซ Lecture 17-18

lecture
distributed training
Distributed training
Author

Gijeong Seong

Published

May 10, 2024

์ด๋ฒˆ ๊ฐ•์˜๋Š” ๋ถ„์‚ฐ ํ•™์Šต(distributed training)์— ๋Œ€ํ•œ ๋‚ด์šฉ์„ ๋‹ค๋ฃฌ๋‹ค. LLM๋“ฑ ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์ด ์ ์  ๋ณต์žกํ•ด์ง€๋ฉด์„œ, ๋Œ€๊ทœ๋ชจ ๋ชจ๋ธ์˜ ํ›ˆ๋ จ์€ ๋‹จ์ผ GPU๋กœ๋Š” ๋” ์ด์ƒ ์ถฉ๋ถ„ํ•˜์ง€ ์•Š๊ณ , ์—ฌ๋Ÿฌ๊ฐœ์˜ GPU๋ฅผ ์‚ฌ์šฉํ•ด์•ผ ๊ฐ€๋Šฅํ•˜๋‹ค. ์ด๋ฒˆ ๊ธ€์—์„œ๋Š” ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์˜ ํ›ˆ๋ จ์„ ๊ฐ€์†ํ™”ํ•˜๊ธฐ ์œ„ํ•œ ๋‹ค์–‘ํ•œ ๋ถ„์‚ฐ ํ›ˆ๋ จ ๊ธฐ๋ฒ•๋“ค์„ ๋‹ค๋ฃจ๊ณ ์ž ํ•œ๋‹ค. ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌํ™”, ํŒŒ์ดํ”„๋ผ์ธ ๋ณ‘๋ ฌํ™”, ํ…์„œ ๋ณ‘๋ ฌํ™” ๋“ฑ์˜ ๊ธฐ๋ฒ•๋“ค์„ ํ†ตํ•ด ํšจ์œจ์ ์ธ ๋ถ„์‚ฐ ํ›ˆ๋ จ์„ ์‹คํ˜„ํ•˜๋Š” ๋ฐฉ๋ฒ•์— ๋Œ€ํ•ด ์‚ดํŽด๋ณผ ๊ฒƒ์ด๋‹ค. ๋˜ํ•œ, ๋ถ„์‚ฐ ํ›ˆ๋ จ ๋ฟ๋งŒ ์•„๋‹ˆ๋ผ GPU๊ฐ„ ํ†ต์‹ ์˜ ๋ณ‘๋ชฉ ํ˜„์ƒ์— ๋Œ€ํ•œ ๋Œ€์ฒ˜๋ฒ•์œผ๋กœ ๊ฐ€์ค‘์น˜ ์••์ถ•(gradient compression) ๋ฐ ์ง€์—ฐ ๊ฐ€์ค‘์น˜ ์—…๋ฐ์ดํŠธ(delayed gradient update) ๋˜ํ•œ ์‚ดํŽด๋ณผ ๊ฒƒ์ด๋‹ค.

1. Background and motivation

์œ„ ํ‘œ๋Š” Nvidia A100 ๊ธฐ์ค€์œผ๋กœ ๋ชจ๋ธ๋ณ„๋กœ ํ›ˆ๋ จ์— ๊ฑธ๋ฆฌ๋Š” ์‹œ๊ฐ„์ด๋‹ค. distributed training์—†์ด ๋‹จ์ผ GPU๋กœ ํ•™์Šตํ•œ๋‹ค๋ฉด, GPT-3๋Š” ํ•™์Šตํ•˜๋Š”๋ฐ 355๋…„์ด๋‚˜ ๊ฑธ๋ฆฐ๋‹ค.

์œ„ ๋†๋‹ด์ฒ˜๋Ÿผ, ํ›ˆ๋ จ์— ๋„ˆ๋ฌด ๋งŽ์€ ์‹œ๊ฐ„์ด ๊ฑธ๋ฆฐ๋‹ค๋ฉด ๋ชจ๋ธ์„ ๋ฐœ์ „์‹œํ‚ค๊ธฐ ์–ด๋ ค์šธ ๊ฒƒ์ด๋‹ค. distribution training์„ ์‚ฌ์šฉํ•˜๋ฉด, ์ด์ƒ์ ์œผ๋กœ๋Š” 10์ผ์ด ๊ฑธ๋ฆฌ๋Š” ์ž‘์—…์„ 1024๊ฐœ์˜ GPU๋กœ๋Š” 14๋ถ„๋งŒ์— ํ•™์Šต์ด ๊ฐ€๋Šฅํ•˜๋‹ค.

2. Parallelization methods for distributed training

๊ทธ๋ ‡๋‹ค๋ฉด ์–ด๋–ป๊ฒŒ ๋ณ‘๋ ฌ๋กœ ํ›ˆ๋ จ์‹œํ‚ฌ ์ˆ˜ ์žˆ์„๊นŒ? ์ด๋ฒˆ ์žฅ์—์„œ๋Š” ๊ฐ„๋‹จํžˆ ์„ธ ๊ฐ€์ง€ ๋ฐฉ๋ฒ•์„ ์†Œ๊ฐœํ•˜๊ณ , ๊ฐ ๋ฐฉ๋ฒ•์— ๋Œ€ํ•œ ์ž์„ธํ•œ ์„ค๋ช…์€ ํ•˜๋‚˜์”ฉ ๋‹ค๋ฃจ๋„๋ก ํ•˜๊ฒ ๋‹ค.

์ฒซ ๋ฒˆ์งธ๋Š” data parallelism์ด๋‹ค. ์ด ๋ฐฉ๋ฒ•์€ ํ•™์Šต ๋ฐ์ดํ„ฐ๋ฅผ ์—ฌ๋Ÿฌ๊ฐœ์˜ GPU์— ๋‚˜๋ˆ ์„œ ํ•™์Šตํ•œ ๋’ค, ๊ณ„์‚ฐ๋œ ๊ฐ๊ฐ์˜ gradient๋ฅผ ๋‹ค์‹œ ์ค‘์•™์—์„œ ํ•ฉ์น˜๋Š” ๋ฐฉ์‹์ด๋‹ค.

๋‘ ๋ฒˆ์งธ๋Š” pipeline parallelism์ด๋‹ค. ์ด ๋ฐฉ๋ฒ•์€ ๋ชจ๋ธ์„ layer ๋ณ„๋กœ ๋‚˜๋ˆ ์„œ ํ• ๋‹นํ•˜๋Š” ๊ฒƒ์ด๋‹ค. GPU๋ณ„๋กœ layer๋ฅผ ๋‚˜๋ˆ  ๊ณ„์‚ฐํ•˜๋Š” ๊ฒƒ์ธ๋ฐ, ํ•˜๋‚˜์˜ GPU์— ๋ชจ๋ธ์˜ ๋ชจ๋“  ๊ฐ€์ค‘์น˜๋ฅผ ๋„ฃ๊ธฐ์—” ๋„ˆ๋ฌด ํด ๋•Œ ์‚ฌ์šฉํ•œ๋‹ค.

๋งˆ์ง€๋ง‰์œผ๋กœ tensor parallelism์ด๋‹ค. ์ด๋Š” layer๋ณด๋‹ค ๋” ์ž˜๊ฒŒ, ํ…์„œ ๋‹จ์œ„๋กœ ์ชผ๊ฐœ๋Š” ๊ฒƒ์ด๋‹ค. ์ง๊ด€์ ์œผ๋กœ ๋ณด๋ฉด ๋ชจ๋ธ์„ ์„ธ๋กœ๋กœ ์ชผ๊ฐœ์„œ ์„œ๋กœ ๋‹ค๋ฅธ GPU์— ํ• ๋‹นํ•˜๋Š” ๊ฒƒ์ธ๋ฐ, ์ž์„ธํ•œ ์„ค๋ช…์€ ๋’ค์— ๋‹ค์‹œ ๋‹ค๋ฃจ๊ฒ ๋‹ค

3. Data parallelism

data parallelism์„ ํ•  ๋–„์—๋Š” parameter server์™€ worker nodes๋ผ๋Š” ๋‘ ๊ฐ€์ง€ ์š”์†Œ๊ฐ€ ์กด์žฌํ•œ๋‹ค. worker nodes๋Š” ๋ถ„ํ• ๋œ ๋ฐ์ดํ„ฐ๋กœ ํ•™์Šต์„ ์ง„ํ–‰ํ•˜๊ณ , parameter server๋Š” worker๋กœ๋ถ€ํ„ฐ ์ „๋‹ฌ๋ฐ›์€ gradient๋ฅผ ํ•ฉ์นœ๋‹ค.

์ˆœ์„œ๋ฅผ ์‚ดํŽด๋ณด๋ฉด

  1. parameter server์—์„œ worker๋กœ ๋ชจ๋ธ ๊ฐ€์ค‘์น˜๋ฅผ ๋ณต์‚ฌ
  2. ํ•™์Šต ๋ฐ์ดํ„ฐ์…‹์„ worker์— ์ ์ ˆํžˆ ๋ฐฐ๋ถ„
  3. worker์—์„œ gradient ๊ณ„์‚ฐ
  4. gradient๋ฅผ parameter server์— ์ „๋‹ฌ
  5. parameter server์—์„œ ์ž์ฒด์ ์œผ๋กœ ๋ชจ๋ธ์„ ์—…๋ฐ์ดํŠธ

4. Communication primitives

๊ทธ๋Ÿฐ๋ฐ, data parallelism์—์„œ paramter server์™€ worker๊ฐ„์˜ ํ†ต์‹ ์€ ์–ด๋–ป๊ฒŒ ์ด๋ฃจ์–ด์งˆ๊นŒ?

๊ฐ€์žฅ ๋‹จ์ˆœํ•œ ๋ฐฉ์‹์€ ์†Œ์ผ“ ํ†ต์‹ ๊ฐ™์€ 1:1 ๋ฐฉ์‹์ด๋‹ค.

๋‹ค๋ฅธ ๋ฐฉ์‹์œผ๋กœ๋Š” ์ผ๋Œ€๋‹ค ๋ฐฉ์‹์ด ์žˆ๋‹ค. ํ•œ ๋…ธ๋“œ๋กœ๋ถ€ํ„ฐ ์ •๋ณด๋ฅผ ๋ถ„์‚ฐํ•˜๊ณ , ํ•ฉ์น˜๋Š” ๊ฒƒ์ด๋‹ค.

์ผ๋Œ€๋‹ค ๋ฐฉ์‹์—์„œ ์กฐ๊ธˆ ๋ณ€ํ˜•ํ•˜๋ฉด, reduce ๋ฐฉ์‹์ด ๋œ๋‹ค. reduce ๋ฐฉ์‹์€ gather๊ณผ ๋น„์Šทํ•˜์ง€๋งŒ, ์ •๋ณด๋ฅผ ๋‹จ์ˆœํžˆ ๋ชจ์œผ๋Š” ๊ฒƒ์ด ์•„๋‹ˆ๋ผ ํ•ฉ์น˜๊ฑฐ๋‚˜ ํ‰๊ท ์„ ๋‚ด๋Š” ๊ฒƒ์œผ๋กœ(๋นจํŒŒ๋…ธ์ดˆ์˜ ํ•ฉ)ํ•˜๋‚˜์˜ ๊ฐ’์„ ๋งŒ๋“ค์–ด๋‚ด๊ณ , ๊ทธ๋Ÿฐ ๋’ค broadcast๋ฅผ ํ†ตํ•ด ๊ฐ’์„ ๋ชจ๋“  node๋กœ ์†ก์‹ ํ•œ๋‹ค

๊ทธ ๋‹ค์Œ์€, ๋‹ค๋Œ€๋‹ค ๋ฐฉ์‹์ด๋‹ค. all-reduce๋Š” ๋ชจ๋“  ๋…ธ๋“œ์— ๋Œ€ํ•ด reduce๋ฅผ ์ง„ํ–‰ํ•˜๋Š” ๊ฒƒ์ด๊ณ , all-gather์„ ํ†ตํ•ด ํ•ฉ์น˜๊ฑฐ๋‚˜ ํ•˜๋Š” ๊ฒƒ์ด ์•„๋‹ˆ๋ผ ๋ชจ๋“  ์ •๋ณด๋ฅผ ๋‹ค ๊ฐ–๊ฒŒ ๋œ๋‹ค.

์šฐ๋ฆฌ๊ฐ€ ์ด์ „ ์žฅ์—์„œ ์‚ดํŽด๋ณธ data parallelism ๋ฐฉ์‹์˜ ๋ณต์žก๋„๋ฅผ ๊ตฌํ•ด๋ณด๋ฉด, O(N)์˜ ๋ณต์žก๋„๋ฅผ๊ฐ€์ง€๊ณ  ์ด๋Š” ๊ฝค ๋ถ€๋‹ด๋˜๋Š” ์ •๋„์ด๋‹ค. ์ด๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด all reduce๋ผ๋Š” ๋ฐฉ์‹์„ ์‚ฌ์šฉํ•œ๋‹ค.

๊ฐ€์žฅ ๊ฐ„๋‹จํ•œ ๋ฐฉ์‹์€, sequential๋กœ ํ•˜๋‚˜์”ฉ reduceํ•˜๋Š” ๋ฐฉ์‹์ด๋‹ค. ์ด ๋ฐฉ์‹์€ ์‹œ๊ฐ„๊ณผ bandwith ์ธก๋ฉด์—์„œ ๋ชจ๋‘ O(N)์˜ ๋ณต์žก๋„๋ฅผ ๊ฐ€์ง„๋‹ค.

Ring ๋ฐฉ์‹์ฒ˜๋Ÿผ ์˜†์˜ ๋…ธ๋“œ๋กœ๋งŒ ์ „์†กํ•˜๊ฒŒ ๋œ๋‹ค๋ฉด, ์‹œ๊ฐ„์€ ๊ทธ๋Œ€๋กœ O(N)์ด์ง€๋งŒ, bandwith๋Š” O(1)์˜ ๋ณต์žก๋„๋ฅผ ๊ฐ–๊ฒŒ ๋œ๋‹ค.

์ข…ํ•ฉํ•ด๋ณด๋ฉด ์œ„ ํ‘œ์™€ ๊ฐ™๋‹ค. ์–ด๋–ค ๋ฐฉ์‹๋„ O(N)์„ ์š”๊ตฌํ•˜๋Š”๋ฐ, ์–ด๋–ค ์‹์œผ๋กœ ๋” ์ค„์ผ ์ˆ˜ ์žˆ์„๊นŒ?

recursive all reduce ๋ฐฉ์‹์ด ๊ทธ ํ•ด๋ฒ•์ด๋‹ค. ์˜†์— ์žˆ๋Š” ๋…ธ๋“œ๋ผ๋ฆฌ ์ •๋ณด๋ฅผ ๊ตํ™˜ํ•˜๋Š” ๋ฐฉ์‹์œผ๋กœ, log(N)์œผ๋กœ ํ•ด๊ฒฐ ๊ฐ€๋Šฅํ•˜๋‹ค.

5. Reducing memory in data parallelism: ZeRO-1 / 2 / 3 and FSDP

data parallelism์—์„œ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ๋” ์ค„์ด๋Š” ZeRO ์‹œ๋ฆฌ์ฆˆ ๋ฐฉ๋ฒ•์„ ์†Œ๊ฐœํ•œ๋‹ค.

ํ•™์Šต์„ ํ• ๋•Œ weight์™€ gradient ์ด์™ธ์—๋„ adam๋“ฑ์˜ optimizer๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด optimizer state(momentum, variance๊ฐ’, master copy)๋ฅผ ์ €์žฅํ•ด์•ผ ํ•œ๋‹ค. ๊ทธ๋ ‡๋‹ค๋ฉด ํ•˜๋‚˜์˜ ํŒŒ๋ผ๋ฏธํ„ฐ์— 16byte๊ฐ€ ํ•„์š”ํ•˜๊ฒŒ ๋œ๋‹ค(fp16๊ธฐ์ค€, weight=2, gradient=2, optim state=12(4*3)) ์œ„ ๊ทธ๋ฆผ์ฒ˜๋Ÿผ ๋งŽ์€ ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ optimizer state๋ฅผ ์ €์žฅํ•˜๊ธฐ ์œ„ํ•ด ์‚ฌ์šฉ๋˜๊ฒŒ ๋œ๋‹ค.

๊ทธ๋ž˜์„œ ๊ณ ์•ˆํ•  ์ˆ˜ ์žˆ๋Š” ๋ฐฉ๋ฒ•์ด optimizer state๋ฅผ sharding(๋ถ„์‚ฐ์ €์žฅ)ํ•˜๋Š” ๋ฐฉ์‹์ด๋‹ค. ์ด๋Ÿฌ๋ฉด ํ•ด๋‹นํ•˜๋Š” ๊ตฌ๊ฐ„์— ๋Œ€ํ•œ weight๋ฐ–์— ์—…๋ฐ์ดํŠธ๋ฅผ ํ•˜์ง€ ๋ชปํ•˜๊ฒŒ ๋˜์ง€๋งŒ, ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์€ ์ค„์–ด๋“ค๊ฒŒ ๋œ๋‹ค.(12/N bytes for optim state)

optimizer state๋ฟ๋งŒ ์•„๋‹ˆ๋ผ, gradient๋„ shardingํ•  ์ˆ˜ ์žˆ๋‹ค.(2/N bytes for gradient)

weight๊นŒ์ง€๋„ shardingํ•  ์ˆ˜ ์žˆ๋‹ค. ํ•˜์ง€๋งŒ gradient์— ๋น„ํ•ด ์ด๊ฑด ์ข€ ์–ด๋ ค์šด๋ฐ, inference๋ฅผ ํ•  ๋•Œ ๋ชจ๋“  weight๊ฐ€ ํ•„์š”ํ•˜๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค. ๋”ฐ๋ผ์„œ ์ด ๋ฐฉ์‹์„ ์‚ฌ์šฉํ•˜๋ฉด, inference์‹œ์—๋Š” ๋‹ค๋ฅธ GPU(node)๋กœ๋ถ€ํ„ฐ weight๊ฐ’์„ ๊ฐ€์ ธ์™€์„œ ์ง„ํ–‰ํ•˜๊ฒŒ ๋œ๋‹ค. (2/N bytes for weights)

6. Pipeline parallelism

data parallelism๊ณผ ๋‹ฌ๋ฆฌ pipeline parallelism์€, ๋ชจ๋ธ์„ ๋ถ„ํ• ํ•˜๋Š” ๋ฐฉ์‹์ด๋‹ค.

๊ฐ„๋‹จํ•˜๊ฒŒ ๊ตฌํ˜„์„ ์ƒ๊ฐํ•ด๋ณด๋ฉด, forward-backward์— ๋Œ€ํ•ด ์œ„์™€ ๊ฐ™์ด ์ผ์ž์ ์ธ ๊ตฌ์กฐ๋กœ ๊ตฌํ˜„ํ•  ์ˆ˜ ์žˆ๋‹ค. ํ•˜์ง€๋งŒ, ์ด๋Ÿฐ ๋ฐฉ์‹์ด๋ฉด F0์ดํ›„ ํ•œ์ฐธ ๋’ค์— B0์ด ์ด๋ฃจ์–ด ์ง€๋Š” ๊ฒƒ์ฒ˜๋Ÿผ ๋น„๊ฒŒ๋˜๋Š” ์‹œ๊ฐ„์ด ๋งŽ์•„์ง„๋‹ค.

์ด๋ฅผ ํ•ด๊ฒฐํ•˜๋Š” ๋ฐฉ๋ฒ•์ด Gpipe์— ์ œ์‹œ๋œ๋‹ค. batch๋ฅผ ์ž˜๊ฒŒ ์ชผ๊ฐœ์„œ forward-backward๊ฐ„์˜ ๊ฐ„๊ฒฉ์„ ์ค„์ด๋Š” ๋ฐฉ์‹์ด๋‹ค. ์ด์ „์˜ naiveํ•œ ๋ฐฉ๋ฒ•์—์„œ ์ด ๋ฐฉ๋ฒ•์œผ๋กœ ๋ฐ”๊ฟ€ ์‹œ 2.5๋ฐฐ์ •๋„ utilization์„ ๋Š˜๋ฆด ์ˆ˜ ์žˆ๋‹ค.(25%->57%)

7. Tensor parallelism

ํ•˜์ง€๋งŒ 57%์˜ utilization๋„ ์ƒ๊ฐํ•ด๋ดค์„ ๋•Œ ๋„ˆ๋ฌด ์ ๊ฒŒ ํ™œ์šฉ๋˜๋Š” ๊ฒƒ ๊ฐ™๊ธฐ๋„ ํ•˜๋‹ค. ๊ทธ๋ ‡๋‹ค๋ฉด ๋” ์ž˜๊ฒŒ ์ชผ๊ฐœ๋Š” ๋ฐฉ์‹์„ ์ƒ๊ฐํ•ด๋ณผ ์ˆ˜ ์žˆ์„ ๊ฒƒ์ด๋‹ค.

์œ„ ๊ทธ๋ฆผ์€ ๊ฐ๊ฐ MLP์™€ self-attention layer์—์„œ tensor parallelism์„ ์ ์šฉํ–ˆ์„ ๋–„์˜ ๊ตฌ์กฐ์ด๋‹ค. ๋…น์ƒ‰ f ๋ฅผ ํ†ตํ•ด X๋ฅผ N๊ฐœ(์˜ˆ์‹œ์—์„œ๋Š” 2๊ฐœ)์˜ chunk๋กœ ์ชผ๊ฐœ๊ณ , ๊ฐ๊ฐ์˜ X๋ฅผ ์„œ๋กœ๋‹ค๋ฅธ GPU์—์„œ ์—ฐ์‚ฐ์„ ์ง„ํ–‰ํ•œ๋‹ค.

8. Hybrid (mixed) parallelism and how to auto-parallelize

์œ„์—์„œ ์šฐ๋ฆฌ๋Š” 3๊ฐ€์ง€ ๋ฐฉ๋ฒ•์˜ parallelism ๋ฐฉ๋ฒ•์„ ์‚ดํŽด๋ณด์•˜๋‹ค. ์ด๋ฒˆ ์žฅ์—์„œ๋Š” ์ด๋Ÿฐ ๋ฐฉ๋ฒ•๋“ค์„ ์กฐํ•ฉํ•ด์„œ ์‚ฌ์šฉํ•˜๋Š” ๋ฐฉ๋ฒ•์— ๋Œ€ํ•ด ์‚ดํŽด๋ณธ๋‹ค.

๋จผ์ € ๊ฐ ๋ณ‘๋ ฌํ™” ๋ฐฉ๋ฒ•์„ ๋‹ค์‹œ ์š”์•ฝํ•ด๋ณด๋ฉด, data parallelism์€ ๋ฐ์ดํ„ฐ๋ฅผ ๋‚˜๋ˆ„๋Š” ๊ฒƒ์ด๊ณ , pipieline parallelism์€ model์„ layer๋‹จ์œ„๋กœ ๋‚˜๋ˆ„๋Š” ๊ฒƒ์ด๋ฉฐ tensor parallelism์€ model์„ tensor๋‹จ์œ„๋กœ ๋‚˜๋ˆ„๋Š” ๊ฒƒ์ด๋‹ค.

์œ„ ๊ทธ๋ฆผ์€ Data parallelism + Pipeline parallelism์„ ๋™์‹œ์— ์ ์šฉํ•œ ๋ชจ์Šต์ด๋‹ค. 2๊ฐœ์˜ GPU๋‹จ์œ„๋กœ data๋ฅผ ๋‚˜๋ˆ„๊ณ , 2๊ฐœ GPU๊ฐ„์— pipeline parallelism์„ ์ ์šฉํ•œ๋‹ค.

์œ„ ๊ทธ๋ฆผ์€ pipeline parallelism + tensor parallelism์ด๋‹ค. ์ด๋ ‡๋“ฏ 2๊ฐœ์˜ ๊ธฐ๋ฒ•๋งŒ์„ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ์•„๋‹ˆ๋ผ 3๊ฐœ์˜ ๋ณ‘๋ ฌํ™” ๋ฐฉ๋ฒ•์„ ๋™์‹œ์— ์‚ฌ์šฉํ•  ์ˆ˜๋„ ์žˆ๋‹ค(3d parallelism)

๊ทธ๋ฆผ์œผ๋กœ ํ‘œํ˜„ํ•˜๋ฉด ์œ„์™€ ๊ฐ™๋‹ค.

์ƒ๊ฐํ•ด ๋ณผ ๋ฌธ์ œ๋Š”, ์ด๋ ‡๊ฒŒ ์—ฌ๋Ÿฌ๊ฐœ์˜ ๋ณ‘๋ ฌํ™” ๋ฐฉ์‹์„ ์‚ฌ์šฉํ•  ๋•Œ, ์–ด๋–ค ์‹์œผ๋กœ ์‚ฌ์šฉํ•ด์•ผ ์ตœ์ ์˜ ๋ฐฉ์‹์ด ๋˜๋Š๋ƒ์˜ ๋ฌธ์ œ์ด๋‹ค. ์ผ๋ฐ˜์ ์œผ๋กœ๋Š” ๋ชจ๋ธ์ด GPU์—์„œ ๋Œ์•„๊ฐ€์ง€ ์•Š์„ ์ •๋„๋กœ ํฌ๋‹ค๋ฉด pipeline, layer๊ฐ€ GPU์—์„œ ๋Œ์•„๊ฐ€์ง€ ์•Š์„ ์ •๋„๋กœ ํฌ๋‹ค๋ฉด tensor ๋ฐฉ์‹์„ ์‚ฌ์šฉํ•˜์ง€๋งŒ ๋‹จ์ˆœํžˆ ๊ทธ๋Ÿฐ ๋ฐฉ์‹์„ ์ ์š”ํ•˜๋Š” ๊ฒƒ์ด ์ตœ์ ์˜ ๋ฐฉ๋ฒ•์€ ์•„๋‹ˆ๋‹ค.

์ด๋ฅผ NAS์™€ ๋น„์Šทํ•œ ๋ฐฉ์‹์œผ๋กœ ์—ฐ์‚ฐ๊ฐ„์˜ ๊ด€๊ณ„๋ฅผ ์„ค์ •ํ•˜๋Š” inter-op pass์™€ ์—ฐ์‚ฐ ๋‚ด๋ถ€์˜ ๋™์ž‘์„ ์„ค์ •ํ•˜๋Š” intra-op pass ๋‘ ๋‹จ๊ณ„๋ฅผ ํ†ตํ•ด ์ž๋™์œผ๋กœ ์ฐพ์•„์ฃผ๋Š” ๋ฐฉ์‹์ด ์กด์žฌํ•œ๋‹ค.

9 Understand the bandwidth and latency bottleneck of distributed training

๋ถ„์‚ฐ ํ•™์Šต ๋ฐฉ์‹์—๋Š” ๋ฐ˜๋“œ์‹œ ๋”ฐ๋ผ์˜ค๋Š” ๋ฌธ์ œ์ ์ด, communication์— ๋”ฐ๋ฅธ bottleneck์ด๋‹ค. ๋ชจ๋ธ์ด ์ปค์ง€๋ฉด ๋” ์—ฌ๋Ÿฌ๊ณณ์— ๋ถ„์‚ฐํ•ด์•ผ ํ•˜๊ณ , ๋ฐ์ดํ„ฐ ํฌ๊ธฐ๋„ ์ปค์ง€๋ฉฐ, ์ „์†ก ์†๋„๋„ ๊ธธ์–ด์ง„๋‹ค. ๊ทธ๋ ‡๊ฒŒ ๋˜๋ฉด communication latency๊ฐ€ ๊ธธ์–ด์งˆ ์ˆ˜ ๋ฐ–์— ์—†๋‹ค.

์‹ค์ œ๋กœ, ๋ถ„์‚ฐ ํ•™์Šต์—์„œ GPU์˜ ๊ฐœ์ˆ˜์™€ speed๋Š” ์ •ํ™•ํžˆ y=x์˜ ๊ทธ๋ž˜ํ”„๋ฅผ ๊ทธ๋ฆฌ์ง€ ์•Š๋Š”๋‹ค. ์ด๋Š” ์ค‘๊ฐ„์ค‘๊ฐ„์— bottleneck์œผ๋กœ ์ž‘์šฉํ•˜๋Š” ๋ถ€๋ถ„์ด ์žˆ์Œ์„ ๋ณด์—ฌ์ค€๋‹ค.

๊ทธ๋ฆฌ๊ณ  ์ด latency๋Š” node๊ฐ„ ๊ฑฐ๋ฆฌ๊ฐ€ ๋ฉ€์ˆ˜๋ก ๋‹น์—ฐํžˆ ํ›จ์”ฌ ๋Š˜์–ด๋‚œ๋‹ค.

10. Gradient compression: overcome the bandwidth bottleneck

๊ทธ๋ ‡๋‹ค๋ฉด ์ด๋Ÿฐ data์ „์†ก์—์„œ bottleneck์„ ์ค„์ด๋ ค๋ฉด ์–ด๋–ป๊ฒŒ ํ•ด์•ผ ํ• ๊นŒ?

์šฐ๋ฆฌ๊ฐ€ ์ด์ „์— ํ•ด์™”๋˜ pruning, quantization ๋ฐฉ์‹์„ ๋˜‘๊ฐ™์ด ์ ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค. pruning/quantization์œผ๋กœ ๋ฐ์ดํ„ฐ ํฌ๊ธฐ๋ฅผ ์ค„์ธ๋‹ค๋ฉด, ๋‹น์—ฐํžˆ ์ „์†กํ•  ๋ฐ์ดํ„ฐ ์–‘๋„ ์ค„์–ด๋“ค๊ณ  bottleneck๋„ ์ค„์–ด๋“ ๋‹ค.

10.1. Gradient Pruning: Sparse Communication, Deep Gradient Compression

pruning ๋ฐฉ์‹์€ ๊ฐ„๋‹จํ•˜๊ฒŒ gradient์ค‘์—์„œ top-k๊ฐœ๋งŒ ์ „์†กํ•˜๋Š” ๋ฐฉ์‹์ด๋‹ค. top-k์˜ ๊ธฐ์ค€์€ ๋‹จ์ˆœํžˆ magnitude๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค. ์ „์†ก๋˜์ง€ ์•Š์€ gradient๋„ error feedback์„ ํ†ตํ•ด local์— ๋‚ด๋น„๋‘์–ด ์‚ฌ์šฉํ•œ๋‹ค.

ํ•˜์ง€๋งŒ ์ด๋Ÿฐ ๋ฐฉ์‹์€ ๊ฝค๋‚˜ ์„ฑ๋Šฅ ์ €ํ•˜๋ฅผ ์ผ์œผํ‚จ๋‹ค. gradient๋งŒ ์‚ฌ์šฉํ•˜๋ฏ€๋กœ, momentum์ด ์—†๊ธฐ ๋–„๋ฌธ์ด๋‹ค.

๋‹จ์ˆœํžˆ accumulateํ•˜๋ฉด ์œ„์™€ ๊ฐ™์ด momentum์„ ์‚ฌ์šฉํ•˜๋Š” ๋ฐฉ์‹๊ณผ ๊ฝค ๋‹ฌ๋ผ์ง€๊ฒŒ ๋œ๋‹ค.

๋”ฐ๋ผ์„œ gradient๊ฐ€ ์•„๋‹ˆ๋ผ velocity๋ฅผ accumulateํ•˜๋Š” ๊ฒƒ์ด ๋” ์ข‹๋‹ค. ๋˜ํ•œ ์—ฌ๋Ÿฌ๊ฐ€์ง€ warm up ๋ฐฉ์‹์„ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ์ข‹๋‹ค.

learning rate๋Š” ๋ฌผ๋ก ์ด๊ณ , pruning sparsity๋„ warm up ๋ฐฉ์‹์„ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ๋„์›€์ด ๋œ๋‹ค. optimizer๊ฐ€ ์ ์‘ํ•˜๋Š”๋ฐ ๋„์›€์„ ์ค€๋‹ค.

pruning ๋ฐฉ์‹์˜ ๋ฌธ์ œ์ ์€, ์—ฌ๋Ÿฌ node๋ผ๋ฆฌ ์ •๋ณด๋ฅผ ๊ตํ™˜ํ•˜๋Š” all-reduce ๊ณผ์ •์„ ๊ฑฐ์น˜๋ฉด์„œ pruning ํ•˜๋Š” ์˜๋ฏธ๊ฐ€ ์—†์–ด์ง„๋‹ค๋Š” ๊ฒƒ์ด๋‹ค(๋” denseํ•ด์ง„๋‹ค)

์ด๋Ÿฐ ๋ฐฉ์‹์„ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด sparseํ•˜๊ฒŒ ๋งŒ๋“œ๋Š” ๊ฒƒ์ด ์•„๋‹ˆ๋ผ low rank matrix๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๋ฐฉ์‹์„ ์‚ฌ์šฉํ•˜๊ธฐ๋„ ํ•œ๋‹ค.

10.2. Gradient Quantization: 1-Bit SGD, TernGrad

quantization ๋ฐฉ๋ฒ•์„ ์†Œ๊ฐœํ•˜๋ฉด, 1bit-SGD ๋ฐฉ์‹์ด ์žˆ๋‹ค. ์ด๋Š” 0๋ณด๋‹ค ํฌ๋ฉด +, ์•„๋‹ˆ๋ฉด -๋กœ ๋‘” ๋’ค scaling factor(u1~u4)๋ฅผ colum๋งˆ๋‹ค ์ ์šฉํ•˜๋Š” ๋ฐฉ์‹์ด๋‹ค. quantization error๋Š” locallyํ•˜๊ฒŒ ์ง‘๊ณ„๋˜๊ณ , quantize๋œ gradient๋งŒ ์ „์†ก๋œ๋‹ค.

๋น„์Šทํ•œ ๋ฐฉ์‹์œผ๋กœ, threshold๋ฅผ ํŠน์ • ๊ฐ’์œผ๋กœ ๋‘๊ณ  quantizeํ•˜๋Š” ๋ฐฉ์‹๋„ ์žˆ๋‹ค.

Terngrad๋Š” ํ™•๋ฅ ๊ฐ’์— ๋”ฐ๋ผ 0, 1, -1 ์ค‘ ํ•˜๋‚˜๋กœ ์–‘์žํ™” ํ•˜๋Š” ๋ฐฉ์‹์ด๋‹ค.

11. Delayed gradient update: overcome the latency bottleneck

pruning์ด๋‚˜ quantize ๋ฐฉ์‹์œผ๋กœ bandwith๋ฌธ์ œ๋Š” ํ•ด๊ฒฐํ–ˆ์ง€๋งŒ, latency๋Š” ์ด๋Ÿฐ ๋ฐฉ๋ฒ•๋“ค๋กœ๋Š” ํ•ด๊ฒฐํ•  ์ˆ˜ ์—†๋‹ค.

๊ฑฐ๋ฆฌ๋‚˜ ์‹ ํ˜ธ ํ˜ผ์žก๊ฐ™์€ ๋ฌผ๋ฆฌ์ ์ธ ์ด์œ ๊ฐ€ ์žˆ๋‹ค.

๊ธฐ๋ณธ์ ์ธ ๋ถ„์‚ฐ ์ฒ˜๋ฆฌ ๋ฐฉ์‹์—์„œ๋Š”, ๊ณ„์‚ฐ->์ „์†ก->๊ณ„์‚ฐ ๋ฐฉ์‹์œผ๋กœ ์ด๋ฃจ์–ด์ง€๋Š”๋ฐ, ์ด๋Ÿฐ ๊ตฌ์กฐ์—์„œ๋Š” ์ „์†ก ์‹œ๊ฐ„์ด ๋Š˜์–ด๋‚˜๋ฉด ๊ทธ ์ฆ‰์‹œ ์ „์ฒด ์‹œ๊ฐ„์ด ๋Š˜์–ด๋‚œ๋‹ค

๊ทธ๋ž˜์„œ ์ „์†ก๊ณผ ๊ณ„์‚ฐ์„ ๋™์‹œ์— ์ง„ํ–‰ํ•˜๋ฉด, ์–ด๋Š์ •๋„ ์‹œ๊ฐ„์„ ์ ˆ์•ฝํ•  ์ˆ˜ ์žˆ๋‹ค.

์ „์†ก ์‹œ๊ฐ„์ด ์–ด๋Š์ •๋„ ๋Š˜์–ด๋‚˜๋”๋ผ๋„, training์„ ๋ฐฉํ•ดํ•˜์ง€ ์•Š๋Š”๋‹ค.

์—ฌ๊ธฐ๊นŒ์ง€ tinyML 17,18๊ฐ•์ธ distributed learning์— ๊ด€ํ•œ ์ •๋ฆฌ์˜€๋‹ค. ์‹ค์ œ๋กœ ๋ชจ๋ธ์ด ์ ์  ์ปค์ง€๋Š” ๋งŒํผ, ํ˜น์€ edge ํ™˜๊ฒฝ์—์„œ ๊ฐœ๊ฐœ์ธ์˜ ๋ฐ์ดํ„ฐ๋กœ ํ•™์Šต์„ ํ•˜๋Š” ๊ฒฝ์šฐ์—๋„ ๋ถ„์‚ฐ ๋ฐฉ์‹์ด ์œ ์šฉํ•˜๊ฒŒ ์‚ฌ์šฉ๋œ๋‹ค. ํ•˜์ง€๋งŒ ๋ถ„์‚ฐ ๋ฐฉ์‹์—์„œ๋Š” ๊ฐœ์ธ์˜ ๋ฐ์ดํ„ฐ๊ฐ€ ์œ ์ถœ๋  ์œ„ํ—˜์ด ์กด์žฌํ•˜๊ธฐ๋„ ํ•œ๋‹ค. ๊ทธ๋Ÿฐ ๋ฐฉ๋ฒ•์ด ๋ฌด์—‡์ธ์ง€, ์–ด๋–ป๊ฒŒ ๋ง‰์„ ์ˆ˜ ์žˆ๋Š”์ง€๋Š” ๋’ค์˜ ๊ฐ•์˜์—์„œ ๋‹ค๋ฃจ๋‹ˆ, ๊ด€์‹ฌ์žˆ๋Š” ๋ถ„์€ ๋’ค์˜ ํฌ์ŠคํŒ…๋„ ์ฐพ์•„๋ณด์‹œ๊ธธ ๋ฐ”๋ž€๋‹ค.