๐งโ๐ซ Lecture 4
์ด์ ํฌ์คํ
์์ Pruning์ ๋ํด์ ๋ฐฐ์ ์๋ค. ์ด๋ฒ์๋ Pruning์ ๋ํ ๋จ์ ์ด์ผ๊ธฐ์ธ Pruning Ratio๋ฅผ ์ ํ๋ ๋ฐฉ๋ฒ
, Fine-tuning ๊ณผ์
์ ๋ํด ์์๋ณด๊ณ , ๋ง์ง๋ง์ผ๋ก Sparsity๋ฅผ ์ํ System Support
์ ๋ํด ์์๋ณด๊ณ ์ ํ๋ค.
1. Pruning Ratio
Pruning์ ํ๊ธฐ ์ํด์ ์ด๋ ์ ๋ Pruning์ ํด์ผ ํ ์ง ์ด๋ป๊ฒ ์ ํด์ผ ํ ๊น?
์ฆ, ๋ค์ ๋งํด์ ๋ช % ์ ๋ ๊ทธ๋ฆฌ๊ณ ์ด๋ป๊ฒ Pruning์ ํด์ผ ์ข์๊น?
์ฐ์ Channel ๋ณ Pruning์ ํ ๋, Channel ๊ตฌ๋ถ ์์ด ๋์ผํ Pruning ๋น์จ(Uniform
)์ ์ ์ฉํ๋ฉด ์ฑ๋ฅ์ด ์ข์ง ์๋ค. ์ค๋ฅธ์ชฝ ๊ทธ๋ํ์์ ์งํฅํด์ผ ํ๋ ๋ฐฉํฅ์ Latency๋ ์ ๊ฒ, Accuracy๋ ๋๊ฒ์ด๋ฏ๋ก ์ผ์ชฝ ์๋จ์ ์์ญ์ด ๋๋๋ก Pruning์ ์งํํด์ผ ํ๋ค. ๊ทธ๋ ๋ค๋ฉด ๊ฒฐ๋ก ์ Channel ๋ณ ๊ตฌ๋ถ์ ํด์ ์ด๋ค Channel์ Pruning ๋น์จ์ ๋๊ฒ, ์ด๋ค Channel์ Pruning ๋น์จ์ ๋ฎ๊ฒ
ํด์ผ ํ๋ค๋ ์ด์ผ๊ธฐ๊ฐ ๋๋ค.
1.1 Sensitiviy Analysis
Channel ๋ณ ๊ตฌ๋ถ์ ํด์ Pruning์ ํ๋ค๋ ๊ธฐ๋ณธ ์์ด๋์ด๋ ์๋์ ๊ฐ๋ค.
- Accuracy์ ์ํฅ์ ๋ง์ด ์ฃผ๋ Layer๋ Pruning์ ์ ๊ฒ ํด์ผ ํ๋ค.
- Accuracy์ ์ํฅ์ ์ ๊ฒ ์ฃผ๋ Layer๋ Pruning์ ๋ง์ด ํด์ผ ํ๋ค.
Accuracy๋ฅผ ๋๋๋ก์ด๋ฉด ์๋์ ๋ชจ๋ธ๋ณด๋ค ๋ ๋จ์ด์ง๊ฒ ๋ง๋ค๋ฉด์ Pruning์ ํด์ ๋ชจ๋ธ์ ๊ฐ๋ณ๊ฒ ๋ง๋๋ ๊ฒ์ด ๋ชฉํ์ด๊ธฐ ๋๋ฌธ์ ๋น์ฐํ ์์ด๋์ด์ผ ๊ฒ์ด๋ค. Accuracy์ ์ํฅ์ ๋ง์ด ์ค๋ค๋ ๋ง์ Sensitiveํ Layer์ด๋ค๋ผ๋ ํํ์ผ๋ก ๋ค๋ฅด๊ฒ ๋งํ ์ ์๋ค. ๋ฐ๋ผ์ ๊ฐ Layer์ Senstivity๋ฅผ ์ธก์ ํด์ Sensitiveํ Layer๋ Pruning Ratio๋ฅผ ๋ฎ๊ฒ ์ค๊ณํ๋ฉด ๋๋ค.
Layer์ Sensitivity๋ฅผ ์ธก์ ํ๊ธฐ ์ํด Sensitivity Analysis๋ฅผ ์งํํด๋ณด์. ๋น์ฐํ ํน์ Layer์ Pruning Ratio๊ฐ ๋์ ์๋ก weight๊ฐ ๋ง์ด ๊ฐ์ง์น๊ธฐ ๋ ๊ฒ์ด๋ฏ๋ก Accuracy๋ ๋จ์ด์ง๊ฒ ๋๋ค.
Pruning Ratio์ ์ํด Pruned ๋๋ weight๋ ์ด์ ๊ฐ์์์ ๋ฐฐ์ด โImportance(weight์ ์ ๋๊ฐ ํฌ๊ธฐ)โ์ ๋ฐ๋ผ ์ ํ๋๋ค.
์์ ๊ทธ๋ฆผ์์ ์ฒ๋ผ Layer 0(L0)
๋ง์ ๊ฐ์ง๊ณ Pruning Ratio๋ฅผ ๋์ฌ๊ฐ๋ฉด์ ๊ด์ฐฐํด๋ณด๋ฉด, ์ฝ 70% ์ดํ๋ถํฐ๋ Accuracy๊ฐ ๊ธ๊ฒฉํ๊ฒ ๋จ์ด์ง๋ ๊ฒ์ ๋ณผ ์ ์๋ค. L0
์์ Ratio๋ฅผ ๋์ฌ๊ฐ๋ฉฐ Accuracy์ ๋ณํ๋ฅผ ๊ด์ฐฐํ ๊ฒ์ฒ๋ผ ๋ค๋ฅธ Layer๋ค๋ ๊ด์ฐฐํด๋ณด์.
L1
์ ๋ค๋ฅธ Layer๋ค์ ๋นํด ์๋์ ์ผ๋ก Pruning Ratio๋ฅผ ๋์ฌ๊ฐ๋ Accuracy์ ๋จ์ด์ง๋ ์ ๋๊ฐ ์ฝํ ๋ฐ๋ฉด, L0
๋ ๋ค๋ฅธ Layer๋ค์ ๋นํด ์๋์ ์ผ๋ก Pruning Ratio๋ฅผ ๋์ฌ๊ฐ๋ฉด Accuracy์ ๋จ์ด์ง๋ ์ ๋๊ฐ ์ฌํ ๊ฒ์ ํ์ธํ ์ ์๋ค. ๋ฐ๋ผ์ L1
์ Sensitivity๊ฐ ๋๋ค๊ณ ๋ณผ ์ ์์ผ๋ฉฐ Pruning์ ์ ๊ฒํด์ผ ํ๊ณ , L0
์ Sensitivity๊ฐ ๋ฎ๋ค๊ณ ๋ณผ ์ ์์ผ๋ฉฐ Pruning์ ๋ง๊ฒํด์ผ ํจ์ ์ ์ ์๋ค.
์ฌ๊ธฐ์ Sensitivity Analysis์์ ๊ณ ๋ คํด์ผํ ๋ช๊ฐ์ง ์ฌํญ๋ค์ ๋ํด์ ์ง๊ณ ๋์ด๊ฐ์.
- Sensitivity Analysis์์ ๋ชจ๋ Layer๋ค์ด ๋
๋ฆฝ์ ์ผ๋ก ์๋ํ๋ค๋ ๊ฒ์ ์ ์ ๋ก ํ๋ค. ์ฆ,
L0
์ Pruning์ดL1
์ ํจ๊ณผ์ ์ํฅ์ ์ฃผ์ง ์๋ ๋ ๋ฆฝ์ฑ์ ๊ฐ์ง๋ค๋ ๊ฒ์ ์ ์ ๋ก ํ๋ค. - Layer์ Pruning Ratio๊ฐ ๋์ผํ๋ค๊ณ ํด์ Pruned Weight์๊ฐ ๊ฐ์์ ์๋ฏธํ์ง ์๋๋ค.
- 100๊ฐ์ weight๊ฐ ์๋ layer์ 10% Pruning Ratio ์ ์ฉ์ 10๊ฐ์ weight๊ฐ pruned ๋์์์ ์๋ฏธํ๊ณ , 500๊ฐ์ weight๊ฐ ์๋ layer์ 10% Pruning Ratio ์ ์ฉ์ 50๊ฐ์ weight๊ฐ pruned ๋์์์ ์๋ฏธํ๋ค.
- Layer์ ์ ์ฒด ํฌ๊ธฐ์ ๋ฐ๋ผ Pruning Ratio์ ์ ์ฉ ํจ๊ณผ๋ ๋ค๋ฅผ ์ ์๋ค.
Sensitivity Analysis๊น์ง ์งํํ ํ์๋ ๋ณดํต ์ฌ๋์ด Accuracy๊ฐ ๋จ์ด์ง๋ ์ ๋, threshold๋ฅผ ์ ํด์ Pruning Ratio๋ฅผ ์ ํ๋ค.
์ ๊ทธ๋ํ์์๋ Accuracy๊ฐ ์ฝ 75%์์ค์ผ๋ก ์ ์ง๋๋ threhsold \(T\) ์ํ์ ์ ๊ธฐ์ค์ผ๋ก L0
๋ ์ฝ 74%, L4
๋ ์ฝ 80%, L3
๋ ์ฝ 82%, L2
๋ 90%๊น์ง Pruning์ ์งํํด์ผ ๊ฒ ๋ค๊ณ ์ ํ ์์๋ฅผ ๋ณด์ฌ์ค๋ค. ๋ฏผ๊ฐํ layer์ธ L0
๋ ์๋์ ์ผ๋ก Pruning์ ์ ๊ฒ
, ๋ ๋ฏผ๊ฐํ layer์ธ L2
๋ Pruning์ ๋ง๊ฒ
ํ๋ ๊ฒ์ ํ์ธํ ์ ์๋ค.
๋ฌผ๋ก ์ฌ๋์ด ์ ํ๋ threshold๋ ๊ฐ์ ์ ์ฌ์ง๊ฐ ๋ฌผ๋ก ์๋ค. Pruning Ratio๋ฅผ ์ข ๋ Automaticํ๊ฒ ์ฐพ๋ ๋ฐฉ๋ฒ์ ๋ํด ์์๋ณด์.
1.2 AMC
AMC๋ AutoML for Model Compression์ ์ฝ์๋ก, ๊ฐํํ์ต(Reinforcement Learning) ๋ฐฉ๋ฒ์ผ๋ก ์ต์ ์ Pruning Ratio๋ฅผ ์ฐพ๋๋ก ํ๋ ๋ฐฉ๋ฒ์ด๋ค.
AMC์ ๊ตฌ์กฐ๋ ์ ๊ทธ๋ฆผ๊ณผ ๊ฐ๋ค. ๊ฐํํ์ต ์๊ณ ๋ฆฌ์ฆ ๊ณ์ด ์ค, Actor-Critic ๊ณ์ด์ ์๊ณ ๋ฆฌ์ฆ์ธ Deep Deterministic Policy Gradient(DDPG)์ ํ์ฉํ์ฌ Pruning Ratio๋ฅผ ์ ํ๋ Action์ ์ ํํ๋๋ก ํ์ตํ๋ค. ์์ธํ MDP(Markov Decision Process) ์ค๊ณ๋ ์๋์ ๊ฐ๋ค.
๊ฐํํ์ต Agent์ ํ์ต ๋ฐฉํฅ์ ๊ฒฐ์ ํ๋ ์ค์ํ Reward Function์ ๋ชจ๋ธ์ Accuracy๋ฅผ ๊ณ ๋ คํด์ Error
๋ฅผ ์ค์ด๋๋ก ์ ๋ํ ๋ฟ๋ง ์๋๋ผ Latency๋ฅผ ๊ฐ์ ์ ์ผ๋ก ๊ณ ๋ คํ ์ ์๋๋ก ๋ชจ๋ธ์ ์ฐ์ฐ๋์ ๋ํ๋ด๋ FLOP
๋ฅผ ์ ๊ฒ ํ๋๋ก ์ ๋ํ๋๋ก ์ค๊ณํ๋ค. ์ค๋ฅธ์ชฝ์ ๋ชจ๋ธ๋ค์ ์ฐ์ฐ๋ ๋ณ(Operations
) Top-1 Accuracy
๊ทธ๋ํ๋ฅผ ๋ณด๋ฉด ์ฐ์ฐ๋์ด ๋ง์์๋ก ๋ก๊ทธํจ์์ฒ๋ผ Accuracy๊ฐ ์ฆ๊ฐํ๋ ๊ฒ์ ๋ณด๊ณ ์ด๋ฅผ ๋ณด๊ณ ๋ฐ์ํ ๋ถ๋ถ์ด๋ผ๊ณ ๋ณผ ์ ์๋ค.
์ด๋ ๊ฒ AMC๋ก Pruning์ ์งํํ์ ๋, Human Expert๊ฐ Pruning ํ ๊ฒ๊ณผ ๋น๊ตํด๋ณด์. ์๋ ๋ชจ๋ธ ์น์
๋ณ Density ํ์คํ ๊ทธ๋จ ๊ทธ๋ํ์์ Total
์ ๋ณด๋ฉด, ๋์ผ Accuracy๊ฐ ๋์ค๋๋ก Pruning์ ์งํํ์ ๋ AMC๋ก Pruning์ ์งํํ ๊ฒ(์ฃผํฉ์)์ด Human Expert Pruning ๋ชจ๋ธ(ํ๋์)๋ณด๋ค Density๊ฐ ๋ฎ์ ๊ฒ์ ํ์ธํ ์ ์๋ค. ์ฆ, AMC๋ก Pruning ์งํํ์ ๋ ๋ ๋ง์ weight๋ฅผ Pruning ๋ ๊ฐ๋ฒผ์ด ๋ชจ๋ธ์ ๊ฐ์ง๊ณ ๋ Accuracy๋ฅผ ์ ์งํ๋ค๊ณ ๋ณผ ์ ์๋ค.
๋๋ฒ์งธ ๊บพ์ ์ ๊ทธ๋ํ์์ AMC๋ฅผ ๊ฐ์ง๊ณ Pruning๊ณผ Fine-tuning์ ๋ฒ๊ฐ์ ๊ฐ๋ฉฐ ์ฌ๋ฌ ์คํ
์ผ๋ก ์งํํด๊ฐ๋ฉด์ ๊ด์ฐฐํ ๊ฒ์ ์กฐ๊ธ ๋ ์์ธํ ์ดํด๋ณด์. ๊ฐ Iteration(Pruning+Fine-tuning)์ stage1, 2, 3, 4๋ก ๋ํ๋ด์ด plotํ ๊ฒ์ ๋ณด๋ฉด, 1x1 conv
๋ณด๋ค 3x3 conv
์์ Density๊ฐ ๋ ๋ฎ์ ๊ฒ์ ํ์ธํ ์ ์๋ค. ์ฆ, 3x3 conv
์์ 1x1 conv
๋ณด๋ค Pruning์ ๋ง์ด ํ ๊ฒ์ ๋ณผ ์ ์๋ค. ์ด๋ฅผ ํด์ํด๋ณด์๋ฉด, AMC๊ฐ 3x3 conv
์ Pruningํ๋ฉด 9๊ฐ์ weight๋ฅผ pruningํ๊ณ ์ด๋ 1x1 conv
pruningํด์ 1๊ฐ์ weight๋ฅผ ์์ ๋ ๊ฒ๋ณด๋ค ํ๋ฒ์ ๋ ๋ง์ weight ์๋ฅผ ์ค์ผ ์ ์๊ธฐ ๋๋ฌธ์ 3x3 conv
pruning์ ์ ๊ทน ํ์ฉํ์ ๊ฒ์ผ๋ก ๋ณผ ์ ์๋ค.
์ด AMC ์คํ ๊ฒฐ๊ณผํ์์ ๋ณด๋ฉด, FLOP์ Time ๊ฐ๊ฐ 50%๋ก ์ค์ธ AMC ๋ชจ๋ธ ๋๋ค Top-1 Accuracy๊ฐ ๊ธฐ์กด์ 1.0 MobileNet
์ Accuracy๋ณด๋ค ์ฝ 0.1~0.4% ์ ๋๋ง ์ค๊ณ Latency๋ SpeedUp์ด ํจ์จ์ ์ผ๋ก ์กฐ์ ๋ ๊ฒ์ ํ์ธํ ์ ์๋ค.
0.75 MobileNet
์ SpeedUp์ด ์ 1.7x ์ธ๊ฐ์?
๊ฒฐ๊ณผํ์์ 0.75 MobileNet
์ 25%์ weight๋ฅผ ๊ฐ์์ํจ ๊ฒ์ด๊ธฐ ๋๋ฌธ์ SpeedUp์ด \(\frac{4}{3} \simeq 1.3\)x์ด์ด์ผ ํ๋ค๊ณ ์๊ฐํ ์ ์๋ค. ํ์ง๋ง ์ฐ์ฐ๋์ quadraticํ๊ฒ ๊ฐ์ํ๊ฒ ๋๊ธฐ ๋๋ฌธ์ \(\frac{4}{3} \cdot \frac{4}{3} \simeq 1.7\)x๋ก SpeedUp์ด ๋๋ค.
1.3 NetAdapt
๋ ๋ค๋ฅธ Pruning Ratio๋ฅผ ์ ํ๋ ๊ธฐ๋ฒ์ผ๋ก NetAdapt์ด ์๋ค. Latency Constraint๋ฅผ ๊ฐ์ง๊ณ layer๋ง๋ค pruning์ ์ ์ฉํด๋ณธ๋ค. ์๋ฅผ ๋ค์ด, ์ค์ผ ๋ชฉํ latency ๋์ lms
๋ก ์ ํ๋ฉด, 10ms
โ 9ms
๋ก ์ค ๋๊น์ง layer์ pruning ratio๋ฅผ ๋์ฌ๊ฐ๋ ๋ฐฉ๋ฒ์ด๋ค.
NetAdapt์ ์ ์ฒด์ ์ธ ๊ณผ์ ์ ์๋์ ๊ฐ์ด ์งํ๋๋ค. ๊ธฐ์กด ๋ชจ๋ธ์์ ๊ฐ layer๋ฅผ Latency Constraint์ ๋๋ฌํ๋๋ก Pruningํ๋ฉด์ Accuracy(\(Acc_A\)๋ฑ)์ ๋ฐ๋ณต์ ์ผ๋ก ์ธก์ ํ๋ค.
- ๊ฐ layer์ pruning ratio๋ฅผ ์กฐ์ ํ๋ค.
Short-term
fine tuning์ ์งํํ๋ค.- Latency Constraint์ ๋๋ฌํ๋์ง ํ์ธํ๋ค.
- Latency Constraint ๋๋ฌํ๋ฉด ํด๋น layer์ ์ต์ ์ Pruning ratio๋ก ํ๋จํ๋ค.
- ๊ฐ layer์ ์ต์ Pruning ratio๊ฐ ์ ํด์ก๋ค๋ฉด ๋ง์ง๋ง์ผ๋ก
Long-term
fine tuning์ ์งํํ๋ค.
์ด์ ๊ฐ์ด NetAdapt์ ๊ณผ์ ์ ์งํํ๋ฉด ์๋์ ๊ฐ์ ์คํ ๊ฒฐ๊ณผ๋ฅผ ๋ณผ ์ ์๋ค. Uniformํ๊ฒ Pruning์ ์งํํ Multipilers
๋ณด๋ค NetAdapt
๊ฐ 1.7x ๋ ๋น ๋ฅด๊ณ ์คํ๋ ค Accuracy๋ ์ฝ 0.3% ์ ๋ ๋์ ๊ฒ์ ์ ์ ์๋ค.
2. Fine-tuning/Train
Prunned ๋ชจ๋ธ์ ํผํฌ๋จผ์ค๋ฅผ ํฅ์ํ๊ธฐ ์ํด์๋ Pruning๋ฅผ ์งํํ๊ณ ๋์ Fine-tuning ๊ณผ์ ์ด ํ์ํ๋ค.
2.1 Iterative Pruning
๋ณดํต Pruned ๋ชจ๋ธ์ Fine-tuning ๊ณผ์ ์์๋ ๊ธฐ์กด์ ํ์ตํ๋ learning rate๋ณด๋ค ์์ rate๋ฅผ ์ฌ์ฉํ๋ค. ์๋ฅผ๋ค์ด ๊ธฐ์กด์ ๋ชจ๋ธ์ ํ์ตํ ๋ ์ฌ์ฉํ learning rate์ \(1/100\) ๋๋ \(1/10\)์ ์ฌ์ฉํ๋ค. ๋ํ Pruning ๊ณผ์ ๊ณผ Fine-tuning ๊ณผ์ ์ 1๋ฒ๋ง ์งํํ๊ธฐ๋ณด๋ค ์ ์ฐจ์ ์ผ๋ก pruning ratio๋ฅผ ๋๋ ค๊ฐ๋ฉฐ Pruning, Fine-tuning์ ๋ฒ๊ฐ์๊ฐ๋ฉฐ ์ฌ๋ฌ๋ฒ ์งํํ๋๊ฒ ๋ ์ข๋ค.
2.2 Regularization
TinyML์ ๋ชฉํ๋ ๊ฐ๋ฅํ ๋ง์ weight๋ค์ 0์ผ๋ก ๋ง๋๋ ๊ฒ์ผ๋ก ์๊ฐํ ์ ์๋ค. ๊ทธ๋์ผ ๋ชจ๋ธ์ ๊ฐ๋ณ๊ฒ ๋ง๋ค ์ ์๊ธฐ ๋๋ฌธ์ด๋ค. ๊ทธ๋์ Regularization ๊ธฐ๋ฒ์ ์ด์ฉํด์ ๋ชจ๋ธ์ weight๋ค์ 0์ผ๋ก, ํน์ 0๊ณผ ๊ฐ๊น๊ฒ ์์ ๊ฐ์ ๊ฐ์ง๋๋ก ๋ง๋ ๋ค. ์์ ๊ฐ์ weight๊ฐ ๋๋๋ก ํ๋ ์ด์ ๋ 0๊ณผ ๊ฐ๊น์ด ์์ ๊ฐ๋ค์ ๋ค์ layer๋ค๋ก ๋์ด๊ฐ๋ฉด์ 0์ด ๋ ๊ฐ๋ฅ์ฑ์ด ๋์์ง๊ธฐ ๋๋ฌธ์ด๋ค. ๊ธฐ์กด์ ๋ฅ๋ฌ๋ ๋ชจ๋ธ๋ค์ ๊ณผ์ ํฉ(Overfitting)์ ๋ง๊ธฐ ์ํ Regularization ๊ธฐ๋ฒ๋ค๊ณผ ๋ค๋ฅด์ง ์์ผ๋ ์๋์ ๋ชฉ์ ์ ๋ค๋ฅธ ๊ฒ์ ์ง์ด๋ณผ ์ ์๋ค.
2.3 The Lottery Ticket Hypothesis
2019๋
ICLR์์ ๋ฐํ๋ ๋
ผ๋ฌธ์์ Jonathan Frankle๊ณผ Michael Carbin์ด ์๊ฐํ The Lottery Ticket Hypothesis(LTH)์ ์ฌ์ธต ์ ๊ฒฝ๋ง(DNN) ํ๋ จ์ ๋ํ ํฅ๋ฏธ๋ก์ด ์์ด๋์ด๋ฅผ ์ ์ํ๋ค. ๋ฌด์์๋ก ์ด๊ธฐํ๋ ๋๊ท๋ชจ ์ ๊ฒฝ๋ง ๋ด์ ๋ ์์ ํ์ ๋คํธ์ํฌ(Winning Ticket
)๊ฐ ์กด์ฌํ๋ค๋ ๊ฒ์ ๋งํ๋ค. ์ด ํ์ ๋คํธ์ํฌ๋ ์ฒ์๋ถํฐ ๋ณ๋๋ก ํ๋ จํ ๋ ์๋ ๋คํธ์ํฌ์ ์ฑ๋ฅ์ ๋๋ฌํ๊ฑฐ๋ ๋ฅ๊ฐํ ์ ์๋ค. ์ด ๊ฐ์ค์ ์ด๋ฌํ Winning Ticket
์ด ํ์ตํ๋ ๋ฐ ์ ํฉํ ์ด๊ธฐ ๊ฐ์ค์น๋ฅผ ๊ฐ๋๋ค๊ณ ๊ฐ์ ํ๋ค.
3. System Support for Sparsity
DNN์ ๊ฐ์ํ ์ํค๋ ๋ฐฉ๋ฒ์ ํฌ๊ฒ 3๊ฐ์ง, Sparse Weight, Sparse Activation, Weight Sharing์ด ์๋ค. Sparse Weight, Sparse Activation์ Pruning์ด๊ณ Weight Sharing์ Quantization์ ๋ฐฉ๋ฒ์ด๋ค.
- Sparse Weight: Weight๋ฅผ Pruningํ์ฌ Computation์ Pruning Ratio์ ๋์ํ์ฌ ๋นจ๋ผ์ง๋ค. ํ์ง๋ง Memory๋ Pruning๋ weight์ ์์น๋ฅผ ๊ธฐ์ตํ๊ธฐ ์ํ memory ์ฉ๋์ด ํ์ํ๋ฏ๋ก Pruning Ratio์ ๋น๋กํ์ฌ ์ค์ง ์๋๋ค.
- Sparse Activation: Weight๋ฅผ Pruningํ๋ ๊ฒ๊ณผ ๋ค๋ฅด๊ฒ Activation์ Test Input์ ๋ฐ๋ผ dynamic ํ๋ฏ๋ก Weight๋ฅผ Pruningํ๋ ๊ฒ๋ณด๋ค Computation์ด ๋ ์ค์ด๋ ๋ค.
- Weight Sharing: Quantization ๋ฐฉ๋ฒ์ผ๋ก 32-bit data๋ฅผ 4-bit data๋ก ๋ณ๊ฒฝํจ์ผ๋ก์จ 8๋ฐฐ์ memory ์ ์ฝ์ ํ ์ ์๋ค.
3.1 EIE
Efficient Inference Engine์ ๊ธฐ๊ณ ํ์ต ๋ชจ๋ธ์ ์ค์๊ฐ์ผ๋ก ์คํํ๊ธฐ ์ํด ์ต์ ํ๋ ์ํํธ์จ์ด ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ ํ๋ ์์ํฌ๋ฅผ ๋งํ๋ค. Processing Elements(PE)์ ๊ตฌ์กฐ
์๋ ๊ทธ๋ฆผ์์ Input๋ณ(\(\vec{a}\)) ์ฐ์ฐ์ ์๋์ ๊ฐ์ด Input์ด 0์ผ ๋๋ skip๋๊ณ 0์ด ์๋ ๋๋ prunning ๋์ง ์์ weight์ ์ฐ์ฐ์ด ์งํ๋๋ค.
EIE ์คํ์ ๊ฐ์ฅ loss๊ฐ ์ ์ data ์๋ฃํ์ธ 16 bit Intํ์ ์ฌ์ฉํ๋ค.(0.5% loss) AlexNet
์ด๋ VGG
์ ๊ฐ์ด ReLU Activation์ด ๋ง์ด ์ฌ์ฉ๋๋ ๋ชจ๋ธ๋ค์ ๊ฒฝ๋ํ๊ฐ ๋ง์ด ๋ ๋ฐ๋ฉด, RNN์ LSTM์ด ์ฌ์ฉ๋ NeuralTalk
๋ชจ๋ธ๋ค ๊ฐ์ ๊ฒฝ์ฐ์๋ ReLU๋ฅผ ์ฌ์ฉํ์ง ์์ ๊ฒฝ๋ํ๋ ์ ์๋ ๋ถ๋ถ์ด ์์ด Activation Density๊ฐ 100%์ธ ๊ฒ์ ํ์ธํ ์ ์๋ค.
3.2 M:N Weight Sparsity
์ด ๋ฐฉ๋ฒ์ Nvidia ํ๋์จ์ด์ ์ง์์ด ํ์ํ ๋ฐฉ๋ฒ์ผ๋ก ๋ณดํต 2:4 Weight Sparsity๋ฅผ ์ฌ์ฉํ๋ค. ์ผ์ชฝ์ Sparse Matrix๋ฅผ ์ฌ๋ฐฐ์นํด์ Non-zero data matrix์ ์ธ๋ฑ์ค๋ฅผ ์ ์ฅํ๋ Index matrix๋ฅผ ๋ฐ๋ก ๋ง๋ค์ด์ ์ ์ฅํ๋ค.
M:N Weight Sparsity ์ ์ฉํ์ง ์์ Dense GEMM๊ณผ ์ ์ฉํ Sparse GEMM์ ๊ณ์ฐํ ๋๋ ์๋์ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ ๊ณผ์ ์ผ๋ก ์ฐ์ฐ์ด ์งํ๋๋ค.
3.3 Sparse Convolution
Submanifold Sparse Convolutional Networks (SSCN)์ ๊ณ ์ฐจ์ ๋ฐ์ดํฐ์์ ํจ์จ์ ์ธ ๊ณ์ฐ์ ๊ฐ๋ฅํ๊ฒ ํ๋ ์ ๊ฒฝ๋ง ์ํคํ ์ฒ์ ํ ํํ์ด๋ค. ์ด ๊ธฐ์ ์ ํนํ 3D ํฌ์ธํธ ํด๋ผ์ฐ๋ ๋๋ ๊ณ ํด์๋ ์ด๋ฏธ์ง์ ๊ฐ์ด ๋๊ท๋ชจ ๋ฐ ๊ณ ์ฐจ์ ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ ๋ ์ค์ํ๋ค. SSCN์ ํต์ฌ ์์ด๋์ด๋ ๋ฐ์ดํฐ์ Sparcity์ ํ์ฉํ์ฌ ๊ณ์ฐ๊ณผ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ํฌ๊ฒ ์ค์ด๋ ๊ฒ์ด๋ค.
์ด๋ฌํ Sparse Convolution์ ๊ธฐ๋ณธ Convolution๊ณผ ๋น๊ตํ์ ๋ ์๋ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด ๋ํ๋ด๋ณผ ์ ์๋ค.
์ฐ์ฐ ๊ณผ์ ์ ๋น๊ตํด๋ณด๊ธฐ ์ํด Input Point Cloud(\(P\)), Feature Map(\(W\)), Ouput Point Cloud(\(Q\))๋ฅผ ์๋์ ๊ฐ์ด ์๋ค๊ณ ํ์. ๊ธฐ์กด์ Convolution๊ณผ Sparse Convolution์ ๋น๊ตํด๋ณด๋ฉด ์ฐ์ฐ๋์ด 9:2๋ก ๋งค์ฐ ์ ์ ์ฐ์ฐ๋ง ํ์ํ ๊ฒ์ ์ ์ ์๋ค.
Feature Map(\(W\))์ ๊ธฐ์ค์ผ๋ก ๊ฐ weight ๋ง๋ค ํ์ํ Input data์ ํฌ๊ธฐ๊ฐ ๋ค๋ฅด๋ค. ์๋ฅผ ๋ค์ด \(W_{-1, 0}\)์ \(P1\)๊ณผ ๋ง์ ์ฐ์ฐ์ด ์งํ๋๋ฏ๋ก \(P1\)๋ง ์ฐ์ฐ์ ๋ถ๋ฌ๋ด๊ฒ ๋๋ค.
๋ฐ๋ผ์ Feature Map์ \(W\)์ ๋ฐ๋ผ ํ์ํ Input data๋ฅผ ํํํ๊ณ ๋ฐ๋ก computation์ ์งํํ๋ฉด ์๋์ ๊ฐ์ด ๊ณ ๋ฅด์ง ๋ชปํ ์ฐ์ฐ๋ ๋ถ๋ฐฐ๊ฐ ์งํ๋๋๋ฐ(์ผ์ชฝ ๊ทธ๋ฆผ) ์ด๋ computation์ overhead๋ ์์ง๋ง regularity๊ฐ ์ข์ง ์๋ค. ๋๋ ๊ฐ์ฅ computation์ด ๋ง์ ๊ฒ์ ๊ธฐ์ค์ผ๋ก Batch ๋จ์๋ก ๊ณ์ฐํ๊ฒ ๋๋ค๋ฉด(๊ฐ์ด๋ฐ ๊ทธ๋ฆผ) ์ ์ computation weight์์์ ๋นํจ์จ์ ์ธ ๊ณ์ฐ ๋๊ธฐ์๊ฐ์ด ์๊ธฐ๋ฏ๋ก overhead๊ฐ ์๊ธด๋ค. ๋ฐ๋ผ์ ์ ์ ํ ๋น์ทํ ์ฐ์ฐ๋์ ๊ฐ์ง๋ grouping์ ์งํํ ๋ค batch๋ก ๋ฌถ์ผ๋ฉด ์ ์ ํ computation์ ์งํํ ์ ์๋ค.(์ค๋ฅธ์ชฝ ๊ทธ๋ฆผ)
์ด๋ฐ Grouping์ ์ ์ฉํ ํ Sparse Convolution์ ์งํํ๋ฉด Adaptive Grouping์ด ์ ์ฉ๋์ด ์๋์ ๊ฐ์ด ์งํ๋๋ค.
์ฌ๊ธฐ๊น์ง๊ฐ 2023๋ ๋ ๊ฐ์์์ ๋ง์ง๋ง Sparse Convolution์ ๋ํด ์ค๋ช ํ ๋ถ๋ถ์ ์ ๋ฆฌํ ๋ถ๋ถ์ด๋ค. ํ์ง๋ง ๊ฐ์์์ ์ค๋ช ์ด ๋ง์ด ์๋ต๋์ด ์์ผ๋ฏ๋ก ์ข ๋ ์์ธํ ๋ด์ฉ์ Youtube ๋ฐํ ์์์ด๋ 2022๋ ๋ ๊ฐ์๋ฅผ ์ฐธ๊ณ ํ๋ ๊ฒ์ ๊ถ์ฅํ๋ค.
4. Reference
- MIT-TinyML-lecture04-Pruning-2
- AMC: Automl for Model Compression and Acceleration on Mobile Devices, 2018
- Continuous control with deep reinforcement learning
- FLOPs๋? ๋ฅ๋ฌ๋ ์ฐ์ฐ๋์ ๋ํด์
- NetAdapt: Platform-Aware Neural Network Adaptation for Mobile Applications
- The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks
- The Lottery Ticket Hypothesis Finding Sparse, Trainable Neural Networks ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ
- LLM Inference - HW/SW Optimizations
- Accelerating Sparse Deep Neural Networks
- Submanifold Sparse Convolutional Networks
- TorchSparse: Efficient Point Cloud Inference Engine
- TorchSparse++: Efficient Training and Inference Framework for Sparse Convolution on GPUs
- mit-han-lab/torchsparse
- MLSysโ22 TorchSparse: Efficient Point Cloud Inference Engine