arXiv: 2604.19398 · PDF
作者: Ziyang Wang, Jiangfeng Xiao, Chuan Xiao, Ruoxiang Li, Rui Mao, Jianbin Qin
主分类: cs.AI · 全部: cs.AI
命中关键词: large language model, llm, rag, inference, kv cache, attention, gpu, latency, fine-tun
TL;DR
GRASPrune 提出面向 LLM 的结构化剪枝框架,用全局预算下的轻量门控分数,在预训练后联合剪枝 FFN 通道和 KV head group,无需微调骨干权重。
核心观点
- 将 FFN 通道与 KV head group 统一在单一全局预算下联合剪枝,而非分模块独立决策。
- 用投影式 straight-through estimator (STE) 学习门控分数,每一步都强制满足硬预算掩码。
- 骨干权重冻结,仅训练轻量 gate,显著降低训练开销。
- 剪枝后通过scaling factor 校准并折叠进权重,得到无额外推理参数的更小 dense checkpoint。
方法
- 训练后剪枝(post-pretraining):在 FFN channel 与 KV head group 两种结构单元上放置 gate。
- 通过 projected STE 让前向使用硬 0/1 mask 并满足预算约束,反向传递连续梯度更新 gate。
- mask 固定后,对保留单元的 scale 做校准,补偿剪枝引起的激活 scale mismatch。
- 将校准后的 scale 折叠到权重,产出一个尺寸更小的 dense 模型,推理时无额外算子。
实验
- 模型:LLaMA-2-7B。
- 剪枝率:50% 参数移除。
- 数据:512 条无标签校准序列,训练 4 个 epoch。
- 硬件:单卡 NVIDIA A100 80GB。
- 评测:WikiText-2 perplexity,以及 5 个 zero-shot 基准的平均准确率。
结果
- WikiText-2 上 perplexity 12.18(50% 剪枝率下)。
- 5 个 zero-shot 基准平均准确率与基线具竞争力(摘要未给具体数值)。
- 无需对整模型做 full fine-tuning,成本低。
为什么重要
- 对推理基础设施,同时压缩 FFN 和 KV cache 的剪枝能同时降低 memory、latency 与 KV cache 占用。
- 单 A100 完成 7B 模型剪枝且无 full FT,门槛低、可复现,适合部署团队快速裁剪自有模型。
- 输出 dense checkpoint,兼容现有推理栈,无需稀疏算子支持。
与已有工作的关系
- 延续 LLM-Pruner、SliceGPT、Wanda、SparseGPT 等 post-training 结构化/非结构化剪枝思路。
- gate + STE 学习 mask 的做法与 Movement Pruning、DSNet 等一脉相承。
- 关注 KV head 组剪枝与 GQA、MQA、KV cache compression 研究方向相关。
- scale 校准折叠思想与 SmoothQuant、AWQ 的 scale 迁移技巧类似。
尚未回答的问题
- 在更大模型(13B/70B)和更高剪枝率下是否仍保持 PPL?
- 与 SparseGPT / Wanda 等强基线的直接对比数字未给出。
- 与量化(INT4/INT8)叠加后的效果与误差累积如何?
- 对 long-context 推理与 KV cache 实际延迟/显存节省的端到端测量缺失。
- gate 训练对校准数据领域分布的敏感性未讨论。
论文图表
图 1: Page 2 (rendered)

图 2: Page 3 (rendered)

图 3: Page 4 (rendered)

原始摘要
Large language models (LLMs) are expensive to serve because model parameters, attention computation, and KV caches impose substantial memory and latency costs. We present GRASPrune, a structured pruning framework applied after pretraining that jointly prunes FFN channels and KV head groups under a single global budget. Instead of learning importance scores without constraints and applying the budget only after training, GRASPrune learns lightweight gate scores with a projected straight-through estimator that enforces a hard mask satisfying the budget at every step while keeping the backbone weights frozen. After the mask is fixed, we calibrate scaling factors on the retained units to mitigate scale mismatch caused by pruning, and fold these factors into the pruned weights to obtain a smaller dense checkpoint with no extra parameters at inference. On LLaMA-2-7B, GRASPrune removes 50% of parameters and achieves 12.18 perplexity on WikiText-2 while maintaining competitive average zero-shot accuracy on five benchmarks, using four epochs on 512 unlabeled calibration sequences on a single NVIDIA A100 80GB GPU without any full model fine-tuning.