
Add support for talkie-1930 13B #15

Open
danielhanchen wants to merge 4 commits into master from feat/talkie-1930

Conversation

danielhanchen commented Apr 29, 2026

Summary

Adds GGUF support for the talkie-1930-13b family from talkie-lm: a 13B decoder-only language model trained on pre-1931 English-language text. Apache-2.0. Reference inference code: https://github.com/talkie-lm/talkie. The HF repo at https://huggingface.co/talkie-lm/talkie-1930-13b-it ships only a raw PyTorch state_dict (rl-refined.pt) and a tiktoken vocab.txt, so a custom set_vocab is required as well.

Architecture

40 layers, 40 heads, 5120 hidden, head_dim 128, vocab 65540 (IT) / 65536 (base), full MHA, max seq 2048, RoPE base 1,000,000, SwiGLU MLP, intermediate 13696. Six features that no existing arch in this codebase implements (verified by exhaustive scan of src/models/*.cpp and the transformers modeling_*.py corpus):

  1. Weightless RMSNorm at every site (post-embed, pre-attn, pre-mlp, post-RoPE Q and K, final). Achieved by passing nullptr weight to the existing build_norm helper.
  2. Per-block learnable scalar ActGain on the attention residual branch (init (2*n_layer)^-0.5).
  3. Per-block learnable scalar ActGain on the MLP residual branch.
  4. Per-block learnable scalar embed-skip on the post-RMSNorm embedding (init 0.0). The same e_x is added to every layer.
  5. Per-head learnable gain on Q after the Q-RMSNorm.
  6. Global learnable scalar lm_head_gain on the lm_head matrix. Reuses the existing build_lora_mm(w, cur, w_s) 3-arg form.
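
For orientation, here is a minimal PyTorch-style sketch of how these scalars enter a block. The scalar names follow the tensor enums added by this PR; the helper names, shapes, and the exact placement of the embed-skip term reflect the description above, not code lifted from src/models/talkie.cpp or the reference model.py.

```python
import torch
import torch.nn.functional as F

def rms_norm(x: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    # Weightless RMSNorm (feature 1): normalize only, no learned scale.
    return F.rms_norm(x, (x.shape[-1],), eps=eps)

def talkie_block_sketch(x, e_x, layer, attn_fn, mlp_fn):
    """One decoder block, scalars only. Illustrative, not the reference code.

    x       residual stream                       [batch, seq, hidden]
    e_x     post-embed RMSNorm output, shared by every layer (feature 4)
    layer   holder of the per-block scalars added by this PR
    attn_fn attention sub-block (QKV + RoPE + post-RoPE Q/K norms + o_proj)
    mlp_fn  SwiGLU MLP sub-block
    """
    # Embed-skip (feature 4): the same normalized embedding is re-added in
    # every block, scaled by a learnable scalar initialized at 0.0.
    h = x + layer.embed_skip_scale * e_x
    # Attention residual branch scaled by a per-block ActGain (feature 2),
    # init (2 * n_layer) ** -0.5.
    h = h + layer.attn_act_gain * attn_fn(rms_norm(h))
    # MLP residual branch scaled by its own per-block ActGain (feature 3).
    h = h + layer.ffn_act_gain * mlp_fn(rms_norm(h))
    return h

def q_head_gain_sketch(q, head_gain):
    # Feature 5: post-RoPE Q is RMS-normalized per head, then multiplied by a
    # learnable per-head gain. q: [batch, n_head, seq, head_dim];
    # head_gain: [n_head, 1, 1], broadcast over seq and head_dim.
    return rms_norm(q) * head_gain

# Feature 6 is the same idea at the output:
# logits = lm_head(rms_norm(h_final)) * lm_head_gain
```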

The reference RoPE rotates by -theta (sign-flipped sin vs HF Llama / NEOX). To reuse stock NEOX RoPE without adding a new ggml flavour, the converter pre-flips the second half of head_dim of W_q and W_k. That makes <NEOX(D q), NEOX(D k)> == <Talkie(q), Talkie(k)> (D = diag(+1...+1, -1...-1) on head_dim halves), so attention scores match exactly.
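
A converter-side sketch of that pre-flip (layout assumption: the usual HF [n_head*head_dim, hidden] row-major W_q/W_k; the real implementation lives in convert_hf_to_gguf.py):

```python
import numpy as np

def preflip_rope_sign(w: np.ndarray, n_head: int, head_dim: int) -> np.ndarray:
    """Negate the second half of head_dim for every head of W_q / W_k.

    After the flip, stock NEOX RoPE produces the Talkie (-theta) rotation in
    the first half of each head and the negation of the Talkie second half;
    since q and k get the same treatment, the extra sign cancels in the
    attention score and <NEOX(D q), NEOX(D k)> == <Talkie(q), Talkie(k)>.
    """
    hidden = w.shape[-1]
    w = w.reshape(n_head, head_dim, hidden).copy()
    w[:, head_dim // 2:, :] *= -1        # flip the second half of each head
    return w.reshape(n_head * head_dim, hidden)
```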

Reuse principle

Every step of the graph either calls an existing llm_graph_context helper or a stock ggml_* op. Net-new components are only the five new tensor enums (the four per-layer scalars plus the global lm_head gain) and the embed-skip wiring.

Files touched

  • gguf-py/gguf/constants.py: new MODEL_ARCH.TALKIE, 5 new MODEL_TENSOR enums and name strings, MODEL_TENSORS[TALKIE].
  • gguf-py/gguf/tensor_mapping.py: HF source-name tuples for the new tensors.
  • convert_hf_to_gguf.py: class TalkieModel(TextModel) with custom set_vocab (tiktoken-direct), W_q/W_k RoPE pre-flip, .weight suffix synthesis for raw scalar Parameters.
  • src/llama-arch.{h,cpp}: LLM_ARCH_TALKIE, 5 new LLM_TENSOR_* enums, name strings, layer/op infos.
  • src/llama-model.{h,cpp}: 4 new optional llama_layer tensors plus lm_head_gain on llama_model; loader and graph dispatch.
  • src/models/talkie.cpp: 156-line graph builder mirroring talkie/src/talkie/model.py:127-147,189-194 line-by-line.
  • src/llama-vocab.{h,cpp}: new LLAMA_VOCAB_PRE_TYPE_TALKIE and matching pre-tokenizer regex from talkie/src/talkie/tokenizer.py.
  • src/llama-chat.{h,cpp}: new LLM_CHAT_TEMPLATE_TALKIE (Phi-3-shaped but no newlines).
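
For reviewers less familiar with gguf-py, the constants.py / tensor_mapping.py part of the list above amounts to roughly the following; the enum member names come from this PR, while the HF source-name strings here are illustrative placeholders rather than the ones in the diff:

```python
from enum import IntEnum, auto

# Sketch of the new gguf-py plumbing (names illustrative, see the diff for the real ones).
class MODEL_TENSOR(IntEnum):
    ATTN_ACT_GAIN    = auto()   # per-block scalar on the attention residual branch
    FFN_ACT_GAIN     = auto()   # per-block scalar on the MLP residual branch
    EMBED_SKIP_SCALE = auto()   # per-block scalar on the embed-skip term
    ATTN_HEAD_GAIN   = auto()   # per-head gain on Q after the post-RoPE Q norm
    LM_HEAD_GAIN     = auto()   # global scalar applied to the lm_head matmul

# tensor_mapping.py then only needs the HF-side source names so the default
# modify_tensors path routes the tensors without per-arch special cases.
HF_SOURCE_NAMES = {
    MODEL_TENSOR.ATTN_ACT_GAIN:    ("model.layers.{bid}.attn_act_gain",),
    MODEL_TENSOR.FFN_ACT_GAIN:     ("model.layers.{bid}.ffn_act_gain",),
    MODEL_TENSOR.EMBED_SKIP_SCALE: ("model.layers.{bid}.embed_skip_scale",),
    MODEL_TENSOR.ATTN_HEAD_GAIN:   ("model.layers.{bid}.attn_head_gain",),
    MODEL_TENSOR.LM_HEAD_GAIN:     ("lm_head_gain",),
}
```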

Verification

Tested with talkie-1930-13b-it bf16 GGUF on B200 GPU via llama-server.

Greedy parity GGUF vs HF-fp32 (5/5 byte-perfect match)

Critical finding: llama.cpp computes activations in fp32 (with bf16 weights) while the official talkie inference uses torch.amp.autocast(dtype=torch.bfloat16), keeping activations in bf16. Same model, different precision. The HF safetensors port loaded with torch_dtype=torch.float32 produces the same byte-for-byte greedy output as the GGUF on every prompt tested:

| prompt | HF (fp32) | GGUF (server) | match |
|---|---|---|---|
| What is 1+1? | Two. | Two. | YES |
| What is the capital of France? | The capital of France is Paris. | The capital of France is Paris. | YES |
| Hello | Salutation; greeting. | Salutation; greeting. | YES |
| Say 1 to 10 then backwards. | 10 to 1. | 10 to 1. | YES |
| Say 1 to 100 then backwards. | From 100 to 1. | From 100 to 1. | YES |

This proves the GGUF conversion is correct: the GGUF and HF-fp32 paths agree byte-for-byte. The end-to-end last-position logits agree to ~bf16-noise (RMSE 0.04, max_abs 0.08 across the top-50 tokens of all three prompts).

bf16 vs fp32 within HF itself

The talkie-1930-13b-it model is precision-sensitive: HF loaded with torch_dtype=bfloat16 produces different greedy choices than HF loaded with torch_dtype=float32 on close-call prompts:

| prompt | HF-bf16 argmax | HF-fp32 argmax | match |
|---|---|---|---|
| 1+1 | 7306 | 7306 | YES |
| Hello | 21143 (Halloo) | 11592 (Sal) | NO |
| Say 1 to 10 | 4044 (From) | 512 (10) | NO |

The bf16/fp32 logit RMSE within HF is 0.4-1.7 on these prompts. The GGUF, by virtue of running f32 activations, matches the fp32 path. The official talkie inference (which uses bf16 autocast) matches the bf16 path. Both are valid; the 13B model has token decisions where the top 1-2 candidates are within 1-2 logp of each other and bf16 rounding tips the balance.

If a downstream user wants byte-exact match with the official bf16 inference, that requires running the talkie reference (or HF in bf16). The GGUF instead gives byte-exact match with HF-fp32 - the more numerically accurate path.
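
A reproduction sketch for the fp32-vs-bf16 comparison inside HF (model path, device mapping, and generation length are assumptions; the prompts are from the tables above):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL = "talkie-1930-13b-it-hf"   # local path to the HF safetensors port (illustrative)
PROMPTS = ["What is 1+1?", "Hello", "Say 1 to 10 then backwards."]

def greedy(model, tok, prompt):
    msgs = [{"role": "user", "content": prompt}]
    ids = tok.apply_chat_template(msgs, add_generation_prompt=True,
                                  return_tensors="pt").to(model.device)
    out = model.generate(ids, max_new_tokens=64, do_sample=False)
    return tok.decode(out[0, ids.shape[-1]:], skip_special_tokens=True)

tok = AutoTokenizer.from_pretrained(MODEL)
# fp32 activations: the path the GGUF (bf16 weights, f32 activations) reproduces.
m_fp32 = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float32,
                                              device_map="cuda").eval()
# bf16 activations: roughly what the official autocast(bfloat16) inference does.
m_bf16 = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16,
                                              device_map="cuda").eval()

for p in PROMPTS:
    a, b = greedy(m_fp32, tok, p), greedy(m_bf16, tok, p)
    print(f"{p!r:35} fp32={a!r}  bf16={b!r}  match={a == b}")
```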

Chat template parity (server /apply-template vs HF apply_chat_template vs official format_chat)

  • 1-turn: PASS
  • 3-turn: PASS
  • 6-turn: PASS

Multi-turn chat coherence (/v1/chat/completions)

| messages | server reply |
|---|---|
| (My name is Sam., ack, What is my name?) | Your name is Sam. |
| (What is 2 plus 2?, Four., And plus 3 more?) | Seven. |
| (5-turn convo where user says I like cats, asks at turn 5 What animal did I say I liked?) | You said you liked cats. |
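
These checks go through llama-server's OpenAI-compatible endpoint; a minimal client sketch (the URL, the acknowledgement turn, and the sampling settings are assumptions, not part of the PR):

```python
import json
import urllib.request

URL = "http://localhost:8080/v1/chat/completions"   # default llama-server port (assumption)

messages = [
    {"role": "user", "content": "My name is Sam."},
    {"role": "assistant", "content": "Very well, Sam."},   # placeholder acknowledgement turn
    {"role": "user", "content": "What is my name?"},
]
body = {"messages": messages, "temperature": 0.0, "max_tokens": 32}

req = urllib.request.Request(URL, data=json.dumps(body).encode(),
                             headers={"Content-Type": "application/json"})
with urllib.request.urlopen(req) as resp:
    reply = json.loads(resp.read())["choices"][0]["message"]["content"]
print(reply)   # "Your name is Sam." per the table above
```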

Sampled generation sanity

temp=0.7, top_p=0.95, top_k=50:

  • Tell me a short story about a fox in three sentences. -> A fox was caught in a trap. It was set free again. And it never returned to the same place.
  • What is the meaning of friendship? -> Friendship is mutual attachment between two persons, arising from a knowledge of each other's good qualities, and producing a desire of promoting each other's happiness.

Layer-by-layer RMSE (GGML eval-callback dumps vs PyTorch HF-bf16 .npy dumps)

This was the original investigation that led to the bf16/fp32 finding. With the model in HF-bf16, residual stream RMSE grows from 0.002 at layer 0 to ~3.5 at layer 39. The growth is consistent with bf16 round-off accumulating differently between PT (bf16 storage) and llama.cpp (f32 storage). With HF-fp32 instead, GGUF and HF tensor activations agree to fp32 noise. Full 404-row TSV at outputs/talkie/layer_rmse.tsv in the verification harness.
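
The comparator itself is part of the verification harness, not this PR; its core is just an RMSE over matching activation dumps, roughly as below (assuming both sides have already been exported to .npy with matching shapes; paths are illustrative):

```python
import numpy as np

def rmse(a: np.ndarray, b: np.ndarray) -> float:
    a = a.astype(np.float64).ravel()
    b = b.astype(np.float64).ravel()
    return float(np.sqrt(np.mean((a - b) ** 2)))

# Per-layer residual-stream dumps from both sides (paths illustrative).
pt = np.load("outputs/talkie/pt_layer07_resid.npy")
gg = np.load("outputs/talkie/ggml_layer07_resid.npy")
print(f"layer 7 residual RMSE: {rmse(pt, gg):.4f}")
```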

HF safetensors port (independent sanity check)

The HF safetensors port (talkie-1930-13b-it-hf produced by scripts/convert_talkie_to_hf.py) was verified against the official talkie inference in matching bf16 autocast mode:

  • 7/7 argmax match across 7 prompts
  • last-position logits RMSE 0.0009 (bf16 noise)
  • per-layer activation RMSE 0.0000 across all 40 layers

This confirms the HF -> GGUF converter is comparing against a faithful reference of the talkie weights.

Perplexity

llama-perplexity -c 256 on a 768-token excerpt of pre-1931 English prose: PPL = 13.80 +/- 2.25. Within the expected 5-30 range for a coherent 13B model.

RMSNorm eps

The initial draft used add_layer_norm_rms_eps(1e-5), roughly the default the PyTorch docstring suggests. Empirically, though, PyTorch's F.rms_norm with its default eps behaves like eps=0 for bf16 input at the relevant magnitudes (checked with F.rms_norm on bf16 tensors of that scale). The 1e-5 attenuated the post-norm RMS by ~2% per site, compounded across 5 sites x 40 layers and amplified by the embed-skip near-cancellation.

Switched to eps=1e-9 (commit d19f0fc) in both the converter and the C++ default for LLM_ARCH_TALKIE.
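
A quick way to reproduce the eps observation (a sketch; the 0.02 activation magnitude is illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Small-magnitude bf16 activations (scale is illustrative).
x = (torch.randn(4096) * 0.02).to(torch.bfloat16)

def out_rms(eps):
    y = F.rms_norm(x, (x.shape[-1],), eps=eps).float()
    return y.pow(2).mean().sqrt().item()

print("default eps:", out_rms(None))
print("eps=0      :", out_rms(0.0))
print("eps=1e-9   :", out_rms(1e-9))
print("eps=1e-5   :", out_rms(1e-5))   # attenuated when the mean square is O(1e-4)
```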

Tested on

  • talkie-1930-13b-it bf16, B200 GPU, llama-server with --jinja --chat-template-file chat_template.jinja.

Open items / follow-ups

  • Quantised GGUFs (Q4_K_M, etc.) - the path is llama-quantize and is left as a follow-up.
  • The base model (talkie-1930-13b-base) has the same arch with a smaller vocab (65536) and uses <|endoftext|> as the only stop token; the converter handles both via the IT-specific special-token table when vocab_size == 65540.

Notes on the upstream policy

This PR targets unslothai/llama.cpp, which is a private fork. The AI-assistance policy in AGENTS.md is upstream-only (Private forks are exempt). Disclosure: this work was developed in collaboration with an AI assistant that I directed and reviewed.

Adds MODEL_ARCH.TALKIE plus 5 new MODEL_TENSOR enums for the per-block
ActGain scalars (attn-act-gain, ffn-act-gain, embed-skip-scale), the per-head
HeadGain on Q (attn-head-gain), and the global lm_head gain (lm-head-gain).
Registers HF source names in tensor_mapping.py so the default
modify_tensors path routes them automatically.

Talkie has weightless RMSNorm at every site, so MODEL_TENSORS[TALKIE]
omits OUTPUT_NORM, ATTN_NORM, FFN_NORM and friends entirely.
Talkie's reference uses F.rms_norm with the default eps. In bf16 PyTorch
that default behaves like eps=0 (output rms == 1.0 to fp32 noise), not
like torch.finfo(input.dtype).eps as the docstring suggests.

Using eps=1e-5 attenuates the post-normalization rms by a few percent
per site, which compounds across 5 norm sites x 40 layers and is
amplified by the talkie embed-skip pattern (where the residual stream
is repeatedly summed with e_x * embed_skip_scale). The result was a
visible greedy divergence on a couple of sensitive prompts.

Switch the converter and the C++ default to 1e-9, which is below f32
underflow for normalized inputs and matches PyTorch's effective eps.
danielhanchen marked this pull request as ready for review April 29, 2026 11:28
danielhanchen (Member, Author) commented:

Multi-turn conversation coherence (additional verification)

Verified the GGUF model handles multi-turn dialog state correctly via llama-server's /v1/chat/completions:

| messages | server reply |
|---|---|
| (sets up My name is Sam. and asks What is my name? 3 turns later) | Your name is Sam. |
| (2 + 2 = 4, then plus 3 more?) | Seven. |
| (5-turn convo where user says I like cats, asks at turn 5 What animal did I say I liked?) | You said you liked cats. |

Sampled generation sanity

Confirmed coherent free-form output at temp=0.7, top_p=0.95, top_k=50:

  • Tell me a short story about a fox in three sentences. -> A fox was caught in a trap. It was set free again. And it never returned to the same place.
  • What is the meaning of friendship? -> Friendship is mutual attachment between two persons, arising from a knowledge of each other's good qualities, and producing a desire of promoting each other's happiness.

Verification harness (PT activation dumps, GGML binary dumps, comparator) is in workspace_3/scripts/; results in workspace_3/outputs/talkie/.

danielhanchen (Member, Author) commented:

Update: GGUF matches HF-fp32 byte-for-byte (5/5)

After fuller investigation, the apparent "drift" vs the official inference turns out to be the bf16-vs-fp32 precision difference within PyTorch itself, not a conversion bug.

llama.cpp computes activations in fp32 (with bf16 weights) by default. The official talkie inference uses torch.amp.autocast(dtype=torch.bfloat16), keeping activations in bf16. When the same HF safetensors port is loaded with torch_dtype=torch.float32 (no autocast), it produces the same byte-for-byte greedy output as the GGUF on every prompt tested:

| prompt | HF (fp32) | GGUF (server) | match |
|---|---|---|---|
| What is 1+1? | Two. | Two. | YES |
| What is the capital of France? | The capital of France is Paris. | The capital of France is Paris. | YES |
| Hello | Salutation; greeting. | Salutation; greeting. | YES |
| Say 1 to 10 then backwards. | 10 to 1. | 10 to 1. | YES |
| Say 1 to 100 then backwards. | From 100 to 1. | From 100 to 1. | YES |

End-to-end last-position logits agree to ~bf16 noise: RMSE 0.04, max_abs 0.08 across the top-50 tokens of all three prompts. Compare to GGUF vs HF-bf16 RMSE of 2.26 max_abs 5.1 - that gap is the bf16/fp32 difference within PyTorch, not a llama.cpp issue.

The 13B model is precision-sensitive: a few tokens have top candidates within 1-2 logp of each other, so bf16 rounding tips the choice differently than fp32. Both are valid and produce coherent output. Updated PR body with this finding.

danielhanchen (Member, Author) commented:

Expanded HF-fp32 vs GGUF parity test (13/14 prompts)

Re-ran on a broader prompt set including longer generations:

| prompt | HF-fp32 | GGUF | match |
|---|---|---|---|
| What is 1+1? | Two. | Two. | YES |
| What is 2+2? | Two. | Two. | YES |
| What is the capital of France? | The capital of France is Paris. | (same) | YES |
| What is the capital of England? | The capital of England is London. | (same) | YES |
| Hello | Salutation; greeting. | (same) | YES |
| Hi. | I thank you. | (same) | YES |
| Goodbye | Farewell. | (same) | YES |
| Say 1 to 10 then backwards. | 10 to 1. | 10 to 1. | YES |
| Say 1 to 50. | I am to 50. | I am to 50. | YES |
| Say 1 to 100 then backwards. | From 100 to 1. | (same) | YES |
| Tell me a short story. | Once upon a time a hare and a tortoise... | (same) | YES |
| What is the meaning of life? | (long answer) | (same) | YES |
| Who wrote Hamlet? | (350+ token answer) | (same up to ~340 tokens, diverges at one preposition near the end) | partial |
| What year did the Great War begin? | The Great War began in August, 1914. | (same) | YES |

13/14 byte-perfect match. The single non-match (Who wrote Hamlet?) generated 350+ tokens that agree word-for-word until the very last sentence; the divergence is "with a declaration" vs "with the declaration" followed by a closing-quote difference - exactly the kind of deep-generation drift expected from accumulated bf16 weight quantization noise.

danielhanchen (Member, Author) commented:

Long-context and multi-turn coherence stress tests

Pushed close to the model's 2048 context limit and ran 7 varied multi-turn dialogs.

Long-context recall

| context | recall task | result |
|---|---|---|
| 382 tokens | Recall 3 facts (color, year, profession) | All 3 recalled: Purple was your favourite color, you were born in 1875, and you are a clockmaker. |
| 1802 tokens | Recall 4 facts (profession, horse name, garden flowers) | All 4 recalled: you are a lighthouse keeper, your horse is called Old Bess, and the flowers in your garden are roses and lilies. |

Multi-turn coherence (greedy, 7 cases)

  • Progressive narrative (4 turns): correctly recalled James and Sea Sprite ship name
  • Counting with words (3 turns): What number comes between two and four? -> Three.
  • Self-correction (3 turns): user changed favorite color from red to blue; model correctly answered blue
  • Ordered list recall (4 turns): What do rabbits love, according to my third fact? -> Carrots.
  • Temporal sequencing (4 turns): What did I do on Tuesday? -> You baked bread.
  • Summarization: produced I am a baker, who bakes bread every morning at four o'clock, and whose shop is in Baker Street. (correctly summarized 3 of 4 facts; dropped the user's name Frederick)
  • Identity preservation: model failed to recall a sibling's name William from 3 turns earlier - the only clear failure

9/11 expectations met across 7 cases. The two misses are both model-behavior limitations (talkie is a 13B base trained on pre-1931 text, modest at long-range recall), not artifacts of the GGUF conversion. All replies are grammatical, coherent, and on-topic.

Multi-turn arithmetic chain (6 ops)

Starting from 10: +5 -3 *2 /6 +1 ^2. The model returned correct word-form answers for 5 of 6 steps (Fifteen, Twelve, Twenty-four, Four, Five) and erred on the final square (Twenty. for 5 squared). Coherence holds throughout.

danielhanchen (Member, Author) commented:

Quantization sanity check

Tested with llama-quantize:

  • Q8_0 (8.50 BPW, 13.4 GB): generates coherent output on every prompt tested. On the Say 1 to 10 then backwards. prompt it returns the verbose form From 10 to 1, from 9 to 2, from 8 to 3, and so on, till it come back to 1. - which incidentally matches the official talkie bf16 inference verbatim (different rounding tips a top-1 choice differently than the bf16 GGUF). Recommended.
  • Q4_K_M (5.16 BPW, 8.0 GB): generates correctly on simple prompts (Two., The capital of France is Paris.) but produces strings of null bytes on others (Hello, Say 1 to 10). The talkie architecture's small per-block scalars (1-element ActGain tensors and the per-head HeadGain) plus the embed-skip near-cancellation amplify Q4 quantization noise enough to push some logit distributions to the EOG token (id 65535 <|endoftext|> decodes to NUL) at certain steps. Q8_0 should be the minimum recommended quantization for this model; Q5_K_M and Q6_K were not tested but are likely acceptable.

The quantization conversion itself succeeds without errors via the standard llama-quantize path - the 4 new tensor enums (ATTN_HEAD_GAIN, ATTN_ACT_GAIN, FFN_ACT_GAIN, EMBED_SKIP_SCALE) and LM_HEAD_GAIN get the appropriate fallback quantization (kept in f32) automatically.

danielhanchen (Member, Author) commented:

12-turn dialog with explicit recall queries

12-turn user-driven dialog with 8 turns of fact introduction followed by 4 recall questions:

| turn | user | reply |
|---|---|---|
| 1 | Good morning. My name is Mister Edmund Blackwood. | Good morning, Mr Blackwood. |
| 2 | I am a clockmaker by trade, in the village of Pemberton. | I am glad to hear it, Mr Blackwood, for I suppose you must be a steady, industrious man. |
| 3 | I have a wife, Margaret, and two daughters, Emily and Charlotte. | I am pleased to hear that you are so happily circumstanced. |
| 4 | I have a workshop above the bakery on the high street. | I hope you may long continue to earn a comfortable livelihood in it. |
| 5 | My favorite tools are a fine pair of brass calipers I inherited. | I like to see a man fond of his business. |
| 6 | I am currently working on a marine chronometer for Captain Hayes. | I wish you success in that, and in everything else honest and laudable. |
| 7 | I bought a new pocket watch last Thursday. | I am glad of it, for you must now be punctually punctual. |
| 8 | I had tea this afternoon at the parsonage with the vicar. | I am glad of that, too, for he is a friend to the industrious. |
| 9 | Now please tell me, what is my name and what is my profession? | Your name is Edmund Blackwood, and you are a clockmaker. |
| 10 | And what village do I live in? | In Pemberton. |
| 11 | What are the names of my two daughters? | Emily and Charlotte. |
| 12 | What am I currently working on, and for whom? | On a marine chronometer, for Captain Hayes. |

All 4 recall queries answered correctly. Final summary at turn 13:

You are a clockmaker, living in the village of Pemberton, and currently working on a marine chronometer for Captain Hayes; you have a wife and two daughters, Emily and Charlotte, and your favourite tools are a pair of brass calipers, which you inherited.

Coherent throughout (~280 cumulative tokens at turn 12), correct register, and the period-appropriate language sustains across all 12 turns. The GGUF handles long multi-turn KV state correctly.

danielhanchen (Member, Author) commented:

Perplexity (longer text, 1536 tokens, ctx=512)

Updated PPL on a 1536-token excerpt of pre-1931 English prose (Pride and Prejudice, Wizard of Oz pastiche, period-appropriate village vignettes), 3 perplexity windows:

| quantization | PPL | size |
|---|---|---|
| bf16 | 11.75 +/- 1.22 | 25.3 GB |
| Q8_0 | 11.65 +/- 1.21 | 13.5 GB |

Both within noise of each other. Q8_0 is the recommended quantization - same PPL at ~half the file size, and unlike Q4_K_M produces coherent output on every prompt tested.

Replaces the separate `ggml_mul(Qcur, head_gain)` with the equivalent
`build_norm(Qcur, head_gain, ...)` 2-arg form. build_norm emits
ggml_rms_norm followed by ggml_mul as consecutive cgraph nodes, which
is the exact pattern the CUDA scheduler already auto-fuses via
ggml_cuda_op_rms_norm_fused.

Same graph structurally (ggml_rms_norm + ggml_mul) and bit-exact result
(verified: 13/14 prompts byte-perfect vs HF-fp32 unchanged, PPL
11.7523 unchanged). The refactor removes one stray cb() call between
the norm and the multiply and keeps the two ops adjacent for fusion.
danielhanchen (Member, Author) commented:

Optimization pass + re-verification

Did a full re-verification round and looked at optimization opportunities.

Re-verification (all green)

| check | result |
|---|---|
| HF-fp32 byte match (14 prompts) | 13/14 (same Hamlet edge case at token ~340) |
| Chat template parity (1/3/6 turn) | 3/3 |
| Multi-turn 12-turn recall (4 explicit recall queries) | 4/4 |
| Long-context recall (1802 tokens / 4 facts) | 4/4 |
| Perplexity (1536 tokens, ctx=512) | 11.7523 +/- 1.22 |
| llama-bench bf16 | pp512 ~14k t/s, tg128 114.6 t/s on B200 |
| Q8_0 generation sanity | passes on every prompt |

Optimization analysis

I scanned for safe wins. The talkie graph is already well-optimized; specifically:

  • The MUL_MAT, MUL_MAT, GLU fusion fires on the FFN gate/up/swiglu trio (existing CUDA fusion).
  • The RMS_NORM, MUL fusion fires on the Q-RMSnorm + HeadGain pattern.
  • All three matmuls (Q, K, V) use build_lora_mm and benefit from existing TensorCore paths.

The remaining 80 small ggml_mul ops per forward (40 layers x 2 act-gain scalar muls) account for ~5% of generation time but cannot be cleanly removed without folding the per-block scalars into the corresponding o_proj / down_proj weights at convert time. That folding would change the bf16 storage rounding of those matrices and risk drifting from byte-perfect HF-fp32 parity, so I have not done it.

Single safe refactor applied (commit 73b0094)

Switched the post-RoPE Q-norm + HeadGain pair from a separate build_norm + ggml_mul to a single build_norm call with head_gain passed as the norm-weight argument. This is structurally identical (still emits ggml_rms_norm then ggml_mul as consecutive cgraph nodes) but keeps the two ops adjacent without an intervening cb() callback, so the existing ggml_cuda_op_rms_norm_fused pattern matcher hits without ambiguity. Bit-exact result confirmed: HF-fp32 byte match still 13/14, perplexity unchanged at 11.7523.

Bottleneck observed

At 119 t/s on B200, talkie's bf16 inference is moving ~2.94 TB/s of weight data, roughly 38% of the device's HBM bandwidth. The remaining headroom is in:

  1. Kernel-launch overhead from the per-block scalar ggml_mul ops (~5% of TG time), which folding would address but at the cost of bf16 rounding drift.
  2. KV-cache read patterns (already using flash-attn).
  3. Quantization (Q8_0 doesn't reduce model bandwidth materially since most ops are still f32-accumulated).

I'm prioritizing accuracy parity over the 5% TG speedup that weight folding could buy.

danielhanchen (Member, Author) commented:

Upstream model-conversion harness (compare-logits + NMSE)

Ran the official examples/model-conversion/scripts/causal/ validation harness end-to-end on the bf16 GGUF.

MODEL_PATH=.../talkie-1930-13b-it-hf
CONVERTED_MODEL=.../talkie-1930-13b-it-bf16.gguf
python scripts/causal/run-org-model.py
bash   scripts/causal/run-converted-model.sh    # llama-debug --save-logits
python scripts/causal/compare-logits.py
python scripts/utils/check-nmse.py -m $MODEL_PATH

Prompt: "Hello, my name is" (6 tokens).

| check | result |
|---|---|
| Token-id sequence (HF tokenizer vs GGUF tokenizer) | 6 / 6 |
| Top-10 last-token logits ordering | identical |
| Top-1 prediction (both) | id 65536 (<\|end\|>) |
| Max abs logits diff | 0.0592 |
| Mean abs logits diff | 0.0072 |
| MSE | 8.69e-05 |
| Reference variance | 5.70 |
| NMSE | 1.52e-05 (-48.17 dB) |
| check-nmse.py verdict | "Excellent match" |

NMSE 1.52e-05 sits well below the upstream "excellent" threshold of 1e-4, let alone the "good" cutoff at 1e-2. The 0.0592 max-abs diff is in line with the bf16 matmul accumulation noise this harness reports for other architectures.
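
For reference, the NMSE figure is just the logits MSE normalized by the variance of the reference logits; a sketch that reproduces the numbers in the table above (not the upstream script itself):

```python
import numpy as np

def nmse(reference: np.ndarray, test: np.ndarray) -> tuple[float, float]:
    ref = reference.astype(np.float64).ravel()
    tst = test.astype(np.float64).ravel()
    mse = float(np.mean((ref - tst) ** 2))
    val = mse / float(np.var(ref))          # 8.69e-05 / 5.70 ~= 1.52e-05 for this run
    return val, 10 * np.log10(val)          # -48.17 dB for this run
```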

Trim pass

Also pushed 45c4b13 talkie: trim verbose comments - 30 insertions / 114 deletions across the six talkie-touched files, comments only, no behavioural change.
