fp32_model_accuracy = evaluate(model, dataloader['test'])
fp32_model_size = get_model_size(model)
print(f"fp32 model has accuracy={fp32_model_accuracy:.2f}%")
print(f"fp32 model has size={fp32_model_size/MiB:.2f} MiB")
fp32 model has accuracy=92.95%
fp32 model has size=35.20 MiB
from fast_pytorch_kmeans import KMeans

def k_means_quantize(fp32_tensor: torch.Tensor, bitwidth=4, codebook=None):
    """
    quantize tensor using k-means clustering
    :param fp32_tensor:
    :param bitwidth: [int] quantization bit width, default=4
    :param codebook: [Codebook] (the cluster centroids, the cluster label tensor)
    :return:
        [Codebook = (centroids, labels)]
            centroids: [torch.(cuda.)FloatTensor] the cluster centroids
            labels: [torch.(cuda.)LongTensor] cluster label tensor
    """
    if codebook is None:
        ############### YOUR CODE STARTS HERE ###############
        # get number of clusters based on the quantization precision
        n_clusters = 2 ** bitwidth  # number of clusters is 2^bitwidth
        ############### YOUR CODE ENDS HERE #################
        # use k-means to get the quantization centroids
        kmeans = KMeans(n_clusters=n_clusters, mode='euclidean', verbose=0)
        labels = kmeans.fit_predict(fp32_tensor.view(-1, 1)).to(torch.long)
        centroids = kmeans.centroids.to(torch.float).view(-1)
        codebook = Codebook(centroids, labels)
    ############### YOUR CODE STARTS HERE ###############
    # decode the codebook into the k-means quantized tensor for inference
    # hint: one line of code
    quantized_tensor = codebook.centroids[codebook.labels].view_as(fp32_tensor)
    ############### YOUR CODE ENDS HERE #################
    fp32_tensor.set_(quantized_tensor.view_as(fp32_tensor))
    return codebook
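As a quick sanity check (a minimal sketch, assuming the Codebook namedtuple defined earlier in the lab and that fast_pytorch_kmeans is installed), quantizing a toy tensor at 2 bits should leave at most 2^2 = 4 distinct values:

import torch

# Toy tensor; k_means_quantize modifies it in place and returns the codebook.
x = torch.randn(8, 8)
codebook = k_means_quantize(x, bitwidth=2)

print(codebook.centroids.numel())   # 2 ** 2 = 4 centroids
print(x.unique().numel() <= 4)      # True: every entry was snapped to a centroid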
print('Note that the storage for codebooks is ignored when calculating the model size.')
quantizers = dict()
for bitwidth in [8, 4, 2]:
    recover_model()
    print(f'k-means quantizing model into {bitwidth} bits')
    quantizer = KMeansQuantizer(model, bitwidth)
    quantized_model_size = get_model_size(model, bitwidth)
    print(f"    {bitwidth}-bit k-means quantized model has size={quantized_model_size/MiB:.2f} MiB")
    quantized_model_accuracy = evaluate(model, dataloader['test'])
    print(f"    {bitwidth}-bit k-means quantized model has accuracy={quantized_model_accuracy:.2f}%")
    quantizers[bitwidth] = quantizer
Note that the storage for codebooks is ignored when calculating the model size.
k-means quantizing model into 8 bits
8-bit k-means quantized model has size=8.80 MiB
8-bit k-means quantized model has accuracy=92.76%
k-means quantizing model into 4 bits
4-bit k-means quantized model has size=4.40 MiB
4-bit k-means quantized model has accuracy=79.07%
k-means quantizing model into 2 bits
2-bit k-means quantized model has size=2.20 MiB
2-bit k-means quantized model has accuracy=10.00%
def update_codebook(fp32_tensor: torch.Tensor, codebook: Codebook):
    """
    update the centroids in the codebook using updated fp32_tensor
    :param fp32_tensor: [torch.(cuda.)Tensor]
    :param codebook: [Codebook] (the cluster centroids, the cluster label tensor)
    """
    n_clusters = codebook.centroids.numel()
    fp32_tensor = fp32_tensor.view(-1)
    for k in range(n_clusters):
        ############### YOUR CODE STARTS HERE ###############
        codebook.centroids[k] = fp32_tensor[codebook.labels == k].mean()
        ############### YOUR CODE ENDS HERE #################
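A small illustration of what the centroid update does (a hedged sketch with made-up numbers, assuming the Codebook namedtuple used above): the cluster labels stay fixed, and each centroid is reset to the mean of the weights currently assigned to it.

import torch

weights = torch.tensor([0.9, 1.1, -0.9, -1.1])
codebook = Codebook(centroids=torch.tensor([1.0, -1.0]),
                    labels=torch.tensor([0, 0, 1, 1]))

weights = weights + 0.1            # pretend fine-tuning shifted the weights
update_codebook(weights, codebook)
print(codebook.centroids)          # roughly tensor([ 1.1000, -0.9000])

This is exactly the k-means "update" step: during quantization-aware training the assignments are frozen while the centroids track the fine-tuned weights.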
accuracy_drop_threshold = 0.5
quantizers_before_finetune = copy.deepcopy(quantizers)
quantizers_after_finetune = quantizers

for bitwidth in [8, 4, 2]:
    recover_model()
    quantizer = quantizers[bitwidth]
    print(f'k-means quantizing model into {bitwidth} bits')
    quantizer.apply(model, update_centroids=False)
    quantized_model_size = get_model_size(model, bitwidth)
    print(f"    {bitwidth}-bit k-means quantized model has size={quantized_model_size/MiB:.2f} MiB")
    quantized_model_accuracy = evaluate(model, dataloader['test'])
    print(f"    {bitwidth}-bit k-means quantized model has accuracy={quantized_model_accuracy:.2f}% before quantization-aware training ")
    accuracy_drop = fp32_model_accuracy - quantized_model_accuracy
    if accuracy_drop > accuracy_drop_threshold:
        print(f"        Quantization-aware training due to accuracy drop={accuracy_drop:.2f}% is larger than threshold={accuracy_drop_threshold:.2f}%")
        num_finetune_epochs = 5
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_finetune_epochs)
        criterion = nn.CrossEntropyLoss()
        best_accuracy = 0
        epoch = num_finetune_epochs
        while accuracy_drop > accuracy_drop_threshold and epoch > 0:
            train(model, dataloader['train'], criterion, optimizer, scheduler,
                  callbacks=[lambda: quantizer.apply(model, update_centroids=True)])
            model_accuracy = evaluate(model, dataloader['test'])
            is_best = model_accuracy > best_accuracy
            best_accuracy = max(model_accuracy, best_accuracy)
            print(f'        Epoch {num_finetune_epochs-epoch} Accuracy {model_accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%')
            accuracy_drop = fp32_model_accuracy - best_accuracy
            epoch -= 1
    else:
        print(f"        No need for quantization-aware training since accuracy drop={accuracy_drop:.2f}% is smaller than threshold={accuracy_drop_threshold:.2f}%")
k-means quantizing model into 8 bits
8-bit k-means quantized model has size=8.80 MiB
8-bit k-means quantized model has accuracy=92.76% before quantization-aware training
No need for quantization-aware training since accuracy drop=0.19% is smaller than threshold=0.50%
k-means quantizing model into 4 bits
4-bit k-means quantized model has size=4.40 MiB
4-bit k-means quantized model has accuracy=79.07% before quantization-aware training
Quantization-aware training due to accuracy drop=13.88% is larger than threshold=0.50%
Epoch 0 Accuracy 92.47% / Best Accuracy: 92.47%
k-means quantizing model into 2 bits
2-bit k-means quantized model has size=2.20 MiB
2-bit k-means quantized model has accuracy=10.00% before quantization-aware training
Quantization-aware training due to accuracy drop=82.95% is larger than threshold=0.50%
Epoch 0 Accuracy 90.21% / Best Accuracy: 90.21%
Epoch 1 Accuracy 90.82% / Best Accuracy: 90.82%
Epoch 2 Accuracy 91.00% / Best Accuracy: 91.00%
Epoch 3 Accuracy 91.12% / Best Accuracy: 91.12%
Epoch 4 Accuracy 91.17% / Best Accuracy: 91.17%
Linear Quantization
In this section, we implement and perform linear quantization.
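Before looking at the code, a quick numeric illustration of the affine mapping it implements (the values here are made up for the example): with scale S = 0.1 and zero point Z = -10, quantization computes q = round(r / S) + Z and dequantization recovers r ≈ (q - Z) * S.

import torch

S, Z = 0.1, -10                         # illustrative scale and zero point
r = torch.tensor([2.5, -0.3, 0.0])      # real-valued inputs
q = torch.round(r / S).to(torch.int32) + Z
r_hat = (q - Z).float() * S

print(q)      # tensor([ 15, -13, -10], dtype=torch.int32)
print(r_hat)  # tensor([ 2.5000, -0.3000,  0.0000])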
def linear_quantize(fp_tensor, bitwidth, scale, zero_point, dtype=torch.int8) -> torch.Tensor:
    """
    linear quantization for single fp_tensor
        from
            fp_tensor = (quantized_tensor - zero_point) * scale
        we have,
            quantized_tensor = int(round(fp_tensor / scale)) + zero_point
    :param tensor: [torch.(cuda.)FloatTensor] floating tensor to be quantized
    :param bitwidth: [int] quantization bit width
    :param scale: [torch.(cuda.)FloatTensor] scaling factor
    :param zero_point: [torch.(cuda.)IntTensor] the desired centroid of tensor values
    :return:
        [torch.(cuda.)FloatTensor] quantized tensor whose values are integers
    """
    assert(fp_tensor.dtype == torch.float)
    assert(isinstance(scale, float) or
           (scale.dtype == torch.float and scale.dim() == fp_tensor.dim()))
    assert(isinstance(zero_point, int) or
           (zero_point.dtype == dtype and zero_point.dim() == fp_tensor.dim()))

    ############### YOUR CODE STARTS HERE ###############
    # Step 1: scale the fp_tensor
    scaled_tensor = fp_tensor / scale
    # Step 2: round the floating value to integer value
    rounded_tensor = torch.round(scaled_tensor)
    ############### YOUR CODE ENDS HERE #################

    rounded_tensor = rounded_tensor.to(dtype)

    ############### YOUR CODE STARTS HERE ###############
    # Step 3: shift the rounded_tensor to make zero_point 0
    shifted_tensor = rounded_tensor + zero_point
    ############### YOUR CODE ENDS HERE #################

    # Step 4: clamp the shifted_tensor to lie in bitwidth-bit range
    quantized_min, quantized_max = get_quantized_range(bitwidth)
    quantized_tensor = shifted_tensor.clamp_(quantized_min, quantized_max)
    return quantized_tensor
def get_quantization_scale_for_weight(weight, bitwidth):
    """
    get quantization scale for single tensor of weight
    :param weight: [torch.(cuda.)Tensor] floating weight to be quantized
    :param bitwidth: [integer] quantization bit width
    :return:
        [floating scalar] scale
    """
    # we just assume values in weight are symmetric
    # we also always make zero_point 0 for weight
    fp_max = max(weight.abs().max().item(), 5e-7)
    _, quantized_max = get_quantized_range(bitwidth)
    return fp_max / quantized_max
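A quick check of the symmetric-scale formula (a sketch, assuming get_quantized_range(8) returns (-128, 127) as in the signed 8-bit setup used here): the largest-magnitude weight is mapped onto the largest representable integer.

import torch

w = torch.tensor([0.50, -0.25, 0.10])
scale = get_quantization_scale_for_weight(w, bitwidth=8)

print(scale)                  # 0.5 / 127 ≈ 0.003937
print(round(0.50 / scale))    # 127: max |w| lands on the int8 maximum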
def quantized_linear(input, weight, bias, feature_bitwidth, weight_bitwidth,
                     input_zero_point, output_zero_point,
                     input_scale, weight_scale, output_scale):
    """
    quantized fully-connected layer
    :param input: [torch.CharTensor] quantized input (torch.int8)
    :param weight: [torch.CharTensor] quantized weight (torch.int8)
    :param bias: [torch.IntTensor] shifted quantized bias or None (torch.int32)
    :param feature_bitwidth: [int] quantization bit width of input and output
    :param weight_bitwidth: [int] quantization bit width of weight
    :param input_zero_point: [int] input zero point
    :param output_zero_point: [int] output zero point
    :param input_scale: [float] input feature scale
    :param weight_scale: [torch.FloatTensor] weight per-channel scale
    :param output_scale: [float] output feature scale
    :return:
        [torch.CharIntTensor] quantized output feature (torch.int8)
    """
    assert(input.dtype == torch.int8)
    assert(weight.dtype == input.dtype)
    assert(bias is None or bias.dtype == torch.int32)
    assert(isinstance(input_zero_point, int))
    assert(isinstance(output_zero_point, int))
    assert(isinstance(input_scale, float))
    assert(isinstance(output_scale, float))
    assert(weight_scale.dtype == torch.float)

    # Step 1: integer-based fully-connected (8-bit multiplication with 32-bit accumulation)
    if 'cpu' in input.device.type:
        # use 32-b MAC for simplicity
        output = torch.nn.functional.linear(input.to(torch.int32), weight.to(torch.int32), bias)
    else:
        # current version pytorch does not yet support integer-based linear() on GPUs
        output = torch.nn.functional.linear(input.float(), weight.float(), bias.float())

    ############### YOUR CODE STARTS HERE ###############
    # Step 2: scale the output
    #         hint: 1. scales are floating numbers, we need to convert output to float as well
    #               2. the shape of weight scale is [oc, 1, 1, 1] while the shape of output is [batch_size, oc]
    real_scale = input_scale * weight_scale.view(-1) / output_scale
    output = output.float() * real_scale

    # Step 3: shift output by output_zero_point
    output += output_zero_point
    ############### YOUR CODE ENDS HERE #################

    # Make sure all value lies in the bitwidth-bit range
    output = output.round().clamp(*get_quantized_range(feature_bitwidth)).to(torch.int8)
    return output
Let's verify the functionality of the quantized fully-connected layer we just defined.
test_quantized_fc()
* Test quantized_fc()
target bitwidth: 2 bits
batch size: 4
input channels: 8
output channels: 8
* Test passed.
def quantized_conv2d(input, weight, bias, feature_bitwidth, weight_bitwidth,
                     input_zero_point, output_zero_point,
                     input_scale, weight_scale, output_scale,
                     stride, padding, dilation, groups):
    """
    quantized 2d convolution
    :param input: [torch.CharTensor] quantized input (torch.int8)
    :param weight: [torch.CharTensor] quantized weight (torch.int8)
    :param bias: [torch.IntTensor] shifted quantized bias or None (torch.int32)
    :param feature_bitwidth: [int] quantization bit width of input and output
    :param weight_bitwidth: [int] quantization bit width of weight
    :param input_zero_point: [int] input zero point
    :param output_zero_point: [int] output zero point
    :param input_scale: [float] input feature scale
    :param weight_scale: [torch.FloatTensor] weight per-channel scale
    :param output_scale: [float] output feature scale
    :return:
        [torch.(cuda.)CharTensor] quantized output feature
    """
    assert(len(padding) == 4)
    assert(input.dtype == torch.int8)
    assert(weight.dtype == input.dtype)
    assert(bias is None or bias.dtype == torch.int32)
    assert(isinstance(input_zero_point, int))
    assert(isinstance(output_zero_point, int))
    assert(isinstance(input_scale, float))
    assert(isinstance(output_scale, float))
    assert(weight_scale.dtype == torch.float)

    # Step 1: calculate integer-based 2d convolution (8-bit multiplication with 32-bit accumulation)
    input = torch.nn.functional.pad(input, padding, 'constant', input_zero_point)
    if 'cpu' in input.device.type:
        # use 32-b MAC for simplicity
        output = torch.nn.functional.conv2d(input.to(torch.int32), weight.to(torch.int32), None, stride, 0, dilation, groups)
    else:
        # current version pytorch does not yet support integer-based conv2d() on GPUs
        output = torch.nn.functional.conv2d(input.float(), weight.float(), None, stride, 0, dilation, groups)
        output = output.round().to(torch.int32)
    if bias is not None:
        output = output + bias.view(1, -1, 1, 1)

    ############### YOUR CODE STARTS HERE ###############
    # hint: this code block should be very similar to quantized_linear()

    # Step 2: scale the output
    #         hint: 1. scales are floating numbers, we need to convert output to float as well
    #               2. the shape of weight scale is [oc, 1, 1, 1] while the shape of output is [batch_size, oc, height, width]
    real_scale = input_scale * weight_scale.view(-1) / output_scale
    output = output.float() * real_scale.unsqueeze(1).unsqueeze(2)

    # Step 3: shift output by output_zero_point
    #         hint: one line of code
    output += output_zero_point
    ############### YOUR CODE ENDS HERE #################

    # Make sure all value lies in the bitwidth-bit range
    output = output.round().clamp(*get_quantized_range(feature_bitwidth)).to(torch.int8)
    return output
# add hook to record the min max value of the activation
input_activation = {}
output_activation = {}

def add_range_recoder_hook(model):
    import functools
    def _record_range(self, x, y, module_name):
        x = x[0]
        input_activation[module_name] = x.detach()
        output_activation[module_name] = y.detach()

    all_hooks = []
    for name, m in model.named_modules():
        if isinstance(m, (nn.Conv2d, nn.Linear, nn.ReLU)):
            all_hooks.append(m.register_forward_hook(
                functools.partial(_record_range, module_name=name)))
    return all_hooks

hooks = add_range_recoder_hook(model_fused)
sample_data = iter(dataloader['train']).__next__()[0]
model_fused(sample_data.cuda())

# remove hooks
for h in hooks:
    h.remove()
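The conversion below relies on get_quantization_scale_and_zero_point, which was implemented earlier in the lab and is not reproduced in this section. For reference, here is a minimal sketch of one common way to derive an asymmetric (affine) scale and zero point from a recorded activation tensor; the lab's actual helper may differ in details.

def get_scale_and_zero_point_sketch(activation, bitwidth):
    # Hypothetical re-derivation, not necessarily the lab's exact implementation:
    # map the observed range [r_min, r_max] onto the signed range [q_min, q_max].
    q_min, q_max = get_quantized_range(bitwidth)
    r_min, r_max = activation.min().item(), activation.max().item()
    scale = (r_max - r_min) / (q_max - q_min)
    if scale <= 0:                     # degenerate range (constant activation)
        return 1.0, 0
    zero_point = int(round(q_min - r_min / scale))
    zero_point = min(max(zero_point, q_min), q_max)   # keep it representable
    return scale, zero_point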
nn.Conv2d: QuantizedConv2d,
nn.Linear: QuantizedLinear,
# the following two are just wrappers, as current
# torch modules do not support int8 data format;
# we will temporarily convert them to fp32 for computation
nn.MaxPool2d: QuantizedMaxPool2d,
nn.AvgPool2d: QuantizedAvgPool2d,
class QuantizedConv2d(nn.Module):
    def __init__(self, weight, bias,
                 input_zero_point, output_zero_point,
                 input_scale, weight_scale, output_scale,
                 stride, padding, dilation, groups,
                 feature_bitwidth=8, weight_bitwidth=8):
        super().__init__()
        # current version Pytorch does not support IntTensor as nn.Parameter
        self.register_buffer('weight', weight)
        self.register_buffer('bias', bias)
        self.input_zero_point = input_zero_point
        self.output_zero_point = output_zero_point
        self.input_scale = input_scale
        self.register_buffer('weight_scale', weight_scale)
        self.output_scale = output_scale
        self.stride = stride
        self.padding = (padding[1], padding[1], padding[0], padding[0])
        self.dilation = dilation
        self.groups = groups
        self.feature_bitwidth = feature_bitwidth
        self.weight_bitwidth = weight_bitwidth

    def forward(self, x):
        return quantized_conv2d(
            x, self.weight, self.bias,
            self.feature_bitwidth, self.weight_bitwidth,
            self.input_zero_point, self.output_zero_point,
            self.input_scale, self.weight_scale, self.output_scale,
            self.stride, self.padding, self.dilation, self.groups
            )

class QuantizedLinear(nn.Module):
    def __init__(self, weight, bias,
                 input_zero_point, output_zero_point,
                 input_scale, weight_scale, output_scale,
                 feature_bitwidth=8, weight_bitwidth=8):
        super().__init__()
        # current version Pytorch does not support IntTensor as nn.Parameter
        self.register_buffer('weight', weight)
        self.register_buffer('bias', bias)
        self.input_zero_point = input_zero_point
        self.output_zero_point = output_zero_point
        self.input_scale = input_scale
        self.register_buffer('weight_scale', weight_scale)
        self.output_scale = output_scale
        self.feature_bitwidth = feature_bitwidth
        self.weight_bitwidth = weight_bitwidth

    def forward(self, x):
        return quantized_linear(
            x, self.weight, self.bias,
            self.feature_bitwidth, self.weight_bitwidth,
            self.input_zero_point, self.output_zero_point,
            self.input_scale, self.weight_scale, self.output_scale
            )

class QuantizedMaxPool2d(nn.MaxPool2d):
    def forward(self, x):
        # current version PyTorch does not support integer-based MaxPool
        return super().forward(x.float()).to(torch.int8)

class QuantizedAvgPool2d(nn.AvgPool2d):
    def forward(self, x):
        # current version PyTorch does not support integer-based AvgPool
        return super().forward(x.float()).to(torch.int8)

# we use int8 quantization, which is quite popular
feature_bitwidth = weight_bitwidth = 8
quantized_model = copy.deepcopy(model_fused)
quantized_backbone = []
ptr = 0
while ptr < len(quantized_model.backbone):
    if isinstance(quantized_model.backbone[ptr], nn.Conv2d) and \
            isinstance(quantized_model.backbone[ptr + 1], nn.ReLU):
        conv = quantized_model.backbone[ptr]
        conv_name = f'backbone.{ptr}'
        relu = quantized_model.backbone[ptr + 1]
        relu_name = f'backbone.{ptr + 1}'

        input_scale, input_zero_point = \
            get_quantization_scale_and_zero_point(
                input_activation[conv_name], feature_bitwidth)
        output_scale, output_zero_point = \
            get_quantization_scale_and_zero_point(
                output_activation[relu_name], feature_bitwidth)

        quantized_weight, weight_scale, weight_zero_point = \
            linear_quantize_weight_per_channel(conv.weight.data, weight_bitwidth)
        quantized_bias, bias_scale, bias_zero_point = \
            linear_quantize_bias_per_output_channel(
                conv.bias.data, weight_scale, input_scale)
        shifted_quantized_bias = \
            shift_quantized_conv2d_bias(quantized_bias, quantized_weight,
                                        input_zero_point)

        quantized_conv = QuantizedConv2d(
            quantized_weight, shifted_quantized_bias,
            input_zero_point, output_zero_point,
            input_scale, weight_scale, output_scale,
            conv.stride, conv.padding, conv.dilation, conv.groups,
            feature_bitwidth=feature_bitwidth, weight_bitwidth=weight_bitwidth
        )

        quantized_backbone.append(quantized_conv)
        ptr += 2
    elif isinstance(quantized_model.backbone[ptr], nn.MaxPool2d):
        quantized_backbone.append(QuantizedMaxPool2d(
            kernel_size=quantized_model.backbone[ptr].kernel_size,
            stride=quantized_model.backbone[ptr].stride
            ))
        ptr += 1
    elif isinstance(quantized_model.backbone[ptr], nn.AvgPool2d):
        quantized_backbone.append(QuantizedAvgPool2d(
            kernel_size=quantized_model.backbone[ptr].kernel_size,
            stride=quantized_model.backbone[ptr].stride
            ))
        ptr += 1
    else:
        raise NotImplementedError(type(quantized_model.backbone[ptr]))  # should not happen

quantized_model.backbone = nn.Sequential(*quantized_backbone)

# finally, quantize the classifier
fc_name = 'classifier'
fc = model.classifier
input_scale, input_zero_point = \
    get_quantization_scale_and_zero_point(
        input_activation[fc_name], feature_bitwidth)
output_scale, output_zero_point = \
    get_quantization_scale_and_zero_point(
        output_activation[fc_name], feature_bitwidth)

quantized_weight, weight_scale, weight_zero_point = \
    linear_quantize_weight_per_channel(fc.weight.data, weight_bitwidth)
quantized_bias, bias_scale, bias_zero_point = \
    linear_quantize_bias_per_output_channel(
        fc.bias.data, weight_scale, input_scale)
shifted_quantized_bias = \
    shift_quantized_linear_bias(quantized_bias, quantized_weight,
                                input_zero_point)

quantized_model.classifier = QuantizedLinear(
    quantized_weight, shifted_quantized_bias,
    input_zero_point, output_zero_point,
    input_scale, weight_scale, output_scale,
    feature_bitwidth=feature_bitwidth, weight_bitwidth=weight_bitwidth
)
print(quantized_model)

def extra_preprocess(x):
    # hint: you need to convert the original fp32 input of range (0, 1)
    #       into int8 format of range (-128, 127)
    ############### YOUR CODE STARTS HERE ###############
    x_scaled = x * 255
    x_shifted = x_scaled - 128
    return x_shifted.clamp(-128, 127).to(torch.int8)
    ############### YOUR CODE ENDS HERE #################

int8_model_accuracy = evaluate(quantized_model, dataloader['test'],
                               extra_preprocess=[extra_preprocess])
print(f"int8 model has accuracy={int8_model_accuracy:.2f}%")
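As a sanity check on the preprocessing above (an illustrative sketch, not part of the lab): mapping inputs in (0, 1) to int8 via x * 255 - 128 corresponds, up to rounding details, to an affine quantization of the input with scale 1/255 and zero point -128.

import torch

x = torch.tensor([0.0, 0.5, 1.0])
print(extra_preprocess(x))   # roughly tensor([-128,    0,  127], dtype=torch.int8)

# Equivalent affine view: q = round(x / (1/255)) + (-128)
q = torch.round(x / (1 / 255)).to(torch.int32) - 128
print(q)                     # tensor([-128,    0,  127], dtype=torch.int32)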