Knowledge Distillation for Edge Deployment: Compressing CNNs Without Losing Accuracy
Published on January 5, 2026
Deep learning models achieve impressive accuracy, but deploying them on resource-constrained devices remains challenging. In this post, I'll share how we used knowledge distillation to compress large CNN models for crack detection, achieving a 20% improvement in Matthews Correlation Coefficient (MCC) while making models suitable for embedded deployment.
The Problem: Deploying AI on Edge Devices
Infrastructure inspection requires AI that can run on portable devices—drones, handheld scanners, or embedded systems. But state-of-the-art models like ResNet-152 or DenseNet-201 demand significant computational resources:
| Model | Parameters | Size (MB) | Inference Time on Jetson Nano (ms) |
|---|---|---|---|
| ResNet-152 | 60.2M | 230 | 890 |
| DenseNet-201 | 20.0M | 77 | 1240 |
| VGG-19 | 143.7M | 549 | 1580 |
For real-time inspection, we needed models that could run in under 100ms while maintaining detection accuracy.
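Latency figures like these can be sanity-checked with a simple warm-up-then-average timing loop; a minimal sketch (the input shape, warm-up count, and iteration count are arbitrary):
import time
import numpy as np
import tensorflow as tf

def measure_latency_ms(model, input_shape=(1, 224, 224, 3), warmup=10, runs=50):
    """Rough average single-batch inference latency in milliseconds."""
    dummy = tf.constant(np.random.rand(*input_shape).astype("float32"))
    for _ in range(warmup):                 # warm-up: build graphs, fill caches
        model(dummy, training=False)
    start = time.perf_counter()
    for _ in range(runs):
        model(dummy, training=False)
    return (time.perf_counter() - start) / runs * 1000.0

# Hypothetical usage: print(measure_latency_ms(resnet152))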
Knowledge Distillation: Teaching Small Models to Think Big
Knowledge distillation transfers "knowledge" from a large teacher model to a smaller student model. The key insight is that the teacher's soft probability outputs contain more information than hard labels alone.
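To see why, compare a hard label with the teacher's temperature-softened probabilities; the logits below are invented purely for illustration:
import tensorflow as tf

# Hypothetical teacher logits for one image: [cracked, uncracked]
logits = tf.constant([[2.0, -1.0]])

hard_label = tf.one_hot([0], depth=2)   # [1, 0]: says nothing about the teacher's confidence
soft_t1 = tf.nn.softmax(logits / 1.0)   # ~[0.95, 0.05]: nearly as sharp as the hard label
soft_t3 = tf.nn.softmax(logits / 3.0)   # ~[0.73, 0.27]: exposes the teacher's uncertainty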
The Distillation Loss
We train the student with a combination of hard label loss and soft target loss:
import tensorflow as tf

def distillation_loss(y_true, y_pred, teacher_pred, temperature=3.0, alpha=0.5):
    """Combined loss for knowledge distillation.

    Args:
        y_true: Ground-truth one-hot labels
        y_pred: Student logits
        teacher_pred: Teacher logits
        temperature: Softmax temperature (higher = softer distributions)
        alpha: Weight for the distillation (soft) loss
    """
    # Hard-label loss: standard cross-entropy on the student logits
    hard_loss = tf.keras.losses.categorical_crossentropy(
        y_true, y_pred, from_logits=True
    )

    # Soft-label loss: KL divergence between temperature-softened distributions
    soft_teacher = tf.nn.softmax(teacher_pred / temperature)
    soft_student = tf.nn.softmax(y_pred / temperature)
    soft_loss = tf.keras.losses.KLDivergence()(soft_teacher, soft_student)

    # Scale by T^2 so the soft-target gradients keep a comparable magnitude
    soft_loss = soft_loss * (temperature ** 2)

    # Weighted combination of the two terms
    return alpha * soft_loss + (1 - alpha) * hard_loss
Training Pipeline
class DistillationTrainer:
    def __init__(self, teacher_model, student_model, optimizer=None, temperature=3.0):
        self.teacher = teacher_model
        self.student = student_model
        self.optimizer = optimizer or tf.keras.optimizers.Adam()
        self.temperature = temperature
        # Freeze teacher weights; only the student is updated
        self.teacher.trainable = False

    def train_step(self, images, labels):
        # Teacher predictions (soft targets), computed outside the tape
        teacher_logits = self.teacher(images, training=False)

        with tf.GradientTape() as tape:
            # Student forward pass
            student_logits = self.student(images, training=True)
            # Distillation loss against both hard labels and soft targets
            loss = tf.reduce_mean(
                distillation_loss(
                    labels,
                    student_logits,
                    teacher_logits,
                    self.temperature,
                )
            )

        # Update only the student's weights
        gradients = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(
            zip(gradients, self.student.trainable_variables)
        )
        return loss
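Wiring the trainer into a training loop is straightforward; a sketch, assuming teacher, student, and a tf.data pipeline named train_ds already exist (the optimizer settings and epoch count are placeholders):
trainer = DistillationTrainer(
    teacher_model=teacher,
    student_model=student,
    optimizer=tf.keras.optimizers.Adam(1e-4),
    temperature=3.0,
)

for epoch in range(20):                    # epoch count is illustrative
    for images, labels in train_ds:        # batches of (images, one-hot labels)
        loss = trainer.train_step(images, labels)
    print(f"epoch {epoch}: last batch loss = {float(loss):.4f}")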
Case Study: SDNET2018 Crack Detection
We applied distillation to the SDNET2018 dataset of 56,000+ concrete surface images (bridge decks, walls, and pavements). Our approach had three parts:
1. Teacher Ensemble
Instead of a single teacher, we used an ensemble of high-performing models:
def create_teacher_ensemble(models):
    """Create an ensemble teacher by averaging the outputs of multiple models."""
    inputs = tf.keras.Input(shape=(224, 224, 3))
    outputs = [model(inputs) for model in models]
    # Average the member predictions into a single soft target
    averaged = tf.keras.layers.Average()(outputs)
    return tf.keras.Model(inputs, averaged)

# Ensemble of ResNet-152, DenseNet-201, and EfficientNet-B4
teacher = create_teacher_ensemble([resnet152, densenet201, efficientnetb4])
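The individual teachers (resnet152, densenet201, efficientnetb4) are assumed here to be fine-tuned classifiers built on tf.keras.applications backbones; a sketch of how such a teacher might be constructed (build_teacher and the two-class logit head are illustrative):
def build_teacher(backbone_fn, num_classes=2):
    """Hypothetical helper: wrap an ImageNet backbone with a small logit head."""
    inputs = tf.keras.Input(shape=(224, 224, 3))
    base = backbone_fn(include_top=False, weights="imagenet",
                       input_shape=(224, 224, 3), pooling="avg")
    x = base(inputs)
    outputs = tf.keras.layers.Dense(num_classes)(x)  # raw logits, no softmax
    return tf.keras.Model(inputs, outputs)

# Each teacher would be fine-tuned on SDNET2018 before joining the ensemble
resnet152 = build_teacher(tf.keras.applications.ResNet152)
densenet201 = build_teacher(tf.keras.applications.DenseNet201)
efficientnetb4 = build_teacher(tf.keras.applications.EfficientNetB4)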
2. Student Architecture Selection
We benchmarked lightweight architectures as students:
| Student Model | Parameters | MCC (No Distillation) | MCC (With Distillation) |
|---|---|---|---|
| ResNet-50 | 25.6M | 0.72 | 0.86 |
| VGG-19 | 20.0M | 0.68 | 0.82 |
| MobileNetV3 | 5.4M | 0.61 | 0.78 |
Knowledge distillation improved MCC by 0.14-0.17 across every student architecture, a relative gain of roughly 19-28%!
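As a refresher, MCC summarizes the full confusion matrix in a single score between -1 and 1; one way to compute it for the binary crack/no-crack task (scikit-learn is assumed to be available, and evaluate_mcc is a hypothetical helper):
from sklearn.metrics import matthews_corrcoef
import numpy as np

def evaluate_mcc(model, dataset):
    """Compute MCC over a tf.data.Dataset of (image, one-hot label) batches."""
    y_true, y_pred = [], []
    for images, labels in dataset:
        logits = model(images, training=False)
        y_pred.extend(np.argmax(logits.numpy(), axis=1))
        y_true.extend(np.argmax(labels.numpy(), axis=1))
    return matthews_corrcoef(y_true, y_pred)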
3. Data Augmentation for Distillation
We found that aggressive augmentation during distillation improved generalization:
augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal_and_vertical"),
    tf.keras.layers.RandomRotation(0.2),
    tf.keras.layers.RandomZoom(0.2),
    tf.keras.layers.RandomContrast(0.2),
    # Crack-specific augmentations
    tf.keras.layers.GaussianNoise(0.1),
    tf.keras.layers.RandomBrightness(0.2),
])
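One detail worth noting: the soft targets are most useful when the teacher scores the same augmented view the student trains on, so it helps to augment each batch once before both forward passes. A minimal sketch building on the DistillationTrainer above:
def distillation_train_step(trainer, images, labels):
    """Augment once so teacher and student see the identical view of the batch."""
    augmented = augmentation(images, training=True)  # training=True enables the random layers
    return trainer.train_step(augmented, labels)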
Interpretability with Grad-CAM
For infrastructure inspection, interpretability is crucial. Engineers need to understand why the model flagged an area. We integrated Grad-CAM visualizations:
def grad_cam(model, image, layer_name='conv5_block3_out'):
    """Generate a Grad-CAM heatmap for the model's top prediction.

    `image` must be a batched tensor of shape (1, H, W, 3).
    """
    # Model exposing both the target conv feature map and the final output
    grad_model = tf.keras.Model(
        inputs=model.input,
        outputs=[model.get_layer(layer_name).output, model.output]
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(image)
        class_idx = tf.argmax(predictions[0])
        loss = predictions[:, class_idx]

    # Gradient of the class score w.r.t. the feature map
    grads = tape.gradient(loss, conv_outputs)

    # Global average pooling of gradients gives per-channel importance weights
    weights = tf.reduce_mean(grads, axis=(1, 2))

    # Weighted combination of feature maps
    cam = tf.reduce_sum(conv_outputs * weights[:, tf.newaxis, tf.newaxis, :], axis=-1)

    # ReLU, then normalize to [0, 1] (epsilon guards against an all-zero map)
    cam = tf.nn.relu(cam)
    cam = cam / (tf.reduce_max(cam) + 1e-8)
    return cam.numpy()[0]
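In practice the low-resolution CAM is upsampled to the input size and blended over the original image; a sketch using tf.image (overlay_heatmap and the blend weight are illustrative, and the image is assumed to be a float tensor scaled to [0, 1]):
def overlay_heatmap(model, image, alpha=0.4):
    """Blend a Grad-CAM heatmap over a single HxWx3 image scaled to [0, 1]."""
    cam = grad_cam(model, tf.expand_dims(image, axis=0))          # (h, w) heatmap
    cam = tf.image.resize(cam[..., tf.newaxis], image.shape[:2])  # upsample to input size
    heatmap = tf.tile(cam, [1, 1, 3])                             # grayscale -> 3 channels
    return (1 - alpha) * image + alpha * heatmap                  # simple additive blend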
Grad-CAM showed that distilled models learned to focus on crack edges and textures, similar to their teacher models—evidence that knowledge transfer was successful.
Deployment Results
After distillation and INT8 quantization, our models achieved real-time performance:
| Model | Size (MB) | Jetson Nano (ms) | MCC |
|---|---|---|---|
| ResNet-152 (Teacher) | 230 | 890 | 0.89 |
| ResNet-50 (Distilled) | 98 | 145 | 0.86 |
| ResNet-50 (Distilled + INT8) | 25 | 67 | 0.84 |
| MobileNetV3 (Distilled + INT8) | 8 | 23 | 0.76 |
The distilled ResNet-50 with INT8 quantization runs at roughly 15 FPS on a Jetson Nano while retaining 94% of the teacher's MCC (0.84 vs. 0.89)!
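For reference, one way to produce the INT8 variants is TensorFlow Lite post-training quantization with a small calibration split; a sketch, where quantize_int8 is a hypothetical helper and calibration_ds is a placeholder for a representative tf.data split:
def quantize_int8(model, calibration_ds, num_batches=100):
    """Convert a Keras model to a fully integer-quantized TFLite flatbuffer."""
    def representative_data():
        # Calibration batches let the converter pick quantization ranges
        for images, _ in calibration_ds.take(num_batches):
            yield [tf.cast(images, tf.float32)]

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_data
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    return converter.convert()

# Hypothetical usage:
# open("resnet50_distilled_int8.tflite", "wb").write(quantize_int8(student, calibration_ds))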
Key Takeaways
- Temperature matters: Higher temperatures (3-5) work better for similar teacher-student architectures; lower temperatures (1-2) for very different architectures.
- Ensemble teachers outperform single teachers: The diversity of predictions provides richer supervision.
- Data augmentation during distillation: Helps the student learn robust features rather than memorizing teacher outputs.
- Combine with quantization: Distillation and quantization are complementary—use both for maximum compression.
Future Work
- Self-distillation: Using the model as its own teacher across training epochs
- Feature-based distillation: Matching intermediate layer representations
- Neural Architecture Search: Automatically finding optimal student architectures
- On-device fine-tuning: Adapting models to new crack types in the field