๐Ÿ‘ฉโ€๐Ÿ’ป Lab 1

lab
pruning
fine-grained
channel
Fine-grained & Channel Pruning
Author

castleflag

Published

March 1, 2024

Lab 1 Pruning

This Lab 1 on pruning is organized around the following goals and contents. You can click the button below to run it directly in Colaboratory.

Goals

  • pruning์˜ ๊ธฐ๋ณธ ๊ฐœ๋…์„ ์ดํ•ดํ•ฉ๋‹ˆ๋‹ค.
  • fine-grained pruning์„ ๊ตฌํ˜„ํ•˜๊ณ  ์ ์šฉํ•ฉ๋‹ˆ๋‹ค.
  • channel pruning์„ ๊ตฌํ˜„ํ•˜๊ณ  ์ ์šฉํ•ฉ๋‹ˆ๋‹ค.
  • pruning์œผ๋กœ๋ถ€ํ„ฐ์˜ ์„ฑ๋Šฅ ๊ฐœ์„ (์˜ˆ: ์†๋„ ํ–ฅ์ƒ)์— ๋Œ€ํ•œ ๊ธฐ๋ณธ์ ์ธ ์ดํ•ด๋ฅผ ์–ป์Šต๋‹ˆ๋‹ค.
  • ์ด๋Ÿฌํ•œ pruning ์ ‘๊ทผ ๋ฐฉ์‹ ๊ฐ„์˜ ์ฐจ์ด์ ๊ณผ tradeoffs๋ฅผ ์ดํ•ดํ•ฉ๋‹ˆ๋‹ค.

Contents

์ด ์‹ค์Šต์—๋Š” Fine-grained Pruning๊ณผ Channel Pruning์˜ ๋‘ ๊ฐ€์ง€ ์ฃผ์š” ์„น์…˜์ด ์žˆ์Šต๋‹ˆ๋‹ค.

์ด 9๊ฐœ์˜ ์งˆ๋ฌธ์ด ์žˆ์Šต๋‹ˆ๋‹ค:

  • Fine-grained Pruning์— ๋Œ€ํ•ด์„œ๋Š” 5๊ฐœ์˜ ์งˆ๋ฌธ์ด ์žˆ์Šต๋‹ˆ๋‹ค (Question 1-5).
  • Channel Pruning์— ๋Œ€ํ•ด์„œ๋Š” 3๊ฐœ์˜ ์งˆ๋ฌธ์ด ์žˆ์Šต๋‹ˆ๋‹ค (Question 6-8).
  • Question 9๋Š” fine-grained pruning๊ณผ channel pruning์„ ๋น„๊ตํ•ฉ๋‹ˆ๋‹ค.

์‹ค์Šต๋…ธํŠธ์— ๋Œ€ํ•œ ์„ค์ • ๋ถ€๋ถ„(Setup)์€ Colaboratory Note๋ฅผ ์—ด๋ฉด ํ™•์ธํ•˜์‹ค ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ํฌ์ŠคํŒ…์—์„œ๋Š” ๋ณด๋‹ค ์‹ค์Šต๋‚ด์šฉ์— ์ง‘์ค‘ํ•  ์ˆ˜ ์žˆ๋„๋ก ์ƒ๋žต๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.

weight ๊ฐ’์˜ ๋ถ„ํฌ๋ฅผ ์‚ดํŽด๋ด…์‹œ๋‹ค.

pruning์œผ๋กœ ๋„˜์–ด๊ฐ€๊ธฐ ์ „์—, dense ๋ชจ๋ธ์—์„œ ๊ฐ€์ค‘์น˜ ๊ฐ’์˜ ๋ถ„ํฌ๋ฅผ ์‚ดํŽด๋ด…์‹œ๋‹ค.

def plot_weight_distribution(model, bins=256, count_nonzero_only=False):
    fig, axes = plt.subplots(3, 3, figsize=(10, 6))
    axes = axes.ravel()
    plot_index = 0
    for name, param in model.named_parameters():
        if param.dim() > 1:  # plot only conv/fc weights
            ax = axes[plot_index]
            param_cpu = param.detach().view(-1).cpu()
            if count_nonzero_only:
                # drop the pruned (zero) weights before plotting
                param_cpu = param_cpu[param_cpu != 0].view(-1)
            ax.hist(param_cpu, bins=bins, density=True,
                    color='blue', alpha=0.5)
            ax.set_xlabel(name)
            ax.set_ylabel('density')
            plot_index += 1
    fig.suptitle('Histogram of Weights')
    fig.tight_layout()
    fig.subplots_adjust(top=0.925)
    plt.show()

plot_weight_distribution(model)

Question 1 (10 pts)

Answer the following questions based on the weight histograms above.

Question 1.1 (5 pts)

๊ฐ๊ธฐ ๋‹ค๋ฅธ ๊ณ„์ธต์—์„œ weight ๋ถ„ํฌ๋“ค์˜ ๊ณตํ†ต์ ์ธ ํŠน์„ฑ์€ ๋ฌด์—‡์ธ๊ฐ€์š”?

Your Answer:

mean์ด 0์ธ normal ๋ถ„ํฌ๋ฅผ ๋”ฐ๋ฅด๊ณ  ์žˆ๋‹ค (backbone์˜ ๊ฒฝ์šฐ, classifier ์ œ์™ธ)

Question 1.2 (5 pts)

How do these characteristics help pruning?

Your Answer:

Since many weights are (close to) zero, they can be removed or skipped during computation with little impact.

Fine-grained Pruning

์ด ์„น์…˜์—์„œ๋Š” fine-grained pruning์„ ๊ตฌํ˜„ํ•˜๊ณ  ์ˆ˜ํ–‰ํ•  ๊ฒƒ์ž…๋‹ˆ๋‹ค.

Fine-grained pruning์€ ๊ฐ€์žฅ ์ค‘์š”๋„๊ฐ€ ๋‚ฎ์€ synapses๋ฅผ ์ œ๊ฑฐํ•ฉ๋‹ˆ๋‹ค. Fine-grained pruning ํ›„์—๋Š” ๊ฐ€์ค‘์น˜ ํ…์„œ \(W\)๊ฐ€ sparseํ•ด์ง€๋ฉฐ, ์ด๋Š” sparsity๋กœ ์„ค๋ช…ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

\(\mathrm{sparsity} := \#\mathrm{Zeros} / \#W = 1 - \#\mathrm{Nonzeros} / \#W\)

where \(\#W\) is the number of elements in \(W\).
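
For instance, the sparsity of a tensor can be checked directly from its zero count. A minimal sketch (the helper name tensor_sparsity is ours, not one of the lab's utilities):

import torch

def tensor_sparsity(tensor: torch.Tensor) -> float:
    # sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    return 1.0 - tensor.count_nonzero().item() / tensor.numel()

w = torch.tensor([[0.0, 1.5, 0.0],
                  [2.0, 0.0, -0.3]])
print(tensor_sparsity(w))  # 0.5 -> 3 zeros out of 6 elements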

In practice, given a target sparsity \(s\), the weight tensor \(W\) is multiplied with a binary mask \(M\) to disregard the removed weights:

\(v_{\mathrm{thr}} = \texttt{kthvalue}(Importance, \#W \cdot s)\)

\(M = Importance > v_{\mathrm{thr}}\)

\(W = W \cdot M\)

where \(Importance\) is an importance tensor of the same shape as \(W\), \(\texttt{kthvalue}(X, k)\) finds the \(k\)-th smallest value of tensor \(X\), and \(v_{\mathrm{thr}}\) is the threshold value.
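
As a toy illustration of the three equations above, using weight magnitude as the importance (this worked example is ours, not part of the lab):

import torch

W = torch.tensor([0.1, -2.0, 0.3, 4.0])
s = 0.5                                       # target sparsity
importance = W.abs()                          # Importance = |W|
k = round(W.numel() * s)                      # #W * s = 2
v_thr = torch.kthvalue(importance, k).values  # 2nd smallest |w| = 0.3
M = importance > v_thr                        # M = [0, 1, 0, 1]
W = W * M                                     # pruned W = [0, -2, 0, 4]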

Magnitude-based Pruning

A widely used importance measure in fine-grained pruning is the magnitude of the weight values, i.e.,

\(Importance=|W|\)

This is known as magnitude-based pruning (see Learning both Weights and Connections for Efficient Neural Networks).

Question 2 (15 pts)

๋‹ค์Œ magnitude-based fine-grained pruning ํ•จ์ˆ˜๋ฅผ ์™„์„ฑํ•ด ์ฃผ์„ธ์š”.

Hint:

  • In step 1, calculate the number of zeros (num_zeros) after pruning. num_zeros should be an integer. You can use either round() or int() to convert a floating-point number into an integer; here we use round().
  • In step 2, calculate the importance of the weight tensor. PyTorch provides the torch.abs(), torch.Tensor.abs(), and torch.Tensor.abs_() APIs.
  • In step 3, calculate the pruning threshold so that all synapses with importance smaller than the threshold are removed. PyTorch provides the torch.kthvalue(), torch.Tensor.kthvalue(), and torch.topk() APIs.
  • In step 4, calculate the pruning mask based on the threshold: 1 in the mask indicates the synapse is kept, and 0 indicates it is removed, i.e., mask = importance > threshold. PyTorch provides the torch.gt() API.
def fine_grained_prune(tensor: torch.Tensor, sparsity : float) -> torch.Tensor:
    """
    magnitude-based pruning for single tensor
    :param tensor: torch.(cuda.)Tensor, weight of conv/fc layer
    :param sparsity: float, pruning sparsity
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    :return:
        torch.(cuda.)Tensor, mask for zeros
    """
    sparsity = min(max(0.0, sparsity), 1.0)
    if sparsity == 1.0:
        tensor.zero_()
        return torch.zeros_like(tensor)
    elif sparsity == 0.0:
        return torch.ones_like(tensor)

    num_elements = tensor.numel()

    ##################### YOUR CODE STARTS HERE #####################
    # Step 1: calculate the #zeros (please use round())
    num_zeros = round(num_elements * sparsity)
    # Step 2: calculate the importance of weight
    importance = torch.abs(tensor)
    # Step 3: calculate the pruning threshold
    threshold = torch.kthvalue(torch.flatten(importance), num_zeros)[0]
    # Step 4: get binary mask (1 for nonzeros, 0 for zeros)
    mask = importance > threshold

    ##################### YOUR CODE ENDS HERE #######################

    # Step 5: apply mask to prune the tensor
    tensor.mul_(mask)

    return mask

์œ„์—์„œ ์ •์˜ํ•œ fine-grained pruning ๊ธฐ๋Šฅ์„ ํ™•์ธํ•˜๊ธฐ ์œ„ํ•ด, ๋”๋ฏธ ํ…์„œ์— ์œ„ ํ•จ์ˆ˜๋ฅผ ์ ์šฉํ•ด ๋ด…์‹œ๋‹ค.

test_fine_grained_prune()

* Test fine_grained_prune()
    target sparsity: 0.75
        sparsity before pruning: 0.04
        sparsity after pruning: 0.76
        sparsity of pruning mask: 0.76
* Test passed.

Question 3 (5 pts)

๋งˆ์ง€๋ง‰ ์…€์€ pruning ์ „ํ›„์˜ ํ…์„œ๋ฅผ ๊ทธ๋ฆฝ๋‹ˆ๋‹ค. 0์ด ์•„๋‹Œ ๊ฐ’์€ ํŒŒ๋ž€์ƒ‰์œผ๋กœ, 0์€ ํšŒ์ƒ‰์œผ๋กœ ํ‘œ์‹œ๋ฉ๋‹ˆ๋‹ค. ๋‹ค์Œ ์ฝ”๋“œ ์…€์—์„œ target_sparsity์˜ ๊ฐ’์„ ์ˆ˜์ •ํ•˜์—ฌ pruning ํ›„ sparse ํ…์„œ์— 0์ด ์•„๋‹Œ ๊ฐ’์ด 10๊ฐœ๋งŒ ๋‚จ๋„๋ก ํ•ด ์ฃผ์„ธ์š”.

##################### YOUR CODE STARTS HERE #####################
# sparsity := #Zeros / #W = 1 - #Nonzeros / #W
# the test tensor has 25 elements, so 1 - 10/25 = 0.6
target_sparsity = 0.6 # please modify the value of target_sparsity
##################### YOUR CODE ENDS HERE #####################
test_fine_grained_prune(target_sparsity=target_sparsity, target_nonzeros=10)

* Test fine_grained_prune()
    target sparsity: 0.60
        sparsity before pruning: 0.04
        sparsity after pruning: 0.60
        sparsity of pruning mask: 0.60
* Test passed.

Now let's wrap the fine-grained pruning function into a class that prunes the whole model. In the FineGrainedPruner class, we have to keep a record of the pruning masks so that they can be applied whenever the model weights change, keeping the model sparse at all times.

class FineGrainedPruner:
    def __init__(self, model, sparsity_dict):
        self.masks = FineGrainedPruner.prune(model, sparsity_dict)

    @torch.no_grad()
    def apply(self, model):
        for name, param in model.named_parameters():
            if name in self.masks:
                param *= self.masks[name]

    @staticmethod
    @torch.no_grad()
    def prune(model, sparsity_dict):
        masks = dict()
        for name, param in model.named_parameters():
            if param.dim() > 1: # we only prune conv and fc weights
                masks[name] = fine_grained_prune(param, sparsity_dict[name])
        return masks
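
A usage sketch (assuming a model and a sparsity_dict such as the one defined in Question 5 below):

# create the masks and prune the model once
pruner = FineGrainedPruner(model, sparsity_dict)
# ... later, after every weight update (e.g., after optimizer.step()),
# re-apply the masks so that the pruned weights stay at zero
pruner.apply(model)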

Sensitivity Scan

๊ฐ ๋ ˆ์ด์–ด๋Š” ๋ชจ๋ธ ์„ฑ๋Šฅ์— ๋Œ€ํ•ด ๊ฐ๊ฐ ๋‹ค๋ฅด๊ฒŒ ๊ธฐ์—ฌํ•ฉ๋‹ˆ๋‹ค. ๊ฐ ๋ ˆ์ด์–ด์— ์ ์ ˆํ•œ sparsity๋ฅผ ๊ฒฐ์ •ํ•˜๋Š” ๊ฒƒ์€ ์–ด๋ ค์šด ์ผ์ž…๋‹ˆ๋‹ค. ๋„๋ฆฌ ์‚ฌ์šฉ๋˜๋Š” ์ ‘๊ทผ ๋ฐฉ์‹์€ sensitivity scan์ž…๋‹ˆ๋‹ค.

During the sensitivity scan, we prune only one layer at a time and observe the accuracy degradation. By scanning over various sparsities, we can draw the layer's sensitivity curve (i.e., accuracy vs. sparsity).

๋‹ค์Œ์€ sensitivity curves์˜ ์˜ˆ์‹œ ๊ทธ๋ฆผ์ž…๋‹ˆ๋‹ค. x์ถ•์€ sparsity ๋˜๋Š” #parameters๊ฐ€ ๊ฐ์†Œํ•œ ๋น„์œจ (์ฆ‰, sparsity)์ž…๋‹ˆ๋‹ค. y์ถ•์€ ๊ฒ€์ฆ ์ •ํ™•๋„์ž…๋‹ˆ๋‹ค. (Learning both Weights and Connections for Efficient Neural Networks์˜ Figure 6)

๋‹ค์Œ ์ฝ”๋“œ ์…€์€ ์Šค์บ”๋œ sparsities์™€ ๊ฐ ๊ฐ€์ค‘์น˜๊ฐ€ prune๋  ๋•Œ์˜ ์ •ํ™•๋„ ๋ฆฌ์ŠคํŠธ๋ฅผ ๋ฐ˜ํ™˜ํ•˜๋Š” sensitivity scan ํ•จ์ˆ˜๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค.

@torch.no_grad()
def sensitivity_scan(model, dataloader, scan_step=0.1, scan_start=0.4, scan_end=1.0, verbose=True):
    sparsities = np.arange(start=scan_start, stop=scan_end, step=scan_step)
    accuracies = []
    named_conv_weights = [(name, param) for (name, param) \
                          in model.named_parameters() if param.dim() > 1]
    for i_layer, (name, param) in enumerate(named_conv_weights):
        param_clone = param.detach().clone()
        accuracy = []
        for sparsity in tqdm(sparsities, desc=f'scanning {i_layer}/{len(named_conv_weights)} weight - {name}'):
            fine_grained_prune(param.detach(), sparsity=sparsity)
            acc = evaluate(model, dataloader, verbose=False)
            if verbose:
                print(f'\r    sparsity={sparsity:.2f}: accuracy={acc:.2f}%', end='')
            # restore
            param.copy_(param_clone)
            accuracy.append(acc)
        if verbose:
            print(f'\r    sparsity=[{",".join(["{:.2f}".format(x) for x in sparsities])}]: accuracy=[{", ".join(["{:.2f}%".format(x) for x in accuracy])}]', end='')
        accuracies.append(accuracy)
    return sparsities, accuracies

๋‹ค์Œ ์…€๋“ค์„ ์‹คํ–‰ํ•˜์—ฌ sensitivity curves๋ฅผ ๊ทธ๋ ค์ฃผ์„ธ์š”. ์™„๋ฃŒํ•˜๋Š” ๋ฐ ์•ฝ 2๋ถ„ ์ •๋„ ๊ฑธ๋ฆด ๊ฒƒ์ž…๋‹ˆ๋‹ค.

sparsities, accuracies = sensitivity_scan(
    model, dataloader['test'], scan_step=0.1, scan_start=0.4, scan_end=1.0)
    sparsity=[0.40,0.50,0.60,0.70,0.80,0.90]: accuracy=[92.42%, 91.19%, 87.55%, 83.39%, 69.41%, 31.81%]
    sparsity=[0.40,0.50,0.60,0.70,0.80,0.90]: accuracy=[92.93%, 92.88%, 92.71%, 92.40%, 91.32%, 84.78%]
    sparsity=[0.40,0.50,0.60,0.70,0.80,0.90]: accuracy=[92.94%, 92.64%, 92.46%, 91.77%, 89.85%, 78.56%]
    sparsity=[0.40,0.50,0.60,0.70,0.80,0.90]: accuracy=[92.86%, 92.72%, 92.23%, 91.09%, 85.35%, 51.31%]
    sparsity=[0.40,0.50,0.60,0.70,0.80,0.90]: accuracy=[92.88%, 92.68%, 92.22%, 89.47%, 76.86%, 38.78%]
    sparsity=[0.40,0.50,0.60,0.70,0.80,0.90]: accuracy=[92.92%, 92.71%, 92.63%, 91.88%, 89.90%, 82.19%]
    sparsity=[0.40,0.50,0.60,0.70,0.80,0.90]: accuracy=[92.94%, 92.86%, 92.65%, 92.10%, 90.58%, 83.65%]
    sparsity=[0.40,0.50,0.60,0.70,0.80,0.90]: accuracy=[92.94%, 92.92%, 92.88%, 92.81%, 92.63%, 91.34%]
    sparsity=[0.40,0.50,0.60,0.70,0.80,0.90]: accuracy=[92.91%, 92.83%, 92.81%, 92.97%, 92.68%, 92.52%]

def plot_sensitivity_scan(sparsities, accuracies, dense_model_accuracy):
    lower_bound_accuracy = 100 - (100 - dense_model_accuracy) * 1.5
    fig, axes = plt.subplots(3, int(math.ceil(len(accuracies) / 3)),figsize=(15,8))
    axes = axes.ravel()
    plot_index = 0
    for name, param in model.named_parameters():
        if param.dim() > 1:
            ax = axes[plot_index]
            curve = ax.plot(sparsities, accuracies[plot_index])
            line = ax.plot(sparsities, [lower_bound_accuracy] * len(sparsities))
            ax.set_xticks(np.arange(start=0.4, stop=1.0, step=0.1))
            ax.set_ylim(80, 95)
            ax.set_title(name)
            ax.set_xlabel('sparsity')
            ax.set_ylabel('top-1 accuracy')
            ax.legend([
                'accuracy after pruning',
                f'{lower_bound_accuracy / dense_model_accuracy * 100:.0f}% of dense model accuracy'
            ])
            ax.grid(axis='x')
            plot_index += 1
    fig.suptitle('Sensitivity Curves: Validation Accuracy vs. Pruning Sparsity')
    fig.tight_layout()
    fig.subplots_adjust(top=0.925)
    plt.show()

plot_sensitivity_scan(sparsities, accuracies, dense_model_accuracy)

Question 4 (15 pts)

Please answer the following questions using the information in the sensitivity curves above.

Question 4.1 (5 pts)

What is the relationship between pruning sparsity and model accuracy? (I.e., does accuracy increase or decrease as sparsity gets higher?)

Your Answer:

As the pruning sparsity increases, the model accuracy tends to decrease.

Question 4.2 (5 pts)

Do all layers have the same sensitivity?

Your Answer:

์–ด๋–ค ๋ ˆ์ด์–ด๋Š” sensitiveํ•˜์ง€ ์•Š๊ณ (classifier), ์–ด๋–ค ๋ ˆ์ด์–ด๋Š” sensitiveํ•˜๋‹ค(conv0) ๋Œ€์ฒด๋กœ, ์•ž์ชฝ ๋ ˆ์ด์–ด(0~1..)์ด ๋ฏผ๊ฐํ•ด๋ณด์ธ๋‹ค

Question 4.3 (5 pts)

์–ด๋–ค ๋ ˆ์ด์–ด๊ฐ€ pruning sparsity์— ๊ฐ€์žฅ ๋ฏผ๊ฐํ•œ๊ฐ€์š”?

Your Answer:

conv0 layer

#Parameters of each layer

์ •ํ™•๋„๋ฟ๋งŒ ์•„๋‹ˆ๋ผ ๊ฐ ๋ ˆ์ด์–ด์˜ ๋งค๊ฐœ๋ณ€์ˆ˜(parameter) ์ˆ˜๋„ sparsity ์„ ํƒ์— ์˜ํ–ฅ์„ ๋ฏธ์นฉ๋‹ˆ๋‹ค. ๋งค๊ฐœ๋ณ€์ˆ˜๊ฐ€ ๋” ๋งŽ์€ ๋ ˆ์ด์–ด๋Š” ๋” ํฐ sparsities๋ฅผ ์š”๊ตฌํ•ฉ๋‹ˆ๋‹ค.

๋‹ค์Œ ์ฝ”๋“œ ์…€์„ ์‹คํ–‰ํ•˜์—ฌ ์ „์ฒด ๋ชจ๋ธ์—์„œ #parameters์˜ ๋ถ„ํฌ๋ฅผ ๊ทธ๋ ค์ฃผ์„ธ์š”.

def plot_num_parameters_distribution(model):
    num_parameters = dict()
    for name, param in model.named_parameters():
        if param.dim() > 1:
            num_parameters[name] = param.numel()
    fig = plt.figure(figsize=(8, 6))
    plt.grid(axis='y')
    plt.bar(list(num_parameters.keys()), list(num_parameters.values()))
    plt.title('#Parameter Distribution')
    plt.ylabel('Number of Parameters')
    plt.xticks(rotation=60)
    plt.tight_layout()
    plt.show()

plot_num_parameters_distribution(model)

Selecting Sparsities Based on the Sensitivity Curves and the #Parameters Distribution

Question 5 (10 pts)

Please select the sparsity of each layer based on the sensitivity curves and the distribution of #parameters in the model.

pruned ๋ชจ๋ธ์˜ ์ „์ฒด ์••์ถ• ๋น„์œจ์€ ๋Œ€์ฒด๋กœ #parameters๊ฐ€ ํฐ ๋ ˆ์ด์–ด์— ์ฃผ๋กœ ์˜์กดํ•˜๋ฉฐ, ๋‹ค๋ฅธ ๋ ˆ์ด์–ด๋Š” pruning์— ๋Œ€ํ•œ sensitivity๊ฐ€ ๋‹ค๋ฆ…๋‹ˆ๋‹ค(Question 4 ์ฐธ์กฐ).

Please make sure that after pruning, the sparse model is 25% of the size of the dense model, and that after finetuning, the validation accuracy is higher than 92.5%.

Hint:

  • Layers with more #parameters should have larger sparsities. (See the #Parameter Distribution figure.)
  • Layers that are sensitive to pruning sparsity (i.e., whose accuracy drops quickly as sparsity grows) should have smaller sparsities. (See the Sensitivity Curves figure.)
recover_model()

sparsity_dict = {
##################### YOUR CODE STARTS HERE #####################
    # please modify the sparsity value of each layer
    # please DO NOT modify the key of sparsity_dict
    'backbone.conv0.weight': 0,
    'backbone.conv1.weight': 0.5,
    'backbone.conv2.weight': 0.5,
    'backbone.conv3.weight': 0.5,
    'backbone.conv4.weight': 0.5,
    'backbone.conv5.weight': 0.8,
    'backbone.conv6.weight': 0.8,
    'backbone.conv7.weight': 0.9,
    'classifier.weight': 0
##################### YOUR CODE ENDS HERE #######################
}
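
Before running the pruner, you can roughly estimate the overall compression from the chosen sparsities and each layer's #parameters (our own sketch; it ignores the small bias and BatchNorm tensors):

kept, total = 0, 0
for name, param in model.named_parameters():
    if param.dim() > 1:
        total += param.numel()
        kept += param.numel() * (1 - sparsity_dict[name])
print(f'expected nonzero ratio: {kept / total * 100:.1f}%')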

Please run the following cell to prune the model according to the defined sparsity_dict and print the information of the sparse model.

pruner = FineGrainedPruner(model, sparsity_dict)
print(f'After pruning with sparsity dictionary')
for name, sparsity in sparsity_dict.items():
    print(f'  {name}: {sparsity:.2f}')
print(f'The sparsity of each layer becomes')
for name, param in model.named_parameters():
    if name in sparsity_dict:
        print(f'  {name}: {get_sparsity(param):.2f}')

sparse_model_size = get_model_size(model, count_nonzero_only=True)
print(f"Sparse model has size={sparse_model_size / MiB:.2f} MiB = {sparse_model_size / dense_model_size * 100:.2f}% of dense model size")
sparse_model_accuracy = evaluate(model, dataloader['test'])
print(f"Sparse model has accuracy={sparse_model_accuracy:.2f}% before finetuning")

plot_weight_distribution(model, count_nonzero_only=True)
After pruning with sparsity dictionary
  backbone.conv0.weight: 0.00
  backbone.conv1.weight: 0.50
  backbone.conv2.weight: 0.50
  backbone.conv3.weight: 0.50
  backbone.conv4.weight: 0.50
  backbone.conv5.weight: 0.80
  backbone.conv6.weight: 0.80
  backbone.conv7.weight: 0.90
  classifier.weight: 0.00
The sparsity of each layer becomes
  backbone.conv0.weight: 0.00
  backbone.conv1.weight: 0.50
  backbone.conv2.weight: 0.50
  backbone.conv3.weight: 0.50
  backbone.conv4.weight: 0.50
  backbone.conv5.weight: 0.80
  backbone.conv6.weight: 0.80
  backbone.conv7.weight: 0.90
  classifier.weight: 0.00
Sparse model has size=8.63 MiB = 24.50% of dense model size
Sparse model has accuracy=87.00% before finetuning

Finetune the fine-grained pruned model

์ด์ „ ์…€์˜ ์ถœ๋ ฅ์—์„œ ๋ณผ ์ˆ˜ ์žˆ๋“ฏ์ด, fine-grained pruning์ด ๋ชจ๋ธ ๊ฐ€์ค‘์น˜์˜ ๋Œ€๋ถ€๋ถ„์„ ์ค„์ด์ง€๋งŒ ๋ชจ๋ธ์˜ ์ •ํ™•๋„๋„ ๋–จ์–ด์กŒ์Šต๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ, sparse ๋ชจ๋ธ์˜ ์ •ํ™•๋„๋ฅผ ํšŒ๋ณตํ•˜๊ธฐ ์œ„ํ•ด finetuneํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

sparse ๋ชจ๋ธ์„ finetuneํ•˜๊ธฐ ์œ„ํ•ด ๋‹ค์Œ ์…€์„ ์‹คํ–‰ํ•ด ์ฃผ์„ธ์š”. ์™„๋ฃŒํ•˜๋Š” ๋ฐ ์•ฝ 3๋ถ„ ์ •๋„ ๊ฑธ๋ฆด ๊ฒƒ์ž…๋‹ˆ๋‹ค.

num_finetune_epochs = 5
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_finetune_epochs)
criterion = nn.CrossEntropyLoss()

best_sparse_model_checkpoint = dict()
best_accuracy = 0
print(f'Finetuning Fine-grained Pruned Sparse Model')
for epoch in range(num_finetune_epochs):
    # At the end of each train iteration, we have to apply the pruning mask
    #    to keep the model sparse during the training
    train(model, dataloader['train'], criterion, optimizer, scheduler,
          callbacks=[lambda: pruner.apply(model)])
    accuracy = evaluate(model, dataloader['test'])
    is_best = accuracy > best_accuracy
    if is_best:
        best_sparse_model_checkpoint['state_dict'] = copy.deepcopy(model.state_dict())
        best_accuracy = accuracy
    print(f'    Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%')
Finetuning Fine-grained Pruned Sparse Model
    Epoch 1 Accuracy 92.66% / Best Accuracy: 92.66%
    Epoch 2 Accuracy 92.77% / Best Accuracy: 92.77%
    Epoch 3 Accuracy 92.80% / Best Accuracy: 92.80%
    Epoch 4 Accuracy 92.68% / Best Accuracy: 92.80%
    Epoch 5 Accuracy 92.77% / Best Accuracy: 92.80%

Please run the following cell to see the information of the best finetuned sparse model.

# load the best sparse model checkpoint to evaluate the final performance
model.load_state_dict(best_sparse_model_checkpoint['state_dict'])
sparse_model_size = get_model_size(model, count_nonzero_only=True)
print(f"Sparse model has size={sparse_model_size / MiB:.2f} MiB = {sparse_model_size / dense_model_size * 100:.2f}% of dense model size")
sparse_model_accuracy = evaluate(model, dataloader['test'])
print(f"Sparse model has accuracy={sparse_model_accuracy:.2f}% after finetuning")
Sparse model has size=8.63 MiB = 24.50% of dense model size
Sparse model has accuracy=92.80% after finetuning

Channel Pruning

์ด ์„น์…˜์—์„œ๋Š” channel pruning์„ ๊ตฌํ˜„ํ•  ๊ฒƒ์ž…๋‹ˆ๋‹ค. Channel pruning์€ ์ „์ฒด ์ฑ„๋„์„ ์ œ๊ฑฐํ•˜์—ฌ ๊ธฐ์กด ํ•˜๋“œ์›จ์–ด(์˜ˆ: GPU)์—์„œ ์ถ”๋ก  ์†๋„๋ฅผ ํ–ฅ์ƒ์‹œํ‚ฌ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋งˆ์ฐฌ๊ฐ€์ง€๋กœ, ๋” ์ž‘์€ ํฌ๊ธฐ(Frobenius norm์œผ๋กœ ์ธก์ •)์˜ ๊ฐ€์ค‘์น˜๋ฅผ ๊ฐ€์ง„ ์ฑ„๋„์„ ์ œ๊ฑฐํ•ฉ๋‹ˆ๋‹ค.

# firstly, let's restore the model weights to the original dense version
#   and check the validation accuracy
recover_model()
dense_model_accuracy = evaluate(model, dataloader['test'])
print(f"dense model has accuracy={dense_model_accuracy:.2f}%")
dense model has accuracy=92.95%

Remove Channel Weights

Unlike fine-grained pruning, channel pruning removes the weights from the tensor entirely, i.e., the number of output channels is reduced:

\(\#\mathrm{out\_channels}_{\mathrm{new}} = \#\mathrm{out\_channels}_{\mathrm{origin}} \cdot (1 - \mathrm{sparsity})\)

Note that the weight tensor \(W\) is still dense after channel pruning; that is why we refer to the sparsity as the prune ratio.

Fine-grained pruning์ฒ˜๋Ÿผ, ๋‹ค๋ฅธ ๋ ˆ์ด์–ด์— ๋Œ€ํ•ด ๋‹ค๋ฅธ pruning ๋น„์œจ์„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ํ•˜์ง€๋งŒ ์ง€๊ธˆ์€ ๋ชจ๋“  ๋ ˆ์ด์–ด์— ๋Œ€ํ•ด ๊ท ์ผํ•œ pruning ๋น„์œจ์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ์šฐ๋ฆฌ๋Š” ๋Œ€๋žต 30%์˜ ๊ท ์ผํ•œ pruning ๋น„์œจ๋กœ 2๋ฐฐ์˜ ๊ณ„์‚ฐ ๊ฐ์†Œ๋ฅผ ๋ชฉํ‘œ๋กœ ํ•ฉ๋‹ˆ๋‹ค(์™œ ๊ทธ๋Ÿฐ์ง€ ์ƒ๊ฐํ•ด ๋ณด์„ธ์š”).

์ด ์„น์…˜์˜ ๋์—์„œ ๋ ˆ์ด์–ด๋ณ„๋กœ ๋‹ค๋ฅธ pruning ๋น„์œจ์„ ์‹œ๋„ํ•ด ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. channel_prune ํ•จ์ˆ˜์— ๋น„์œจ ๋ฆฌ์ŠคํŠธ๋ฅผ ์ „๋‹ฌํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Question 6 (10 pts)

Channel pruning์„ ์œ„ํ•œ ๋‹ค์Œ ํ•จ์ˆ˜๋ฅผ ์™„์„ฑํ•ด ์ฃผ์„ธ์š”.

Here we naively prune all output channels other than the first \(\#\mathrm{out\_channels}_{\mathrm{new}}\) channels.

def get_num_channels_to_keep(channels: int, prune_ratio: float) -> int:
    """A function to calculate the number of channels to PRESERVE after pruning
    Note that preserve_rate = 1. - prune_ratio
    """
    ##################### YOUR CODE STARTS HERE #####################
    return int(round((1 - prune_ratio) * channels))
    ##################### YOUR CODE ENDS HERE #####################

@torch.no_grad()
def channel_prune(model: nn.Module,
                  prune_ratio: Union[List, float]) -> nn.Module:
    """Apply channel pruning to each of the conv layer in the backbone
    Note that for prune_ratio, we can either provide a floating-point number,
    indicating that we use a uniform pruning rate for all layers, or a list of
    numbers to indicate per-layer pruning rate.
    """
    # sanity check of provided prune_ratio
    assert isinstance(prune_ratio, (float, list))
    n_conv = len([m for m in model.backbone if isinstance(m, nn.Conv2d)])
    # note that for the ratios, it affects the previous conv output and next
    # conv input, i.e., conv0 - ratio0 - conv1 - ratio1-...
    if isinstance(prune_ratio, list):
        assert len(prune_ratio) == n_conv - 1
    else:  # convert float to list
        prune_ratio = [prune_ratio] * (n_conv - 1)

    # we prune the convs in the backbone with a uniform ratio
    model = copy.deepcopy(model)  # prevent overwrite
    # we only apply pruning to the backbone features
    all_convs = [m for m in model.backbone if isinstance(m, nn.Conv2d)]
    all_bns = [m for m in model.backbone if isinstance(m, nn.BatchNorm2d)]
    # apply pruning. we naively keep the first k channels
    assert len(all_convs) == len(all_bns)
    for i_ratio, p_ratio in enumerate(prune_ratio):
        prev_conv = all_convs[i_ratio]
        prev_bn = all_bns[i_ratio]
        next_conv = all_convs[i_ratio + 1]
        original_channels = prev_conv.out_channels  # same as next_conv.in_channels
        n_keep = get_num_channels_to_keep(original_channels, p_ratio)

        # prune the output of the previous conv and bn
        prev_conv.weight.set_(prev_conv.weight.detach()[:n_keep])
        prev_bn.weight.set_(prev_bn.weight.detach()[:n_keep])
        prev_bn.bias.set_(prev_bn.bias.detach()[:n_keep])
        prev_bn.running_mean.set_(prev_bn.running_mean.detach()[:n_keep])
        prev_bn.running_var.set_(prev_bn.running_var.detach()[:n_keep])

        # prune the input of the next conv (hint: just one line of code)
        ##################### YOUR CODE STARTS HERE #####################
        next_conv.weight.set_(next_conv.weight.detach()[:, :n_keep])
        ##################### YOUR CODE ENDS HERE #####################

    return model

๊ตฌํ˜„์ด ์˜ฌ๋ฐ”๋ฅธ์ง€ ํ™•์ธํ•˜๊ธฐ ์œ„ํ•ด ๋‹ค์Œ ์…€์„ ์‹คํ–‰ํ•˜์—ฌ ํ™•์ธํ•˜์„ธ์š”.

dummy_input = torch.randn(1, 3, 32, 32).cuda()
pruned_model = channel_prune(model, prune_ratio=0.3)
pruned_macs = get_model_macs(pruned_model, dummy_input)
assert pruned_macs == 305388064
print('* Check passed. Right MACs for the pruned model.')
* Check passed. Right MACs for the pruned model.

Now let's evaluate the performance of the model after uniform channel pruning with a 30% pruning rate.

Naively removing 30% of the channels leads to low accuracy:

pruned_model_accuracy = evaluate(pruned_model, dataloader['test'])
print(f"pruned model has accuracy={pruned_model_accuracy:.2f}%")
pruned model has accuracy=28.14%

Ranking Channels by Importance

As we can see, removing the first 30% of the channels in all layers significantly reduces the accuracy. One possible way to address this is to find the less important channel weights and remove those. A popular criterion for importance is the Frobenius norm of the weights corresponding to each input channel:

\(importance_{i} = \|W_{i}\|_2, \;\; i = 0, 1, 2,\cdots, \#\mathrm{in\_channels}-1\)

We can sort the channel weights from more important to less important and then keep the first \(k\) channels for each layer.
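
As a quick sanity check, the same per-input-channel importance can be computed in a vectorized way (our own sketch; the graded function below uses an explicit loop instead):

import torch

w = torch.randn(8, 4, 3, 3)                # (out_channels, in_channels, kH, kW)
importance = torch.norm(w, dim=(0, 2, 3))  # Frobenius norm per input channel
print(importance.shape)                    # torch.Size([4])
print(torch.argsort(importance, descending=True))  # most to least important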

Question 7 (15 pts)

Please complete the following functions, which sort the weight tensor based on the Frobenius norm.

Hint:

  • ํ…์„œ์˜ Frobenius norm์„ ๊ณ„์‚ฐํ•˜๊ธฐ ์œ„ํ•ด, Pytorch๋Š” torch.norm API๋ฅผ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.
# function to sort the channels from important to non-important
def get_input_channel_importance(weight):
    in_channels = weight.shape[1]
    importances = []
    # compute the importance for each input channel
    for i_c in range(weight.shape[1]):
        channel_weight = weight.detach()[:, i_c]
        ##################### YOUR CODE STARTS HERE #####################
        importance = torch.norm(channel_weight, p=2)
        ##################### YOUR CODE ENDS HERE #####################
        importances.append(importance.view(1))
    return torch.cat(importances)

@torch.no_grad()
def apply_channel_sorting(model):
    model = copy.deepcopy(model)  # do not modify the original model
    # fetch all the conv and bn layers from the backbone
    all_convs = [m for m in model.backbone if isinstance(m, nn.Conv2d)]
    all_bns = [m for m in model.backbone if isinstance(m, nn.BatchNorm2d)]
    # iterate through conv layers
    for i_conv in range(len(all_convs) - 1):
        # each channel sorting index, we need to apply it to:
        # - the output dimension of the previous conv
        # - the previous BN layer
        # - the input dimension of the next conv (we compute importance here)
        prev_conv = all_convs[i_conv]
        prev_bn = all_bns[i_conv]
        next_conv = all_convs[i_conv + 1]
        # note that we always compute the importance according to input channels
        importance = get_input_channel_importance(next_conv.weight)
        # sorting from large to small
        sort_idx = torch.argsort(importance, descending=True)

        # apply to previous conv and its following bn
        prev_conv.weight.copy_(torch.index_select(
            prev_conv.weight.detach(), 0, sort_idx))
        for tensor_name in ['weight', 'bias', 'running_mean', 'running_var']:
            tensor_to_apply = getattr(prev_bn, tensor_name)
            tensor_to_apply.copy_(
                torch.index_select(tensor_to_apply.detach(), 0, sort_idx)
            )

        # apply to the next conv input (hint: one line of code)
        ##################### YOUR CODE STARTS HERE #####################
        next_conv.weight.copy_(torch.index_select(next_conv.weight.detach(), 1, sort_idx))
        ##################### YOUR CODE ENDS HERE #####################

    return model

์ด์ œ ๋‹ค์Œ ์…€์„ ์‹คํ–‰ํ•˜์—ฌ ๊ฒฐ๊ณผ๊ฐ€ ์˜ฌ๋ฐ”๋ฅธ์ง€ ํ™•์ธํ•˜์„ธ์š”.

print('Before sorting...')
dense_model_accuracy = evaluate(model, dataloader['test'])
print(f"dense model has accuracy={dense_model_accuracy:.2f}%")

print('After sorting...')
sorted_model = apply_channel_sorting(model)
sorted_model_accuracy = evaluate(sorted_model, dataloader['test'])
print(f"sorted model has accuracy={sorted_model_accuracy:.2f}%")

# make sure accuracy does not change after sorting, since it is
# equivalent transform
assert abs(sorted_model_accuracy - dense_model_accuracy) < 0.1
print('* Check passed.')
Before sorting...
dense model has accuracy=92.95%
After sorting...
sorted model has accuracy=92.95%
* Check passed.

Finally, let's compare the accuracy of the pruned model with and without channel sorting.

channel_pruning_ratio = 0.3  # pruned-out ratio

print(" * Without sorting...")
pruned_model = channel_prune(model, channel_pruning_ratio)
pruned_model_accuracy = evaluate(pruned_model, dataloader['test'])
print(f"pruned model has accuracy={pruned_model_accuracy:.2f}%")


print(" * With sorting...")
sorted_model = apply_channel_sorting(model)
pruned_model = channel_prune(sorted_model, channel_pruning_ratio)
pruned_model_accuracy = evaluate(pruned_model, dataloader['test'])
print(f"pruned model has accuracy={pruned_model_accuracy:.2f}%")
 * Without sorting...
pruned model has accuracy=28.14%
 * With sorting...
pruned model has accuracy=36.81%

As we can see, channel sorting slightly improves the accuracy of the pruned model, but there is still a large degradation, which is very common with channel pruning. We can perform fine-tuning to recover this accuracy.

num_finetune_epochs = 5
optimizer = torch.optim.SGD(pruned_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_finetune_epochs)
criterion = nn.CrossEntropyLoss()

best_accuracy = 0
for epoch in range(num_finetune_epochs):
    train(pruned_model, dataloader['train'], criterion, optimizer, scheduler)
    accuracy = evaluate(pruned_model, dataloader['test'])
    is_best = accuracy > best_accuracy
    if is_best:
        best_accuracy = accuracy
    print(f'Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%')
Epoch 1 Accuracy 91.66% / Best Accuracy: 91.66%
Epoch 2 Accuracy 92.10% / Best Accuracy: 92.10%
Epoch 3 Accuracy 92.01% / Best Accuracy: 92.10%
Epoch 4 Accuracy 92.18% / Best Accuracy: 92.18%
Epoch 5 Accuracy 92.16% / Best Accuracy: 92.18%

Measure acceleration from pruning

fine-tuning์ด ๋๋‚˜๋ฉด ๋ชจ๋ธ์€ ์ •ํ™•๋„๋ฅผ ๊ฑฐ์˜ ํšŒ๋ณตํ•ฉ๋‹ˆ๋‹ค. channel pruning๋Š” fine-grained pruning์— ๋น„ํ•ด ์ผ๋ฐ˜์ ์œผ๋กœ ์ •ํ™•๋„๋ฅผ ํšŒ๋ณตํ•˜๊ธฐ๊ฐ€ ๋” ์–ด๋ ต๋‹ค๋Š” ๊ฒƒ์„ ์ด๋ฏธ ์•Œ๊ณ  ๊ณ„์‹ค ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ specialized model format์ด ์—†์œผ๋ฉด ์ง์ ‘์ ์œผ๋กœ ๋ชจ๋ธ ํฌ๊ธฐ๊ฐ€ ์ž‘์•„์ง€๊ณ  ๊ณ„์‚ฐ์ด ์ž‘์•„์ง‘๋‹ˆ๋‹ค. GPU์—์„œ๋„ ๋” ๋น ๋ฅด๊ฒŒ ์‹คํ–‰๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Now let's compare the model size, computation, and latency of the pruned model.

# helper function to measure the latency of a regular PyTorch model.
#   Unlike fine-grained pruning, channel pruning
#   directly leads to model size reduction and speedup.
@torch.no_grad()
def measure_latency(model, dummy_input, n_warmup=20, n_test=100):
    model.eval()
    # warmup
    for _ in range(n_warmup):
        _ = model(dummy_input)
    # real test
    t1 = time.time()
    for _ in range(n_test):
        _ = model(dummy_input)
    t2 = time.time()
    return (t2 - t1) / n_test  # average latency

table_template = "{:<15} {:<15} {:<15} {:<15}"
print (table_template.format('', 'Original','Pruned','Reduction Ratio'))

# 1. measure the latency of the original model and the pruned model on CPU
#   which simulates inference on an edge device
dummy_input = torch.randn(1, 3, 32, 32).to('cpu')
pruned_model = pruned_model.to('cpu')
model = model.to('cpu')

pruned_latency = measure_latency(pruned_model, dummy_input)
original_latency = measure_latency(model, dummy_input)
print(table_template.format('Latency (ms)',
                            round(original_latency * 1000, 1),
                            round(pruned_latency * 1000, 1),
                            round(original_latency / pruned_latency, 1)))

# 2. measure the computation (MACs)
original_macs = get_model_macs(model, dummy_input)
pruned_macs = get_model_macs(pruned_model, dummy_input)
print(table_template.format('MACs (M)',
                            round(original_macs / 1e6),
                            round(pruned_macs / 1e6),
                            round(original_macs / pruned_macs, 1)))

# 3. measure the model size (params)
original_param = get_num_parameters(model)
pruned_param = get_num_parameters(pruned_model)
print(table_template.format('Param (M)',
                            round(original_param / 1e6, 2),
                            round(pruned_param / 1e6, 2),
                            round(original_param / pruned_param, 1)))

# put model back to cuda
pruned_model = pruned_model.to('cuda')
model = model.to('cuda')
                Original        Pruned          Reduction Ratio
Latency (ms)    24.2            13.0            1.9            
MACs (M)        606             305             2.0            
Param (M)       9.23            5.01            1.8            
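
Note that measure_latency above times the model on CPU, where reading the wall clock is enough. On GPU, CUDA kernels are launched asynchronously, so a timing helper would need to synchronize before reading the clock. A sketch of such a variant (ours, not part of the lab; it assumes model and dummy_input live on the GPU):

@torch.no_grad()
def measure_latency_gpu(model, dummy_input, n_warmup=20, n_test=100):
    model.eval()
    for _ in range(n_warmup):
        _ = model(dummy_input)
    torch.cuda.synchronize()   # wait for the warmup kernels to finish
    t1 = time.time()
    for _ in range(n_test):
        _ = model(dummy_input)
    torch.cuda.synchronize()   # wait for all timed kernels before stopping the clock
    t2 = time.time()
    return (t2 - t1) / n_test  # average latency in seconds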

Question 8 (10 pts)

์ด์ „ ์ฝ”๋“œ์…€์˜ ์ •๋ณด๋ฅผ ์ด์šฉํ•˜์—ฌ ๋‹ค์Œ ์งˆ๋ฌธ์— ๋‹ต๋ณ€ํ•ด ์ฃผ์‹œ๊ธฐ ๋ฐ”๋ž๋‹ˆ๋‹ค.

Question 8.1 (5 pts)

30%์˜ ์ฑ„๋„์„ ์ œ๊ฑฐํ•˜๋ฉด ๋Œ€๋žต 50%์˜ ๊ณ„์‚ฐ ์ ˆ๊ฐ ํšจ๊ณผ๊ฐ€ ๋ฐœ์ƒํ•˜๋Š” ์ด์œ ๋ฅผ ์„ค๋ช…ํ•˜์„ธ์š”.

Your Answer:

MAC์€ 2๋ฐฐ, Param์€ 1.8๋ฐฐ ์ค„์–ด๋“ค์—ˆ์ง€๋งŒ, latency๋Š” 1.7๋ฐฐ๋งŒ ๋” ๋นจ๋ผ์กŒ๋‹ค ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ จ๋œ ์ด์œ ๋ผ๊ณ  ์ถ”์ •๋จ

Question 8.2 (5 pts)

Explain why the latency reduction ratio is slightly smaller than the computation reduction.

Your Answer:

Although the MACs dropped by 2.0x and the parameters by 1.8x, the latency improved by only 1.9x. Latency is not determined by computation alone; memory access and other overheads do not shrink in proportion to the MACs.


Compare Fine-grained Pruning and Channel Pruning

Question 9 (10 pts)

์ด๋ฒˆ ๋žฉ์—์„œ ๋ชจ๋“  ์‹คํ—˜์„ ํ•œ ํ›„์—๋Š” fine-grained pruning์™€ channel pruning์— ์ต์ˆ™ํ•ด์งˆ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Please answer the following questions using what you have learned in the lecture and this lab.

Question 9.1 (5 pts)

What are the pros and cons of fine-grained pruning and channel pruning?

You can discuss them from the perspective of compression ratio, accuracy, latency, hardware support (i.e., whether a specialized hardware accelerator is required), etc.

Your Answer:

  1. Fine-grained pruning
  • Pros
    • Higher accuracy
    • Usually a larger compression ratio, since we can flexibly find "redundant" weights
  • Cons
    • CPU overhead
    • Memory overhead
    • Needs specialized hardware support for real speedup (e.g., EIE)
  2. Channel pruning
  • Pros
    • Fast inference
  • Cons
    • Smaller compression ratio

Question 9.2 (5 pts)

์Šค๋งˆํŠธํฐ์—์„œ ๋ชจ๋ธ์„ ๋” ๋นจ๋ฆฌ ์‹คํ–‰์‹œํ‚ค๊ณ  ์‹ถ๋‹ค๋ฉด, ์–ด๋–ค ๊ฐ€์ง€์น˜๊ธฐ ๋ฐฉ๋ฒ•์„ ์‚ฌ์šฉํ•  ๊ฒƒ์ธ๊ฐ€์š”? ๊ทธ ์ด์œ ๋Š” ๋ฌด์—‡์ธ๊ฐ€์š”?

Your Answer:

Channel pruning, because it does not require special hardware support and directly reduces inference time.