arXiv: 2604.19398 · PDF
Authors: Ziyang Wang, Jiangfeng Xiao, Chuan Xiao, Ruoxiang Li, Rui Mao, Jianbin Qin
Primary category: cs.AI · all: cs.AI
Matched keywords: large language model, llm, rag, inference, kv cache, attention, gpu, latency, fine-tuning
TL;DR
GRASPrune is a post-pretraining structured pruning framework that jointly prunes FFN channels and KV head groups under a single global budget using projected straight-through gate learning, producing a smaller dense checkpoint without fine-tuning the backbone.
Key Ideas
- Joint structured pruning across FFN channels and KV head groups under one unified global budget.
- Projected straight-through estimator (STE) enforces the hard budget mask at every training step, not just post-hoc (see the sketch after this list).
- Backbone weights stay frozen; only lightweight gate scores are learned.
- Post-mask scale calibration folded into pruned weights → no inference-time overhead.
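A minimal PyTorch sketch of the projected-STE idea (not the authors' code; the class name, tensor sizes, and 50% unit budget are illustrative): the forward pass projects gate scores onto a hard top-k mask that satisfies the global budget exactly, while the backward pass passes gradients straight through to the continuous scores.

```python
import torch

class ProjectedSTEGate(torch.autograd.Function):
    """Hard top-k gate with a straight-through backward pass (illustrative)."""

    @staticmethod
    def forward(ctx, scores, budget):
        # Projection step: keep exactly `budget` highest-scoring units and zero
        # the rest, so the global budget holds at every training step.
        mask = torch.zeros_like(scores)
        mask[torch.topk(scores, budget).indices] = 1.0
        return mask

    @staticmethod
    def backward(ctx, grad_mask):
        # Straight-through estimator: the gradient w.r.t. the hard mask is
        # passed unchanged to the underlying continuous gate scores.
        return grad_mask, None

# One score vector concatenating all FFN channels and KV head groups, so a
# single global budget governs both structure types jointly (sizes made up).
scores = torch.randn(11008 + 8, requires_grad=True)
budget = int(0.5 * scores.numel())
mask = ProjectedSTEGate.apply(scores, budget)  # multiplied onto FFN/KV activations
```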
Approach
- Attach learnable gate scores to FFN channels and KV head groups.
- Apply a projected STE that always satisfies the global budget as a hard constraint during optimization.
- Freeze the LLM backbone; train only gates on unlabeled calibration data.
- Once the mask is fixed, calibrate scaling factors on the retained units to compensate for the scale mismatch introduced by pruning (see the sketch after this list).
- Fold scales into the remaining weights, yielding a compact dense checkpoint.
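A rough sketch of the calibrate-and-fold step for an FFN down-projection (the helper name, its signature, and the single least-squares scalar are assumptions made to keep the example short; the paper calibrates factors per retained unit and its exact estimator may differ):

```python
import torch

@torch.no_grad()
def calibrate_and_fold(W_down, mask, H):
    """Hypothetical helper, not the paper's exact procedure.
    W_down : (d_model, d_ffn) FFN down-projection weight
    mask   : (d_ffn,) hard 0/1 mask from the projected-STE stage
    H      : (n_tokens, d_ffn) FFN hidden activations on calibration data
    """
    keep = mask.bool()
    dense_out  = H @ W_down.T                    # full layer output
    pruned_out = H[:, keep] @ W_down[:, keep].T  # output with pruned channels dropped
    # Least-squares scale that best re-aligns the pruned output with the dense
    # output on the calibration set (a single scalar here for brevity).
    s = (dense_out * pruned_out).sum() / pruned_out.pow(2).sum().clamp_min(1e-8)
    # Fold the calibrated scale into the retained weights so inference sees a
    # plain, smaller dense matrix with no extra parameters.
    return s * W_down[:, keep]
```

Because the scale is folded into the retained columns, the exported checkpoint is just a smaller dense weight matrix, consistent with the paper's claim of no extra parameters at inference.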
Experiments
- Model: LLaMA-2-7B.
- Calibration: 512 unlabeled sequences, 4 epochs on a single A100 80GB.
- Perplexity benchmark: WikiText-2.
- Zero-shot accuracy: average over five benchmarks (names not specified in abstract).
- Baselines and specific competing methods are not listed in the abstract.
Results
- 50% parameter removal on LLaMA-2-7B.
- 12.18 WikiText-2 perplexity at that pruning rate.
- “Competitive” average zero-shot accuracy across five benchmarks (exact numbers and baseline comparisons not given in the abstract).
- No full fine-tuning required.
Why It Matters
Offers a cheap, fine-tuning-free path to halving LLM serving memory while jointly attacking both FFN and KV-cache costs — directly relevant for inference infra teams constrained by GPU memory and latency budgets.
Connections to Prior Work
- Post-training structured pruning: LLM-Pruner, SliceGPT, Sheared-LLaMA.
- KV-cache compression / GQA-style head grouping.
- Gate-based / mask-learning pruning with STE (e.g., movement pruning, DiffPruning).
- Distinction: most prior methods learn unconstrained importance scores and apply the budget only after training; GRASPrune enforces the budget as a hard constraint throughout optimization.
Open Questions
- Which baselines does it beat, and by how much on each of the five zero-shot tasks?
- Does the method scale to larger models (13B/70B) or other families (Mistral, Qwen)?
- How does it compare to pruning + LoRA recovery pipelines?
- Sensitivity to calibration data size/domain.
- Behavior at more aggressive pruning ratios (>50%) and effect on long-context KV efficiency.
Figures
Figure 1: Page 2 (rendered)
Figure 2: Page 3 (rendered)
Figure 3: Page 4 (rendered)
Original abstract
Large language models (LLMs) are expensive to serve because model parameters, attention computation, and KV caches impose substantial memory and latency costs. We present GRASPrune, a structured pruning framework applied after pretraining that jointly prunes FFN channels and KV head groups under a single global budget. Instead of learning importance scores without constraints and applying the budget only after training, GRASPrune learns lightweight gate scores with a projected straight-through estimator that enforces a hard mask satisfying the budget at every step while keeping the backbone weights frozen. After the mask is fixed, we calibrate scaling factors on the retained units to mitigate scale mismatch caused by pruning, and fold these factors into the pruned weights to obtain a smaller dense checkpoint with no extra parameters at inference. On LLaMA-2-7B, GRASPrune removes 50% of parameters and achieves 12.18 perplexity on WikiText-2 while maintaining competitive average zero-shot accuracy on five benchmarks, using four epochs on 512 unlabeled calibration sequences on a single NVIDIA A100 80GB GPU without any full model fine-tuning.