import neurenix as nx
from neurenix.quantization import (
quantize_model,
quantization_aware_training,
calibrate_model,
QuantizationType
)
# 1. Train the full-precision (FP32) baseline model.
# NOTE(review): create_resnet18, train_one_epoch, train_loader, evaluate and
# test_loader are not defined in this snippet; they are assumed to come from
# surrounding code.
model = create_resnet18()
optimizer = nx.optim.SGD(model.parameters(), lr=0.1)
for _ in range(90):  # number of FP32 training epochs
    train_one_epoch(model, train_loader, optimizer)
# Assumes evaluate() returns accuracy as a percentage (hence the "%" suffix).
print(f"FP32 accuracy: {evaluate(model, test_loader):.2f}%")
# 2. Post-training quantization (fast, good for inference).
# NOTE(review): `calibrate_model` is imported at the top of the file but never
# called. If neurenix's INT8 PTQ derives activation ranges from sample data, a
# calibration pass over representative batches probably belongs before this
# call. Confirm against the neurenix.quantization documentation.
quantized_model = quantize_model(model, dtype=QuantizationType.INT8)
print(f"PTQ INT8 accuracy: {evaluate(quantized_model, test_loader):.2f}%")

# 3. Quantization-aware training (typically better accuracy than PTQ).
# NOTE(review): this assumes quantize_model / quantization_aware_training
# return new models and leave `model` untouched. If either works in place,
# QAT would not start from the FP32 weights, and the "FP32" model saved in
# step 4 would not be full precision. Verify.
qat_model = quantization_aware_training(model, dtype=QuantizationType.INT8)
# Fine-tune with a 10x smaller learning rate than the FP32 run (0.01 vs 0.1).
optimizer = nx.optim.SGD(qat_model.parameters(), lr=0.01)
for _ in range(10):  # number of QAT fine-tuning epochs
    train_one_epoch(qat_model, train_loader, optimizer)
# Convert the fine-tuned QAT model into an actual INT8 model for evaluation.
final_quantized = quantize_model(qat_model, dtype=QuantizationType.INT8)
print(f"QAT INT8 accuracy: {evaluate(final_quantized, test_loader):.2f}%")
# 4. Measure the on-disk size of the FP32 and INT8 models.
import os

# Both files are written to the current working directory.
nx.save(model, 'model_fp32.pth')
nx.save(final_quantized, 'model_int8.pth')
# Byte counts are divided by 1024**2, so the values are mebibytes (MiB),
# not SI megabytes (10**6 bytes). The printed labels reflect that.
fp32_size = os.path.getsize('model_fp32.pth') / 1024 ** 2
int8_size = os.path.getsize('model_int8.pth') / 1024 ** 2
print(f"FP32 model: {fp32_size:.2f} MiB")
print(f"INT8 model: {int8_size:.2f} MiB")
# The ratio is unit-independent: how many times smaller the INT8 file is.
print(f"Compression ratio: {fp32_size / int8_size:.2f}x")