Making LoRA Work at Scale: A Debugging Story

At Resolve, we train and fine-tune large models as part of our core infrastructure. We recently scaled up our LoRA training to distributed multi-GPU setups with tensor parallelism using Megatron-Core.

Training went well. The loss curve converged cleanly. The model generated correct outputs. You'd look at this and move on to evals.

Something looked off

Our internal evals weren't hitting the scores we expected. Training loss went down, accuracy went up. But not by nearly as much as we'd expect, which smelled like a classic training bug.

We looked at the traces. The model's tool calls were diverging from expected behavior right from the start of each example. Not subtly drifting partway through - different from the very first call. That shouldn't happen for a model that's supposedly learned.

So we looked at the log probabilities. We have tooling that renders per-token log probs as a color-coded heatmap: green for high-confidence predictions, red for surprises. We put the base model and our LoRA-trained model side by side on the same data.

They looked almost the same. The trained model showed a slight reduction in red, but nowhere near what you'd expect from a model whose training loss says it learned the data.

The model still generated reasonable text. If you just read the outputs, you might not notice a problem. But the eval scores, the traces, and the heatmap were all telling the same story: the weights weren't quite right. The question was whether this was just noise or something real.

I'd encourage readers to enjoy the journey before the destination, but if you'd rather skip ahead, you can jump straight to the answer.

Making it undeniable

To get a definitive signal, we designed a controlled experiment: fine-tune a large MoE model with LoRA and train it to heavily overfit on a small memorization dataset. If the model can perfectly reproduce training data at inference time, the pipeline is working. If it can't, something is broken. No ambiguity.

We trained until it converged. The training loss reached ~1e-4, and the model reproduced every training example word for word.

Then we loaded the checkpoint into SGLang and measured cross-entropy on the same data.

Training loss: ~1e-4. Inference loss: ~0.03. A 300x gap. Same model, same data, same loss metric.

The heatmap confirmed it. For a model that has memorized its data, this should have been solid green. Instead, we got scattered red everywhere — wrong confidence on tokens the model had perfectly memorized during training. The gap from our initial training wasn't noise. Something between training and inference was silently degrading the weights.

Narrowing it down

To find the bug, we first mapped out every variable in the pipeline that could cause the gap:

| Variable | Options tested |
| --- | --- |
| Fine-tuning method | Full SFT, LoRA |
| Model architecture | Dense, MoE |
| Inference framework | SGLang, HuggingFace (online LoRA), PEFT merge, native FP32 merge, Megatron-native eval |
| Precision | BF16, FP16, FP32 |
| LoRA application | Online (no merge) vs. merged into base weights |
| Parallelism config | Various TP, PP, EP combinations |
| Checkpoint method | Megatron distributed checkpoint vs. custom torch.save() |

The goal was to change one variable at a time until we found the one that mattered. We worked through a series of hypotheses, testing each one against the memorization benchmark.

Does full fine-tuning have the same problem?

This was the first thing we checked. Distributed training with Megatron-Core involves a lot of moving parts: MoE routing across expert-parallel ranks, tensor-parallel communication for attention and MLP layers, gradient accumulation across data-parallel groups. A subtle bug in any of these could silently corrupt what the model learns.

To test this, we ran the same experiment with full fine-tuning instead of LoRA. Same model, same parallelism config, same memorization dataset. If full SFT also showed a gap, the problem would be in the training infrastructure itself.

It didn't. Full SFT on a dense model gave us a training loss of 9.59e-6 and an inference loss of ~9e-6. Essentially identical. We repeated this on MoE architectures with multiple parallelism configurations (different TP, PP, EP combinations). Every single one matched.

| Method | Training Loss | Inference Loss | Gap |
| --- | --- | --- | --- |
| Full SFT (Dense) | 9.59e-6 | ~9e-6 | 1x |
| Full SFT (MoE, TP=4, EP=2) | ~0 | 7e-6 | ~1x |
| Full SFT (MoE, TP=1, EP=8) | ~0 | ~0 | 1x |
| LoRA (MoE) | ~1e-4 | ~0.03 | ~300x |

We also tested across different model sizes and architectures (both dense and MoE) to rule out model-specific issues. Full SFT matched in every case. Megatron's training infrastructure was sound. The issue was LoRA-specific.

Is inference handling LoRA incorrectly?

LoRA adds a low-rank adapter to the base model's weights. At inference time, you have a choice: apply the adapter on the fly ("online LoRA") or merge it into the base weights permanently. Different frameworks implement this differently, and any of them could introduce errors.
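In exact arithmetic the two application modes compute the same function. A minimal numpy sketch (dimensions, rank, and scaling are hypothetical, not our production config):

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, r, alpha = 32, 64, 8, 16        # hypothetical dims: LoRA rank r, scaling alpha
W = rng.standard_normal((d_out, d_in))       # base weight
A = rng.standard_normal((r, d_in))           # lora_A (down-projection)
B = rng.standard_normal((d_out, r))          # lora_B (up-projection)
x = rng.standard_normal(d_in)
scale = alpha / r

online = W @ x + scale * (B @ (A @ x))       # apply the adapter on the fly
merged = (W + scale * (B @ A)) @ x           # merge into the base weights first
assert np.allclose(online, merged)           # identical in float64
```

In float64 the two paths agree to machine precision; the interesting question is what happens at BF16.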

Maybe SGLang's FP8 KV cache was quantizing something it shouldn't. Maybe the PEFT merge operation was losing precision.

We tested this one exhaustively. We ran seven different inference configurations on the same checkpoint:

| Inference Method | Precision | Loss |
| --- | --- | --- |
| SGLang (production, FP8 KV cache) | BF16 + FP8 | ~0.058 |
| SGLang (all optimizations disabled) | BF16 | ~0.025 |
| HuggingFace online LoRA | FP16 | 0.030 |
| HuggingFace online LoRA | BF16 | 0.030 |
| HuggingFace online LoRA | FP32 | 0.030 |
| PEFT merge_and_unload | FP16 | 0.030 |
| Native FP32 merge (bypass PEFT) | FP32 | 0.028 |

Every method agreed with every other method. Online LoRA (no merge) gave the same result as merged weights. FP16, BF16, and FP32 all landed in the same range. When that many independent implementations produce consistent results, the bug is upstream. Something was wrong with what gets saved, not how it gets loaded.

Is numerical precision causing accumulated errors?

This one looked promising initially. BF16 has only 8 bits of mantissa compared to FP16's 11 bits. When merging LoRA weights (W' = W + α·B·A), the low-order bits of the LoRA delta get rounded away because they fall below W's representable precision. The order of operations matters because floating-point rounding breaks the usual algebraic identities: computed in finite precision, (W + α·B·A)·x generally differs from W·x + α·(B·A)·x, even though they're equal in exact arithmetic. Computing the two terms separately preserves more of the correction, and these errors compound across transformer layers.
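To make the rounding effect concrete, here's a toy sketch that emulates BF16 by truncating a float32 to its top 16 bits (a simplification of real round-to-nearest BF16 casting, not our production code). A LoRA correction smaller than BF16's mantissa step near 1.0 vanishes entirely when merged first, but survives when applied separately:

```python
import numpy as np

def bf16_trunc(x):
    # crude BF16 emulation: keep only the top 16 bits of the float32 encoding
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)

w = np.float32(1.0)
delta = np.float32(1e-3)  # a LoRA correction below BF16's ~2^-8 spacing near 1.0

merged = bf16_trunc(w + delta)               # merge first, then round to BF16
online = bf16_trunc(w) + bf16_trunc(delta)   # keep the terms separate

print(float(merged - w))   # 0.0 — the correction was rounded away
print(float(online - w))   # ~0.001 — the correction survives
```

The same delta that disappears in the merged path is preserved almost exactly in the separate path, which is why online application is more faithful at low precision.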

We measured the per-layer difference between applying LoRA online versus merging it into the base weights:

| Precision | Max Per-Layer Diff | Relative Error |
| --- | --- | --- |
| FP32 | 0.000001 | 0.001% |
| FP16 | 0.003906 | 2.5% |
| BF16 | 0.031250 | 19.75% |

BF16's rounding error was 8x worse than FP16's, and it compounds across transformer layers. For a moment, it felt like we'd found the answer.

But switching to FP16 or FP32 for inference didn't meaningfully reduce the overall loss gap. The end-to-end numbers stayed in the ~0.028-0.030 range regardless of precision. The precision error was real and measurable, but something much larger was drowning out the signal. We set it aside and kept looking.

Is the weight conversion or merge math wrong?

Converting a checkpoint from Megatron's internal format to HuggingFace PEFT format involves reshaping tensors, renaming parameters, and applying the LoRA merge formula. Any of these steps could introduce errors.

We checked each one:

  • Weight conversion: Compared Megatron and HuggingFace tensors element by element. Max difference: 2.98e-8. Noise floor.
  • LoRA merge formula: Verified layer-by-layer at 1e-6 tolerance. Every layer matched.
  • Scaling: Both training and inference used alpha/rank = 32/128 = 0.25. Identical.
  • LayerNorm fusion: Verified both apply to post-normalized input. Identical.
  • RoPE positional encoding: Exact match on theta and frequencies.
  • Softmax precision: BF16 vs FP32 attention difference was 0.002. Too small.
  • Loss formula: Both use sum(losses)/num_tokens. Standard cross-entropy, properly averaged.
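The first two checks above boil down to a max-abs-diff over paired state dicts. A minimal sketch of that comparison (parameter names and values hypothetical):

```python
import numpy as np

def max_state_dict_diff(sd_a, sd_b):
    # element-by-element comparison of two checkpoints' tensors
    assert sd_a.keys() == sd_b.keys(), "parameter names must match"
    return max(float(np.max(np.abs(sd_a[k] - sd_b[k]))) for k in sd_a)

megatron_sd = {"layer.0.lora_A.weight": np.ones((8, 16))}
hf_sd = {"layer.0.lora_A.weight": np.ones((8, 16)) + 1e-8}

print(max_state_dict_diff(megatron_sd, hf_sd))  # ~1e-8: noise floor
```

Anything at the 1e-8 level is indistinguishable from serialization noise; a real divergence shows up orders of magnitude above that.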

Everything matched. The math was right. The weights we had were being applied correctly. The problem was that we didn't have the right weights.

The breakthrough

We had eliminated precision, merge math, and weight conversion. Every inference method agreed on the same wrong answer. That consistency pointed upstream, but we needed to pinpoint exactly where.

So we cut inference out of the equation entirely. Instead of converting checkpoints and loading them into SGLang or HuggingFace, we ran evaluation directly inside Megatron using its native eval path. Same training code, same model and configs in memory, no conversion, no external inference framework.

We also added an in-training validation step that evaluates the model before saving. If the model is correct in memory but broken after save/load, the checkpoint is the problem.

| Test | Training Loss | Eval Loss | Gap |
| --- | --- | --- | --- |
| In-training validation (before save) | 1.26e-5 | 1.34e-5 | 1.07x |
| After checkpoint save/load | 1.35e-4 | 0.00658 | 50x |

The model was correct in memory. The checkpoint was corrupting it.

To confirm, we wrote a custom export that bypassed Megatron's distributed checkpointing entirely. Instead of going through dist_checkpointing.save(), we gathered the adapter state dicts from all ranks and wrote them with raw torch.save(). Loaded them back. Ran evaluation.

| Checkpoint Method | Training Loss | Eval Loss | Gap |
| --- | --- | --- | --- |
| Standard Megatron checkpoint | 1.35e-4 | 0.00658 | 50x |
| Custom .pt export | 2.47e-5 | 2.54e-5 | 1.03x |

Same model. Different checkpoint method.

The training code was correct. The inference code was correct. The checkpoint save was dropping adapter weights.

The first bug: checkpoint corruption

With tensor parallelism, a model's weight matrices are split across multiple GPU ranks. Rank 0 holds one slice, rank 1 holds another. When you add LoRA adapters to these layers, the adapters get sharded the same way. Each rank has its own slice of each adapter.

The problem: these shards have identical parameter names but different tensor data. Rank 0's layer.0.lora_A.weight contains different values than rank 1's layer.0.lora_A.weight. Both are needed to reconstruct the full adapter.
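A toy illustration of why both shards matter, assuming a small lora_A split row-wise across TP=2 (shapes hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
lora_A = rng.standard_normal((8, 16))  # the full adapter matrix

rank0_shard = lora_A[:4]  # rows held by TP rank 0
rank1_shard = lora_A[4:]  # rows held by TP rank 1

# reconstructing the adapter requires both shards
full = np.concatenate([rank0_shard, rank1_shard])
assert np.array_equal(full, lora_A)

# if rank 1's shard is dropped, half the learned adapter is simply gone
broken = np.concatenate([rank0_shard, np.zeros_like(rank1_shard)])
assert not np.array_equal(broken, lora_A)
```

Losing one shard doesn't crash anything; it just leaves the reconstructed adapter numerically wrong.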

During PEFT-filtered checkpoint save (which saves only adapter weights, not the full model), each TP rank creates a ShardedTensor for its adapter shard with the same logical key. These shards should be distinguished by their tensor offsets, but somewhere in the distributed save coordination, they are incorrectly treated as duplicates. The result is that only one TP rank's shard is persisted:

```python
# What should happen:
# Rank 0: "layer.0.lora_A.weight" → shard [0:rank/2, :] → saved
# Rank 1: "layer.0.lora_A.weight" → shard [rank/2:rank, :] → saved

# What actually happens with PEFT-filtered save:
# Rank 0: "layer.0.lora_A.weight" → shard [0:rank/2, :] → saved
# Rank 1: "layer.0.lora_A.weight" → shard [rank/2:rank, :] → silently dropped
```

No errors, no warnings. Just missing data.
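The failure mode can be mimicked in a few lines of pure Python (a toy sketch, not Megatron's actual data structures): if shards are deduplicated by parameter name alone, one rank's data silently wins.

```python
# Each TP rank produces a shard with the same logical key but different data/offsets.
shards = [
    {"key": "layer.0.lora_A.weight", "tp_rank": 0, "rows": (0, 64)},
    {"key": "layer.0.lora_A.weight", "tp_rank": 1, "rows": (64, 128)},
]

# Buggy dedup: keyed by name only, so one shard silently overwrites the other.
saved_buggy = {s["key"]: s for s in shards}
assert len(saved_buggy) == 1  # half the adapter is gone, and no error is raised

# Correct dedup: the key includes what distinguishes the shards (offsets / replica id).
saved_fixed = {(s["key"], s["rows"]): s for s in shards}
assert len(saved_fixed) == 2  # both shards persisted
```

This is the essence of the replica_id fix: make the identity of a shard include its position in the full tensor, not just its name.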

The model still worked because partial weights were preserved. Enough to produce plausible text, but not enough to match the model that was actually trained.

NVIDIA had already found and fixed this exact issue for expert-parallel adapters in MoE layers (PR #1564, related to volcengine/verl#4303). That fix adds proper replica_id handling so TP shards are distinguished during save.

But dense LoRA adapters, the most common case, never got the same treatment. We traced the likely root cause to ParallelLinearAdapter.sharded_state_dict() in Megatron-Bridge: the expert adapter path sets replica_id correctly, but the dense adapter path doesn't.

The fix

We implemented a workaround: temporarily disable the PEFT filter during checkpoint save, forcing Megatron to save the full model state dict. The full model save path already handles TP sharding correctly, so all adapter shards are preserved. The trade-off is larger checkpoint files, but the weights are correct.

```python
# Workaround: disable PEFT filter to preserve TP shards
peft_backup = state.cfg.peft
state.cfg.peft = None  # Save full model (preserves TP sharding)
try:
    save_checkpoint(state, model, ...)
finally:
    state.cfg.peft = peft_backup  # Restore config
```

We filed an issue with NVIDIA's Megatron-Bridge team with a detailed root cause analysis and a suggested fix. The proper fix would ensure that TP shards of dense adapters are correctly distinguished during PEFT-filtered save, mirroring the handling that already exists for expert adapters. NVIDIA's team has proposed a fix in PR #2252.

The second bug: BF16 precision loss

With the checkpoint bug fixed, we circled back to the precision finding we'd set aside earlier. The per-layer BF16 associativity errors were real. The checkpoint corruption had been drowning out the signal, like trying to hear a conversation next to a fire alarm. Once the alarm was silenced, the conversation became audible.

In some ways, this is the more interesting finding. Checkpoint corruption is a software bug: someone missed a case, you file a PR, it gets fixed. BF16 precision loss is structural. It's a property of the arithmetic format that most large-scale training relies on. BF16 offers good dynamic range with lower memory and compute cost, but the trade-off is reduced mantissa precision, and that has real consequences when LoRA adapters need to be applied with exact fidelity across dozens of transformer layers.

The mitigation is straightforward: use FP16 or higher precision when applying LoRA weights at inference time, regardless of the precision used during training.
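A minimal numpy sketch of the mitigation, using a crude BF16 emulation (truncating float32 to its top 16 bits, a simplification of real casting): accumulate the merge in FP32 and only downcast, if at all, at the end.

```python
import numpy as np

def bf16_trunc(x):
    # crude BF16 emulation: keep only the top 16 bits of the float32 encoding
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)

rng = np.random.default_rng(0)
W = rng.standard_normal((64, 64)).astype(np.float32)
delta = (1e-3 * rng.standard_normal((64, 64))).astype(np.float32)  # small LoRA correction

bf16_merge = bf16_trunc(bf16_trunc(W) + bf16_trunc(delta))  # merge entirely in BF16
fp32_merge = W + delta                                      # accumulate in FP32

err = float(np.max(np.abs(bf16_merge - fp32_merge)))
print(err)  # ~1e-2: BF16 rounding error on the same order as the LoRA correction itself
```

In the BF16 path the rounding error is comparable in magnitude to the correction being merged, which is exactly why merging at higher precision matters.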

Why this matters

We found two bugs in the same pipeline. One was a software bug: checkpoint serialization silently dropping adapter weights. The other was a numerical property of BF16 arithmetic compounding errors across transformer layers. Different in kind, both real, both capable of silently degrading the quality you paid for with training compute.

The checkpoint bug ended up in the most mundane part of the pipeline: checkpoint serialization. Not the model. Not the math. Not inference. Just how a checkpoint gets written to disk. It affects anyone using LoRA with tensor parallelism in Megatron-based training. The precision bug affects anyone using BF16 for LoRA weight application, regardless of framework.

What got us there was process. The memorization test gave us a controlled, reproducible signal. Changing one variable at a time let us isolate the layer of abstraction that was broken. When configuration changes couldn't explain the problem, we bypassed entire subsystems until we found the one that was broken. And when the big bug was fixed, the smaller one became visible.

Distributed training has sharp edges where parallelism strategies meet adapter methods. As LoRA adoption grows, we expect more issues like this. We're investing in validation tooling that catches these discrepancies before deployment, and contributing fixes upstream when we find them. These frameworks are shared infrastructure. Bugs affect everyone, and so do fixes.

Thanks to the Megatron-Bridge team at NVIDIA for engaging on the upstream fix.