Latent Reasoning: Teaching Small LLMs to Think Without Thinking

How I fine-tuned Mistral 7B to reach 98% accuracy on a classification task we previously ran through GPT-4, generating a single output token at inference (roughly 100x fewer output tokens and ~25x lower latency), using a novel training technique where reasoning happens during training but not at inference.

When building AI products at scale, every millisecond and every token counts. At my previous startup, we were burning through GPT-4 API credits like there was no tomorrow—thousands of dollars per day just to classify customer support conversations. The irony? Most of these classifications were simple binary decisions that a much smaller model could handle.

The challenge wasn't accuracy—it was getting a small model to reason as well as GPT-4 without the latency and cost overhead of chain-of-thought prompting. This led me to develop what I call "Latent Reasoning": a training technique where the model learns to reason during training, but that reasoning is compressed into the answer token probabilities during inference.

The result? 98% accuracy on a complex classification task, with inference generating only a single token. Let me show you how it works.

The Problem: LLM Inference is Expensive

We had a customer support AI that needed to determine when a conversation should be escalated to a human agent. The key signal: "Is the customer saying that their technical problem has been resolved?"

Sounds simple, right? But consider these edge cases:

  • "Thanks, got it" — Does this mean resolved? (Usually no, just acknowledging a step)
  • "I found it" — Found what? The phone? The setting? Need context.
  • "Perfect!" — Could mean resolved, could mean understood instructions
  • "I have it now, thank you" — This one is actually resolved

GPT-4 with chain-of-thought prompting handled these nuances beautifully—but at 100+ output tokens per classification, we were looking at:

  • ~500ms latency per classification
  • $0.003-0.01 per request
  • Rate limiting issues at scale

We needed a solution that could run locally, cost nearly nothing, and still capture the reasoning that made GPT-4 so good at this task.

The Insight: Reasoning as Probability Compression

Here's the key observation that led to the breakthrough:

Key Insight

When a language model generates chain-of-thought reasoning, the reasoning tokens condition the final answer token. But what if we could compress all that reasoning into the answer token's probability distribution directly?

Traditional approaches to training small models on classification tasks use one of two methods:

  1. Direct answer training: Train on input → "yes" or input → "no"
  2. Chain-of-thought distillation: Train on input → "reasoning... therefore yes"

Method 1 is fast at inference but loses reasoning capability. Method 2 preserves reasoning but requires generating 50-100+ tokens at inference time.

Latent Reasoning flips the script: We train on input → "yes, because [reasoning]" but at inference, we only generate the first token and use its probability as our answer.
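
To make the contrast concrete, here is a minimal sketch (illustrative record, not the production pipeline) of the three training targets for the same example, and what actually gets generated at inference:

# Hypothetical labeled record, in the format described in the next section
example = {
    "running_message": "Customer: I have it now, thank you",
    "label": "yes",
    "reason": "the customer has stated 'i have it now' which indicates resolution",
}

# 1. Direct answer training: fast at inference, but no reasoning signal
direct_target = example["label"]                                      # "yes"

# 2. Chain-of-thought distillation: reasoning first, answer last (slow at inference)
cot_target = f"{example['reason']}, therefore {example['label']}"

# 3. Latent Reasoning: answer first, reasoning second
latent_target = f"{example['label']}, because in the end, {example['reason']}"

# At inference we generate only the first token of the latent target and read off
# its probability; the reasoning text is never generated.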

The Training Data: Labels + Reasoning

The magic starts with how we structure the training data. Each example contains:

# Training data structure
{
    "running_message": "Customer: ... Support Agent: ... Customer: I have it now, thank you",
    "label": "yes",
    "reason": "the customer has stated 'i have it now' which indicates resolution"
}

# Combined output format for training
output = label + ", because in the end, " + reason
# → "yes, because in the end, the customer has stated 'i have it now' which indicates resolution"

This format is critical. The model learns to:

  1. Output the answer token first (yes or no)
  2. Follow with reasoning that justifies that answer

During training, the cross-entropy loss covers the entire completion, so gradients flow through the answer token and every reasoning token that follows it. The answer token learns to encode information about the upcoming reasoning, creating what I call compressed reasoning.
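
As a rough sketch of the mechanics (assuming the tokenizer and model loaded in the fine-tuning section below, plus the same prompt masking that the completion-only collator performs; prompt_text stands in for the fully formatted prompt):

import torch

# prompt_text: the fully formatted prompt (system turn + generation prompt), see Step 3
prompt_ids = tokenizer(prompt_text, return_tensors="pt").input_ids
completion_ids = tokenizer(
    "yes, because in the end, the customer has stated 'i have it now' "
    "which indicates resolution<|im_end|>",
    add_special_tokens=False,
    return_tensors="pt",
).input_ids

input_ids = torch.cat([prompt_ids, completion_ids], dim=1).to(model.device)
labels = input_ids.clone()
labels[:, : prompt_ids.shape[1]] = -100  # supervise only the completion tokens

# Cross-entropy is computed for every completion token, including the very first
# "yes"/"no" token, whose prediction is conditioned only on the prompt.
loss = model(input_ids=input_ids, labels=labels).loss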

Data Preparation: The 1,500 Label Journey

High-quality training data was crucial. Here's how we built the dataset:

Step 1: Gold Standard Labeling

We manually labeled ~1,500 customer support conversation endings with binary labels and explanations:

import pandas as pd

df = pd.read_csv('./data/issue_resolved_gold_standard_w_reason_new.csv')

# Distribution
# no:  764 examples
# yes: 718 examples

# Example entry
{
    "running_message": """Customer: Can you help me with my tech issue?
Support Agent: Hi there! Let's check if Silent Mode is on...
Customer: I checked and it's off, still not ringing.
Support Agent: Let's try Do Not Disturb settings...
Customer: That was it! Working now, thanks!""",
    "label": "yes",
    "reason": "the customer explicitly confirms 'Working now' indicating resolution"
}
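
Before training, you also want a held-out set to measure accuracy on. A stratified split along these lines keeps the yes/no balance intact (the 90/10 ratio and seed are illustrative assumptions, not the exact split from this project):

from sklearn.model_selection import train_test_split

# Stratify on the label so both splits keep the roughly 50/50 yes/no balance
train_df, eval_df = train_test_split(
    df, test_size=0.1, stratify=df["label"], random_state=42
)
print(train_df["label"].value_counts())
print(eval_df["label"].value_counts())
# (the walkthrough below keeps using the full df for brevity)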

Step 2: Prompt Engineering

The input prompt guides the model's reasoning process:

question_prompt = """Take a deep breath and let's think step-by-step.
Answer "yes" or "no": Is the Customer saying that all of their technical problems have been resolved?

1. Understand the customer's issues so that it's clear if their issues have been answered and resolved.
2. Analyze the customer's responses: Look for phrases like "it's working now", "problem solved", etc.
3. Identify keywords indicating problem resolution: Words like "resolved", "fixed", "working", "solved", etc.
4. Make a decision based on the analysis.
5. Document the reasoning and decision.

### NOTE:
- If the customer states that there is nothing else we can help with, answer: "yes"
- If the customer mentions completion of any intermediate step but not their entire problem, answer: "no"
- Usually, when a customer just answers "Thanks", "got it", "ok", that doesn't mean resolved → answer "no"

Transcript:
```
{running_message}
```

The answer is: """

Step 3: Format for Training

Using the ChatML format for Mistral-OpenOrca:

# Format the complete training example
df['out_w_reason'] = df['label'] + ', because in the end, ' + df['reason']

# Apply chat template
chat = [{"role": "system", "content": filled_prompt}]
input_text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

# Final format:
# <|im_start|>system
# [prompt with transcript]
# <|im_end|>
# <|im_start|>assistant
# yes, because in the end, the customer has confirmed the issue is resolved.<|im_end|>
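
Putting Steps 2 and 3 together, the text the trainer consumes can be assembled roughly like this (a sketch; build_training_text and the exact concatenation are my shorthand, not lifted from the original notebook):

def build_training_text(row):
    # Fill the Step 2 prompt template with this conversation's transcript
    filled_prompt = question_prompt.format(running_message=row["running_message"])

    # System turn plus the generation prompt ("<|im_start|>assistant\n")
    chat = [{"role": "system", "content": filled_prompt}]
    input_text = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )

    # Append the target completion; the collator in the next section masks
    # everything before "<|im_start|>assistant\n"
    return input_text + row["out_w_reason"] + "<|im_end|>"

df["text"] = df.apply(build_training_text, axis=1)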

Fine-Tuning with QLoRA

We used QLoRA (Quantized Low-Rank Adaptation) to fine-tune Mistral 7B efficiently on a single A10G GPU:

import torch
from transformers import BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

# 4-bit quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

# LoRA configuration - target all attention + MLP layers
peft_config = LoraConfig(
    lora_alpha=64,
    lora_dropout=0.05,
    r=32,                    # rank
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention
        "gate_proj", "up_proj", "down_proj",      # MLP
        "lm_head",                                # output
    ],
)

# Only train on completions (not the prompt)
response_template = "<|im_start|>assistant\n"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

# Training arguments
training_arguments = TrainingArguments(
    output_dir="./results",
    num_train_epochs=2,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=5e-5,
    lr_scheduler_type="constant",
    optim="paged_adamw_8bit",
    bf16=True,
)

# SFT Trainer with NEFTune noise for better generalization
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=4096,
    tokenizer=tokenizer,
    data_collator=collator,
    args=training_arguments,
    neftune_noise_alpha=5  # Adds noise to embeddings during training
)

Key training stats:

  • Trainable parameters: 85M (1.16% of 7.3B total)
  • Training time: ~2 hours on A10G
  • VRAM usage: ~13GB
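
A quick way to sanity-check those numbers before launching a run is to wrap the quantized base model with the LoRA config and print the parameter counts (a sketch; the checkpoint name assumes the Open-Orca/Mistral-7B-OpenOrca base):

from transformers import AutoModelForCausalLM

# Load the base checkpoint in 4-bit using the bnb_config defined above
model = AutoModelForCausalLM.from_pretrained(
    "Open-Orca/Mistral-7B-OpenOrca",
    quantization_config=bnb_config,
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)

# Wrap with the LoRA adapters just to inspect parameter counts;
# SFTTrainer applies peft_config itself when you pass it in
peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()
# reports trainable vs. total parameters (roughly 85M of 7.3B, ~1.16% here)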

The Inference Trick: Single Token Generation

Here's where the magic happens. At inference time, we only generate one token and extract probabilities:

import numpy as np
import torch
from transformers import GenerationConfig

def run_inference(input_text):
    generation_config = GenerationConfig(
        max_new_tokens=1,          # Only generate ONE token
        do_sample=False,
        use_cache=True,
        num_beams=1,
        eos_token_id=tokenizer.eos_token_id,
    )

    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        output_scores=True,        # Get logits
        generation_config=generation_config,
        return_dict_in_generate=True
    )

    # Extract probabilities from the first (only) generated token
    probs = torch.stack(outputs.scores, dim=1).softmax(-1).cpu()
    filtered_probs = np.array(probs).flatten()

    # Aggregate probabilities across token variants
    # "yes" can be tokenized as: "yes", " yes", "Yes", " Yes"
    yes_token_ids = [9780, 5081, 5592, 5613]
    no_token_ids = [708, 1510, 1770, 4032, 2501]

    yes_prob = np.sum(filtered_probs[yes_token_ids])
    no_prob = np.sum(filtered_probs[no_token_ids])

    # Normalize and decide
    total = yes_prob + no_prob
    yes_normalized = yes_prob / total
    no_normalized = no_prob / total

    return "yes" if yes_normalized > no_normalized else "no", max(yes_normalized, no_normalized)
Why This Works

The model was trained to output yes, because... or no, because.... During training, the answer token had to "commit" to an answer that would be consistent with the reasoning that follows. This forces the model to compress its reasoning into the probability distribution of that first token.

Token Aggregation: Handling Tokenization Variance

One critical detail: the same word can be tokenized differently depending on context:

# Finding all token variants
tokenizer.convert_tokens_to_ids(['yes', '▁yes', '▁Yes', 'Yes'])
# → [9780, 5081, 5592, 5613]

tokenizer.convert_tokens_to_ids(['▁no', 'no', '▁No', 'NO', 'No'])
# → [708, 1510, 1770, 4032, 2501]

The ▁ prefix indicates that the token includes a leading space (SentencePiece tokenization). By aggregating probabilities across all variants, we capture the model's full "belief" in each answer.
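
Rather than hard-coding the IDs, you can also discover the variants by scanning the vocabulary. This sketch normalizes case and strips the SentencePiece ▁ prefix (it may pick up a few extra variants, such as "YES", which is fine since we aggregate anyway):

def find_variant_ids(tokenizer, word):
    """Collect IDs of all single tokens that spell `word`, ignoring case and '▁'."""
    target = word.lower()
    return [
        tok_id
        for tok, tok_id in tokenizer.get_vocab().items()
        if tok.replace("▁", "").lower() == target
    ]

yes_token_ids = find_variant_ids(tokenizer, "yes")
no_token_ids = find_variant_ids(tokenizer, "no")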

Results: 98% Accuracy with 100x Fewer Output Tokens

The numbers speak for themselves:

Metric                  GPT-4 + CoT      Mistral 7B + Latent Reasoning
Accuracy                ~99%             98.2%
Output Tokens           100-150          1
Latency                 ~500ms           ~20ms
Cost per 1M requests    $3,000-10,000    $0 (self-hosted)

We also get a confidence score for free—the normalized probability. This enables smart routing:

# Production routing logic
answer, confidence = run_inference(conversation)

if confidence > 0.95:
    return answer  # High confidence, use SLM result
else:
    return gpt4_classify(conversation)  # Low confidence, escalate to GPT-4

Why "Latent" Reasoning?

I call this technique "Latent Reasoning" because:

  1. The reasoning exists — the model learned to reason during training
  2. The reasoning is hidden — it's encoded in token probabilities, not generated text
  3. The reasoning influences output — answer probabilities reflect reasoning quality

It's analogous to how humans develop intuition: we don't consciously reason through every decision, but our quick judgments are informed by past reasoning experiences.

Extending the Technique

This approach generalizes beyond binary classification:

Multi-class Classification

# Train on: "category_a, because [reasoning]"
# Aggregate probabilities for each category's token variants
categories = {
    "billing": [tok_ids_for_billing_variants],
    "technical": [tok_ids_for_technical_variants],
    "general": [tok_ids_for_general_variants],
}
probs = {cat: sum(filtered_probs[ids]) for cat, ids in categories.items()}

Scoring/Ranking

# Train on: "8, because the response quality is good but..."
# Use probability distribution across digit tokens [0-9]
score_tokens = [tokenizer.convert_tokens_to_ids(str(i)) for i in range(10)]
expected_score = sum(i * filtered_probs[score_tokens[i]] for i in range(10))

Lessons Learned

1. Reasoning Order Matters

Training on answer, reasoning works better than reasoning, answer for this technique. The model needs to "commit" to the answer first.

2. Quality Over Quantity

1,500 high-quality examples with good reasoning outperformed 10,000 examples with labels only.

3. Token Variant Aggregation is Critical

Without aggregating across tokenization variants, accuracy dropped by ~5%.

4. Confidence Calibration

The confidence scores are well-calibrated—low confidence predictions are genuinely harder cases.
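
One way to check that on a held-out set is a simple reliability binning (a sketch; eval_df and build_inference_prompt are assumptions: a labeled held-out split and a helper that fills the prompt and applies the chat template):

import pandas as pd

records = []
for _, row in eval_df.iterrows():
    pred, conf = run_inference(build_inference_prompt(row))
    records.append({"correct": pred == row["label"], "confidence": conf})

results = pd.DataFrame(records)
results["bin"] = pd.cut(results["confidence"], bins=[0.5, 0.7, 0.9, 0.95, 1.0])
print(results.groupby("bin")["correct"].agg(["mean", "count"]))
# Well calibrated means accuracy ("mean") climbs with the confidence bin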

Conclusion

Latent Reasoning bridges the gap between expensive, slow chain-of-thought models and fast but naive classifiers. By teaching the model to reason during training and compressing that reasoning into single-token probabilities, we get the best of both worlds.

The technique is particularly powerful for:

  • High-volume classification tasks
  • Latency-sensitive applications
  • Cost-constrained deployments
  • Cases where explainability isn't required at inference time

Connect with me on LinkedIn if you want the full training notebook. If you're building AI products at scale, I'd love to hear how you're solving similar challenges—reach out!