
Training Safety: Defense in Depth

Training neural networks is treacherous. Gradients explode. Embeddings collapse. Memory runs out. Models memorize instead of learning. Featrix handles all of this automatically through layered safety mechanisms that detect problems and recover without human intervention.

The Philosophy

The Featrix architecture reflects a core principle: the system must produce useful results on arbitrary data without human intervention, and it must never silently degrade.

Every mechanism described here exists because a real failure mode was encountered on real data—a column of ZIP codes that looked numeric, a scalar column that collapsed while strings carried the model, a bf16 gradient that overflowed and corrupted the optimizer state, a DataLoader worker that leaked gigabytes of RAM.

Gradient Safety: Defense in Depth

bf16-Safe Gradient Clipping

Featrix trains in bfloat16 for speed, but reduced-precision training invites numeric trouble: gradient-norm computation can overflow, and individual gradients can come back as inf or NaN. To guard against this, the system (sketched below the list):

  1. Casts all gradients to float32 before computing norms
  2. Scans every gradient tensor for inf/NaN values and zeros them out
  3. Prevents the classic inf × 0 = NaN corruption where one infinite gradient poisons the entire model
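
A minimal sketch of this pattern in plain PyTorch; the helper name, the clipping threshold, and the overall structure are illustrative assumptions, not Featrix's actual implementation:

```python
import torch

def bf16_safe_clip_(model: torch.nn.Module, max_norm: float = 1.0) -> float:
    """Compute the gradient norm in float32, scrub inf/NaN, then clip (sketch)."""
    params = [p for p in model.parameters() if p.grad is not None]
    total_sq = 0.0
    for p in params:
        grad32 = p.grad.detach().to(torch.float32)       # 1. cast before computing norms
        if not torch.isfinite(grad32).all():
            # 2. zero inf/NaN entries so one bad tensor cannot poison the model (3.)
            grad32 = torch.nan_to_num(grad32, nan=0.0, posinf=0.0, neginf=0.0)
            p.grad.copy_(grad32.to(p.grad.dtype))
        total_sq += float(grad32.pow(2).sum())
    total_norm = total_sq ** 0.5
    if total_norm > max_norm:
        for p in params:
            p.grad.mul_(max_norm / (total_norm + 1e-6))  # standard norm clipping
    return total_norm
```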

Per-Column Gradient Monitoring

The system doesn't just track a single global gradient norm. It monitors every column encoder individually:

Column gradients (exponential moving average):
- age: 0.0015 ✓
- income: 0.0023 ✓
- description: 0.0089 ✓
- zip_code: 847,293.5 ⚠️ EXPLODING

A single column with exploding gradients can destabilize the entire model while the global norm looks normal (averaged away by healthy columns).
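
A sketch of the bookkeeping, assuming each column has its own encoder module; the class, its names, and the decay factor are illustrative, not Featrix's API:

```python
import torch

class ColumnGradMonitor:
    """Track an exponential moving average of gradient norms per column encoder (sketch)."""

    def __init__(self, column_encoders: dict, decay: float = 0.9):
        self.encoders = column_encoders   # e.g. {"age": age_encoder, "zip_code": zip_encoder}
        self.decay = decay
        self.ema = {}

    def update(self) -> dict:
        """Call after backward(): refresh each column's EMA gradient norm."""
        for name, encoder in self.encoders.items():
            sq = 0.0
            for p in encoder.parameters():
                if p.grad is not None:
                    sq += float(p.grad.detach().to(torch.float32).pow(2).sum())
            norm = sq ** 0.5
            self.ema[name] = self.decay * self.ema.get(name, norm) + (1 - self.decay) * norm
        return self.ema
```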

Three-Tier Explosion Response

When a column's gradient norm exceeds thresholds:

| Threshold | Response |
|-----------|----------|
| > 1e4 | Clip: Scale gradients down while preserving direction |
| > 1e6 | Zero: Zero out gradients entirely for this batch |
| 3 consecutive zeros | Freeze: Stop training this column altogether |

The freeze is deferred to the next batch start, not applied mid-backward-pass, to avoid corrupting the computation graph.
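
Putting the tiers together, a simplified sketch; the thresholds come from the table above, while the counter and freeze-queue bookkeeping are assumptions:

```python
CLIP_THRESHOLD = 1e4
ZERO_THRESHOLD = 1e6
FREEZE_AFTER_ZEROS = 3

def respond_to_column(name, encoder, grad_norm, zero_counts, freeze_queue):
    """Apply the clip / zero / freeze tiers to one column's encoder (sketch)."""
    if grad_norm > ZERO_THRESHOLD:
        for p in encoder.parameters():
            if p.grad is not None:
                p.grad.zero_()                          # tier 2: drop this batch's gradients
        zero_counts[name] = zero_counts.get(name, 0) + 1
        if zero_counts[name] >= FREEZE_AFTER_ZEROS:
            freeze_queue.add(name)                      # tier 3: freeze, but only later
    elif grad_norm > CLIP_THRESHOLD:
        scale = CLIP_THRESHOLD / grad_norm              # tier 1: shrink, preserve direction
        for p in encoder.parameters():
            if p.grad is not None:
                p.grad.mul_(scale)
        zero_counts[name] = 0
    else:
        zero_counts[name] = 0

def apply_deferred_freezes(column_encoders, freeze_queue):
    """Run at the start of the next batch, never mid-backward-pass."""
    for name in freeze_queue:
        for p in column_encoders[name].parameters():
            p.requires_grad_(False)
    freeze_queue.clear()
```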

NaN/Inf Recovery

If the total gradient norm comes back as NaN or infinity despite all protections:

  1. Identify every parameter with corrupted gradients
  2. Zero those gradients
  3. Replace any NaN/Inf values in the parameters themselves with 0.0
  4. Skip the optimizer step for this batch
  5. Log which columns were involved

Training continues. The batch is lost, but the model is not.
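
A sketch of that recovery path; the function and its logging are illustrative, but the five steps map directly onto the list above:

```python
import math
import torch

def recover_from_nan(model: torch.nn.Module, total_norm: float) -> bool:
    """Return True if this batch's optimizer step should be skipped (sketch)."""
    if math.isfinite(total_norm):
        return False
    corrupted = []
    for name, p in model.named_parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            corrupted.append(name)                 # 1. identify corrupted gradients
            p.grad.zero_()                         # 2. zero them
        bad = ~torch.isfinite(p.data)
        if bad.any():
            p.data[bad] = 0.0                      # 3. scrub NaN/Inf in the weights too
    print(f"Skipping optimizer step; corrupted gradients in: {corrupted}")   # 5. log
    return True                                    # 4. the caller skips optimizer.step()
```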

Embedding Collapse Detection

Embedding collapse—where all rows map to the same point—is insidious because the loss can look fine while embeddings are useless.

Multiple Detection Mechanisms

Spread Loss: Cross-entropy on the self-similarity matrix. Each row should be most similar to itself, creating a repulsive force between different rows.

Per-Column Diversity Loss: Global diversity can mask per-column failure. If 80 string columns produce diverse embeddings, 2 collapsed scalar columns are invisible in the global metric. The per-column diversity loss checks each column independently.

The Scalar-Only Probe: Periodically computes Recall@1 using only scalar columns, then compares it to the joint Recall@1. If scalar-only recall is near random while joint recall is strong, the model is ignoring the scalar columns entirely, a problematic form of collapse.

Hemisphere Clustering: A subtler failure where embeddings spread locally but cluster on one side of the sphere. The halfspace coverage loss samples random hyperplane splits and checks for one-sided clustering.
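
As a concrete illustration of the first mechanism, a spread loss of this form can be written as cross-entropy over the batch self-similarity matrix, with each row as its own target; the L2 normalization and the temperature value here are assumptions:

```python
import torch
import torch.nn.functional as F

def spread_loss(embeddings: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """Each row should be most similar to itself; off-diagonal similarity is pushed down."""
    z = F.normalize(embeddings, dim=-1)                 # rows on the unit sphere
    sim = z @ z.t() / temperature                       # (batch, batch) self-similarity
    targets = torch.arange(z.size(0), device=z.device)  # row i's positive is row i
    return F.cross_entropy(sim, targets)
```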

Embedding Health Zones

| Zone | Std/Dim | Status | Response |
|------|---------|--------|----------|
| Random | ≥ 0.055 | Model hasn't learned yet | Normal |
| Healthy | 0.04 – 0.055 | Target zone | Maintain |
| Recovering | 0.035 – 0.04 | Improving after intervention | Monitor |
| Warning | 0.02 – 0.035 | Embeddings compressing | Increase spread/diversity |
| Emergency | < 0.02 | Critical collapse | Aggressive intervention |
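
Classifying a batch into these zones reduces to a few threshold checks; the thresholds mirror the table, and computing std/dim as the per-dimension standard deviation averaged over dimensions is an assumption:

```python
import torch

def embedding_health_zone(embeddings: torch.Tensor) -> str:
    """Map a batch of embeddings onto the health zones above via std per dimension (sketch)."""
    std_per_dim = embeddings.float().std(dim=0).mean().item()
    if std_per_dim >= 0.055:
        return "random"        # model hasn't learned yet
    if std_per_dim >= 0.04:
        return "healthy"
    if std_per_dim >= 0.035:
        return "recovering"
    if std_per_dim >= 0.02:
        return "warning"
    return "emergency"         # critical collapse
```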

Adaptive Loss Weighting: The Control Loop

Featrix doesn't train with fixed loss weights. It runs a continuous control loop:

Emergency Response (std/dim < 0.02):
- Spread weight × 1.5
- Diversity weight × 1.3
- Diversity temperature × 0.8 (sharper gradients)
- Marginal weight × 0.8

Warning Response (std/dim 0.02 – 0.035):
- Spread weight × 1.2
- Diversity weight × 1.1

Recovery Response (std/dim back to healthy):
- Spread weight × 0.95 per epoch (gradual back-off)

All weight changes happen through a nudge system: gradual adjustment over at least 3 epochs using cosine interpolation. Stacking limits prevent runaway escalation—at most 4 consecutive increases before improvement must be seen.
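
The nudge mechanic itself is small: cosine interpolation from the current weight toward the target over a minimum number of epochs. A sketch, with the names and exact easing choice as assumptions:

```python
import math

class WeightNudge:
    """Ease a loss weight toward a target over at least `min_epochs` epochs (sketch)."""

    def __init__(self, start: float, target: float, min_epochs: int = 3):
        self.start, self.target, self.min_epochs = start, target, min_epochs
        self.epochs_elapsed = 0

    def step(self) -> float:
        """Advance one epoch and return the interpolated weight."""
        self.epochs_elapsed = min(self.epochs_elapsed + 1, self.min_epochs)
        t = self.epochs_elapsed / self.min_epochs        # progress in [0, 1]
        eased = (1 - math.cos(math.pi * t)) / 2          # cosine ease-in/ease-out
        return self.start + (self.target - self.start) * eased
```

A separate counter that allows at most 4 consecutive upward nudges before requiring improvement would implement the stacking limit.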

Training Failure Detection

Featrix continuously monitors for six specific failure modes:

| Failure | Detection | Severity |
|---------|-----------|----------|
| Dead Network | Gradient norms < 1e-6 | Critical |
| Very Slow Learning | < 1% improvement, tiny gradients | High |
| Severe Overfitting | Training ↓ while validation ↑ | High |
| No Learning | < 0.05% validation improvement for 15 epochs | Medium |
| Moderate Overfitting | Train/val gap > 10% after epoch 10 | Medium |
| Unstable Training | High loss variance, oscillation patterns | Low |

A "convergence exception" suppresses instability warnings during late-stage fine-tuning, where some oscillation at the bottom of the loss landscape is normal.
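
Several of these checks reduce to simple rules over the loss and gradient history. A sketch with thresholds taken from the table; the window sizes and helper shape are assumptions:

```python
def detect_failures(train_losses: list, val_losses: list,
                    grad_norm: float, epoch: int) -> list:
    """Return the failure modes triggered by the current history (illustrative)."""
    failures = []
    if grad_norm < 1e-6:
        failures.append("DEAD_NETWORK")                       # critical
    if epoch > 10 and val_losses[-1] > train_losses[-1] * 1.10:
        failures.append("MODERATE_OVERFITTING")               # train/val gap > 10%
    if len(val_losses) >= 15:
        window = val_losses[-15:]
        improvement = (window[0] - window[-1]) / max(window[0], 1e-12)
        if improvement < 0.0005:
            failures.append("NO_LEARNING")                    # < 0.05% over 15 epochs
    return failures
```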

Learning Rate: Four-Phase Schedule

| Phase | Training % | Curve | Purpose |
|-------|------------|-------|---------|
| Aggressive Warmup | 0–15% | Cubic ramp | Prevent gradient explosion |
| Stabilization | 15–20% | Hold at max | Let model stabilize |
| OneCycle | 20–70% | Cosine anneal | Main learning phase |
| Linear Cooldown | 70–100% | Linear descent | Final convergence |
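
Expressed as a function of training progress, the schedule looks roughly like this; the phase boundaries and curve shapes follow the table, while the minimum learning rate and the floor of the cosine phase are assumptions:

```python
import math

def learning_rate(progress: float, max_lr: float, min_lr: float = 1e-6) -> float:
    """Four-phase LR over training progress in [0, 1] (sketch).
    The 10%-of-max floor for the cosine phase and min_lr are assumptions."""
    mid_lr = 0.1 * max_lr                                 # assumed end of the cosine phase
    if progress < 0.15:                                   # aggressive warmup: cubic ramp
        return min_lr + (max_lr - min_lr) * (progress / 0.15) ** 3
    if progress < 0.20:                                   # stabilization: hold at max
        return max_lr
    if progress < 0.70:                                   # OneCycle: cosine anneal
        t = (progress - 0.20) / 0.50
        return mid_lr + (max_lr - mid_lr) * (1 + math.cos(math.pi * t)) / 2
    t = (progress - 0.70) / 0.30                          # linear cooldown
    return mid_lr + (min_lr - mid_lr) * t
```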

Dynamic adjustments happen on top of this schedule when the training rules engine detects plateaus or opportunities to push harder.

Memory Management

GPU memory exhaustion after hours of training with no checkpoint is catastrophic.

OOM Retry: On out-of-memory, the system clears GPU cache, reduces batch size, and retries—up to 3 times.

Aggressive Defragmentation: Before validation runs (which temporarily double memory usage), multi-pass GPU clearing with garbage collection.

Worker Process Management: DataLoader workers can leak memory. The system tracks expected worker counts, detects runaways, and force-kills orphaned workers.

Memory-Aware Validation: GPUs with 16 GB or less get 2 validation workers; larger GPUs get 4.

CPU-Side Checkpointing: The encoder is cloned to CPU before checkpointing, preventing OOM when unpickling would try to allocate GPU memory for both checkpoint and active model.
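
The OOM retry path, in PyTorch terms, is a try/except around the training step; the cache-clearing calls are standard PyTorch, the three retries follow the description above, and the batch-halving policy is an assumption:

```python
import gc
import torch

def run_with_oom_retry(step_fn, batch, max_retries: int = 3):
    """Run one training step; on CUDA OOM, clear caches, shrink the batch, and retry."""
    for attempt in range(max_retries + 1):
        try:
            return step_fn(batch)
        except torch.cuda.OutOfMemoryError:
            if attempt == max_retries:
                raise                                  # give up after the final retry
            gc.collect()
            torch.cuda.empty_cache()                   # release cached blocks on the GPU
            # Assumed policy: halve the batch (requires a sliceable batch object).
            batch = batch[: max(1, len(batch) // 2)]
```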

Early Stopping with Safeguards

Patience-based early stopping has multiple safeguards (see the sketch after this list):

  • Minimum epoch threshold: Disabled until epoch 50
  • NO_LEARNING recovery block: Blocked for 10 epochs after detecting a plateau
  • NO_STOP override: External file can disable early stopping entirely
  • Finalization phase: 5 epochs of spread+joint loss focus before actually stopping
  • Data rotation: Instead of stopping on plateau, rotate to fresh data partition (up to 3 rotations)
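
A sketch of the stop decision with those safeguards folded in; the patience value, file layout, and function shape are assumptions, while the epoch-50 threshold, the 10-epoch recovery block, and the NO_STOP file come from the list above:

```python
import os

MIN_EPOCHS = 50          # early stopping is disabled before this epoch
PATIENCE = 20            # illustrative; the real patience value is not stated above

def should_stop(epoch: int, epochs_without_improvement: int,
                last_plateau_epoch, job_dir: str) -> bool:
    """Patience-based early stopping with the safeguards described above (sketch)."""
    if epoch < MIN_EPOCHS:
        return False                                    # minimum epoch threshold
    if os.path.exists(os.path.join(job_dir, "NO_STOP")):
        return False                                    # external override file
    if last_plateau_epoch is not None and epoch - last_plateau_epoch < 10:
        return False                                    # NO_LEARNING recovery block
    return epochs_without_improvement >= PATIENCE
```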

Checkpoint and Recovery

Automatic Resume: On startup, searches for the latest valid checkpoint and resumes. "Valid" means the checkpoint's column set matches current data.

Corrupted Checkpoint Handling: Falls back to earlier checkpoint rather than crashing.

Full State Recovery: Checkpoints save model weights, optimizer state, LR scheduler state, dropout scheduler state, loss weight timeline, and gradient tracking history.
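
In PyTorch terms, that state is just a dictionary handed to `torch.save`; the key names here are assumptions:

```python
import torch

def save_checkpoint(path, model, optimizer, lr_scheduler, extra_state: dict):
    """Persist everything needed to resume training exactly where it left off (sketch)."""
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            # dropout scheduler state, loss-weight timeline, gradient history,
            # and the current column set would travel in extra_state.
            **extra_state,
        },
        path,
    )
```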

External Control

Signal files allow control without interrupting the process:

| Signal | Effect |
|--------|--------|
| ABORT | Stop immediately, mark failed |
| PAUSE | Save checkpoint, pause gracefully |
| FINISH | Complete current epoch, save |
| NO_STOP | Disable early stopping |
| RESTART | Log restart (for diagnostics) |
| PUBLISH | Flag model for publication |

Signal files are checked at batch boundaries to ensure consistent state.
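
Polling for these files at each batch boundary takes a few lines; the directory layout is an assumption, while the signal names come from the table:

```python
import os

SIGNALS = ("ABORT", "PAUSE", "FINISH", "NO_STOP", "RESTART", "PUBLISH")

def pending_signals(job_dir: str) -> set:
    """Called at each batch boundary; returns the signal files currently present."""
    return {name for name in SIGNALS if os.path.exists(os.path.join(job_dir, name))}
```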

WeightWatcher: Spectral Health

Beyond loss-based diagnostics, Featrix runs WeightWatcher analysis on weight matrices. WeightWatcher fits a power law to each layer's eigenvalue spectrum and reports the exponent (alpha):

| Alpha | Interpretation |
|-------|----------------|
| 2–5 | Healthy, good generalization |
| > 6 | Noise-dominated, stopped learning |

Best-Epoch Selection: Featrix doesn't just save the checkpoint with the lowest validation loss. It computes:

composite = val_loss × (1 + 0.3 × (alpha_score - 1))

A model with slightly higher validation loss but healthier weight matrices is preferred over one that achieved lower loss through memorization.
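
As a helper, assuming `alpha_score` is normalized so that 1.0 is healthiest and larger values indicate noisier spectra:

```python
def composite_score(val_loss: float, alpha_score: float) -> float:
    """Lower is better: validation loss, penalized when weight spectra look unhealthy."""
    return val_loss * (1 + 0.3 * (alpha_score - 1))

# The best epoch minimizes the composite rather than the raw validation loss, e.g.:
# best = min(history, key=lambda e: composite_score(e["val_loss"], e["alpha_score"]))
```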

Observable Training

Every training run produces structured, machine-readable records:

  • Training timeline: Pre-calculated plan vs. actual results per epoch
  • Adaptive event log: Every parameter change with full attribution (what changed, why, what triggered it, expected effect)
  • Per-epoch quality scores: 0–100 score from embedding health (40pts), gradient health (30pts), ranking quality (30pts)
  • Per-row tracking: Which rows are hard, which are late learners, which flip-flop

Every decision Featrix makes is recorded, attributed, and available for human review—because a system that can't explain what it did is a system you can't trust.

The Result

All of this machinery exists so you don't have to babysit training. Upload your data, start training, and Featrix:

  1. Detects and classifies problems automatically
  2. Intervenes with appropriate responses
  3. Recovers from failures without losing progress
  4. Produces a model that has been stress-tested throughout training
  5. Documents every decision for post-hoc analysis

You get a trained model and a complete record of how it got there—not a black box that might have silently degraded.