When Reasoning Models Break Tokenization: The Hidden Complexity of Multiturn Training

I recently spent two weeks refactoring multiturn tokenization and masking for VeRL. While VeRL already had a functional implementation, what initially seemed like a straightforward refactor turned out to be surprisingly nuanced. Through multiple iterations, we arrived at a solution that is both robust and flexible for VeRL users. This post shares the key learnings and design choices from that journey.

Single-Turn: The Simple Case

In single-turn LLM training, each example consists of a prompt string — typically created from a set of messages using a chat template — and an LLM response (the labels in SFT, or text generated by the actor in RL). These pieces are tokenized, concatenated, and padded for training.

The Basic Structure

Let’s walk through a concrete example to understand the process:

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "How are you?"}
]
response = "I'm good, thank you!"


Message Formatting

To prepare this for training, we follow a two-step process:

  1. Format the prompt: Use the tokenizer’s chat template to convert messages into a properly formatted prompt string
  2. Format the response: Append an EOS token to the assistant’s response so that the LLM learns when to stop
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", use_fast=True)

prompt_str = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
response_str = response + tokenizer.eos_token


This produces:

# Prompt string
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
How are you?<|im_end|>
<|im_start|>assistant

# Response string
I'm good, thank you!<|im_end|>


From Strings to Token IDs

Next, we tokenize these strings separately, then concatenate them to form a single sequence. Finally, we pad the sequence to match our desired training length:

# Tokenize
prompt_token_ids = tokenizer.encode(prompt_str, add_special_tokens=False)
response_token_ids = tokenizer.encode(response_str, add_special_tokens=False)

# Concatenate and pad
input_ids = prompt_token_ids + response_token_ids
input_ids = input_ids + [tokenizer.pad_token_id] * (max_length - len(input_ids))


Creating Training Masks

With our token IDs ready, we need to generate two critical masks:

  1. Attention Mask — A binary mask (0s and 1s) indicating which tokens the model should attend to. Padding tokens receive 0 and are ignored during the forward pass.

  2. Loss Mask — Specifies which tokens to compute loss on. In practice it doubles as the labels tensor: positions we want to learn from hold the target token IDs, while prompt and padding positions are set to -100 (ignored by the loss function), since we only want to learn from the assistant’s response.

Here’s how this works in practice:

prompt_token_ids = [1, 2, 3, 4, 5]
response_token_ids = [6, 7, 8]

# Create masks separately
prompt_attn_mask = [1, 1, 1, 1, 1]
response_attn_mask = [1, 1, 1]
prompt_loss_mask = [-100, -100, -100, -100, -100]
response_loss_mask = [6, 7, 8]

# Concatenate and pad (assuming pad token ID is 151643)
input_ids = [1, 2, 3, 4, 5, 6, 7, 8, 151643, 151643]
attention_mask = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
loss_mask = [-100, -100, -100, -100, -100, 6, 7, 8, -100, -100]


This setup ensures the model focuses exclusively on learning from the assistant’s response, ignoring prompts and padding during training. Simple and straightforward — but as we’ll see next, multiturn conversations complicate this elegant approach.
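For completeness, here is a minimal sketch of how such a labels tensor is typically consumed: standard next-token cross-entropy with ignore_index=-100. This is generic PyTorch, not VeRL-specific code; the stand-in logits and the vocabulary size are assumptions for illustration.

import torch
import torch.nn.functional as F

# Stand-in logits; in practice these come from the model's forward pass.
# 152064 is simply "large enough to cover the Qwen token IDs above".
vocab_size = 152064
logits = torch.randn(1, len(input_ids), vocab_size)
labels = torch.tensor([loss_mask])

# Shift so position t predicts token t+1; ignore_index skips the -100 positions
loss = F.cross_entropy(
    logits[:, :-1, :].reshape(-1, vocab_size),
    labels[:, 1:].reshape(-1),
    ignore_index=-100,
)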

Multiturn and Multistep: The Complex Case

Before diving into the challenges, let’s clarify two important concepts:

  • Multiturn: Unlike single-turn (one user message → one assistant response), multiturn conversations have multiple rounds of user-assistant exchanges in a single training example.
  • Multistep: After receiving a user message, the assistant may need to perform multiple steps — typically tool calls — to gather information before providing a final response.

These patterns are essential for training capable AI agents, but they significantly complicate tokenization and masking. In VeRL, we support both multiturn and multistep rollout to enable training of sophisticated agent models.
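As a concrete (made-up) example combining both patterns, a single training sample might look like the following, where only the assistant messages should contribute to the loss:

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},                          # turn 1
    {"role": "assistant", "content": "Let me check with the calculator."},  # step 1: decide to call a tool
    {"role": "tool", "content": "2 + 2 = 4"},                               # step 2: tool output
    {"role": "assistant", "content": "2 + 2 = 4."},                         # step 3: final answer for turn 1
    {"role": "user", "content": "And 3 + 3?"},                              # turn 2
    {"role": "assistant", "content": "3 + 3 = 6."},
]

Real rollouts typically attach structured tool-call arguments to the assistant message, but the shape of the problem is the same: assistant content is interleaved with content we must not train on.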

The Core Challenge

The core challenge is this: we still want the model to learn only from assistant messages, but in multiturn multistep settings, these assistant responses are scattered throughout a longer conversation history that includes user messages, system prompts, and potentially tool outputs.

Here’s why this becomes problematic:

  1. Multiple interleaved messages: Instead of a simple prompt → response structure, we now have alternating sequences like: system → user → assistant → user → assistant → tool → assistant → user → assistant.

  2. Single string collapse: Modern models use chat templates that convert this entire sequence into one continuous training string, making it difficult to track where each assistant response begins and ends.

  3. Token-level precision required: During training, we need to identify exactly which tokens belong to assistant responses for loss computation — but after tokenization, the message boundaries are no longer obvious.

This creates a fundamental problem: identifying which tokens should be included in the loss becomes challenging once tokenization collapses all messages into a single string, especially when assistant responses may appear at multiple, unpredictable positions within that string.

Finding a Solution

1. Using Tokenizer’s Built-in Functionality

HuggingFace tokenizers actually provide an option to return an assistant tokens mask when applying chat templates to messages.
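As a sketch, using that option looks roughly like this (the return_assistant_tokens_mask argument requires return_dict=True; exact behavior depends on your transformers version):

out = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_assistant_tokens_mask=True,
)
input_ids = out["input_ids"]
assistant_mask = out["assistant_masks"]  # 1 for tokens inside assistant turns, 0 elsewhere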

This seems like the perfect solution at first glance. The tokenizer could automatically identify which tokens belong to assistant messages and create the loss mask for us.

However, there’s a catch: this feature only works if the model’s chat template explicitly marks assistant messages using the {% generation %} keyword. The template needs to wrap assistant content with special markers so the tokenizer knows exactly which parts to mask.

Unfortunately, as of today, very few models support this feature. Since VeRL is a framework that needs to support a wide variety of models, we can’t rely on a solution that only works for a handful of them. This limitation ruled out what would otherwise be the most elegant approach.

2. Mimicking Single-Turn Tokenization

Another approach is to treat multiturn like single-turn: tokenize and mask each message separately, then concatenate them. However, this isn’t straightforward because chat templates add different tokens depending on the context.

For example, when tokenizing messages separately, the template might add unexpected system prompts:

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "How are you?"}
]

# Tokenizing together
tokenizer.apply_chat_template(messages, tokenize=False)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# How are you?<|im_end|>

# Tokenizing separately - note the extra system message!
tokenizer.apply_chat_template([messages[1]], tokenize=False)
# <|im_start|>system
# You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
# <|im_start|>user
# How are you?<|im_end|>


To work around this, VeRL’s first implementation manually maintained format configs for popular models. These configs specified the exact tokens that chat templates add around messages (like <|im_start|>system\n and <|im_end|>\n).

For example, when processing tool responses in Qwen:

# Tool messages to add
messages = [
    {"role": "tool", "content": "1 + 1 = 2"},
    {"role": "tool", "content": "2 + 2 = 4"}
]

# Expected Qwen format:
# <|im_start|>user
# <tool_response>
# 1 + 1 = 2
# </tool_response>
# <tool_response>
# 2 + 2 = 4
# </tool_response><|im_end|>

# Using our format config, we'd tokenize each part:
token_ids = (
    tokenizer.encode("<|im_start|>user") +  # from format config
    tokenizer.encode("\n<tool_response>\n") +  # from format config
    tokenizer.encode("1 + 1 = 2") + 
    tokenizer.encode("\n</tool_response>") +  # from format config
    tokenizer.encode("\n<tool_response>\n") +   # from format config
    tokenizer.encode("2 + 2 = 4") + 
    tokenizer.encode("\n</tool_response>") +   # from format config
    tokenizer.encode("<|im_end|>")  # from format config
)


Since we tokenize each part separately, we know each part’s length and can create appropriate masks.
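A rough sketch of that bookkeeping (hypothetical, not VeRL's actual config format): each segment carries a flag saying whether its tokens should receive loss.

# Hypothetical segments derived from a format config: (token_ids, train_on_this)
segments = [
    (tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False), False),  # role header: no loss
    (tokenizer.encode("2 + 2 = 4", add_special_tokens=False), True),                 # assistant content: loss
    (tokenizer.encode("<|im_end|>\n", add_special_tokens=False), True),              # EOS: loss, so the model learns to stop
]

input_ids, attention_mask, loss_mask = [], [], []
for ids, trainable in segments:
    input_ids += ids
    attention_mask += [1] * len(ids)
    loss_mask += ids if trainable else [-100] * len(ids)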

However, this approach had significant drawbacks:

  1. Scalability: High maintenance burden - We had to manually maintain format configs for every model and update them whenever chat templates changed or new models were released. Matching the exact behavior of chat templates (especially spaces and newlines) was error-prone.

  2. Consistency: Tokenization differences - Tokenizing parts separately can produce different token IDs than tokenizing the whole string. Adjacent characters might be fused into single tokens differently:

# Tokenized together
token_ids1 = [151644, 872, 198, 27, 14172, 9655, 1339, 16, 488, 220, 16, 284, 220, 17, 271, 522, 14172, 9655, 29, 151645]

# Tokenized separately  
token_ids2 = [151644, 872, 198, 27, 14172, 9655, 397, 198, 16, 488, 220, 16, 284, 220, 17, 198, 198, 522, 14172, 9655, 29, 151645]

# Both decode to the same text, but with different token IDs!
# The key difference: two consecutive \n characters get tokenized as one token when together
# <|im_start|>user\n<tool_response>\n\n1 + 1 = 2\n\n</tool_response><|im_end|>


3. Incremental Tokenization with Validation

Given the scalability and consistency issues, we explored a solution that relies solely on the model’s built-in chat template, eliminating the need for manual format configs in VeRL.

The key insight: to tokenize a new message, we can compute the difference between applying the chat template to different numbers of messages:

  1. Apply the template to the first i messages
  2. Apply the template to the first i+1 messages (including the new message)
  3. The string difference between these two results is exactly how the template formats the new message

Here’s how it works:

# For an assistant message at position i
# add_generation_prompt=True puts the generation prompt (the assistant header) in `prev`,
# so the delta contains only the tokens the model actually generates
prev = tokenizer.apply_chat_template(messages[:i], add_generation_prompt=True, tokenize=False)
curr = tokenizer.apply_chat_template(messages[:i+1], add_generation_prompt=False, tokenize=False)

# Extract and tokenize only the new content
delta = curr[len(prev):]
new_token_ids = tokenizer.encode(delta, add_special_tokens=False)

# Create masks for the new tokens
attention_mask = [1] * len(new_token_ids)
loss_mask = new_token_ids if messages[i]["role"] == "assistant" else [-100] * len(new_token_ids)


This approach elegantly solves the maintenance problem — we rely entirely on the model’s chat template. However, it still tokenizes substrings separately, which can produce different results in edge cases. To catch these issues, we added validation that compares the incremental tokenization against full tokenization. You can see the prototype implementation here.
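Conceptually, that validation boils down to comparing the concatenated deltas against a one-shot tokenization of the full conversation. A simplified sketch (ignoring the generation-prompt handling above):

import warnings

# Accumulate token IDs turn by turn via the delta trick
incremental_ids = []
for i in range(len(messages)):
    prev = tokenizer.apply_chat_template(messages[:i], tokenize=False) if i > 0 else ""
    curr = tokenizer.apply_chat_template(messages[:i + 1], tokenize=False)
    incremental_ids += tokenizer.encode(curr[len(prev):], add_special_tokens=False)

# Tokenize the fully rendered conversation in one shot
full_ids = tokenizer.encode(
    tokenizer.apply_chat_template(messages, tokenize=False),
    add_special_tokens=False,
)

if incremental_ids != full_ids:
    warnings.warn("Incremental tokenization diverged from full tokenization")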

Why We Moved On

We discovered this approach fails with reasoning models that render messages differently based on their position. Consider Qwen/QwQ-32B, which automatically removes reasoning content (marked with <think></think> tags) from non-final assistant messages:

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "2 + 2 = ?"},
    {"role": "assistant", "content": "<think>user asked about a simple math question.</think> 2 + 2 = 4"},
    {"role": "user", "content": "Thank you!"}
]

tokenizer = AutoTokenizer.from_pretrained("Qwen/QwQ-32B", use_fast=True)

# When the assistant message is last:
prev = tokenizer.apply_chat_template(messages[:-1], add_generation_prompt=False, tokenize=False)
# ... <|im_start|>assistant\n<think>user asked about a simple math question.</think> 2 + 2 = 4<|im_end|>

# When the same message is NOT last:
curr = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
# ... <|im_start|>assistant\n2 + 2 = 4<|im_end|>  # Note: <think> content removed!
# <|im_start|>user
# Thank you!<|im_end|>


The reasoning content disappears when it’s no longer the final message, making it impossible to compute a consistent delta. This position-dependent rendering affects multiple models (QwQ-32B, Qwen3 series) and fundamentally breaks the incremental approach.
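You can see the failure directly by checking the prefix assumption in the example above:

# `prev` still contains the <think> block, but `curr` has it stripped,
# so prev is no longer a prefix of curr and curr[len(prev):] is garbage.
print(curr.startswith(prev))  # False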

4. The Final Solution: Fixed-Base Incremental Tokenization

One possible reason for this behavior is context optimization: reasoning traces often contain information that might be implicitly reflected in the final response. Removing this potentially redundant content during inference could help use the model’s context window more efficiently.

While this optimization makes sense for inference, it’s problematic for training. During training, we must preserve reasoning content so models learn:

  • When to engage in reasoning before responding
  • How to produce high-quality reasoning chains

This creates a fundamental mismatch: models ship with inference-optimized chat templates, not training-compatible ones.

First Attempt: Automatic Template Replacement

We initially tried to solve this by automatically replacing inference templates with training-compatible versions (prototype here). Creating training templates was straightforward — we simply removed the logic that strips reasoning content.

However, detecting which models need replacement proved too complicated. We explored several identification methods:

  1. Model type doesn’t work — Qwen2.5 and QwQ-32B share the same type (qwen2) but only QwQ needs replacement
  2. Template hashing is brittle — minor formatting changes produce different hashes for functionally identical templates
  3. Model name/path matching is unreliable — checking for substrings like “qwen3” or “QwQ” in the model name may work initially, but users can rename or re-publish fine-tuned models under completely different names while they still use the same chat template

Plus, we’d still need to manually create training templates for every new model release — not scalable.

The Breakthrough: Fixed-Base Approach

We observed that chat templates only conditionally modify content in two scenarios:

  1. Adding default system messages when none exist
  2. Removing reasoning content from assistant messages under certain conditions, like when they’re not the final message

Our solution was to sidestep the problem entirely. Instead of using all previous messages as the base for incremental tokenization, we use a fixed, minimal conversation that never changes throughout processing.

Here’s the concept:

BASE_CONVERSATION = [
    {"role": "system", "content": "You are a helpful assistant."}, 
    {"role": "user", "content": "I am a user."}
]

# Calculate delta for any new message
base = tokenizer.apply_chat_template(BASE_CONVERSATION, add_generation_prompt=False, tokenize=False)
with_new_message = tokenizer.apply_chat_template([*BASE_CONVERSATION, new_message], add_generation_prompt=False, tokenize=False)
delta = with_new_message[len(base):]


This neatly avoids the position-dependent rendering issue:

  • The system message prevents templates from adding defaults
  • No models we tested conditionally modify system or user message content
  • The base length remains constant throughout the conversation

You can find our full implementation here.
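For intuition, here is a simplified sketch (not VeRL's actual code) of how the fixed-base deltas can be accumulated into training tensors across a whole conversation, reusing the BASE_CONVERSATION defined above:

def tokenize_with_fixed_base(tokenizer, messages):
    base_str = tokenizer.apply_chat_template(
        BASE_CONVERSATION, add_generation_prompt=False, tokenize=False)
    # What the template appends to open an assistant turn, e.g. "<|im_start|>assistant\n"
    gen_prompt = tokenizer.apply_chat_template(
        BASE_CONVERSATION, add_generation_prompt=True, tokenize=False)[len(base_str):]

    input_ids, loss_mask = [], []
    for msg in messages:
        delta = tokenizer.apply_chat_template(
            [*BASE_CONVERSATION, msg], add_generation_prompt=False, tokenize=False)[len(base_str):]
        if msg["role"] == "assistant":
            # No loss on the assistant header; loss on the generated content and EOS
            header_ids = tokenizer.encode(gen_prompt, add_special_tokens=False)
            body_ids = tokenizer.encode(delta[len(gen_prompt):], add_special_tokens=False)
            input_ids += header_ids + body_ids
            loss_mask += [-100] * len(header_ids) + body_ids
        else:
            ids = tokenizer.encode(delta, add_special_tokens=False)
            input_ids += ids
            loss_mask += [-100] * len(ids)
    attention_mask = [1] * len(input_ids)
    return input_ids, attention_mask, loss_mask

Because the base never changes, len(base_str) is constant and each message’s delta is independent of its position in the conversation, which is exactly what position-dependent templates were breaking.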

Robustness Through Validation

While this approach works for all models we’ve tested, it relies on an assumption: chat templates don’t conditionally modify system and user messages. To guard against future models breaking this assumption, we:

  • Run validation checks at the end of each rollout
  • Compare incremental tokenization against full tokenization
  • Warn users immediately if discrepancies are detected

This ensures the system fails loudly rather than silently producing incorrect results.

Next Steps

While fixed-base tokenization solves our immediate training needs, we’ve identified a discrepancy across different stages of a model’s lifecycle:

  • Training: Full reasoning traces are preserved in training data
  • RL rollout: Full reasoning traces are maintained
  • Production inference: Inference-optimized chat templates remove reasoning content for better latency

The Consistency Challenge

Ideally, model inputs should remain consistent throughout the lifecycle for optimal performance. However, the reality is that we need different chat templates for training versus inference.

This discrepancy may not significantly impact capable models — as mentioned earlier, they likely align reasoning and response content, making the information somewhat redundant. However, for less capable models or challenging tasks, inconsistencies can emerge. For example, a model might decide in its reasoning to call three tools but then lose track and only call two at the end.

Current Mitigation

While we continue evaluating the full impact of this discrepancy, VeRL provides a flag allowing users to choose which chat template to use during rollout:

  • Training template (default): Preserves consistency between training and rollout stages
  • Inference template: Removes reasoning traces, which can help when context length is a bottleneck for rollout

We default to the training template because we believe consistency between training and rollout is more critical than context optimization.

Looking Forward

We’re actively working on a solution that guarantees consistency across all three stages — training, rollout, and inference. Stay tuned for updates as we develop this unified approach.

This post is licensed under CC BY 4.0 by the author.