๐งโ๐ซ Lecture 12-13
์ด๋ฒ ๊ธ์์๋ ํ๋ NLP์ ํต์ฌ ๋๊ตฌ์ธ transformer์ LLM, ๊ทธ๋ฆฌ๊ณ ์ด๋ค์ ์ต์ ํํ๋ ์ฌ๋ฌ ๋ฐฉ๋ฒ์ ๋ํด์ ์๊ฐํ๋ค. transformer๋ฅผ ์ต์ ํ ํ๊ธฐ ์ํ ๋ฐฉ๋ฒ์ผ๋ก positional embedding์ ๋ณํ๋ llama2์ ์ฌ์ฉ๋ grouped-query-attention๋ฑ์ ๋ค๋ฃฌ๋ค. ๋ ๋์๊ฐ, LLMs์ ํจ์จ์ ์ธ ์ถ๋ก (inference)์ ์กฐ์ (fine-tuning) ์๊ณ ๋ฆฌ์ฆ๊ณผ ์์คํ ์ ๋ํด์๋ ๋ค๋ฃฌ๋ค. inference์ ๋ํด์๋ vLLM, StreamingLLM, fine-tuning์ LoRA, QLoRA, Adapter๊ธฐ๋ฒ ๋ฑ์ ์๊ฐํ๋ค.
1. Transformer basics
๊ฐ์์์๋ transformer ๊ตฌ์กฐ์ tokenizer, encoding, attention๋ฑ์ ๋ํด ๊ฐ๋ตํ ์๊ฐํ๊ณ ์์ง๋ง ์ด ๊ธ์ transformer/LLM์ ์ต์ ํ์ ๋ํด ๋ค๋ฃจ๊ณ ์๊ณ , transformer ์์ฒด์ ๊ตฌ์กฐ๋ฅผ ๋ค๋ฃจ๋ฉด ๊ธ์ด ๋๋ฌด ๊ธธ์ด์ง ๊ฒ ๊ฐ์์ ์์ธํ ๊ตฌ์กฐ๋ ์๋์ ๋ ๋งํฌ๋ฅผ ์ฐธ๊ณ ๋ถํ๋๋ฆฝ๋๋ค
https://wikidocs.net/31379
https://blogs.nvidia.co.kr/blog/what-is-a-transformer-model/
2. Transformer design variants
์ด๋ฒ ์ฅ์์๋ ์๋ณธ transformer(attention is all you need) ์ดํ์ transformer๋ฅผ ๋ฐ์ ์ํจ ์ฌ๋ฌ ๊ธฐ๋ฒ์ ๋ํด ๋ค๋ฃฌ๋ค.
Absolute positional encoding -> Relative positional encoding
์๋ณธ transformer ๋ชจ๋ธ์์๋ positional embedding์ผ๋ก sinusoid embedding์ ์ฌ์ฉํ๋ค. ์ด๋ position๋ง๋ค ๋ ๋ฆฝ์ ์ด๋ฉด์๋ ์ฐ์์ ์ธ embedding vector๋ฅผ ๋ง๋ค์ด๋ธ๋ค.
๊ทธ๋ฌ๋ ์ด๋ฐ index์ ์์กด์ ์ธ absolute positional encoding ๋ฐฉ์์๋, training์ค์ ๋ณด์ง ๋ชปํ ๊ธธ์ด๋ฅผ ๋์ํ๊ธฐ ์ด๋ ต๋ค๋ ๋ฌธ์ ์ ์ด ์๋ค. ์๋ฅผ ๋ค์ด, 250 token๊น์ง๋ง ํ์ตํ๋๋ฐ 251 token์ ๋ฐ์ดํฐ๊ฐ ๋ค์ด์ค๋ ์ํฉ ๋ง์ด๋ค.
Relative positional encoding์ ์ฌ์ฉํ๋ฉด train short, test long์ ๋ฌ์ฑํ ์ ์๋ค. ์ด์ธ์๋ absolute positional encoding์ ์์น ์ ๋ณด๋ฅผ input embedding์ ๋ํด Q/K/V ์ ์ฒด์ ์ํฅ์ ๋ฏธ์น์ง๋ง, relative positional encoding์ Q,K์ bias๋ฅผ ๋ํ๋ ๋ฐฉ์์ผ๋ก attention score์ ์ํฅ์ ์ค๋ค(V์๋ ์ํฅ์ ๋ฏธ์น์ง ์๋๋ค)
Attention with Linear Biases (ALiBi)
๊ฐ์ฅ ๊ฐ๋จํ ๋ฐฉ๋ฒ์ผ๋ก๋ ALiBi๊ฐ ์๋ค. ์ด ๋ฐฉ๋ฒ์ ๋จ์ํ attention matrix์ query์ key์ ๊ฑฐ๋ฆฌ์ ๋ํ offset์ ๋ํด์ค๋ค.
Rotary Positional Embedding (RoPE)
๋ค๋ฅธ ๋ฐฉ๋ฒ์ผ๋ก๋ llama2์๋ ์ฌ์ฉ๋ ์ ๋๋ก ๋๋ฆฌ ์ฌ์ฉ๋๊ณ ์๋ RoPE์ด๋ค. RoPE์ ์์ด๋์ด๋ ์ ๋ ฅ ๋ฐ์ดํฐ์ ์์น ์ ๋ณด๋ฅผ ํ์ (rotation)์ ํตํด ์ธ์ฝ๋ฉํ๋ ๊ฒ์ด๋ค. d์ฐจ์์ word embedding์ผ๋ก d/2๊ฐ์ ์ง์ ๋ง๋ค๊ณ , ๊ฐ๊ฐ์ pair๋ฅผ 2d ์ขํ๋ก ๊ฐ์ ํ๊ณ , position์ ๋ฐ๋ผ ํ์ ์ํค๋ ๊ฒ์ด๋ค. ์ ๊ทธ๋ฆผ์์ x1,x2์ m * theta๋ฅผ ๊ณฑํด์ ํ์ ์ํค๊ณ ์๋ค. RoPE๋ฅผ ์์์ผ๋ก ๋ํ๋ด๋ฉด ์์ ๊ฐ๋ค
LLM์ ๋ณดํต ํ์ตํ ๋ context ๊ธธ์ด์ ์ ํ์ด ์๋ค. ์๋ฅผ๋ค์ด llama๋ 2k, llama2๋ 4k๋ก ์ ํ๋ ๋ฐ์ดํฐ๋ก ํ์ตํ๋ค. ๊ทธ๋ฌ๋ RoPE ๋ฐฉ์ ๋์ ๋ ํฐ context ๊ธธ์ด๋ ๋ค๋ฃฐ ์ ์๋ค. ๋ ์์ theta๋ฅผ ์ฌ์ฉํ๋ฉด, ๋ ์ด์ดํ๊ฒ interpolate ํ๋ฉด์ context๊ธธ์ด๋ฅผ ๋๋ฆด ์ ์๋ค. ์ ๊ทธ๋ฆผ์์ ๋จ์ํ 4096 context ๊ธธ์ด๋ unseen์ด๋ผ ์คํจํ์ง๋ง, theta๋ฅผ ์ ๋ฐ์ผ๋ก ์ค์ธ ์๋ ๊ทธ๋ํ์์๋ ์๋ context ๊ธธ์ด ์์ ๋ค์ด์ค๊ธฐ ๋๋ฌธ์ ์ฑ๊ณตํ๋ ๋ชจ์ต์ ๋ณผ ์ ์๋ค.
KV cache optimizations(Multi-Head Attention (MHA) -> Multi-Query Attention (MQA) -> Grouped-Query Attention(GQA)
KV cache๋ attention ๋งค์ปค๋์ฆ์ Key, Value๋ฅผ ์ ์ฅํด๋๋ ๊ฒ์ ๋งํ๋ค. transfomer๋ฅผ decode(gpt-style. decoder ๋ชจ๋ธ ์ฌ์ฉํด ์์ฑํ๋ ๊ฒ์ ์๋ฏธํจ)ํ ๋๋ ์ง๊ธ ์์ ํ ํฐ์ attention์ ๊ณ์ฐํ๊ธฐ ์ํด ์ด์ ํ ํฐ๋ค์ Key, Value๊ฐ์ ๋ชจ๋ ์ ์ฅํ๊ณ ์์ด์ผ ํ๋ค. ์ ๊ทธ๋ฆผ์์ โtrainiumโ ํ ํฐ์ ์์ฑํ๊ธฐ ์ํด์ ์ด์ โIโ, โloveโ์ K,V๊ฐ ํ์ํ๋ค(Query๋ ํ์ํ์ง ์์) ๋จ์ํ ์์ํด๋ด๋, KV cache๋ฅผ ์ ์ฅํ๊ธฐ ์ํด ์ฌ์ฉ๋๋ ๋ฉ๋ชจ๋ฆฌ๊ฐ ๋๋ฌด ๋ง์ด ํ์ํ๋ค. llama2-7b ๋ชจ๋ธ์์ KV cache ํฌ๊ธฐ๋ batch_size * 32(layers) * 128(n_emd) * N(length) * 2(K,V๋๊น 2๊ฐ) * 2byte(fp16) = 512KB * BS * N ๋งํผ ํ์ํ๋ค. llama2-70B ๋ชจ๋ธ์ ์ด๋ฐ ์์ผ๋ก ๊ณ์ฐํด๋ณด๋ฉด, batch size 16์ผ ๋ 4096๋ฒ์งธ token์ ์ฒ๋ฆฌํ ๋ KV cache์ ์ฉ๋์ 160GB์ ๋ฌํ๋ค. ๋ฐ๋ผ์ KV cache์ ์ฌ์ด์ฆ๋ฅผ ์ค์ผ ํ์๊ฐ ์๊ณ , ๊ทธ ๋ฐฉ๋ฒ์ด multi-query-attention(MQA), grouped-query-attention(GQA)์ด๋ค. ์ด์ค GQA๋ llama2์๋ ์ ์ฉ๋ ์ ๋๋ก ๋ง์ด ์ฌ์ฉ๋๋ ๋ฐฉ์์ด๋ค. ๊ฐ๊ฐ์ ๋ฐฉ์์ ์ดํด๋ณด๋ฉด
MQA : ๋ชจ๋ value์ key๋ฅผ ํ๋๋ก ํ๊ท ๋ธ๋ค
GQA : ๋ชจ๋ value์ key๋ฅผ G๊ฐ๋ก ํ๊ท ๋ธ๋ค(๋ณดํต G๋ N/8)
์ ๊ทธ๋ฆผ์ฒ๋ผ MQA, GQA๋ฅผ ์ฌ์ฉํ๋ฉด KV cache ํฌ๊ธฐ๋ฅผ ๋ง์ด ์ค์ผ ์ ์๋ค.
FFN->GLU
inverted bottleneck, relu๋ฅผ ์ฌ์ฉํ๋ ๊ธฐ์กด FFN ๊ณ์ธต์, GLU(Gated Linear Unit)๊ณผ swish ํ์ฑํ ํจ์๋ฅผ ์ฌ์ฉํ๋ฉด ์ฑ๋ฅ์ด ๋ ๋์์ง๋ค๊ณ ํ๋ค. ์ด๋ ์ฑ๋ฅ์ PPL(perplexity)๋ก ์ธก์ ํ๋ค.
3. Large language models(LLMs)
LLM์ด ์ด๋ค ์ผ์ ํ ์ ์๋์ง, ์ด๋ค ์ข ๋ฅ๊ฐ ์๋์ง์ ๋ํด์๋ ์ด๋ฏธ ๋ง์ด ๊ธ๊ณผ ์ ๋ณด๊ฐ ์๊ธฐ ๋๋ฌธ์, 3์ฅ์์๋ LLM์ ์ฌ๋ฌ๊ฐ์ง ํน์ง์ ๋ํด ๊ฐ๋จํ ์ค๋ช ํ๋ค. LLM์ ์ ๊ธฐํ ํน์ง์ค ํ๋๋, model size๊ฐ ์ปค์ง๋ค ๋ณด๋ฉด ์ด๋์ ํน์ task์ ๋ํ ๋ฅ๋ ฅ์ด ์๊ธด๋ค๋ ๊ฒ์ด๋ค. ์ ๋ ฅ ๋งฅ๋ฝ์ ๋ง๋ ์ซ์ ์ฐ์ฐ์ ํ๋ค๊ฑฐ๋, ์์ธ ์ํ๋ฒณ ์ฒ ์๋ฅผ ์ฐพ์๋ผ ์๋ ์๋ค.
๋ํ ์ด์ NLP ์๋์์๋ downstream task๋ฅผ ํ๊ธฐ ์ํด์ fine tuning์ ํด์ผ ํ์ง๋ง, LLM์ ํ์ธํ๋ ์์ด Zero-shot์ด๋ Few-shot ๋ฐฉ์์ผ๋ก downstream task๋ฅผ ํด๊ฒฐํ๋ค.
์ฃผ๋ชฉ๋ฐ๋ LLM ๋ชจ๋ธ๊ณผ ๊ฐ๊ฐ์ ํน์ง์ ๊ฐ๋ตํ ์ ๋ฆฌํด๋ณด๋ฉด, llama๋ SwiGLU๋ฅผ ์ ์ฉํ๊ณ , llama2์์๋ training tokens์ ํฌ๊ฒ ๋๋ฆฐ ์ ์ด falcon์ 180B๋ผ๋ ๊ฑฐ๋ํ model size๊ฐ mistral์ sliding window attention์ด๋ผ๋ attention ๊ธฐ๋ฒ์ด ๋ ํนํ ์ ์ด๋ค.
์น์น ๋ผ ๋ฒ์น(The Chinchilla Law)
์น์น ๋ผ ๋ฒ์น์ model size๋ฟ๋ง ์๋๋ผ, training data์ ํฌ๊ธฐ๋ ๋๋ ค์ผ ์ต์ ์ computation-accuracy trade-off ํ๋ ์ง์ ์ ์ฐพ์ ์ ์๋ค๋ ๊ฒ์ด๋ค. (๋ฌด์กฐ๊ฑด data ํฌ๊ธฐ๋ฅผ ๋๋ฆฌ๋๊ฒ ์ข๋ค๋๊ฒ ์๋๋ผ ๋ฐ์ดํฐ ์์ ๋ฐ๋ฅธ ์ต์ model size๊ฐ ์๋ค๋ ๋ป์ด๋ค) llama-2๊ฐ ๋น๊ต์ ์ ์ ํ๋ผ๋ฏธํฐ์ ๋ง์ train token์ผ๋ก ์ข์ ์ฑ๋ฅ์ ๋ณด์ฌ์ค๋ค.
4. Advanced topics, multi-modal LLM
์ดํ ๊ฐ์์์ ํ๋ฒ ๋ ์์ธํ ๋ค๋ฃจ๊ธฐ ๋๋ฌธ์ ์ด๋ฒ ๊ฐ์ ์ ๋ฆฌ ๊ธ์์๋ ์๋ตํฉ๋๋ค
5. Efficient inference algorithms for LLMs
์์ ๊ฐ์๋ค์์ inference๋ฅผ ํจ์จ์ ์ผ๋ก ํ๊ธฐ ์ํ ๋ฐฉ๋ฒ์ผ๋ก quantization๊ณผ pruning์ ๋ฐฉ๋ฒ์ด ์์๊ณ , ์ด ๋ฐฉ๋ฒ๋ค์ LLM์๋ ์ ์ฉํด ๋ณผ ์ ์์ ๊ฒ์ ๋๋ค.
5.1. Quantization: SmoothQuant, AWQ, TinyChat
ํ์ง๋ง, ๋จ์ํ W8A8๊ณผ ๊ฐ์ด quantizeํ๋ ๊ฒ์, ๊ต์ฅํ ํฐ ์ฑ๋ฅ ์ ํ๋ฅผ ๋ณด์ฌ์ค๋ค ์ด์ ๋, LLM์์๋ activation์ outlier๊ฐ ์ค์ํ ์ญํ ์ ํ๊ธฐ ๋๋ฌธ์ด๋ค ์ ๊ทธ๋ฆผ์ ์ค๋ฅธ์ชฝ์ฒ๋ผ activation์๋ ๊ต์ฅํ ํฐ outlier๊ฐ ์กด์ฌํ๊ณ weight๋ ๋น๊ต์ ํธ์ฐจ๊ฐ ์๋ค. ๋ฐ๋ผ์ activation์ 10์ผ๋ก ๋๋๊ณ , weight์ 10์ ๊ณฑํ๋ฉด ์์์ ๊ฐ์ ๋ณํ์ง ์์ง๋ง, activation์ ์ข ๋ ํธํ๊ฒ quantizeํ ์ ์๊ฒ ๋๋ค(์ค๋ฅธ์ชฝ ๊ทธ๋ฆผ). ์ด๋ฐ ๋ฐฉ์์ activation์ ์ข๋ ํํํ๊ฒ ๋ง๋ ๋ค๊ณ ํด์ smoothQuant๋ผ๊ณ ํ๋ค. smoothQuant๋ฐฉ์์ llama ๋ชจ๋ธ์์๋ ๋งค์ฐ ์ ๋์ํ๋ค.
์ ๊ทธ๋ฆผ์ x์ถ์ธ compute intensity๋ FLOPs / MemoryBandwith๋ฅผ ๋ํ๋ธ๋ค. ์ฆ, ๋ฐ์ดํฐ ํ๋๋น ์ฐ์ฐ์ ์ผ๋ง๋ ํจ์จ์ ์ผ๋ก ํ ์ ์๋๋์ ๋ํ ์งํ์ด๋ค. ์ ๊ทธ๋ฆผ์์ batch size๊ฐ 1์ผ ๋ ๋ฎ์ TFLOPS๋ฅผ ๋ณด์ด๋ ์ด์ ๋ ๋ฉ๋ชจ๋ฆฌ ๋๋ฌธ์ด๋ค. LLM์์ ๋งค ํ ํฐ์ ์์ฑํ๊ธฐ ์ํด์๋ ํฐ ๋ฉ๋ชจ๋ฆฌ fetch๊ฐ ํ์ํ๋ค(parameter fetch.). activation๊ณผ weight ์ค์์๋ weight๊ฐ ํจ์ฌ ๋ ํฌ๋ฏ๋ก, weight๋ฅผ ์ค์ด๋๋ฐ ๋ ์ง์คํด์ผ ํ๋ค.
์์์ ์ดํด๋ณธ W8A8 ๋ฐฉ์์ quantization์ batch serving(ํ๋ฒ์ ์ฌ๋ฌ batch๋ฅผ ์ฒ๋ฆฌํ๋ ์ผ)์์๋ ์ ๋์ํ๋ค. ํ๋๋ง ์ฒ๋ฆฌํ๋ ์์ ์(single-batch) memory-bounded(๋ฉ๋ชจ๋ฆฌ๊ฐ ๋ถ์กฑํ๋ฉด bottleneck์ด ๋๋ค)์ด๋ค. ๋น์ฐํ weight๋ฅผ ๋ฐ๋ก quantize ํ๋ฉด ์ ๊ทธ๋ฆผ์ฒ๋ผ ์ฑ๋ฅ ์ ํ๊ฐ ๋ฐ์ํ๋ค ์ค๋ฅธ์ชฝ ๊ทธ๋ฆผ์ ์ดํด๋ณด๋ฉด, RTN๋ฐฉ์์ ๋จ์ํ ์ ์ฉํ ๊ฒฝ์ฐ Perplexity๊ฐ ๋ง์ด ์์นํ ๊ฒ์ ๋ณผ ์ ์๋ค. ๋ ๋์ ๋ฐฉ๋ฒ์, ์ค์ํ(salient) weight๋ค๋ง quantize ํ์ง ์๊ณ ๋๋ ๊ฒ์ธ๋ฐ, salient ํ๋ค๊ณ ํ๋จํ๋ ๊ธฐ์ค์ โactivationโ๊ฐ์ ๊ธฐ๋ฐ์ผ๋ก ํ๋จํ ๋(magnitude-base, ๋จ์ํ ์ ๋๊ฐ์ด ํฌ๋ฉด ์ค์ํ๋ค๊ณ ํ๋จ) ์ข์ ์ฑ๋ฅ์ ๋ณด์ธ๋ค. ์ด๋ฐ ๋ฐฉ์์ AWQ(Activation-aware Weight Quantization) ๋ผ๊ณ ํ๋ค.
SmoothQuant์ AWQ๋ ์ค๋๋ ๋๋ฆฌ ์ฌ์ฉ๋๋ ๋ฐฉ์์ด๋ค.
5.2. Pruning/sparsity: SpAtten, H2O, MoE
quantization์ ํ์ผ๋ฉด, pruning๋ ํด ๋ด์ผ ํ๋ค. Wanda๋ AWQ์ฒ๋ผ Weight์ Activation์ ๊ณ ๋ คํด์ pruning ํ๋ ๋ฐฉ์์ด๋ค SpAtten์ ์ค์ํ์ง ์์ ํ ํฐ ์์ฒด๋ฅผ ์ญ์ ํ๋ ๋ฐฉ์์ด๋ค. ์ค๋ฅธ์ชฝ attention ๋งต ๊ธฐ๋ฐ์ผ๋ก, ๊ฐ์ฅ ๋ฎ์ attention ํฉ๊ณ๋ฅผ ๊ฐ์ง ํ ํฐ์ ์ญ์ ํ๋ค. H2O๋ Heavy Hitter Token(H2)๋ฅผ ์ค์ฌ์ผ๋ก ๋จ๊ธฐ๊ณ , ๋๋จธ์ง๋ฅผ pruningํ๋ ๋ฐฉ์์ด๋ค. ์ฌ๊ธฐ์ ๋งํ๋ Heavy Hitter๋ attention ๊ธฐ๋ฐ์ผ๋ก ์ ์ ํ๋ค. ์ดํดํ๊ธฐ๋ก๋, SpAtten์ ๋ฐฉ์๊ณผ ๋น์ทํ๋ค๊ณ ๋๊ผ๋ค. DejaVu๋ ์ ๋ ฅ์ ์ํฅ์ ๋ฐ์ง ์๋ attention head๋ค์ด ์กด์ฌํ๊ณ , ์ด๊ฒ์ contextual sparsity๋ผ๊ณ ๋ถ๋ฅด๋ฉฐ, ์ด ํจํด์ MLP๋ฅผ ํตํด ์์ํ ์ ์๋ค๋ ๊ฐ์ค์ ์ธ์ ๋ค. ์ด๋ฐ contextual sparsity๋ฅผ ์ ๊ฑฐํ๋ ๋ฐฉ์์ DejaVu๋ ์ฌ์ฉํ๋ค. MoE(Mixture of Experts) ๋ FFN์ N๊ฐ๋ก ๋๋๊ณ , Expert๋ฅผ ์ฌ์ฉํด ๊ทธ์ค์ ํ๋๋ฅผ ๊ณ ๋ฅด๋ ๊ฐ๋ ์ ๋์ ํ๋ค. ๊ทธ๋ฆผ ์ค๊ฐ์ ์๋ Router๋ก๋ถํฐ ํ๋ฅ ์ ์ผ๋ก ์ด๋ค FFN์ ์ฌ์ฉํ ์ง MoE๋ฐฉ์์ GPT-4์์ ์ฌ์ฉํ๊ณ ์๋ค๊ณ ์๋ ค์ ธ ์๋ค.
6. Efficient inference systems for LLMs
์ด ์ฅ์์๋ system์ ๊ด์ ์์ ๋ ํจ์จ์ ์ผ๋ก LLM์ inferenceํ๋ ๋ฒ์ ๋ค๋ฃฌ๋ค.
6.1. vLLM(Paged Attention)
๋ค์์ ์ฌ์ฉ์๊ฐ LLM์ ์ฌ์ฉํ๋ ํ๊ฒฝ์์ ๋ฌด์์ด ๋ฌธ์ ๊ฐ ๋ ๊น? ์ ๊ทธ๋ฆผ์ฒ๋ผ ์ฐ๋ฆฌ๋ LLM์ ์ถ๋ ฅ์ด ์ผ๋ง๋ ๊ธธ์ด์ง ์ง ๋ชจ๋ฅด๊ธฐ ๋๋ฌธ์, ์ผ๋ง๋ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํ ๋นํด์ผ ํ ์ง ์ ์ ์๋ค. ๋ฐ๋ผ์ <resv> ์ฒ๋ผ ๋ด๋ถ ๋จํธํ, ํน์ ๋ค๋ฅธ ์์ฒญ๊ฐ์ ๊ฐ๊ฒฉ์ผ๋ก ์ธํด ์ธ๋ถ ๋จํธํ๊ฐ ๋ฐ์ํ๊ฒ ๋๋ค. ๋ง์น ์ค์ ์ด์์ฒด์ ์ ๋ฉ๋ชจ๋ฆฌ๊ฐ๋ค. ๊ทธ๋ ๋ค๋ฉด, ์ฌ๋ฐ๊ฒ๋ ์ด์์ฒด์ ์์ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ผ๋ก ์ด๋ฅผ ํด๊ฒฐํ ์ ์๊ณ , ๊ทธ ๋ฐฉ๋ฒ์ด ๋ฐ๋ก Page๋ฅผ ์ฌ์ฉํ๋ ๋ฐฉ์์ด๋ค.
OS์์ ๋ค๋ฅธ ํ๋ก์ธ์ค๊ฐ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํ ๋ page๋จ์๋ก ์ฌ์ฉํ๋ฏ์ด, LLM์์๋ ๋ค๋ฅธ ์์ฒญ๋ค ๊ฐ์ KV cache๋ฅผ page ๋จ์๋ก ์ฌ์ฉํ๋ฉด ๋๋ค. ์์ฒ๋ผ ๋ค๋ฅธ ์์ฒญ์ page๋จ์๋ก ๋ฐ์ ์ ์๋ค.
๋ ๋๋ผ์ด ์ ์, ํ๋์ KV Cache๋ฅผ ๊ณต์ ํ ์ ์๋ค๋ ์ ์ด๋ค. ์ ๋ฌธ์ฅ์ ๊ณต์ ํ๊ฑฐ๋, ์๋๋ฉด Prompt๊ฐ์ด ๋ง์ด ์ฌ์ฉ๋๋ ๋ฌธ์ฅ์ KV cache๋ฅผ ๊ณต์ ํด ํจ์จ์ ์ผ๋ก ๋๋์ inference๊ฐ ๊ฐ๋ฅํ๋ค.
์ด๋ฐ ๋ฐฉ์์ Paged Attention ์ด๋ผ๊ณ ํ๊ณ , ์ด ๋ฐฉ๋ฒ์ ์ฌ์ฉํ ๊ฒ์ด vLLM์ด๋ผ๋ ๋ฐฉ๋ฒ๋ก ์ด๋ค.
6.2. StreamingLLM
LLM ๋ฐฐํฌ์ ๋ ๋ค๋ฅธ ๋ฌธ์ ๋ ๊ธธ์ด ๋ฌธ์ ์ด๋ค. ์์ฒญ๋๊ฒ ๊ธด ๋ฌธ์ฅ์ด๋, ํน์ ์ฑ๋ด์์ ์์ฒญ ์์ ์ ์ด์ผ๊ธฐํ๋ ๋ด์ฉ๊น์ง ๊ธฐ์ตํ๋ ค๋ฉด ๋ฉ๋ชจ๋ฆฌ๊ฐ ๋งค์ฐ ๋ง์ด ํ์ํ๋ค. ๋จ์ํ transformer ๋ฐฉ์์ ์ฌ์ฉํ๋ฉด(๋ ธ๋์ ๊ทธ๋ํ) ๋ฉ๋ชจ๋ฆฌ๋ ์ ํ์ ์ผ๋ก ์ฆ๊ฐํ๊ณ , perplexity๋ ์ ๋ ฅ ๊ธธ์ด 4K ์ดํ๋ก ํญ๋ฐ์ ์ผ๋ก ์ฆ๊ฐํ๋ค(training์์ ๋ณด์ง ๋ชปํ ๊ธธ์ด์ด๊ธฐ ๋๋ฌธ์) windowed attention(์ผ์ context๋ง ๊ธฐ์ต, ๋ น์)์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ผ์ ํ์ง๋ง ์ ๋ ฅ ๊ธธ์ด๊ฐ window๊ธธ์ด๋ฅผ ๋ฒ์ด๋๋ ์๊ฐ(๊ทธ๋ฆผ์์๋ 1K์ ๋) perplexity๊ฐ ๊ธ์ฆํ๋ค(์ฒซ ๋ช ํ ํฐ์ด ๋งค์ฐ ์ค์ํ๊ธฐ ๋๋ฌธ์) a๊ฐ ์์์ ๋งํ ๋จ์ transformer๋ฐฉ์, b๊ฐ windowed attention์ด๋ค. c๋ sliding window๋ฐฉ์์ธ๋ฐ, ์ด์ ํ ํฐ์ ๋ฉ๋ชจ๋ฆฌ์ ๋๋๊ฒ ์๋๋ผ ๋ค์ ๊ณ์ฐํ๋ ๋ฐฉ๋ฒ์ด๋ค. ์ด ๋ฐฉ๋ฒ์ perplexity๋ ๊ด์ฐฎ์ง๋ง, ์ฐ์ฐํ๋๋ฐ ๋๋ฌด ๋ง์ ์๊ฐ์ด ๋ ๋ค.
์ด๋ฐ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์ฐพ์๋ด๊ธฐ ์ํ ์์ด๋์ด๋ฅผ Attention Sink์์ ์ฐพ์๋ค. ์ ๊ทธ๋ฆผ์์ ๋ณด๋ฉด ์ฒซ๋ฒ์งธ ํ ํฐ์ attention score๊ฐ ๋งค์ฐ ๋์ ๊ฒ์ ์ ์ ์๋ค. ๊ทธ๋ฐ๋ฐ, ๊ทธ๋ค์ด ๋ฌธ๋งฅ์ ์ผ๋ก(semantically)์ค์ํ์ง ์์ ๊ฒฝ์ฐ์๋ ๊ทธ๋ ๋ค. ์ด๋ฐ ํ์์ Attention Sink๋ผ๊ณ ํ๋๋ฐ, ์ ์ผ์ด๋๋ ํ์์ผ๊น? attention์ ๊ตฌํ ๋ softmax๋ฅผ ์ฌ์ฉํ๊ฒ ๋๋๋ฐ, decoding์ ํ๋ฉด์ ์ฒซ๋ฒ์งธ ํ ํฐ์ ๋ชจ๋ ํ ํฐ์ decode ํ ๋ ๋ฑ์ฅํ๊ฒ ๋๋ฏ๋ก, ๋น์ฐํ ์ด๋์ ๋์ ๊ฐ์ ๊ณ์ ๋ํด๊ฐ์ ์๊ธฐ๋ ํ์์ด๋ผ๋ ๊ฒ์ด๋ค. ๊ทธ๋์ ์ด๋ฐ attention sink๊ฐ ์ผ์ด๋๋ ์ฒซ ํ ํฐ์ ๋ฌด์กฐ๊ฑด ๋จ๊ฒจ๋๊ณ , windowed attention์ ์ฌ์ฉํ๋ฉด ๋ ๊ด์ฐฎ์ ๊ฒฐ๊ณผ๋ฅผ ์ป์ ์ ์๋ค๋ ๊ฒ์ด๋ค. ์ด๋ฐ ํ์์ ๋ํ ๋ ผ๋ฆฌ์ ์ธ ์ค๋ช ์ ์ฐพ์ง ๋ชปํ์ง๋ง, ์๋ง๋ ์ฒซ ํ ํฐ์ด ๋ฌธ๋งฅ์ ์ผ๋ก ์ค์ํ์ง ์๋๋ผ๋ โsink(์์๋๋)โ์ ์ญํ ์ ํ๋ ๊ฒ์ด๋ผ๊ณ ์๊ฐ๋๋ค. ablation study์์๋ ํ๋์ ํ ํฐ์ด ์๋๋ผ, 4๊ฐ์ token์ sink๋ก ํ๋๊ฒ ํ๊ท ์ ์ผ๋ก ์ข๋ค๋ ๊ฒฐ๊ณผ๊ฐ ์๋ค
6.3. FlashAttention
FlashAttention์ ์ข ๋ ํ๋์จ์ด์ ์ธ ์ ๊ทผ์ด๋ค. HBM(High Bandwith Memory)์ ์ ๊ทผํ๋ ํ์๋ฅผ ์ค์ด๋ ์์ด๋์ด์ด๋ค. ํ๋ ฌ ์ฐ์ฐ์ ํ ๋ ์ ์ฒด ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ถ๋ฌ์ค๋ ๊ฒ์ด ์๋๋ผ ํ๋์ฉ ๋ถ๋ฌ์์(Copy Block to SRAM๋ถ๋ถ) GPU์ SRAM ๋ด์์ ์ฐ์ฐ์ ์ต๋ํ ๋ง์น๊ฒ ๋ค๋ ์์ด๋์ด์ด๋ค. ์์์ ๋ค๋ค๋ MQA, GQA๋ฑ์ ์ ์ฉํ FlashAttention-2๋ผ๋ ๋ ผ๋ฌธ๋ ์๋ค.
6.4. Speculative decoding
LLM์ decoding์ ๋งค์ฐ memory-boundedํ๋ค. ํ๋ํ๋์ ํ ํฐ์ ์์ฑํ ๋๋ง๋ค ๋งค์ฐ ๋ง์ ๋ฉ๋ชจ๋ฆฌ ์ฐ์ฐ์ด ํ์ํ๋ค. Speculative Decoding์ ์ด๋ฐ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์์ ๋ชจ๋ธ๋ก K๊ฐ์ ํ ํฐ์ ์์ฑํ ๋ค ํฐ ๋ชจ๋ธ๋ก ์ด ํ ํฐ์ด ์ข์์ง ์๋์ง ํ๋จํ๊ณ ๋์์ ์์ฑํ๋ค(ํฐ ๋ชจ๋ธ์์๋ batch size๊ฐ 1์ผ ๋๋ K์ผ๋๋ ๋น์ทํ๋ฏ๋ก) K๊ฐ์ token์ ์์ฑํ ๋, ํฐ ๋ชจ๋ธ์ K๋ฒ ํธ์ถํ๋ ๊ฒ์ด ์๋๋ผ ์์ ๋ชจ๋ธ์ K๋ฒ, ํฐ ๋ชจ๋ธ์ 1๋ฒ๋ง ํธ์ถํ๋ฉด ๋๋ฏ๋ก decoding ์๊ฐ์ ์ ์ฝํ ์ ์๋ค(๋๋ต 2~3๋ฐฐ)
7. Efficient fine-tuning for LLMs
์ด๋ฒ ์ฅ์์๋ LLM fine tuning์ ํจ์จ์ ์ผ๋ก ํ๋ ๋ฒ์ ์์๋ณธ๋ค.
7.1. LoRA/QLoRA
LoRA๋ ๋ชจ๋ธ ์ ์ฒด๋ฅผ updateํ๋๊ฒ ์๋๋ผ, ์์ bypass branch์ weight๋ง updateํ๋ ๋ฐฉ๋ฒ์ด๋ค. LLM์ pretrainํ ๊ฐ์ค์น W๊ฐ ์๊ณ , full-fine tuningํ์ ๋ ๋ฌ๋ผ์ง๋ ๊ฐ์ค์น๋ฅผ delta W๋ผ๊ณ ํ์. ๊ทธ๋ฌ๋ฉด ๊ทธ delta W๋ฅผ low-rank ํ๋ ฌ์ธ A์ B์ ๊ณฑ(์ ๊ทธ๋ฆผ์ ์ฃผํฉ์ ํ๋ ฌ)์ผ๋ก ๋ํ๋ด์๋ ์์ด๋์ด์ด๋ค.
QLoRA๋ ๊ฐ๋จํ ๋งํ๋ฉด LoRA์ quantization์ ๋ํ ๊ฒ์ด๋ค. NF4(NormalFloat4)๋ผ๋ normal distribution์ ์ต์ ํ๋ ๋ฐ์ดํฐ ํ์ , Double Quantization, paged optimizer๋ฑ์ ๊ธฐ๋ฒ์ ์ฌ์ฉํ๋ค.
7.2. Adapter
Adapter๋ transformer ๋ธ๋ก์ learnableํ ์์ ๋ธ๋ก์ ํ๋ ๋ผ์ ๋ฃ๋ ๊ฒ์ด๋ค. ์ ๊ทธ๋ฆผ์์ ์ด๋ํฐ๋ ์ค๋ฅธ์ชฝ ๋ ธ๋์ ๊ตฌ์กฐ๋ฅผ ๋ฃ์ ๊ฒ์ด๋ค. ํ์ง๋ง ์๋ก์ด layer๊ฐ ์ถ๊ฐ๋๋ ๊ฒ์ด๋ผ, inference์ ์๊ฐ์ด ์กฐ๊ธ ๋ ๋์ด๋ ์ ์๋ค๋ ๋ฌธ์ ์ ์ด ์๋ค.
7.3. Prompt Tuning
์์ ๋ฐฉ๋ฒ๋ค๊ณผ๋ ๋ค๋ฅด๊ฒ, tuning์์ด prompt๋ง ์ ๋ ฅํด์ ํน์ ํ task์ ๋ํ ์ฑ๋ฅ์ ๋์ด๋ ๋ฐฉ๋ฒ์ด๋ค. ์๋ฅผ ๋ค์ด, โ๋ค์ ๋ฌธ์ฅ์ ์์ฝํด์ค :โ ๋ผ๋ ๋ฌธ์ฅ์ ์ ๋ ฅ์ ์ถ๊ฐํ๋ฉด ์์ฝ task์ ๋ํ ์ฑ๋ฅ์ด ์ฌ๋ผ๊ฐ๋ค. ์ด๋ฅผ ํ์ํ๋ฉด, ํ๋์ ๋ชจ๋ธ๋ก ์ฌ๋ฌ๊ฐ์ง task์ ๋์ํ ์ ์๊ฒ ๋๋ฉฐ, ๋ชจ๋ธ์ด ์ปค์ง์๋ก ํด๋น task์ ๋ํด์๋ง fine-tuningํ ๋ชจ๋ธ์ด๋ ๋น์ทํ ์ฑ๋ฅ์ ๋ด๊ฒ ๋๋ค.