Knowledge Distillation for Edge Deployment: Compressing CNNs Without Losing Accuracy

Published on January 5, 2026

Deep learning models achieve impressive accuracy, but deploying them on resource-constrained devices remains challenging. In this post, I'll share how we used knowledge distillation to compress large CNN models for crack detection, achieving a 20% improvement in Matthews Correlation Coefficient (MCC) while making models suitable for embedded deployment.

The Problem: Deploying AI on Edge Devices

Infrastructure inspection requires AI that can run on portable devices—drones, handheld scanners, or embedded systems. But state-of-the-art models like ResNet-152 or DenseNet-201 demand significant computational resources:

Model          Parameters   Size (MB)   Inference Time (Jetson Nano)
ResNet-152     60.2M        230         890 ms
DenseNet-201   20.0M        77          1240 ms
VGG-19         143.7M       549         1580 ms

For real-time inspection, we needed models that could run in under 100ms while maintaining detection accuracy.

Knowledge Distillation: Teaching Small Models to Think Big

Knowledge distillation transfers "knowledge" from a large teacher model to a smaller student model. The key insight is that the teacher's soft probability outputs contain more information than hard labels alone.
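
For intuition, here is a minimal sketch (hypothetical logits, not taken from our training code) of how temperature scaling softens the teacher's output distribution:

import tensorflow as tf

# Hypothetical teacher logits for a 3-class example.
logits = tf.constant([[8.0, 2.5, 0.5]])

hard = tf.nn.softmax(logits)        # ~[0.995, 0.004, 0.001], nearly one-hot
soft = tf.nn.softmax(logits / 3.0)  # T = 3 gives ~[0.81, 0.13, 0.07]; relative
                                    # class similarities become visible to the student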

The Distillation Loss

We train the student with a combination of hard label loss and soft target loss:

import tensorflow as tf

def distillation_loss(y_true, y_pred, teacher_pred, temperature=3.0, alpha=0.5):
    """
    Combined loss for knowledge distillation.
    
    Args:
        y_true: Ground truth labels
        y_pred: Student logits (pre-softmax)
        teacher_pred: Teacher logits (pre-softmax)
        temperature: Softmax temperature (higher = softer)
        alpha: Weight for distillation loss
    """
    # Hard label loss (standard cross-entropy, computed from raw logits)
    hard_loss = tf.reduce_mean(
        tf.keras.losses.categorical_crossentropy(y_true, y_pred, from_logits=True)
    )
    
    # Soft label loss (KL divergence with temperature)
    soft_teacher = tf.nn.softmax(teacher_pred / temperature)
    soft_student = tf.nn.softmax(y_pred / temperature)
    soft_loss = tf.keras.losses.KLDivergence()(soft_teacher, soft_student)
    
    # Temperature scaling for gradient magnitude
    soft_loss = soft_loss * (temperature ** 2)
    
    # Combined loss
    return alpha * soft_loss + (1 - alpha) * hard_loss
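
As a quick sanity check on the expected inputs (dummy tensors, purely illustrative): labels are one-hot encoded, and both student and teacher predictions are raw logits.

labels = tf.one_hot([1, 0], depth=2)          # (batch, classes) one-hot targets
student_logits = tf.random.normal((2, 2))     # raw (pre-softmax) student outputs
teacher_logits = tf.random.normal((2, 2))     # raw (pre-softmax) teacher outputs

loss = distillation_loss(labels, student_logits, teacher_logits, temperature=3.0)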

Training Pipeline

class DistillationTrainer:
    def __init__(self, teacher_model, student_model, optimizer, temperature=3.0):
        self.teacher = teacher_model
        self.student = student_model
        self.optimizer = optimizer
        self.temperature = temperature
        
        # Freeze teacher weights
        self.teacher.trainable = False
    
    def train_step(self, images, labels):
        # Get teacher predictions
        teacher_logits = self.teacher(images, training=False)
        
        with tf.GradientTape() as tape:
            # Get student predictions
            student_logits = self.student(images, training=True)
            
            # Calculate distillation loss
            loss = distillation_loss(
                labels, 
                student_logits, 
                teacher_logits,
                self.temperature
            )
        
        # Update student weights
        gradients = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(
            zip(gradients, self.student.trainable_variables)
        )
        
        return loss
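
A minimal sketch of how the trainer might be driven (the optimizer, learning rate, epoch count, and `train_ds` pipeline below are illustrative assumptions, not our exact configuration; `teacher` is any already-trained teacher model):

# Student outputs raw logits (classifier_activation=None) to match the loss above.
student = tf.keras.applications.ResNet50(
    weights=None, classes=2, classifier_activation=None
)
trainer = DistillationTrainer(
    teacher, student, optimizer=tf.keras.optimizers.Adam(1e-4), temperature=3.0
)

for epoch in range(30):
    for images, labels in train_ds:   # train_ds yields (image batch, one-hot labels)
        loss = trainer.train_step(images, labels)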

Case Study: SDNET2018 Crack Detection

We applied distillation to the SDNET2018 dataset containing 56,000+ images of concrete surfaces. Our approach:

1. Teacher Ensemble

Instead of a single teacher, we used an ensemble of high-performing models:

def create_teacher_ensemble(models):
    """Create ensemble teacher from multiple models."""
    inputs = tf.keras.Input(shape=(224, 224, 3))
    
    outputs = []
    for model in models:
        outputs.append(model(inputs))
    
    # Average predictions
    averaged = tf.keras.layers.Average()(outputs)
    
    return tf.keras.Model(inputs, averaged)

# Ensemble of ResNet-152, DenseNet-201, and EfficientNet-B4
teacher = create_teacher_ensemble([resnet152, densenet201, efficientnetb4])

2. Student Architecture Selection

We benchmarked several candidate student architectures:

Student Model   Parameters   MCC (No Distillation)   MCC (With Distillation)
ResNet-50       25.6M        0.72                    0.86
VGG-19          143.7M       0.68                    0.82
MobileNetV3     5.4M         0.61                    0.78

Knowledge distillation improved MCC by 0.14-0.17 absolute (roughly 20-28% relative) across all student architectures!

3. Data Augmentation for Distillation

We found that aggressive augmentation during distillation improved generalization:

augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal_and_vertical"),
    tf.keras.layers.RandomRotation(0.2),
    tf.keras.layers.RandomZoom(0.2),
    tf.keras.layers.RandomContrast(0.2),
    # Simulate sensor noise and lighting variation seen in field imagery
    tf.keras.layers.GaussianNoise(0.1),
    tf.keras.layers.RandomBrightness(0.2),
])
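
One plausible way to wire this into the distillation loop (a sketch, assuming teacher and student see the same augmented view of each batch so the soft targets match the student's inputs):

def augmented_train_step(trainer, images, labels):
    # Augment once, then distill on the augmented batch so the teacher's
    # soft targets correspond exactly to what the student sees.
    augmented = augmentation(images, training=True)
    return trainer.train_step(augmented, labels)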

Interpretability with Grad-CAM

For infrastructure inspection, interpretability is crucial. Engineers need to understand why the model flagged an area. We integrated Grad-CAM visualizations:

import numpy as np

def grad_cam(model, image, layer_name='conv5_block3_out'):
    """Generate Grad-CAM heatmap for model prediction."""
    grad_model = tf.keras.Model(
        inputs=model.input,
        outputs=[model.get_layer(layer_name).output, model.output]
    )
    
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(image)
        class_idx = tf.argmax(predictions[0])
        loss = predictions[:, class_idx]
    
    # Gradient of class output w.r.t. feature map
    grads = tape.gradient(loss, conv_outputs)
    
    # Global average pooling of gradients
    weights = tf.reduce_mean(grads, axis=(1, 2))
    
    # Weighted combination of feature maps
    cam = tf.reduce_sum(conv_outputs * weights[:, tf.newaxis, tf.newaxis, :], axis=-1)
    
    # ReLU and normalize to [0, 1] (epsilon guards against division by zero)
    cam = tf.nn.relu(cam)
    cam = cam / (tf.reduce_max(cam) + tf.keras.backend.epsilon())
    
    return cam.numpy()[0]
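
To inspect an individual prediction, the low-resolution heatmap can be upsampled and overlaid on the input image (a sketch using matplotlib; the colormap and blending alpha are arbitrary choices):

import matplotlib.pyplot as plt

def show_cam(model, image, layer_name='conv5_block3_out'):
    """Overlay a Grad-CAM heatmap on a single (224, 224, 3) image scaled to [0, 1]."""
    cam = grad_cam(model, image[tf.newaxis, ...], layer_name)
    cam = tf.image.resize(cam[..., tf.newaxis], image.shape[:2]).numpy().squeeze()

    plt.imshow(image)
    plt.imshow(cam, cmap='jet', alpha=0.4)  # semi-transparent heatmap on top
    plt.axis('off')
    plt.show()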

Grad-CAM showed that distilled models learned to focus on crack edges and textures, similar to their teacher models—evidence that knowledge transfer was successful.

Deployment Results

After distillation and INT8 quantization, our models achieved real-time performance:

Model                            Size (MB)   Inference Time on Jetson Nano (ms)   MCC
ResNet-152 (Teacher)             230         890                                  0.89
ResNet-50 (Distilled)            98          145                                  0.86
ResNet-50 (Distilled + INT8)     25          67                                   0.84
MobileNetV3 (Distilled + INT8)   8           23                                   0.76

The distilled ResNet-50 with INT8 quantization runs at roughly 15 FPS on a Jetson Nano while retaining 94% of the teacher's MCC!
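
For reference, here is a minimal post-training INT8 quantization sketch using the TensorFlow Lite converter (`calibration_ds` is an assumed tf.data pipeline of representative images; the exact deployment toolchain may differ):

def representative_dataset():
    # A few hundred representative batches let the converter calibrate
    # activation ranges for full-integer quantization.
    for images, _ in calibration_ds.take(200):
        yield [tf.cast(images, tf.float32)]

converter = tf.lite.TFLiteConverter.from_keras_model(student)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

tflite_model = converter.convert()
with open('student_int8.tflite', 'wb') as f:
    f.write(tflite_model)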

Key Takeaways

  1. Temperature matters: Higher temperatures (3-5) work better for similar teacher-student architectures; lower temperatures (1-2) for very different architectures.
  2. Ensemble teachers outperform single teachers: The diversity of predictions provides richer supervision.
  3. Data augmentation during distillation: Helps the student learn robust features rather than memorizing teacher outputs.
  4. Combine with quantization: Distillation and quantization are complementary—use both for maximum compression.

Related Project

Efficient Model for Concrete Cracks Detection - view the full project, which benchmarks 9 CNN architectures.