๐งโ๐ซ Lecture 17-18
์ด๋ฒ ๊ฐ์๋ ๋ถ์ฐ ํ์ต(distributed training)์ ๋ํ ๋ด์ฉ์ ๋ค๋ฃฌ๋ค. LLM๋ฑ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ด ์ ์ ๋ณต์กํด์ง๋ฉด์, ๋๊ท๋ชจ ๋ชจ๋ธ์ ํ๋ จ์ ๋จ์ผ GPU๋ก๋ ๋ ์ด์ ์ถฉ๋ถํ์ง ์๊ณ , ์ฌ๋ฌ๊ฐ์ GPU๋ฅผ ์ฌ์ฉํด์ผ ๊ฐ๋ฅํ๋ค. ์ด๋ฒ ๊ธ์์๋ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ํ๋ จ์ ๊ฐ์ํํ๊ธฐ ์ํ ๋ค์ํ ๋ถ์ฐ ํ๋ จ ๊ธฐ๋ฒ๋ค์ ๋ค๋ฃจ๊ณ ์ ํ๋ค. ๋ฐ์ดํฐ ๋ณ๋ ฌํ, ํ์ดํ๋ผ์ธ ๋ณ๋ ฌํ, ํ ์ ๋ณ๋ ฌํ ๋ฑ์ ๊ธฐ๋ฒ๋ค์ ํตํด ํจ์จ์ ์ธ ๋ถ์ฐ ํ๋ จ์ ์คํํ๋ ๋ฐฉ๋ฒ์ ๋ํด ์ดํด๋ณผ ๊ฒ์ด๋ค. ๋ํ, ๋ถ์ฐ ํ๋ จ ๋ฟ๋ง ์๋๋ผ GPU๊ฐ ํต์ ์ ๋ณ๋ชฉ ํ์์ ๋ํ ๋์ฒ๋ฒ์ผ๋ก ๊ฐ์ค์น ์์ถ(gradient compression) ๋ฐ ์ง์ฐ ๊ฐ์ค์น ์ ๋ฐ์ดํธ(delayed gradient update) ๋ํ ์ดํด๋ณผ ๊ฒ์ด๋ค.
1. Background and motivation
์ ํ๋ Nvidia A100 ๊ธฐ์ค์ผ๋ก ๋ชจ๋ธ๋ณ๋ก ํ๋ จ์ ๊ฑธ๋ฆฌ๋ ์๊ฐ์ด๋ค. distributed training์์ด ๋จ์ผ GPU๋ก ํ์ตํ๋ค๋ฉด, GPT-3๋ ํ์ตํ๋๋ฐ 355๋ ์ด๋ ๊ฑธ๋ฆฐ๋ค.
์ ๋๋ด์ฒ๋ผ, ํ๋ จ์ ๋๋ฌด ๋ง์ ์๊ฐ์ด ๊ฑธ๋ฆฐ๋ค๋ฉด ๋ชจ๋ธ์ ๋ฐ์ ์ํค๊ธฐ ์ด๋ ค์ธ ๊ฒ์ด๋ค. distribution training์ ์ฌ์ฉํ๋ฉด, ์ด์์ ์ผ๋ก๋ 10์ผ์ด ๊ฑธ๋ฆฌ๋ ์์ ์ 1024๊ฐ์ GPU๋ก๋ 14๋ถ๋ง์ ํ์ต์ด ๊ฐ๋ฅํ๋ค.
2. Parallelization methods for distributed training
๊ทธ๋ ๋ค๋ฉด ์ด๋ป๊ฒ ๋ณ๋ ฌ๋ก ํ๋ จ์ํฌ ์ ์์๊น? ์ด๋ฒ ์ฅ์์๋ ๊ฐ๋จํ ์ธ ๊ฐ์ง ๋ฐฉ๋ฒ์ ์๊ฐํ๊ณ , ๊ฐ ๋ฐฉ๋ฒ์ ๋ํ ์์ธํ ์ค๋ช ์ ํ๋์ฉ ๋ค๋ฃจ๋๋ก ํ๊ฒ ๋ค.
์ฒซ ๋ฒ์งธ๋ data parallelism์ด๋ค. ์ด ๋ฐฉ๋ฒ์ ํ์ต ๋ฐ์ดํฐ๋ฅผ ์ฌ๋ฌ๊ฐ์ GPU์ ๋๋ ์ ํ์ตํ ๋ค, ๊ณ์ฐ๋ ๊ฐ๊ฐ์ gradient๋ฅผ ๋ค์ ์ค์์์ ํฉ์น๋ ๋ฐฉ์์ด๋ค.
๋ ๋ฒ์งธ๋ pipeline parallelism์ด๋ค. ์ด ๋ฐฉ๋ฒ์ ๋ชจ๋ธ์ layer ๋ณ๋ก ๋๋ ์ ํ ๋นํ๋ ๊ฒ์ด๋ค. GPU๋ณ๋ก layer๋ฅผ ๋๋ ๊ณ์ฐํ๋ ๊ฒ์ธ๋ฐ, ํ๋์ GPU์ ๋ชจ๋ธ์ ๋ชจ๋ ๊ฐ์ค์น๋ฅผ ๋ฃ๊ธฐ์ ๋๋ฌด ํด ๋ ์ฌ์ฉํ๋ค.
๋ง์ง๋ง์ผ๋ก tensor parallelism์ด๋ค. ์ด๋ layer๋ณด๋ค ๋ ์๊ฒ, ํ ์ ๋จ์๋ก ์ชผ๊ฐ๋ ๊ฒ์ด๋ค. ์ง๊ด์ ์ผ๋ก ๋ณด๋ฉด ๋ชจ๋ธ์ ์ธ๋ก๋ก ์ชผ๊ฐ์ ์๋ก ๋ค๋ฅธ GPU์ ํ ๋นํ๋ ๊ฒ์ธ๋ฐ, ์์ธํ ์ค๋ช ์ ๋ค์ ๋ค์ ๋ค๋ฃจ๊ฒ ๋ค
3. Data parallelism
data parallelism์ ํ ๋์๋ parameter server์ worker nodes๋ผ๋ ๋ ๊ฐ์ง ์์๊ฐ ์กด์ฌํ๋ค. worker nodes๋ ๋ถํ ๋ ๋ฐ์ดํฐ๋ก ํ์ต์ ์งํํ๊ณ , parameter server๋ worker๋ก๋ถํฐ ์ ๋ฌ๋ฐ์ gradient๋ฅผ ํฉ์น๋ค.
์์๋ฅผ ์ดํด๋ณด๋ฉด
- parameter server์์ worker๋ก ๋ชจ๋ธ ๊ฐ์ค์น๋ฅผ ๋ณต์ฌ
- ํ์ต ๋ฐ์ดํฐ์ ์ worker์ ์ ์ ํ ๋ฐฐ๋ถ
- worker์์ gradient ๊ณ์ฐ
- gradient๋ฅผ parameter server์ ์ ๋ฌ
- parameter server์์ ์์ฒด์ ์ผ๋ก ๋ชจ๋ธ์ ์ ๋ฐ์ดํธ
4. Communication primitives
๊ทธ๋ฐ๋ฐ, data parallelism์์ paramter server์ worker๊ฐ์ ํต์ ์ ์ด๋ป๊ฒ ์ด๋ฃจ์ด์ง๊น?
๊ฐ์ฅ ๋จ์ํ ๋ฐฉ์์ ์์ผ ํต์ ๊ฐ์ 1:1 ๋ฐฉ์์ด๋ค.
๋ค๋ฅธ ๋ฐฉ์์ผ๋ก๋ ์ผ๋๋ค ๋ฐฉ์์ด ์๋ค. ํ ๋ ธ๋๋ก๋ถํฐ ์ ๋ณด๋ฅผ ๋ถ์ฐํ๊ณ , ํฉ์น๋ ๊ฒ์ด๋ค.
์ผ๋๋ค ๋ฐฉ์์์ ์กฐ๊ธ ๋ณํํ๋ฉด, reduce ๋ฐฉ์์ด ๋๋ค. reduce ๋ฐฉ์์ gather๊ณผ ๋น์ทํ์ง๋ง, ์ ๋ณด๋ฅผ ๋จ์ํ ๋ชจ์ผ๋ ๊ฒ์ด ์๋๋ผ ํฉ์น๊ฑฐ๋ ํ๊ท ์ ๋ด๋ ๊ฒ์ผ๋ก(๋นจํ๋ ธ์ด์ ํฉ)ํ๋์ ๊ฐ์ ๋ง๋ค์ด๋ด๊ณ , ๊ทธ๋ฐ ๋ค broadcast๋ฅผ ํตํด ๊ฐ์ ๋ชจ๋ node๋ก ์ก์ ํ๋ค
๊ทธ ๋ค์์, ๋ค๋๋ค ๋ฐฉ์์ด๋ค. all-reduce๋ ๋ชจ๋ ๋ ธ๋์ ๋ํด reduce๋ฅผ ์งํํ๋ ๊ฒ์ด๊ณ , all-gather์ ํตํด ํฉ์น๊ฑฐ๋ ํ๋ ๊ฒ์ด ์๋๋ผ ๋ชจ๋ ์ ๋ณด๋ฅผ ๋ค ๊ฐ๊ฒ ๋๋ค.
์ฐ๋ฆฌ๊ฐ ์ด์ ์ฅ์์ ์ดํด๋ณธ data parallelism ๋ฐฉ์์ ๋ณต์ก๋๋ฅผ ๊ตฌํด๋ณด๋ฉด, O(N)์ ๋ณต์ก๋๋ฅผ๊ฐ์ง๊ณ ์ด๋ ๊ฝค ๋ถ๋ด๋๋ ์ ๋์ด๋ค. ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด all reduce๋ผ๋ ๋ฐฉ์์ ์ฌ์ฉํ๋ค.
๊ฐ์ฅ ๊ฐ๋จํ ๋ฐฉ์์, sequential๋ก ํ๋์ฉ reduceํ๋ ๋ฐฉ์์ด๋ค. ์ด ๋ฐฉ์์ ์๊ฐ๊ณผ bandwith ์ธก๋ฉด์์ ๋ชจ๋ O(N)์ ๋ณต์ก๋๋ฅผ ๊ฐ์ง๋ค.
Ring ๋ฐฉ์์ฒ๋ผ ์์ ๋ ธ๋๋ก๋ง ์ ์กํ๊ฒ ๋๋ค๋ฉด, ์๊ฐ์ ๊ทธ๋๋ก O(N)์ด์ง๋ง, bandwith๋ O(1)์ ๋ณต์ก๋๋ฅผ ๊ฐ๊ฒ ๋๋ค.
์ข ํฉํด๋ณด๋ฉด ์ ํ์ ๊ฐ๋ค. ์ด๋ค ๋ฐฉ์๋ O(N)์ ์๊ตฌํ๋๋ฐ, ์ด๋ค ์์ผ๋ก ๋ ์ค์ผ ์ ์์๊น?
recursive all reduce ๋ฐฉ์์ด ๊ทธ ํด๋ฒ์ด๋ค. ์์ ์๋ ๋ ธ๋๋ผ๋ฆฌ ์ ๋ณด๋ฅผ ๊ตํํ๋ ๋ฐฉ์์ผ๋ก, log(N)์ผ๋ก ํด๊ฒฐ ๊ฐ๋ฅํ๋ค.
5. Reducing memory in data parallelism: ZeRO-1 / 2 / 3 and FSDP
data parallelism์์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ ์ค์ด๋ ZeRO ์๋ฆฌ์ฆ ๋ฐฉ๋ฒ์ ์๊ฐํ๋ค.
ํ์ต์ ํ ๋ weight์ gradient ์ด์ธ์๋ adam๋ฑ์ optimizer๋ฅผ ์ฌ์ฉํ๋ฉด optimizer state(momentum, variance๊ฐ, master copy)๋ฅผ ์ ์ฅํด์ผ ํ๋ค. ๊ทธ๋ ๋ค๋ฉด ํ๋์ ํ๋ผ๋ฏธํฐ์ 16byte๊ฐ ํ์ํ๊ฒ ๋๋ค(fp16๊ธฐ์ค, weight=2, gradient=2, optim state=12(4*3)) ์ ๊ทธ๋ฆผ์ฒ๋ผ ๋ง์ ๋ฉ๋ชจ๋ฆฌ๊ฐ optimizer state๋ฅผ ์ ์ฅํ๊ธฐ ์ํด ์ฌ์ฉ๋๊ฒ ๋๋ค.
๊ทธ๋์ ๊ณ ์ํ ์ ์๋ ๋ฐฉ๋ฒ์ด optimizer state๋ฅผ sharding(๋ถ์ฐ์ ์ฅ)ํ๋ ๋ฐฉ์์ด๋ค. ์ด๋ฌ๋ฉด ํด๋นํ๋ ๊ตฌ๊ฐ์ ๋ํ weight๋ฐ์ ์ ๋ฐ์ดํธ๋ฅผ ํ์ง ๋ชปํ๊ฒ ๋์ง๋ง, ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ค์ด๋ค๊ฒ ๋๋ค.(12/N bytes for optim state)
optimizer state๋ฟ๋ง ์๋๋ผ, gradient๋ shardingํ ์ ์๋ค.(2/N bytes for gradient)
weight๊น์ง๋ shardingํ ์ ์๋ค. ํ์ง๋ง gradient์ ๋นํด ์ด๊ฑด ์ข ์ด๋ ค์ด๋ฐ, inference๋ฅผ ํ ๋ ๋ชจ๋ weight๊ฐ ํ์ํ๊ธฐ ๋๋ฌธ์ด๋ค. ๋ฐ๋ผ์ ์ด ๋ฐฉ์์ ์ฌ์ฉํ๋ฉด, inference์์๋ ๋ค๋ฅธ GPU(node)๋ก๋ถํฐ weight๊ฐ์ ๊ฐ์ ธ์์ ์งํํ๊ฒ ๋๋ค. (2/N bytes for weights)
6. Pipeline parallelism
data parallelism๊ณผ ๋ฌ๋ฆฌ pipeline parallelism์, ๋ชจ๋ธ์ ๋ถํ ํ๋ ๋ฐฉ์์ด๋ค.
๊ฐ๋จํ๊ฒ ๊ตฌํ์ ์๊ฐํด๋ณด๋ฉด, forward-backward์ ๋ํด ์์ ๊ฐ์ด ์ผ์์ ์ธ ๊ตฌ์กฐ๋ก ๊ตฌํํ ์ ์๋ค. ํ์ง๋ง, ์ด๋ฐ ๋ฐฉ์์ด๋ฉด F0์ดํ ํ์ฐธ ๋ค์ B0์ด ์ด๋ฃจ์ด ์ง๋ ๊ฒ์ฒ๋ผ ๋น๊ฒ๋๋ ์๊ฐ์ด ๋ง์์ง๋ค.
์ด๋ฅผ ํด๊ฒฐํ๋ ๋ฐฉ๋ฒ์ด Gpipe์ ์ ์๋๋ค. batch๋ฅผ ์๊ฒ ์ชผ๊ฐ์ forward-backward๊ฐ์ ๊ฐ๊ฒฉ์ ์ค์ด๋ ๋ฐฉ์์ด๋ค. ์ด์ ์ naiveํ ๋ฐฉ๋ฒ์์ ์ด ๋ฐฉ๋ฒ์ผ๋ก ๋ฐ๊ฟ ์ 2.5๋ฐฐ์ ๋ utilization์ ๋๋ฆด ์ ์๋ค.(25%->57%)
7. Tensor parallelism
ํ์ง๋ง 57%์ utilization๋ ์๊ฐํด๋ดค์ ๋ ๋๋ฌด ์ ๊ฒ ํ์ฉ๋๋ ๊ฒ ๊ฐ๊ธฐ๋ ํ๋ค. ๊ทธ๋ ๋ค๋ฉด ๋ ์๊ฒ ์ชผ๊ฐ๋ ๋ฐฉ์์ ์๊ฐํด๋ณผ ์ ์์ ๊ฒ์ด๋ค.
์ ๊ทธ๋ฆผ์ ๊ฐ๊ฐ MLP์ self-attention layer์์ tensor parallelism์ ์ ์ฉํ์ ๋์ ๊ตฌ์กฐ์ด๋ค. ๋ น์ f ๋ฅผ ํตํด X๋ฅผ N๊ฐ(์์์์๋ 2๊ฐ)์ chunk๋ก ์ชผ๊ฐ๊ณ , ๊ฐ๊ฐ์ X๋ฅผ ์๋ก๋ค๋ฅธ GPU์์ ์ฐ์ฐ์ ์งํํ๋ค.
8. Hybrid (mixed) parallelism and how to auto-parallelize
์์์ ์ฐ๋ฆฌ๋ 3๊ฐ์ง ๋ฐฉ๋ฒ์ parallelism ๋ฐฉ๋ฒ์ ์ดํด๋ณด์๋ค. ์ด๋ฒ ์ฅ์์๋ ์ด๋ฐ ๋ฐฉ๋ฒ๋ค์ ์กฐํฉํด์ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ ๋ํด ์ดํด๋ณธ๋ค.
๋จผ์ ๊ฐ ๋ณ๋ ฌํ ๋ฐฉ๋ฒ์ ๋ค์ ์์ฝํด๋ณด๋ฉด, data parallelism์ ๋ฐ์ดํฐ๋ฅผ ๋๋๋ ๊ฒ์ด๊ณ , pipieline parallelism์ model์ layer๋จ์๋ก ๋๋๋ ๊ฒ์ด๋ฉฐ tensor parallelism์ model์ tensor๋จ์๋ก ๋๋๋ ๊ฒ์ด๋ค.
์ ๊ทธ๋ฆผ์ Data parallelism + Pipeline parallelism์ ๋์์ ์ ์ฉํ ๋ชจ์ต์ด๋ค. 2๊ฐ์ GPU๋จ์๋ก data๋ฅผ ๋๋๊ณ , 2๊ฐ GPU๊ฐ์ pipeline parallelism์ ์ ์ฉํ๋ค.
์ ๊ทธ๋ฆผ์ pipeline parallelism + tensor parallelism์ด๋ค. ์ด๋ ๋ฏ 2๊ฐ์ ๊ธฐ๋ฒ๋ง์ ์ฌ์ฉํ๋ ๊ฒ์ด ์๋๋ผ 3๊ฐ์ ๋ณ๋ ฌํ ๋ฐฉ๋ฒ์ ๋์์ ์ฌ์ฉํ ์๋ ์๋ค(3d parallelism)
๊ทธ๋ฆผ์ผ๋ก ํํํ๋ฉด ์์ ๊ฐ๋ค.
์๊ฐํด ๋ณผ ๋ฌธ์ ๋, ์ด๋ ๊ฒ ์ฌ๋ฌ๊ฐ์ ๋ณ๋ ฌํ ๋ฐฉ์์ ์ฌ์ฉํ ๋, ์ด๋ค ์์ผ๋ก ์ฌ์ฉํด์ผ ์ต์ ์ ๋ฐฉ์์ด ๋๋๋์ ๋ฌธ์ ์ด๋ค. ์ผ๋ฐ์ ์ผ๋ก๋ ๋ชจ๋ธ์ด GPU์์ ๋์๊ฐ์ง ์์ ์ ๋๋ก ํฌ๋ค๋ฉด pipeline, layer๊ฐ GPU์์ ๋์๊ฐ์ง ์์ ์ ๋๋ก ํฌ๋ค๋ฉด tensor ๋ฐฉ์์ ์ฌ์ฉํ์ง๋ง ๋จ์ํ ๊ทธ๋ฐ ๋ฐฉ์์ ์ ์ํ๋ ๊ฒ์ด ์ต์ ์ ๋ฐฉ๋ฒ์ ์๋๋ค.
์ด๋ฅผ NAS์ ๋น์ทํ ๋ฐฉ์์ผ๋ก ์ฐ์ฐ๊ฐ์ ๊ด๊ณ๋ฅผ ์ค์ ํ๋ inter-op pass์ ์ฐ์ฐ ๋ด๋ถ์ ๋์์ ์ค์ ํ๋ intra-op pass ๋ ๋จ๊ณ๋ฅผ ํตํด ์๋์ผ๋ก ์ฐพ์์ฃผ๋ ๋ฐฉ์์ด ์กด์ฌํ๋ค.
9 Understand the bandwidth and latency bottleneck of distributed training
๋ถ์ฐ ํ์ต ๋ฐฉ์์๋ ๋ฐ๋์ ๋ฐ๋ผ์ค๋ ๋ฌธ์ ์ ์ด, communication์ ๋ฐ๋ฅธ bottleneck์ด๋ค. ๋ชจ๋ธ์ด ์ปค์ง๋ฉด ๋ ์ฌ๋ฌ๊ณณ์ ๋ถ์ฐํด์ผ ํ๊ณ , ๋ฐ์ดํฐ ํฌ๊ธฐ๋ ์ปค์ง๋ฉฐ, ์ ์ก ์๋๋ ๊ธธ์ด์ง๋ค. ๊ทธ๋ ๊ฒ ๋๋ฉด communication latency๊ฐ ๊ธธ์ด์ง ์ ๋ฐ์ ์๋ค.
์ค์ ๋ก, ๋ถ์ฐ ํ์ต์์ GPU์ ๊ฐ์์ speed๋ ์ ํํ y=x์ ๊ทธ๋ํ๋ฅผ ๊ทธ๋ฆฌ์ง ์๋๋ค. ์ด๋ ์ค๊ฐ์ค๊ฐ์ bottleneck์ผ๋ก ์์ฉํ๋ ๋ถ๋ถ์ด ์์์ ๋ณด์ฌ์ค๋ค.
๊ทธ๋ฆฌ๊ณ ์ด latency๋ node๊ฐ ๊ฑฐ๋ฆฌ๊ฐ ๋ฉ์๋ก ๋น์ฐํ ํจ์ฌ ๋์ด๋๋ค.
10. Gradient compression: overcome the bandwidth bottleneck
๊ทธ๋ ๋ค๋ฉด ์ด๋ฐ data์ ์ก์์ bottleneck์ ์ค์ด๋ ค๋ฉด ์ด๋ป๊ฒ ํด์ผ ํ ๊น?
์ฐ๋ฆฌ๊ฐ ์ด์ ์ ํด์๋ pruning, quantization ๋ฐฉ์์ ๋๊ฐ์ด ์ ์ฉํ ์ ์๋ค. pruning/quantization์ผ๋ก ๋ฐ์ดํฐ ํฌ๊ธฐ๋ฅผ ์ค์ธ๋ค๋ฉด, ๋น์ฐํ ์ ์กํ ๋ฐ์ดํฐ ์๋ ์ค์ด๋ค๊ณ bottleneck๋ ์ค์ด๋ ๋ค.
10.1. Gradient Pruning: Sparse Communication, Deep Gradient Compression
pruning ๋ฐฉ์์ ๊ฐ๋จํ๊ฒ gradient์ค์์ top-k๊ฐ๋ง ์ ์กํ๋ ๋ฐฉ์์ด๋ค. top-k์ ๊ธฐ์ค์ ๋จ์ํ magnitude๋ฅผ ์ฌ์ฉํ๋ค. ์ ์ก๋์ง ์์ gradient๋ error feedback์ ํตํด local์ ๋ด๋น๋์ด ์ฌ์ฉํ๋ค.
ํ์ง๋ง ์ด๋ฐ ๋ฐฉ์์ ๊ฝค๋ ์ฑ๋ฅ ์ ํ๋ฅผ ์ผ์ผํจ๋ค. gradient๋ง ์ฌ์ฉํ๋ฏ๋ก, momentum์ด ์๊ธฐ ๋๋ฌธ์ด๋ค.
๋จ์ํ accumulateํ๋ฉด ์์ ๊ฐ์ด momentum์ ์ฌ์ฉํ๋ ๋ฐฉ์๊ณผ ๊ฝค ๋ฌ๋ผ์ง๊ฒ ๋๋ค.
๋ฐ๋ผ์ gradient๊ฐ ์๋๋ผ velocity๋ฅผ accumulateํ๋ ๊ฒ์ด ๋ ์ข๋ค. ๋ํ ์ฌ๋ฌ๊ฐ์ง warm up ๋ฐฉ์์ ์ฌ์ฉํ๋ ๊ฒ์ด ์ข๋ค.
learning rate๋ ๋ฌผ๋ก ์ด๊ณ , pruning sparsity๋ warm up ๋ฐฉ์์ ์ฌ์ฉํ๋ ๊ฒ์ด ๋์์ด ๋๋ค. optimizer๊ฐ ์ ์ํ๋๋ฐ ๋์์ ์ค๋ค.
pruning ๋ฐฉ์์ ๋ฌธ์ ์ ์, ์ฌ๋ฌ node๋ผ๋ฆฌ ์ ๋ณด๋ฅผ ๊ตํํ๋ all-reduce ๊ณผ์ ์ ๊ฑฐ์น๋ฉด์ pruning ํ๋ ์๋ฏธ๊ฐ ์์ด์ง๋ค๋ ๊ฒ์ด๋ค(๋ denseํด์ง๋ค)
์ด๋ฐ ๋ฐฉ์์ ํด๊ฒฐํ๊ธฐ ์ํด sparseํ๊ฒ ๋ง๋๋ ๊ฒ์ด ์๋๋ผ low rank matrix๋ฅผ ์ฌ์ฉํ๋ ๋ฐฉ์์ ์ฌ์ฉํ๊ธฐ๋ ํ๋ค.
10.2. Gradient Quantization: 1-Bit SGD, TernGrad
quantization ๋ฐฉ๋ฒ์ ์๊ฐํ๋ฉด, 1bit-SGD ๋ฐฉ์์ด ์๋ค. ์ด๋ 0๋ณด๋ค ํฌ๋ฉด +, ์๋๋ฉด -๋ก ๋ ๋ค scaling factor(u1~u4)๋ฅผ colum๋ง๋ค ์ ์ฉํ๋ ๋ฐฉ์์ด๋ค. quantization error๋ locallyํ๊ฒ ์ง๊ณ๋๊ณ , quantize๋ gradient๋ง ์ ์ก๋๋ค.
๋น์ทํ ๋ฐฉ์์ผ๋ก, threshold๋ฅผ ํน์ ๊ฐ์ผ๋ก ๋๊ณ quantizeํ๋ ๋ฐฉ์๋ ์๋ค.
Terngrad๋ ํ๋ฅ ๊ฐ์ ๋ฐ๋ผ 0, 1, -1 ์ค ํ๋๋ก ์์ํ ํ๋ ๋ฐฉ์์ด๋ค.
11. Delayed gradient update: overcome the latency bottleneck
pruning์ด๋ quantize ๋ฐฉ์์ผ๋ก bandwith๋ฌธ์ ๋ ํด๊ฒฐํ์ง๋ง, latency๋ ์ด๋ฐ ๋ฐฉ๋ฒ๋ค๋ก๋ ํด๊ฒฐํ ์ ์๋ค.
๊ฑฐ๋ฆฌ๋ ์ ํธ ํผ์ก๊ฐ์ ๋ฌผ๋ฆฌ์ ์ธ ์ด์ ๊ฐ ์๋ค.
๊ธฐ๋ณธ์ ์ธ ๋ถ์ฐ ์ฒ๋ฆฌ ๋ฐฉ์์์๋, ๊ณ์ฐ->์ ์ก->๊ณ์ฐ ๋ฐฉ์์ผ๋ก ์ด๋ฃจ์ด์ง๋๋ฐ, ์ด๋ฐ ๊ตฌ์กฐ์์๋ ์ ์ก ์๊ฐ์ด ๋์ด๋๋ฉด ๊ทธ ์ฆ์ ์ ์ฒด ์๊ฐ์ด ๋์ด๋๋ค
๊ทธ๋์ ์ ์ก๊ณผ ๊ณ์ฐ์ ๋์์ ์งํํ๋ฉด, ์ด๋์ ๋ ์๊ฐ์ ์ ์ฝํ ์ ์๋ค.
์ ์ก ์๊ฐ์ด ์ด๋์ ๋ ๋์ด๋๋๋ผ๋, training์ ๋ฐฉํดํ์ง ์๋๋ค.
์ฌ๊ธฐ๊น์ง tinyML 17,18๊ฐ์ธ distributed learning์ ๊ดํ ์ ๋ฆฌ์๋ค. ์ค์ ๋ก ๋ชจ๋ธ์ด ์ ์ ์ปค์ง๋ ๋งํผ, ํน์ edge ํ๊ฒฝ์์ ๊ฐ๊ฐ์ธ์ ๋ฐ์ดํฐ๋ก ํ์ต์ ํ๋ ๊ฒฝ์ฐ์๋ ๋ถ์ฐ ๋ฐฉ์์ด ์ ์ฉํ๊ฒ ์ฌ์ฉ๋๋ค. ํ์ง๋ง ๋ถ์ฐ ๋ฐฉ์์์๋ ๊ฐ์ธ์ ๋ฐ์ดํฐ๊ฐ ์ ์ถ๋ ์ํ์ด ์กด์ฌํ๊ธฐ๋ ํ๋ค. ๊ทธ๋ฐ ๋ฐฉ๋ฒ์ด ๋ฌด์์ธ์ง, ์ด๋ป๊ฒ ๋ง์ ์ ์๋์ง๋ ๋ค์ ๊ฐ์์์ ๋ค๋ฃจ๋, ๊ด์ฌ์๋ ๋ถ์ ๋ค์ ํฌ์คํ ๋ ์ฐพ์๋ณด์๊ธธ ๋ฐ๋๋ค.