
WIP: DeepSeek-V4-Flash architecture port (BF16 reference, accuracy-first MVP) #14

Draft

danielhanchen wants to merge 10 commits into master from deepseek-v4-flash-port

Conversation

@danielhanchen
Collaborator

Summary

Port DeepSeek-V4-Flash (284B total / 13B activated parameters, 1M context, mixed FP8 + native FP4 weights) to llama.cpp.

This is a draft PR while implementation lands in stages. Accuracy first, performance second. The MVP target uses BF16 routed experts (FP4 dequant at convert time) and a dense-Indexer fallback (V3.2-nolight equivalent). MTP serving, native MXFP4 expert passthrough, and the real top-k=512 DSA Indexer are deferred to follow-ups.

Reference (golden baseline)

The upstream inference/model.py implementation from the DeepSeek-V4-Flash HF repo, run deterministically on B200, serves as the golden baseline; C++ activations and outputs are compared against it.

Architectural deltas vs DeepSeek-V3 (high level)

  • 43 main blocks + 1 MTP head; hidden 4096; 64 attn heads; head_dim 512 (rope_head_dim 64 + nope 448).
  • MQA-style with one shared K=V projection (wkv: D -> head_dim). Per-head Q-norm without learnable scale.
  • Grouped low-rank output: wo_a [n_groups=8, o_lora_rank=1024, ...] einsum + wo_b.
  • Per-layer compress_ratios array (length 44) with values in {0, 4, 128}: SWA-only (0), CSA + indexer (4), HCA (128). All paths combine a 128-token sliding window with the compressed blocks.
  • Per-head learnable attn_sink added in softmax denominator.
  • mHC residual streams (hc_mult=4, sinkhorn 20 iters with asymmetric iter-0 eps placement) replace the standard residual add.
  • MoE: 256 routed + 1 shared, 6 activated. sqrt-softplus scoring; router bias affects top-k selection only (not weights). First 3 layers use hash routing via a tid2eid int32 LUT. SwiGLU clamps: gate.clamp(max=10) BEFORE silu, and up.clamp(±10), on routed AND shared experts.
  • YaRN RoPE: ON for compress layers (theta=160000, factor=16, original_max=65536); OFF for SWA layers (theta=10000).
  • Quant: FP8 e4m3fn (UE8M0 scales, 128x128 blocks) for non-expert weights; FP4 e2m1 (UE8M0 scales, 32-block) for routed experts. FP8 to BF16 at convert time. Routed FP4 to BF16 in MVP; native MXFP4 deferred.
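
For reference, this is the dequant semantics of the two packed formats as a numpy sketch (not the converter code). Assumptions: the UE8M0 scale byte s decodes to 2**(s - 127) per the usual MX convention, FP8 values are assumed already cast to float (numpy has no e4m3 dtype), and the low-nibble-first FP4 element order is a guess, not confirmed against the checkpoint.

```python
import numpy as np

# standard e2m1 value table: 1 sign bit, 2 exponent bits, 1 mantissa bit
E2M1_TABLE = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)

def ue8m0_to_f32(scale_u8: np.ndarray) -> np.ndarray:
    # exponent-only scale byte -> power of two (bias 127)
    return (2.0 ** (scale_u8.astype(np.int32) - 127)).astype(np.float32)

def dequant_fp8_block128(w: np.ndarray, scale_u8: np.ndarray) -> np.ndarray:
    """w: [M, N] e4m3fn values already cast to float; scale_u8: [M/128, N/128] UE8M0 bytes."""
    s = np.repeat(np.repeat(ue8m0_to_f32(scale_u8), 128, axis=0), 128, axis=1)
    return w.astype(np.float32) * s[: w.shape[0], : w.shape[1]]

def dequant_fp4_block32(packed_u8: np.ndarray, scale_u8: np.ndarray) -> np.ndarray:
    """packed_u8: [M, K/2] two e2m1 nibbles per byte; scale_u8: [M, K/32] UE8M0 bytes."""
    lo, hi = packed_u8 & 0x0F, packed_u8 >> 4
    nibbles = np.empty(packed_u8.shape[:-1] + (packed_u8.shape[-1] * 2,), dtype=np.uint8)
    nibbles[..., 0::2], nibbles[..., 1::2] = lo, hi          # assumed nibble order
    vals = E2M1_TABLE[nibbles & 0x7] * np.where(nibbles & 0x8, -1.0, 1.0)
    s = np.repeat(ue8m0_to_f32(scale_u8), 32, axis=-1)       # one scale per 32 elements along K
    return (vals * s).astype(np.float32)
```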

MVP scope (this PR)

  • Architecture registration (LLM_ARCH_DEEPSEEK_V4 + KV keys + tensor schema in arch.h/arch.cpp)
  • Hparams parsing in llm_load_hparams
  • Tensor loader in llm_load_tensors
  • Converter: DeepseekV4Model class in convert_hf_to_gguf.py
  • Stub graph: load-only llm_build_deepseek_v4 (junk logits)
  • Chat template port: LLM_CHAT_TEMPLATE_DEEPSEEK_V4 matching encoding_dsv4.encode_messages chat-mode subset
  • BF16 reference graph (full 43 layers + dense indexer fallback, MTP tensors loaded but graph skips MTP forward)
  • Per-sub-block .npy goldens captured from inference/model.py deterministic runs on B200
  • RMSE harness comparing C++ activations to Python goldens, layerwise (first/last layer, then expand)
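
A minimal sketch of that harness follows; the goldens/ layout and the C++ activation-dump path and naming are assumptions for illustration, not a fixed interface.

```python
import numpy as np

def rmse(a: np.ndarray, b: np.ndarray) -> float:
    a, b = a.astype(np.float64), b.astype(np.float64)
    return float(np.sqrt(np.mean((a - b) ** 2)))

def check(layer: int, sub: str, tol: float = 1e-2) -> None:
    golden = np.load(f"goldens/layer{layer}.{sub}.npy")    # from the Python reference run
    actual = np.load(f"cpp_dumps/layer{layer}.{sub}.npy")  # hypothetical C++ activation dump
    err = rmse(golden, actual)
    print(f"layer {layer:2d} {sub:12s} rmse={err:.3e} {'OK' if err < tol else 'FAIL'}")

for layer in (0, 42):                         # first/last layer first, then expand
    for sub in ("attn_out", "ffn_out"):       # hypothetical sub-block names
        check(layer, sub)
```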

Deferred

  • Native MXFP4 routed expert passthrough (gpt-oss precedent reusable)
  • Real DSA top-k=512 Indexer
  • MTP self-speculative decoding
  • Performance work (custom kernels for compressed attention and indexer paths)

Test plan

End-to-end tests use llama-server with chat template applied. Easy-to-verify prompts:

  • "What is 1+1?"
  • "Say 1 to 100 then backwards"
  • "Say 1 to 1000"

Token-by-token argmax must match the Python reference for the prompts above.
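
As a sketch of that check, assuming llama-server's OpenAI-compatible endpoint on the default port (the real comparison is on token IDs; string equality here is only the illustration, and the expected strings come from a deterministic run of the Python reference):

```python
import requests

reference = {  # greedy completions produced separately by inference/model.py
    "What is 1+1?": "...",
    "Say 1 to 100 then backwards": "...",
    "Say 1 to 1000": "...",
}

for prompt, expected in reference.items():
    r = requests.post(
        "http://localhost:8080/v1/chat/completions",
        json={"messages": [{"role": "user", "content": prompt}],
              "temperature": 0.0, "max_tokens": 4096},
    )
    got = r.json()["choices"][0]["message"]["content"]
    assert got == expected, f"mismatch on {prompt!r}"
```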

Layerwise validation order: layer 0 + layer 42 first. Expand to {0, 1, 2, 3, 41, 42} (covering SWA, CSA, HCA boundaries). Then all 43 layers.

Add LLM_ARCH_DEEPSEEK_V4 to llm_arch enum and register the architecture
name "deepseek-v4". Declare the V4-specific KV keys (o_lora_rank,
o_groups, compress_ratios, compress_rope_freq_base, hyperconnections.*,
n_hash_layers) and tensor IDs covering the single shared K=V projection,
grouped low-rank wo_a/wo_b, KV compressor (CSA + HCA), separate indexer
compressor, manifold-constrained hyper-connection (mHC) tensors, hash
routing LUT, and MTP-specific tensors.

Add tensor name templates and llm_tensor_info entries; mirror the
DEEPSEEK2 case in llm_arch_supports_sm_tensor and llama_model_rope_type
(V4 uses GGML_ROPE_TYPE_NORMAL / LLAMA_ROPE_TYPE_NORM, matching the
view_as_complex pair convention in inference/model.py).

This commit is load-only scaffolding. Hparams parsing, tensor loading,
and the build_v4 graph follow in subsequent commits.

Refs:
  https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash/blob/main/inference/model.py
  https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash/blob/main/config.json

Add SQRT_SOFTPLUS = 4 to llama_expert_gating_func_type for V4 routing.
Add llama_hparams fields for V4 (o_lora_rank, o_groups, hc_mult,
hc_sinkhorn_iters, hc_eps, n_hash_layers, compress_rope_freq_base,
compress_ratios array).

llm_load_hparams: register a DEEPSEEK_V4 case that reads all V4 KV
keys plus the standard ones (Q lora, expert counts, sliding window,
swiglu clamps for routed and shared, indexer dims for the CSA layers
that are dense-fallback in the MVP).

llm_load_tensors: PR1 minimal - load only top-level token_embd,
output_norm, output. Per-layer V4 tensors land in a follow-up commit
once the full graph is wired.

src/models/deepseek-v4.cpp + models.h: stub llm_build_deepseek_v4
that embeds, RMSNorms, and projects to logits so the model loads
end-to-end and llama-server boots. Output is junk.

Graph dispatch in llama-model.cpp routes LLM_ARCH_DEEPSEEK_V4 to the
new builder.
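
In numpy terms, the stub computes roughly the following (a sketch of the intended semantics, not the ggml calls):

```python
import numpy as np

def stub_forward(tokens, tok_embd, output_norm_w, output_w, eps=1e-6):
    h = tok_embd[tokens]                                                  # [n_tokens, n_embd]
    h = h / np.sqrt(np.mean(h * h, axis=-1, keepdims=True) + eps) * output_norm_w
    return h @ output_w.T                                                 # [n_tokens, n_vocab], junk until the blocks land
```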

Refs:
  https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash/blob/main/inference/model.py
  https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash/blob/main/config.json

PR1 BF16 path:
- DeepseekV4Model registered for "DeepseekV4ForCausalLM"
- set_gguf_parameters writes all V4 KVs (q/o lora ranks, o_groups, sliding_window,
  compress_ratios array, compress_rope_freq_base, hc_*, n_hash_layers, indexer_*,
  swiglu_clamp arrays, expert_gating_func=SQRT_SOFTPLUS)
- modify_tensors does:
    * source-name remap (layers.{i}.attn.wq_a.weight → blk.{i}.attn_q_a.weight,
      mtp.0.* → blk.43.mtp.*, hc_* → hc_*.{fn,base,scale}.weight)
    * FP8 e4m3 weight × UE8M0 scale (block 128×128) → BF16 dequant
    * FP4 e2m1 packed (int8, 2 nibbles/byte) × UE8M0 scale (block 32 along K) → BF16
    * tid2eid int64 → int32
    * keep hc_*, attn_sinks, ape, exp_probs_b in F32; downcast other 2D → BF16
    * stack 256 routed experts per layer/per kind into 3D [n_experts, out, in]
    * reshape attn_o_a 2D (n_groups*o_lora_rank, dim_per_group) → 3D for grouped einsum (see the numpy sketch after this list)
- gguf-py ExpertGatingFuncType: add SOFTMAX_WEIGHT=3 and SQRT_SOFTPLUS=4
  to keep python enum aligned with src/llama-hparams.h
- llama-model.cpp DEEPSEEK_V4 hparams: switch swiglu_clamp_{exp,shexp} from scalar
  read-and-broadcast to get_key_or_arr (matches converter's per-layer array emit)
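
Numpy sketch of the stacking/reshape steps flagged above; inputs are assumed to be already-dequantized float arrays, and GGUF dimension ordering is left to the writer:

```python
import numpy as np

def stack_experts(expert_weights: list) -> np.ndarray:
    # 256 per-expert [out, in] matrices -> one [n_experts, out, in] tensor
    assert len(expert_weights) == 256
    return np.stack(expert_weights, axis=0)

def reshape_wo_a(wo_a: np.ndarray, n_groups: int = 8, o_lora_rank: int = 1024) -> np.ndarray:
    # 2D (n_groups * o_lora_rank, dim_per_group) -> 3D for the grouped einsum
    dim_per_group = wo_a.shape[1]
    return wo_a.reshape(n_groups, o_lora_rank, dim_per_group)
```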

Refs:
- https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash/blob/main/inference/model.py
- https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash/blob/main/inference/convert.py
- https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash/blob/main/inference/kernel.py

- LLM_CHAT_TEMPLATE_DEEPSEEK_V4 enum + str-table entry "deepseek-v4"
- Detection by markers: V3 markers (<|User|>, <|Assistant|>, <|end▁of▁sentence|>)
  AND </think> sentinel; placed before V3 detection so it wins
- Renderer: emits <|begin▁of▁sentence|> (V4 tokenizer has add_bos_token=False),
  per-message system="{content}", user="<|User|>{content}", assistant="{content}<eos>",
  with <|Assistant|></think> emitted as a turn-transition suffix when next role
  is assistant or as add_generation_prompt suffix
- Converter writes equivalent Jinja chat template into the GGUF so HF
  apply_chat_template renders byte-equivalent output
- test-chat-template.cpp: V4 case for the standard 6-message multi-turn fixture
- llama-model-loader.cpp: instantiate get_arr<std::array<uint32_t, LLAMA_MAX_LAYERS>>
  to satisfy the swiglu_clamp + compress_ratios array reads from the V4 path

Verified byte-for-byte equivalence of three pipelines on 7 cases
(single user, system+user, full multi-turn, 5-turn, developer role, "Say 1 to 100",
"Say 1 to 1000"):
  encoding_dsv4.encode_messages == llama.cpp LLM_CHAT_TEMPLATE_DEEPSEEK_V4
                               == HF AutoTokenizer.apply_chat_template (with V4 Jinja)
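
For illustration, one reading of the rendering rules above as a plain-Python sketch; the `<eos>` literal stands in for the actual end-of-sentence token string and the developer-role case is omitted:

```python
def render_v4(messages, add_generation_prompt=True, eos="<eos>"):
    out = "<|begin▁of▁sentence|>"                     # template emits BOS itself
    for i, m in enumerate(messages):
        if m["role"] == "system":
            out += m["content"]
        elif m["role"] == "user":
            out += "<|User|>" + m["content"]
        elif m["role"] == "assistant":
            out += m["content"] + eos
        if i + 1 < len(messages) and messages[i + 1]["role"] == "assistant":
            out += "<|Assistant|></think>"            # turn-transition suffix
    if add_generation_prompt:
        out += "<|Assistant|></think>"                # generation-prompt suffix
    return out

# render_v4([{"role": "system", "content": "S"}, {"role": "user", "content": "U"}])
#   -> "<|begin▁of▁sentence|>S<|User|>U<|Assistant|></think>"
```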

Refs:
- https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash/blob/main/encoding/encoding_dsv4.py
- https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash/blob/main/tokenizer_config.json

Wires up all DeepSeek-V4-Flash per-layer tensors so the converter's GGUF
loads end-to-end. Stub graph still uses only token_embd + output_norm +
output (junk logits); per-layer tensors are TENSOR_NOT_REQUIRED so partial
GGUFs still load.

Per-block (43 main + 1 MTP):
  attn_norm, ffn_norm, attn_q_a_norm, attn_kv_norm,
  attn_q_a, attn_q_b, attn_kv (single shared MQA), attn_sinks,
  attn_o_a (3D [dim_per_group, o_lora_rank, n_groups]), attn_o_b,
  hc_attn_{fn,base,scale}, hc_ffn_{fn,base,scale},
  ffn_gate_inp, exp_probs_b OR ffn_gate_tid2eid (hash layers 0..2),
  ffn_{gate,down,up}_exps (stacked), ffn_{gate,down,up}_shexp.

Per-block, conditional on compress_ratios[il]:
  ratio>0  → compressor.{wkv,wgate,ape,norm}
  ratio==4 → indexer.{attn_q_b,weights_proj,compressor.{wkv,wgate,ape,norm}}

MTP block (last layer): adds e_proj, h_proj, enorm, hnorm, norm,
mtp.hc_head.{fn,base,scale}.

Top-level: token_embd, output_norm, output, output.hc_head.{fn,base,scale}.

Also:
- llama-model.h: V4 fields on llama_layer (v4_attn_kv, v4_attn_kv_norm,
  v4_attn_o_{a,b}, v4_hc_{attn,ffn}_{fn,base,scale}, v4_compressor_*,
  v4_idx_compressor_*, v4_ffn_gate_tid2eid, v4_mtp_*) and on llama_model
  (v4_hc_head_{fn,base,scale}).
- convert_hf_to_gguf.py: emit hc_head as `output.hc_head.{fn,base,scale}`
  (matches gguf-py constants.py); MTP block keeps standard attention/ffn
  names at bid=N so they slot into the per-layer table; only the MTP
  extras carry the `.mtp.` infix; indexer.wq_b → indexer.attn_q_b
  (matches LLM_TENSOR_INDEXER_ATTN_Q_B template).

Previous _maybe_pair buffered ALL .weight tensors awaiting a sibling .scale.
That over-collected: V4 has many BF16 / F32 .weight tensors (norms,
hc_*, compressor.{wkv,wgate,ape,norm}, indexer.{compressor,weights_proj},
ffn.gate.{weight,bias,tid2eid}, mtp norms) that have no .scale companion;
they would never emit because the loop kept them buffered.

Fix: explicit allowlist _FP8_PACKED_SUFFIXES of the names that ship a
sibling .scale on disk (verified via safetensors index inspection):
  attn.wq_a, attn.wq_b, attn.wkv, attn.wo_a, attn.wo_b,
  attn.indexer.wq_b,
  ffn.shared_experts.{w1,w2,w3},
  e_proj, h_proj
Routed-expert weights (ffn.experts.*.w[123].weight) are FP4-packed and
handled by the separate _consume_expert path.

modify_tensors restructured to emit non-paired tensors directly without
touching the buffer. Verified by tests/converter_smoke/synthetic_v4_test.py
covering FP8 dequant, wo_a 2D→3D reshape, tid2eid → int32, hc_* F32
passthrough, top-level output.hc_head, FP4 expert stacking (256 → 3D),
MTP bid=43 remap with both `.mtp.` extras and standard attn tensors.
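
Control-flow sketch of the fix (helper and constant names from this commit; not the literal converter code, and the FP4 routed-expert path is left out):

```python
_FP8_PACKED_SUFFIXES = (
    "attn.wq_a", "attn.wq_b", "attn.wkv", "attn.wo_a", "attn.wo_b",
    "attn.indexer.wq_b",
    "ffn.shared_experts.w1", "ffn.shared_experts.w2", "ffn.shared_experts.w3",
    "e_proj", "h_proj",
)

def needs_scale(name: str) -> bool:
    # only these .weight tensors ship a sibling .scale on disk
    return name.endswith(".weight") and name[: -len(".weight")].endswith(_FP8_PACKED_SUFFIXES)

pending = {}  # stem -> weight tensor still waiting for its .scale sibling

def on_tensor(name, tensor, emit, dequant_fp8):
    stem = name.rsplit(".", 1)[0]
    if needs_scale(name):
        pending[stem] = tensor                                 # hold until the scale arrives
    elif name.endswith(".scale") and stem in pending:
        emit(stem + ".weight", dequant_fp8(pending.pop(stem), tensor))
    else:
        emit(name, tensor)                                     # everything else emits directly
```
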
Previous version yielded BF16 from FP8/FP4 dequant and torch.int32 for
tid2eid. The framework's LazyTorchTensor.numpy() at line 13745 only
supports F16/F32/U8 — BF16/I32/I64 raise KeyError. The fix splits along
two paths:

1. _fp8_dequant and _fp4_dequant now return F32. The GGUF writer's
   tensor_force_quant pipeline downcasts to BF16 at write time when
   --outtype bf16 is set, mirroring the standard conversion contract for
   every other model in the codebase.

2. tid2eid: instead of yielding (name, int32_tensor), call
   self.gguf_writer.add_tensor(name, np.int32_array, raw_dtype=I32)
   directly and return [] from modify_tensors so the standard pipeline
   skips it. This keeps the I32 storage type intact; loader on the C++
   side reads it as I32 for the hash-routing LUT (see the sketch after this list).

3. _emit_single uniformly upcasts non-F32 inputs to F32 before yielding
   (in practice this only triggers for the wo_a path post-reshape since
   the parent loop already cast everything to F32).

Refs:
- inference/convert.py — FP8 / FP4 dequant formula
- inference/model.py:556-559 — tid2eid is I32 by construction

synthetic_v4_test.py updated: asserts F32 dtype on dequant outputs and
verifies the gguf_writer.add_tensor I32 path for tid2eid via a stub.

…op type

Two small graph extensions used by the V4-Flash forward graph:

1. LLAMA_EXPERT_GATING_FUNC_TYPE_SQRT_SOFTPLUS = 4 dispatch in build_moe_ffn:
   probs = sqrt(softplus(logits)). Wired into the gating switch using the
   already-existing ggml_softplus and ggml_sqrt ops.
   ref: https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash/blob/main/inference/model.py:565,571

2. LLM_FFN_SWIGLU_CLAMPED_PRE op type. Pattern:
     gate.clamp(-INF, lim) → silu, up.clamp(±lim), gate*up
   Plumbed into both build_ffn (shared expert / dense FFN path) and
   build_moe_ffn (routed expert path). Per-layer limit comes from
   hparams.swiglu_clamp_{exp,shexp} arrays. The shared expert site reads
   swiglu_clamp_shexp[il] for V4 (V4 sets both arrays = swiglu_limit=10.0).
   This is the pre-silu clamp variant; not to be confused with the existing
   Step35 post-silu clamp at line 1200-1218.
   ref: https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash/blob/main/inference/model.py:600-602
        SGLang PR #23776 (apply on shared experts too)

Verified clean build of llama target.
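
Reference semantics of both extensions, as a numpy sketch (the actual change is ggml graph ops; this only shows the intended math):

```python
import numpy as np

def sqrt_softplus(logits: np.ndarray) -> np.ndarray:
    # probs = sqrt(softplus(logits)); logaddexp(0, x) is a numerically stable softplus
    return np.sqrt(np.logaddexp(0.0, logits))

def swiglu_clamped_pre(gate: np.ndarray, up: np.ndarray, lim: float = 10.0) -> np.ndarray:
    # clamp BEFORE the activation: gate capped from above only, up clamped symmetrically
    gate = np.minimum(gate, lim)
    up = np.clip(up, -lim, lim)
    return (gate / (1.0 + np.exp(-gate))) * up    # silu(gate) * up
```
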
V4-Flash safetensors include the UE8M0 (E=8, M=0) FP8 scale dtype
alongside the standard FP8 e4m3fn weight tensors. The framework's
LazyTorchTensor._dtype_str_map didn't know "F8_E8M0", causing KeyError
the moment safetensors-rs handed back a scale tensor.

Adds:
  _dtype_str_map["F8_E8M0"] = torch.float8_e8m0fnu
  _dtype_byteswap_map[torch.float8_e8m0fnu] = np.uint8
  _dtype_map[torch.int32] = np.int32

The int32 entry is needed by the tid2eid path: hash-routing LUT is
written via gguf_writer.add_tensor with raw_dtype=I32. Tid2eid is
delivered to modify_tensors as F32 (auto-cast at line 783-784); we
materialize via LazyTorchTensor.to_eager() so the .to(int32).numpy()
chain runs on a real torch.Tensor instead of a Lazy wrapper.

Verified with the synthetic V4 fixture; weights conversion now proceeds past
blk.0 (previously it stopped at the first tid2eid tensor).
