
How Featrix Builds an Embedding Space: Technical Deep Dive


Executive Summary

Featrix's Embedding Space (ES) is a self-supervised neural network that learns joint representations of structured tabular data. Unlike traditional feature engineering, the ES automatically discovers relationships, encodes heterogeneous data types, and creates dense vector representations that capture the mutual information between all columns simultaneously. This document provides a comprehensive technical deep-dive into the architecture, training dynamics, and production systems that make this possible with zero configuration.


PART I: UNDERSTANDING FEATRIX

This section is for anyone evaluating Featrix: data scientists, business analysts, ML engineers, and decision-makers. Stop reading at any point and you'll still have learned something useful.


0. The Traditional ML Nightmare (And Why Most Teams Fail)

Machine learning promised to revolutionize how we extract insights from data. The reality? For most teams, building a production ML system is an exercise in frustration, wasted time, and failed projects. The problem isn't a lack of algorithms or computing power—it's that traditional ML requires hundreds of manual decisions, each with massive implications, and most teams simply don't have the expertise to make them correctly.

0.1 The Manual Choice Hell

Every step of the traditional ML pipeline demands expertise that most organizations don't have. Consider what happens when you try to build a simple classification model:

Data Encoding: Your dataset has a mix of numbers, categories, and text. Do you one-hot encode the categories? That explodes dimensionality. Use target encoding? That leaks information and causes overfitting. Try label encoding? You've just told your model that "California" is mathematically greater than "Alabama," creating nonsensical ordinal relationships. What about text fields—emails, addresses, product descriptions? Do you use TF-IDF? Word2Vec? BERT embeddings? Each choice has profound implications for model performance, and there's no way to know which is right until you've wasted weeks trying them all.

Handling Missing Data: Your dataset has nulls scattered throughout. Do you impute with mean? Median? Mode? Forward-fill for time series? Use KNN imputation? MICE? Each column might need a different strategy, but how do you know which? Worse, the imputation strategy interacts with your model choice—tree-based models handle nulls differently than neural networks. Make the wrong call and your model learns patterns from your imputation artifacts, not from your actual data.

Feature Engineering: The model needs features, but which ones? Do you create polynomial interactions? Extract temporal features from timestamps? Bin continuous variables? How? The curse of dimensionality means that too many features degrade performance, but too few means you miss critical patterns. You need deep domain expertise to know which features matter, but even experts guess wrong. And when your data changes—a new product category appears, a supplier adds fields—your entire feature engineering pipeline breaks.

Dealing with Messy Data: Real-world data is a disaster. Product names have typos. Addresses are formatted inconsistently. Categories have synonyms ("canceled," "cancelled," "CANCELLED," "Canceled"). Do you manually clean every variant? Build fuzzy matching? At what threshold? How do you handle edge cases? Each decision is a potential source of error, and there are thousands of them.

Architecture Selection: Should you use a Random Forest? XGBoost? A neural network? How deep? How wide? What activation functions? Batch normalization? Residual connections? These aren't academic questions—they determine whether your model works at all. But there's no principled way to choose. You try a few, pick the one that seems best, and hope you didn't miss something better.

Hyperparameter Tuning: Now comes the really painful part. Learning rate? Batch size? Regularization strength? Number of trees? Tree depth? Dropout rate? Early stopping patience? Each combination takes hours or days to train. Grid search is exponentially expensive. Random search is gambling. Bayesian optimization requires yet more expertise. Teams spend weeks grid-searching hyperparameters, burning through expensive GPU hours, only to find that their best model is barely better than a baseline.

Training Infrastructure: Before you can even start training, you need to set up the infrastructure. Install PyTorch or TensorFlow. Match CUDA versions with driver versions with library versions. Configure multi-GPU training. Set up distributed training. Debug cryptic CUDA errors. Handle out-of-memory crashes. This alone can take days for an experienced engineer.

The brutal truth? Each of these decisions requires expertise most teams don't have. Make the wrong choices—and you will, because there's no way to know in advance—and you've wasted weeks of work on a model that doesn't perform. Make enough wrong choices and your entire project fails.

0.2 AutoML: A Band-Aid on a Bullet Wound

The ML community's answer to this complexity was AutoML: let the machine search for good hyperparameters and architectures. Services like Google AutoML, H2O, and Auto-sklearn promise to automate the pain away.

But AutoML is a band-aid, not a solution. It automates hyperparameter search, not the fundamental problems:

You still need to encode your data. AutoML doesn't know how to handle your email addresses or product descriptions. You still need to choose encodings, still need to impute missing values, still need to engineer features. All the messy data preparation work remains manual.

Grid search is still expensive. AutoML tries hundreds or thousands of configurations. Each one requires training a model from scratch. On a reasonably-sized dataset, this means days or weeks of compute time. The costs spiral quickly—teams routinely spend thousands of dollars on cloud compute for AutoML runs that produce mediocre results.

It's a black box. When AutoML produces a model, can you explain why it chose that architecture? Those hyperparameters? If performance degrades in production, can you diagnose why? If stakeholders ask "why did the model make this prediction?", what do you tell them? AutoML gives you a model, but no understanding.

One model per target. Want to predict customer churn? Train an AutoML model. Want to predict customer lifetime value? Train another AutoML model. Want to segment customers? Train a third model. Each prediction task requires a complete, expensive AutoML run. You're solving the same problem—understanding your customer data—over and over, from scratch each time.

It still requires expertise. Interpreting AutoML results, choosing evaluation metrics, handling class imbalance, preventing data leakage—these still require ML expertise. AutoML democratizes hyperparameter search, not machine learning itself.

The fundamental problem remains: traditional ML treats every prediction task as a separate, manual, expensive endeavor. Whether you tune by hand or use AutoML, you're stuck in a cycle of trial-and-error that consumes months and produces brittle, one-off solutions.

0.3 Why This Makes ML Impossible for the Average Bear

The consequences of this complexity are devastating for most organizations:

Time to production is measured in months. A skilled data scientist might spend 2-3 months on a single prediction task: weeks on data cleaning and feature engineering, weeks on model selection and training, weeks on validation and debugging. For teams without ML expertise, these timelines stretch to six months or more—if they succeed at all.

Most projects fail. Industry estimates suggest 85-90% of ML projects never make it to production. The complexity overwhelms teams. Stakeholders lose patience. Budgets run out. The model that worked in development fails in production due to data drift, and nobody knows how to fix it.

Iteration is impossible. Business requirements change, but iterating on an ML model means revisiting all those manual decisions. Want to add a new feature? Rebuild your pipeline. Need to predict a different target? Start from scratch. Each iteration takes months. By the time you deliver, requirements have changed again.

The cost is enormous. Between salaries for ML engineers, cloud compute for training, and the opportunity cost of failed projects, organizations spend millions attempting to operationalize ML. Most never see ROI.

Explainability is fiction. When the business asks "why did the model predict this?", the honest answer is often "we don't know." The pipeline is so complex—imputation strategies, engineered features, ensemble models, AutoML-selected architectures—that even the engineers who built it can't fully explain its behavior. This destroys trust and prevents adoption.

The promise of ML—democratizing access to powerful predictive models—has become a cruel joke. In practice, ML is accessible only to large tech companies with dedicated ML teams, massive compute budgets, and tolerance for failure. For everyone else, traditional ML is a nightmare of complexity, wasted effort, and broken promises.

0.4 The Featrix Solution: Foundational Learning, Zero Expertise Required

Featrix takes a fundamentally different approach. Instead of building one-off models for each prediction task, Featrix builds a foundational representation of your data—a learned understanding of what your data means and how it relates to itself. Once you have this foundation, prediction tasks become trivial.

Truly Automatic, Not AutoML

Featrix doesn't automate hyperparameter search. It eliminates hyperparameters entirely. There are no learning rates to tune, no architectures to select, no regularization strengths to configure. The system analyzes your data and determines optimal settings automatically:

  • Mixed data types? Featrix detects types and applies specialized encoders—domain encoders for emails, timestamp encoders for dates, semantic encoders for text—without you specifying anything.
  • Missing values? Handled automatically with learned representations that capture uncertainty.
  • Class imbalance? Detected and corrected with automatic class weighting.
  • Batch size? Computed from dataset size and available memory.
  • Number of epochs? Calculated to achieve a target number of optimizer updates.
  • Messy data, typos, synonyms? The semantic encoders learn that "canceled" and "cancelled" mean the same thing.

You don't choose these settings. You don't even know they exist. The system handles it.

Super Easy: Upload Data, Get Predictions

The Featrix workflow is radically simple:

  1. Upload your data (CSV, Parquet, JSON—we handle it)
  2. Wait 10-30 minutes while the embedding space trains
  3. Train a predictor on any target column in 2-5 minutes
  4. Get production-ready predictions with calibrated probabilities

No data scientists required. No GPUs to configure. No pipelines to build. A business analyst can do this.

Super Repeatable: Same Data, Same Results, Always

Traditional ML is notoriously non-deterministic. Run the same code twice, get different results. Featrix is deterministic by design. Upload the same data, get the same embedding space, get the same predictions. Every time. This means:

  • Reproducible experiments for compliance and auditing
  • Reliable A/B testing
  • Confidence in production deployments
  • No mysterious performance variations

The Foundational Approach: Train Once, Predict Many

This is where Featrix diverges most radically from traditional ML. Instead of training a separate model for each target variable, Featrix learns a foundational representation of your data through self-supervised learning:

The embedding space trains on unlabeled data, learning to predict columns from other columns. This forces it to discover the statistical structure of your data—which features correlate, which distributions matter, where natural clusters exist. The result is an embedding space that captures what objectively matters about your data from a statistical perspective.

Once you have this foundation, prediction tasks become simple: train a small neural network (2-4 layers) on top of the embeddings. This takes minutes, not months. Want to predict customer churn? Two minutes. Also predict customer lifetime value? Another two minutes. Segment customers by behavior? Two more minutes.

You're not rebuilding the world each time. You're reusing a foundational understanding of your data, built once, used everywhere.

Fast Iteration: Minutes to Production, Not Months

Because the foundational embedding space captures the general structure of your data, iterating on prediction tasks is fast:

  • New target variable? Train a predictor in 2-5 minutes.
  • New data? Extend the existing embedding space, don't retrain from scratch.
  • Requirements changed? Swap predictors, keep the foundation.
  • A/B test different approaches? Train multiple predictors simultaneously.

This changes the economics of ML. Instead of months-long projects with uncertain outcomes, you get production-ready models in hours. Instead of failed projects, you get iterative improvement. Instead of expensive specialists, you get accessible tools.

Production-Ready Out of the Box

Featrix models come with everything production deployments need:

  • Calibrated probabilities: Predictions include reliable confidence estimates
  • Guardrails: Automatic detection of out-of-distribution inputs
  • Explainability: Feature importance, contribution analysis
  • Monitoring: Built-in drift detection and model performance tracking
  • Versioning: Complete training history and model provenance

Traditional ML requires months of additional work to make models production-ready. Featrix delivers this by default.

0.5 From Nightmare to Reality: What This Means

The implications are profound. Machine learning shifts from being a specialized, expensive, high-risk endeavor to being a commodity tool accessible to any organization with data:

For Data Scientists: Spend your time on high-value analysis and insight generation, not on hyperparameter tuning and pipeline debugging. Build ten models in the time it used to take to build one.

For Business Analysts: Directly answer business questions with predictive models, no ML expertise required. Test hypotheses in hours, not months.

For Organizations: Achieve ROI from ML projects in weeks instead of never. Scale ML across the organization without hiring an army of specialists.

The rest of this document explains how Featrix achieves this—the neural architecture, the self-supervised training objective, the automatic configuration systems, and the production infrastructure that makes this possible. But the core insight is simple: by building foundational representations instead of one-off models, and by eliminating manual decisions through intelligent automation, Featrix makes machine learning accessible to everyone.

Traditional ML is dead. Foundational learning is the future. Let's see how it works.


1. Working with Data in Featrix

Before you can train models, you need data. But not all data is created equal. Featrix handles a remarkably wide range of data formats and structures, from simple CSV files to complex multi-table relationships with nested JSON. Understanding what kinds of data Featrix accepts—and how it handles the messy, real-world problems that come with that data—is essential to using the system effectively.

1.1 Data Formats Featrix Accepts

Featrix is designed to meet you where your data lives. You don't need to transform everything into a proprietary format or build complex ETL pipelines. The system accepts:

Single Table (Most Common)
  • CSV files: The bread and butter of data science. Upload a CSV with column headers, and Featrix detects types, handles missing values, and starts training.
  • Parquet files: For larger datasets, Parquet provides compression and columnar storage. Featrix reads Parquet natively.
  • Pandas DataFrames: If you're working in Python notebooks, pass a DataFrame directly. No file I/O needed.

Multiple Tables with Relationships
  • CSV/Parquet with foreign keys: Multiple tables that need to be joined—customers and transactions, orders and line_items, users and events.
  • Explicit join specifications: Tell Featrix how the tables relate (one-to-one, one-to-many, many-to-many), and the system constructs hierarchical embeddings.

Nested and Hierarchical Data
  • JSON columns: Tables with JSON blobs (product metadata, API responses, nested configurations) are handled by training child embedding spaces on the flattened JSON.
  • Time series within rows: Event sequences, transaction histories, sensor readings—Featrix can treat these as structured columns or as nested data.

The key insight: Featrix doesn't force you to flatten everything into a single table with one-hot encodings and hundreds of sparse columns. Multi-table relationships and nested data are first-class citizens.

1.2 Data Types: Structured, Semi-Structured, and Unstructured

Real-world datasets are a mix of clean numbers, messy categories, and free-form text. Featrix automatically detects and encodes all three:

Structured Data
  • Numbers (scalars): Age, revenue, temperature, counts. Encoded with 20 different strategies (linear, log, percentile, robust, etc.) to handle any distribution.
  • Categories (sets): Country, product_category, status. Learned embeddings capture semantic relationships between categories.
  • Dates and timestamps: order_date, created_at, last_login. Decomposed into cyclical features (month, day_of_week) and linear features (year).
  • Booleans: is_active, has_subscription. Treated as binary categorical.

Semi-Structured Data
  • Emails: user@company.com → Decomposed into domain (company.com), TLD (.com), and a free-email flag (gmail, yahoo).
  • URLs: https://example.com/path?query=value → Protocol, domain, path, and query params all encoded separately.
  • Phone numbers: +1-555-123-4567 → Country code, area code, and number components.
  • Addresses: "123 Main St, NYC, NY 10001" → Can be treated as a composite hybrid column (street, city, state, zip) or as unstructured text.

Unstructured Data
  • Free text: Product descriptions, customer notes, review comments. Encoded with pre-trained BERT embeddings (all-MiniLM-L6-v2) that capture semantic meaning.
  • Long documents: Blog posts, support tickets. Chunked and encoded, then aggregated with attention pooling.

The power here is that Featrix doesn't treat everything as text or everything as numbers. It detects the type, applies the right encoder, and lets the transformer learn how these heterogeneous representations relate to each other.

1.3 Multi-Table Data: Joins and Relationships

Most real-world datasets aren't single tables—they're collections of related tables. Customer data lives in one table, transactions in another, products in a third. To build a model that predicts customer churn, you need to join these tables. But joins are dangerous.

Join Types and Their Challenges

  • One-to-one (1:1): Each customer has one billing address. Safe to join—no row explosion, no data duplication. Result table has the same number of rows as the input.
  • Many-to-one (N:1): Many transactions belong to one customer. Safe to join if you're aggregating up to the customer level (total transaction count, average purchase size). Dangerous if you're working at the transaction level—now every row carries redundant customer data.
  • One-to-many (1:N): One customer has many transactions. This is where things go wrong. If you naively join customers (1,000 rows) with transactions (50,000 rows), you get a result table with 50,000 rows where customer data is duplicated across every transaction. A customer with 100 transactions has their age, city, and income repeated 100 times. The model sees this repetition and learns that these customers are "more important" than customers with 1 transaction—not because they are, but because they appear more often in the data.
  • Many-to-many (N:M): Orders and products (via an order_items join table). Every order can have multiple products, every product can appear in multiple orders. Join these tables and you get a Cartesian explosion—1,000 orders × 50 products = 50,000 rows, most of which are meaningless combinations that never occurred.

The Row Explosion Problem

Row explosion is the silent killer of multi-table models. Here's what happens:

  1. You have 1,000 customers and 50,000 transactions (50 transactions per customer on average).
  2. You join them with a 1:N relationship to predict transaction fraud.
  3. The result table has 50,000 rows, but customers with more transactions are over-represented.
  4. Customer A (1 transaction) appears 1 time. Customer B (200 transactions) appears 200 times.
  5. The model learns Customer B's patterns 200× as strongly as Customer A's.
  6. When you evaluate accuracy, 90% of your data is high-transaction customers, so the model optimizes for them and ignores the long tail.
  7. You've introduced sampling bias through the join—not because your data was biased, but because the join structure created artificial repetition.
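
To see the effect concretely, here is a small pandas sketch that reproduces the row explosion. The data and column names are illustrative, not a Featrix API:

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# 1,000 customers, each with a random number of transactions (1:N relationship)
customers = pd.DataFrame({"customer_id": np.arange(1_000), "age": rng.integers(18, 80, 1_000)})
txn_counts = rng.integers(1, 200, size=1_000)
transactions = pd.DataFrame({
    "customer_id": np.repeat(customers["customer_id"].to_numpy(), txn_counts),
    "amount": rng.exponential(scale=50, size=txn_counts.sum()),
})

joined = customers.merge(transactions, on="customer_id")
print(len(customers), len(transactions), len(joined))
# 1,000 customers explode into ~100,000 joined rows. Each customer's attributes
# are repeated once per transaction, so heavy users dominate anything trained
# on the joined table.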

How Featrix Handles Multi-Table Data

Featrix doesn't force you to denormalize into a single flat table. Instead, it uses hierarchical embedding spaces:

  1. Train a separate embedding space for each table: Customer ES learns customer patterns (age, income, city). Transaction ES learns transaction patterns (amount, timestamp, merchant).
  2. Learn join relationships: When you specify that customers and transactions have a 1:N relationship, Featrix trains a join embedding that aggregates transaction patterns up to the customer level.
  3. Avoid row explosion: The customer-level model sees each customer exactly once, with learned representations of their transaction patterns (average amount, transaction frequency, merchant diversity) rather than raw repeated rows.

This architecture mirrors how databases actually work—normalized tables with foreign key relationships—but with learned aggregations instead of hand-crafted SQL GROUP BY clauses.

Validating Join Accuracy

Joins can fail in subtle ways:

  • Missing foreign keys: 10% of transactions have customer_id = NULL because of data quality issues. Do you drop these rows? Impute a synthetic customer? Each choice biases your model.
  • Mismatched key types: Customers are keyed by integer IDs (1, 2, 3) but transactions use string IDs ("1", "2", "3"). The join silently fails, matching nothing.
  • Time consistency: You join customer data from January with transaction data from December. Customers who signed up in November have no January data, creating temporal leakage.

Featrix validates joins before training:

  • Coverage check: What % of rows successfully joined? <95% triggers a warning.
  • Key type validation: Foreign keys must match types (int ↔ int, str ↔ str).
  • Cardinality check: If you declared a 1:1 join but the result has 10× more rows, something is wrong.
  • Distribution verification: Does the joined data have the same class balance as the original? If not, the join introduced selection bias.
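
These checks are straightforward to approximate outside of Featrix as well. The sketch below is illustrative—it is not Featrix's internal validation code—but it uses the same thresholds listed above:

import pandas as pd

def validate_join(left: pd.DataFrame, right: pd.DataFrame, key: str, declared: str = "1:N"):
    # Key type validation: mismatched dtypes make the merge silently match nothing
    if left[key].dtype != right[key].dtype:
        print(f"[WARN] Key '{key}' dtype mismatch: {left[key].dtype} vs {right[key].dtype}")

    joined = right.merge(left, on=key, how="left", indicator=True)

    # Coverage check: what fraction of right-hand rows found a match?
    coverage = (joined["_merge"] == "both").mean()
    if coverage < 0.95:
        print(f"[WARN] Join coverage {coverage:.1%} (<95%). Review missing foreign keys.")

    # Cardinality check: a declared 1:1 join should not multiply rows
    if declared == "1:1" and len(joined) > 1.1 * len(right):
        print(f"[WARN] Declared 1:1 join produced {len(joined):,} rows from {len(right):,}.")

    return joined.drop(columns="_merge")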

1.4 Common Data Problems Featrix Handles Automatically

Real data is a disaster. Featrix expects this and builds in defenses:

Missing Values (NULLs, NaNs, Empty Strings)
  • Every encoder has a learned "replacement embedding" for missing values. The model learns what "unknown" means in context.
  • No imputation needed—no filling with mean, median, or mode. Missingness is treated as information (e.g., missing income might correlate with fraud).

Mixed Data Types in Columns
  • A column labeled "age" contains 25, "thirty", NULL, and 999 (obviously wrong). Featrix detects the majority type (scalar), parses what it can, and flags outliers.

High Cardinality Categories
  • A user_id column has 10,000 unique values. Traditional one-hot encoding creates 10,000 dimensions—computationally impossible and statistically meaningless. Featrix uses learned embeddings (one d_model-dimensional vector per category) plus semantic BERT fallbacks for unseen categories.

Rare Classes (1% Minority Class)
  • You're predicting fraud (0.5% positive rate). A naive model can achieve 99.5% accuracy by always predicting "not fraud"—useless in production. Featrix uses:
      - Focal loss (down-weight easy examples, focus on hard ones)
      - Automatic class weights (sqrt inverse frequency)
      - Adaptive detection of reverse bias (when the model over-predicts the minority)
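
For reference, focal loss and square-root inverse-frequency weights are standard techniques; a minimal NumPy sketch (illustrative, not Featrix's implementation) looks like this:

import numpy as np

def sqrt_inverse_class_weights(labels):
    # weight_c = sqrt(1 / frequency_c), normalized so the mean weight is 1
    values, counts = np.unique(labels, return_counts=True)
    weights = np.sqrt(counts.sum() / counts)
    return dict(zip(values, weights / weights.mean()))

def focal_loss(probs, targets, gamma=2.0):
    # Down-weights easy examples: loss = -(1 - p_t)^gamma * log(p_t)
    p_t = np.where(targets == 1, probs, 1.0 - probs)
    p_t = np.clip(p_t, 1e-7, 1.0)
    return -((1.0 - p_t) ** gamma) * np.log(p_t)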

Noisy Data (Typos, Inconsistent Formatting)
  • Product names: "iPhone 13 Pro", "Iphone 13 pro", "iphone13pro", "iPhone 13 Pro Max". Semantic encoders (BERT) learn these are all similar, even though exact string matching fails.

Useless Columns (UUIDs, Hashes, All-Null, Uniform Values)
  • transaction_id (UUID format, 100% unique) = pure noise. Featrix detects high uniqueness + low semantic similarity + structural pattern (8-4-4-4-12 hex) → auto-exclude.
  • country (all values = "United States") = zero information. Featrix detects uniform columns → auto-exclude.
  • metadata_json (all NULL) = no data. Featrix detects all-null columns → auto-exclude.
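
A rough approximation of these exclusion heuristics in pandas (the thresholds here are illustrative, not the exact rules Featrix applies):

import pandas as pd

UUID_PATTERN = r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"

def flag_useless_columns(df: pd.DataFrame) -> dict:
    flagged = {}
    for col in df.columns:
        s = df[col]
        uniqueness = s.nunique(dropna=True) / max(len(s), 1)
        if s.isna().all():
            flagged[col] = "all null"
        elif s.nunique(dropna=True) <= 1:
            flagged[col] = "uniform value"
        elif s.astype(str).str.match(UUID_PATTERN).mean() > 0.9:
            flagged[col] = "UUID-like identifier"
        elif uniqueness > 0.98:
            flagged[col] = "near-unique identifier"
    return flagged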

The Philosophy: Fail Loud, Recover Gracefully

Featrix doesn't silently ignore problems. When it detects suspicious data, it logs warnings:

[WARN] Column 'user_id' has 98.2% unique values and low semantic similarity.
       Likely an identifier column. Excluding from training.
[WARN] Column 'country' has only 1 unique value ('United States').
       Provides no discrimination. Excluding from training.
[WARN] Join between customers and transactions has 87.3% coverage.
       1,234 transactions have missing customer_id. Review data quality.

But the system continues. It doesn't crash because of dirty data—it defends against it, logs what it did, and trains anyway. You can review the warnings and decide if they matter for your use case.

The result: Featrix works on real-world data, not just the clean toy datasets used in academic papers. It expects missing values, outliers, typos, and structural problems—and it handles them automatically.


2. Featrix Core Concepts

Now that you understand what data Featrix accepts, let's talk about what Featrix actually does with that data. There are three key concepts you need to understand: the Embedding Space, the Predictor, and the Prediction. Everything else is implementation detail.

2.1 The Embedding Space (Foundational Model)

What it is: A learned universal representation of your entire dataset.

Think of it like this: you have a customer table with 50 columns and 100,000 rows. Each row represents a customer—their age, location, purchase history, behavior patterns. Traditional ML treats each column as an independent feature. Featrix doesn't. The Embedding Space learns a single dense vector (a "row embedding") for each customer that captures everything about that customer in 128 to 512 dimensions.

This vector isn't hand-crafted. You didn't tell the system to compute "total purchases divided by account age" or "whether the customer is in a high-value zip code." The system learned what matters by trying to predict columns from other columns. When it successfully predicts a customer's purchase frequency from their age and location, it has learned that those features correlate. When it fails to predict their favorite color from anything else, it learns that favorite color is independent. The embedding captures these learned relationships.

What it does: Maps every row → a dense vector that captures all relationships between columns.

Feed any customer into the trained Embedding Space, and you get back:
  • A row embedding (128–512 dimensions): The customer as a point in learned feature space
  • Column embeddings (one per column): Learned representations of what each column means
  • Mutual information estimates: Which columns predict which other columns (measured in bits)
  • A 3D projection: A compressed visualization of the row embedding

Training: Self-supervised, no labels needed.

The Embedding Space trains without any target variable. You don't tell it you want to predict churn or lifetime value. It just looks at your data and learns structure. The training objective is simple: mask some columns, predict them from the remaining columns. If the model can predict column B from columns A, C, and D, then B must be correlated with those columns. Do this across all columns, millions of times, and the model learns the joint distribution of your data.

This is called self-supervised learning. No labels, no human annotation, no manual feature engineering. Just the data teaching itself.

Training time: 30 minutes to 2 hours depending on data size.

  • 1,000 rows, 20 columns: ~10 minutes
  • 10,000 rows, 50 columns: ~30 minutes
  • 100,000 rows, 100 columns: ~2 hours

This is a one-time cost. Once you've trained the Embedding Space, you reuse it for every downstream task.

Analogy: Like Word2Vec, but for your entire dataset.

Word2Vec learns that "king" and "queen" are similar, and that "king - man + woman ≈ queen." It does this by predicting words from context. Featrix does the same thing, but for rows in your table. It learns that "high-income urban customer" and "high-income suburban customer" are similar, and that they're both different from "low-income rural customer."

When you need it: Foundation for all downstream tasks.

You always train an Embedding Space first. It's the foundation. Everything else—predictions, clustering, similarity search—is built on top of this learned representation.

2.2 The Predictor (Task-Specific Model)

What it is: A lightweight neural network that predicts a specific target variable.

Once you have an Embedding Space, you train Predictors on top of it. A Predictor is a simple 2-4 layer neural network that takes row embeddings as input and outputs predictions for a single target column.

Want to predict customer churn? Train a Predictor with target_column="churned". Want to predict lifetime value? Train a different Predictor with target_column="lifetime_value". Both Predictors use the same Embedding Space—the same learned row representations—but they specialize for different tasks.

What it does: row_embedding → prediction (with confidence).

The Predictor doesn't re-learn what your data means—the Embedding Space already did that. The Predictor just learns a mapping from the row embedding to the target variable. This is much faster than training from scratch because the hard part (understanding the data) is already done.

Training: Supervised, requires labeled data.

Unlike the Embedding Space, Predictors need labels. You must have a target column with known values (churn = yes/no, lifetime value = $1,250, segment = gold/silver/bronze). The Predictor learns to predict these labels from the row embeddings.

Training time: 5-20 minutes.

Because the Embedding Space did the heavy lifting, Predictor training is fast:
  • 1,000 rows: ~2 minutes
  • 10,000 rows: ~5 minutes
  • 100,000 rows: ~15 minutes

You can train dozens of Predictors in the time it would take to train one traditional model from scratch.

Analogy: Like fine-tuning BERT for sentiment analysis.

BERT is a foundational model trained on massive amounts of text to learn language. When you fine-tune BERT for sentiment analysis, you add a small classification layer on top and train just that layer (or fine-tune the whole model with a very low learning rate). Featrix Predictors work the same way: the Embedding Space is the foundational model, and the Predictor is the task-specific fine-tuning.

When you need it: When you have a specific prediction target.

You only train a Predictor when you have a supervised learning task—something with a known target variable you want to predict on new data.

2.3 The Prediction (Output)

What it is: The result of running a trained Predictor on new data.

Once your Predictor is trained, you feed it new rows (customers you haven't seen before) and get back predictions.

Structure: Predictions + Confidence + Explanations + Guardrails.

Featrix doesn't just give you a prediction. It gives you:
  • The prediction itself: "This customer will churn" or "Predicted lifetime value = $1,840"
  • Confidence score: "85% confident" or "95% confidence interval: [$1,200, $2,500]"
  • Feature importance: "The top 3 factors were: account_age (42%), last_purchase_days (31%), support_tickets (18%)"
  • Warnings: "This customer's income is outside the training distribution. Prediction uncertainty is high."

Classification Output:
  • Binary classification: {"prediction": "churn", "probability": 0.847, "confidence": "high"}
  • Multi-class: {"prediction": "gold", "probabilities": {"gold": 0.63, "silver": 0.28, "bronze": 0.09}}

Regression Output:
  • Continuous values: {"prediction": 1840.32, "confidence_interval": [1204.18, 2476.45], "std": 318.07}

Guardrails:
  • Out-of-distribution detection: "This row is unlike anything in the training data. Embedding distance = 4.2σ above mean."
  • Class imbalance warnings: "Model was trained on 3% positive rate. Threshold optimized for F1 = 0.34."
  • Feature contributions: "This prediction was driven by transaction_count (38%) and account_age (-22%)."

API Response: JSON with predictions, probabilities, warnings.

All of this is returned in a structured JSON response that's easy to parse and log:

{
  "predictions": [
    {
      "row_id": 1001,
      "prediction": "churn",
      "probability": 0.847,
      "confidence": "high",
      "warnings": [],
      "feature_importance": {
        "account_age_days": -0.42,
        "days_since_last_purchase": 0.31,
        "support_tickets_30d": 0.18
      }
    }
  ],
  "metadata": {
    "model_version": "v2.3.1",
    "trained_date": "2025-01-03T10:30:00Z",
    "threshold": 0.34,
    "class_distribution_train": {"churn": 0.03, "active": 0.97}
  }
}

2.4 The Workflow: From Data to Predictions

Here's the complete flow:

1. Upload Data
2. Train Embedding Space (30-120 min)
   ├─ Self-supervised learning
   ├─ No labels needed
   └─ Produces: row embeddings + column embeddings + MI estimates
3. Train Predictor (5-20 min)
   ├─ Supervised learning
   ├─ Requires target column
   └─ Produces: trained predictor model
4. Make Predictions (milliseconds per row)
   ├─ New data → Embedding Space → row embedding
   ├─ Row embedding → Predictor → prediction
   └─ Returns: prediction + confidence + explanations + warnings

Example Timeline:
  • 9:00 AM: Upload customer data (100,000 rows, 50 columns)
  • 9:05 AM: Embedding Space training starts
  • 10:15 AM: Embedding Space training completes (70 minutes)
  • 10:20 AM: Train churn predictor (5 minutes)
  • 10:25 AM: Start making predictions (production-ready)

Total time from upload to predictions: 85 minutes. No data scientists. No hyperparameter tuning. No feature engineering.

2.5 Reusability: Train Once, Use Many Times

The key insight: one Embedding Space supports unlimited Predictors.

You train the Embedding Space once (the expensive part). Then you train as many Predictors as you want (cheap):

Customer Data
Embedding Space (train once: 60 min)
     ├─ Churn Predictor (train: 5 min)
     ├─ Lifetime Value Predictor (train: 8 min)
     ├─ Segment Predictor (train: 6 min)
     ├─ Next Purchase Predictor (train: 7 min)
     └─ Support Ticket Risk Predictor (train: 5 min)

Total time: 60 min (ES) + 31 min (5 predictors) = 91 minutes for 5 production models.

Traditional ML would require training 5 separate models from scratch, each taking 2-3 months of work. Featrix does it in 91 minutes.

When to retrain the Embedding Space:
  • Data distribution changes significantly (new product categories, geographic expansion)
  • Schema changes (new columns added, old columns removed)
  • Model performance degrades over time (concept drift)

When to retrain a Predictor:
  • Target definition changes (you redefine what "churn" means)
  • Class distribution shifts (churn rate goes from 3% to 8%)
  • A new Embedding Space is trained

When to retrain nothing:
  • Just making predictions on new data (inference is fast, no retraining needed)


3. What You Can Do With Featrix

You've trained an Embedding Space and some Predictors. What now? Here's what Featrix enables, across three broad categories: classification, regression, and unsupervised learning.

3.1 Classification Problems

Binary Classification: Predict yes/no, true/false, 0/1.

Use cases:
  • Churn prediction: Will this customer cancel their subscription next month?
  • Fraud detection: Is this transaction fraudulent?
  • Loan approval: Should we approve this loan application?
  • Email spam: Is this email spam or legitimate?
  • Medical diagnosis: Does this patient have the disease?

Featrix handles:
  • Imbalanced classes: 1% positive rate? No problem. Automatic focal loss + class weighting.
  • Cost asymmetry: False positives cost $10, false negatives cost $1,000? Specify the costs, get the optimal threshold.
  • Calibrated probabilities: When the model says 70%, it means 70%—not "some arbitrary score."

Example:

churn_predictor = SinglePredictorMLP(
    embedding_space=es,
    target_column="will_churn",
    positive_label="yes",
    cost_false_positive=50,   # Cost of incorrectly predicting churn
    cost_false_negative=500   # Cost of missing a churner
)
churn_predictor.train()

# Predictions come with calibrated probabilities
predictions = churn_predictor.predict(new_customers)
# [{"prediction": "yes", "probability": 0.82, "confidence": "high", ...}, ...]

Multi-Class Classification: Predict one of N categories.

Use cases:
  • Customer segmentation: gold/silver/bronze tier
  • Product categorization: electronics/clothing/home/toys
  • Sentiment analysis: positive/neutral/negative
  • Disease diagnosis: healthy/diabetes/prediabetes
  • Lead scoring: hot/warm/cold

Featrix handles:
  • Imbalanced classes: One class has 60%, another has 2%? Automatic class weighting.
  • Hierarchical categories: Categories have natural orderings (mild/moderate/severe)? Use ordinal encoding.
  • Softmax output: Probabilities sum to 1.0 across all classes.

Example:

segment_predictor = SinglePredictorMLP(
    embedding_space=es,
    target_column="customer_segment",
    target_column_type="set"
)
segment_predictor.train()

predictions = segment_predictor.predict(new_customers)
# [{"prediction": "gold", "probabilities": {"gold": 0.63, "silver": 0.28, "bronze": 0.09}}, ...]

3.2 Regression Problems

Continuous Value Prediction: Predict a number.

Use cases:
  • Lifetime value (LTV): How much revenue will this customer generate?
  • Price prediction: What's the fair price for this house/car/product?
  • Demand forecasting: How many units will we sell next month?
  • Time-to-event: How many days until this customer makes another purchase?
  • Risk scoring: What's the expected loss on this loan?

Featrix handles:
  • Bounded ranges: Output must be between 0 and 100? Automatic sigmoid/tanh scaling.
  • Heavy-tailed distributions: Revenue ranges from $10 to $100,000? Log-scale transformations.
  • Uncertainty quantification: Not just a point estimate—you get confidence intervals.

Example:

ltv_predictor = SinglePredictorMLP(
    embedding_space=es,
    target_column="lifetime_value_usd",
    target_column_type="scalar"
)
ltv_predictor.train()

predictions = ltv_predictor.predict(new_customers)
# [{"prediction": 1840.32, "confidence_interval": [1204.18, 2476.45], "std": 318.07}, ...]

3.3 Use the Embedding Space Directly (No Predictor Needed)

Once you have an Embedding Space, you don't always need a Predictor. The row embeddings themselves are incredibly useful for unsupervised learning tasks.

Similarity Search: Find k-nearest neighbors in embedding space.

Use cases:
  • Recommendation systems: "Customers who bought X also bought Y" → find customers similar to X, see what they bought.
  • Duplicate detection: Find near-duplicate records in your database.
  • Anomaly detection: Find rows that are very different from everything else.
  • Content-based search: "Find products similar to this one."

How it works:
  1. Encode all rows → row embeddings (vectors)
  2. Store the embeddings in a vector database (Pinecone, Weaviate, Milvus) or just use NumPy
  3. Query: encode the new row → find the k nearest neighbors by cosine similarity

Example:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Get embeddings for all customers
customer_embeddings = es.encode(customer_df)  # Shape: (100000, 128)

# Find 10 most similar customers to customer_id=42
query_embedding = customer_embeddings[42]
similarities = cosine_similarity([query_embedding], customer_embeddings)[0]
top_10_indices = np.argsort(similarities)[-11:-1]  # Exclude self

similar_customers = customer_df.iloc[top_10_indices]

Clustering: Group similar rows together.

Use cases:
  • Customer segmentation: Automatically discover natural customer groups (no labels needed).
  • Anomaly detection: Outliers = points far from any cluster.
  • Exploratory data analysis: "What natural groups exist in this data?"
  • Dimensionality reduction for visualization: The 3D projection shows clusters visually.

How it works:
  1. Encode all rows → row embeddings
  2. Run K-means, DBSCAN, or hierarchical clustering on the embeddings
  3. Assign cluster labels, analyze cluster characteristics

Example:

from sklearn.cluster import KMeans

# Get embeddings
customer_embeddings = es.encode(customer_df)

# Cluster into 5 groups
kmeans = KMeans(n_clusters=5, random_state=42)
cluster_labels = kmeans.fit_predict(customer_embeddings)

# Analyze clusters
customer_df['cluster'] = cluster_labels
customer_df.groupby('cluster').agg({
    'revenue': 'mean',
    'age': 'median',
    'account_age_days': 'mean'
})

Anomaly Detection: Outliers = rows far from neighbors.

Use cases:
  • Fraud detection: Transactions that don't fit any normal pattern.
  • Quality control: Defective products with unusual sensor readings.
  • Network security: Unusual traffic patterns.
  • Medical diagnosis: Patients with rare combinations of symptoms.

How it works:
  1. Encode all rows → row embeddings
  2. For each row, compute the distance to its k-nearest neighbors
  3. Rows with high mean distance = outliers

Example:

import numpy as np
from sklearn.neighbors import NearestNeighbors

customer_embeddings = es.encode(customer_df)

# Fit k-NN model
knn = NearestNeighbors(n_neighbors=10)
knn.fit(customer_embeddings)

# Compute mean distance to 10 nearest neighbors for each point
distances, indices = knn.kneighbors(customer_embeddings)
anomaly_scores = distances.mean(axis=1)

# Top 1% highest scores = anomalies
threshold = np.percentile(anomaly_scores, 99)
anomalies = customer_df[anomaly_scores > threshold]

3.4 Cross-Modal Search: Search by meaning, not exact text match.

Use cases:
  • Product search: "Find products similar to 'wireless noise-canceling headphones'" → matches "Bluetooth headphones with ANC" even though the text doesn't match.
  • Customer search: "Find customers who behave like our top 10% by LTV" → encode the query, search the embeddings.
  • Document retrieval: "Find support tickets about billing errors" → semantic match, not keyword match.

How it works:
  1. Encode the query (could be text, a row, or a synthetic description)
  2. Search the embedding space for nearest neighbors
  3. Return the top-k results

Example:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Find customers similar to high-value segment
high_value_customers = customer_df[customer_df['lifetime_value'] > 10000]
high_value_embeddings = es.encode(high_value_customers)
centroid = high_value_embeddings.mean(axis=0)

# Search all customers for similarity to this centroid
all_embeddings = es.encode(customer_df)
similarities = cosine_similarity([centroid], all_embeddings)[0]
top_100_similar = np.argsort(similarities)[-100:]

# These customers "look like" high-value customers even if they're not yet
lookalike_customers = customer_df.iloc[top_100_similar]

3.5 Exploratory Data Analysis

Understand Your Data Before Building Models

Use cases:
  • Discover natural clusters: What groups exist in this data?
  • Visualize high-dimensional data: The 3D sphere projection shows structure at a glance.
  • Identify outliers: Which rows are weird?
  • Understand column relationships: Mutual information heatmaps show which columns predict which.

3D Visualization:
  • Every row → a 3D point on a unit sphere
  • Similar rows cluster together
  • Outliers sit far from clusters
  • Color by target variable to see separability

Example:

# Get 3D embeddings
short_embeddings = es.encode_short(customer_df)  # Shape: (100000, 3)

# Plot in 3D
import plotly.express as px
fig = px.scatter_3d(
    x=short_embeddings[:, 0],
    y=short_embeddings[:, 1],
    z=short_embeddings[:, 2],
    color=customer_df['churned'],
    title="Customer Embedding Space (3D)"
)
fig.show()

Mutual Information Heatmap:
  • Shows which columns predict which
  • Quantified in bits (not just correlation)
  • Detects nonlinear relationships

Example:

mi_matrix = es.get_mutual_information_matrix()
# mi_matrix[i, j] = bits of information column i provides about column j

import seaborn as sns

column_names = list(customer_df.columns)  # assumed: ES column order matches the DataFrame
sns.heatmap(mi_matrix, xticklabels=column_names, yticklabels=column_names)


4. Safety, Monitoring, and Zero Configuration

Traditional ML models fail in production for two reasons: they were misconfigured during training, or they degrade over time without anyone noticing. Featrix addresses both problems with automatic configuration and built-in monitoring.

4.1 Stratified Sampling (Avoiding Data Leakage)

The Problem: Naive random train/val splits leak information.

Imagine you're predicting customer churn. You randomly split 80% train, 20% validation. Sounds reasonable. But:
  • Rare classes might not appear in validation (0.5% fraud rate → maybe zero fraud cases in validation).
  • Correlated rows (multiple transactions from the same customer) appear in both train and validation, leaking information.
  • The split creates biased subsets—maybe all high-value customers ended up in training by chance.

Featrix's Solution: Stratified splitting for classification.

For classification tasks, each class is split separately to ensure representation:

  • ≥ min_samples in class → split proportionally (e.g., 80/20 within that class)
  • < min_samples but > 1 → all to training (can't validate reliably with 1-2 samples)
  • = 1 sample → all to training (single sample can't split)

After splitting:
  • Validation coverage is reported: "85.6% of samples can be validated, 144 samples excluded across 23 rare categories."
  • Distribution verification: Per-column KL divergence between train/val. KL > 1.0 triggers a warning.

For regression tasks:
  • Gradual rotation: Every 25+ epochs during Embedding Space training (not Predictor training), swap 10% of samples between train and val to prevent overfitting to the validation set.
  • This only happens for self-supervised ES training, never for labeled Predictor training (where a fixed split is required).

Example:

Dataset: 10,000 rows, target = "churned" (97% active, 3% churned)
Stratified split:
  - Active class: 9,700 rows → 7,760 train, 1,940 val
  - Churned class: 300 rows → 240 train, 60 val
Result: Both train and val have ~3% churn rate
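
The per-class rules above are easy to sketch. This is an illustrative implementation; the min_samples value is an assumption, not Featrix's exact cutoff:

import numpy as np
import pandas as pd

def stratified_split(df: pd.DataFrame, target: str, val_frac=0.2, min_samples=10, seed=0):
    rng = np.random.default_rng(seed)
    train_parts, val_parts = [], []
    for _, group in df.groupby(target):
        if len(group) >= min_samples:
            # Large enough class: split proportionally (e.g., 80/20 within the class)
            in_val = rng.random(len(group)) < val_frac
            val_parts.append(group[in_val])
            train_parts.append(group[~in_val])
        else:
            # Rare class (including singletons): keep everything in training
            train_parts.append(group)
    val = pd.concat(val_parts) if val_parts else df.iloc[0:0]
    return pd.concat(train_parts), val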

4.2 Training Monitoring and Auto-Correction

What's Tracked: Loss, gradients, learning rate, dropout, spread loss, temperature.

Every epoch, Featrix logs:
  • Training loss, validation loss
  • Learning rate (from the OneCycleLR scheduler)
  • Dropout rate (from the adaptive schedule)
  • Gradient norms (unclipped and clipped)
  • Spread loss (contrastive objective that prevents embedding collapse)
  • Adaptive temperature (scales with batch size and column count)

Failure Detection: 6 modes for Embedding Space, 6 for Predictors.

Embedding Space failures:
  • DEAD_NETWORK: Gradients < 1e-8 → model isn't learning anything
  • NO_LEARNING: Val loss flat for 5+ epochs after epoch 15 → stuck in a plateau
  • SEVERE_OVERFITTING: Val loss ↑ while train loss ↓ → memorizing training data
  • UNSTABLE_TRAINING: Loss oscillates wildly (coefficient of variation > 10%)
  • VERY_SLOW_LEARNING: <1% improvement over 5 epochs, low gradients
  • MODERATE_OVERFITTING: Train/val gap > 10%, early warning

Predictor failures:
  • DEAD_NETWORK: All outputs identical (prob_std < 0.001)
  • CONSTANT_PROBABILITY: Very low variance (prob_std < 0.03)
  • SINGLE_CLASS_BIAS: >95% of predictions are the same class (always predicting the majority)
  • RANDOM_PREDICTIONS: AUC < 0.55 (no better than random guessing)
  • UNDERCONFIDENT: >70% of predictions in the [0.4, 0.6] range (model unsure about everything)
  • POOR_DISCRIMINATION: AUC < 0.65, accuracy < 0.6 (varied but wrong)

Auto-Interventions: LR boost, temperature boost, early stop blocking.

When NO_LEARNING is detected:
  1. LR boost (3×): Multiply the current LR by 3 for 20 epochs to escape the plateau
  2. Early stop blocking: Prevent early stopping for 10 epochs to give the model time to recover
  3. Temperature boost (2×), if still stuck: Soften the contrastive objective to allow more exploration

When SINGLE_CLASS_BIAS is detected in a Predictor:
  1. Focal gamma reduction: Less focus on hard examples, more on easy examples
  2. Min_weight increase: Give more credit to the majority class to balance predictions

Example timeline:

Epoch 20: NO_LEARNING detected (val loss flat from 4.52 → 4.51 over 5 epochs)
Epoch 21: LR boost applied (0.0008 → 0.0024), early stopping blocked
Epochs 21-40: LR boost active, loss drops from 4.51 → 3.89
Epoch 41: LR boost ends, return to schedule, early stopping re-enabled
Result: Plateau escaped, training continues normally
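
The plateau check itself is simple. Here is an illustrative version; the window and tolerance mirror the description above, but this is not Featrix's exact code:

def detect_no_learning(val_losses, window=5, min_epoch=15, tol=0.01):
    """Flag a plateau when validation loss has been essentially flat for `window` epochs."""
    if len(val_losses) < max(min_epoch, window + 1):
        return False
    recent = val_losses[-window:]
    relative_change = (max(recent) - min(recent)) / max(abs(recent[-1]), 1e-8)
    return relative_change < tol

# On detection: boost the LR 3x for ~20 epochs and block early stopping for
# 10 epochs, as in the timeline above.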

4.3 Prediction Guardrails

Confidence Scoring: Always included, calibrated probabilities.

Featrix predictions always include confidence:
  • Binary classification: "probability": 0.82, "confidence": "high"
  • Regression: "prediction": 1840.32, "confidence_interval": [1204.18, 2476.45]

Probabilities are calibrated using temperature scaling, Platt scaling, or isotonic regression so that "70%" means "70% chance," not "arbitrary score 0.7."
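
Temperature scaling, the simplest of the three methods named, fits a single scalar T on held-out data so that scaled probabilities match observed frequencies. A minimal binary-classification sketch with SciPy (illustrative, not Featrix's calibration code):

import numpy as np
from scipy.optimize import minimize_scalar

def fit_temperature(val_logits, val_labels):
    """Find T > 0 minimizing negative log-likelihood of sigmoid(logit / T) on validation data."""
    z = np.asarray(val_logits, dtype=float)
    y = np.asarray(val_labels, dtype=float)

    def nll(T):
        p = np.clip(1.0 / (1.0 + np.exp(-z / T)), 1e-7, 1 - 1e-7)
        return -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))

    return minimize_scalar(nll, bounds=(0.05, 10.0), method="bounded").x

# At inference, report sigmoid(logit / T) instead of sigmoid(logit).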

Out-of-Distribution Detection: Warns when input is unlike training data.

Every prediction computes the embedding distance from the training centroid:
  • Distance < 2σ: Normal, high confidence
  • Distance 2σ–3σ: Warning, moderate confidence
  • Distance > 3σ: Strong warning, low confidence

Example warning:

{
  "prediction": "churn",
  "probability": 0.65,
  "warnings": [
    "Row embedding distance = 3.4σ above mean. Input is unlike training data. Prediction uncertainty is high."
  ]
}
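
The distance check behind this warning can be sketched in a few lines (illustrative; Featrix may use a different distance or centroid definition internally):

import numpy as np

def ood_sigma(train_embeddings: np.ndarray, row_embedding: np.ndarray) -> float:
    """Distance of a new row from the training centroid, in standard deviations."""
    centroid = train_embeddings.mean(axis=0)
    train_dists = np.linalg.norm(train_embeddings - centroid, axis=1)
    d = np.linalg.norm(row_embedding - centroid)
    return (d - train_dists.mean()) / (train_dists.std() + 1e-8)

# > 2 sigma: warn; > 3 sigma: strong warning (the thresholds described above)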

Class Imbalance Warnings: Flags when predictions may be biased.

If training data was 3% positive rate, every prediction response includes:

{
  "metadata": {
    "class_distribution_train": {"churn": 0.03, "active": 0.97},
    "optimal_threshold": 0.34,
    "note": "Model trained on imbalanced data. Threshold optimized for F1 score."
  }
}

Feature Importance: Which columns contributed to this prediction.

Every prediction includes feature importance:

{
  "feature_importance": {
    "account_age_days": -0.42,    // Older accounts less likely to churn
    "days_since_last_purchase": 0.31,  // Recent inactivity increases churn risk
    "support_tickets_30d": 0.18   // More tickets = higher churn risk
  }
}

Negative values = feature pushes prediction down, positive = pushes up.

4.4 Zero Hyperparameters

No Grid Search Needed: Batch size, LR, epochs, dropout, temperature all auto-set.

Traditional ML requires tuning:
  • Learning rate (1e-5 to 1e-2)
  • Batch size (16 to 2048)
  • Epochs (10 to 1000+)
  • Dropout (0.0 to 0.7)
  • Weight decay (1e-6 to 1e-1)
  • Temperature (0.01 to 1.0)
  • Loss weights (manual balancing)
  • Class weights (often forgotten)

Featrix computes all of these automatically from your data:

Parameter           How It's Computed
Batch size          min(2048, max(32, n_rows/100)), rounded to a power of 2
Learning rate       By dataset size: <1K rows = 5e-5, 5K-20K = 2e-4, >20K = 3e-4
Epochs              Target 36,000 optimizer updates: epochs = 36000 / steps_per_epoch
Dropout             Scheduled: 50% early (exploration) → 25% late (refinement)
Weight decay        By dataset size: <1K = 0.1, 1K-5K = 0.01, >20K = 0.0001
Temperature         Adaptive: base / (batch_factor × column_factor), clamped to [0.01, 0.2]
Class weights       Inverse-sqrt frequency: weight = sqrt(1/frequency), normalized
Gradient clipping   Adaptive: clip_threshold = loss × 2.0 (scales with loss magnitude)

Data-Driven, Not Heuristics:
  • Small datasets (<1K rows): Conservative LR, high weight decay to prevent overfitting
  • Large datasets (>20K rows): Aggressive LR, low weight decay for faster convergence
  • Imbalanced classes: Automatic class weighting, focal loss with adaptive gamma
  • Many columns: Lower temperature to prevent contrastive collapse
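
The formulas in the table are concrete enough to sketch directly. The version below is illustrative: the buckets not spelled out in the table (e.g., 1K-5K rows for the learning rate) are filled in as assumptions:

import math

def auto_config(n_rows: int, steps_per_epoch: int) -> dict:
    # Batch size: min(2048, max(32, n_rows / 100)), rounded down to a power of 2
    raw = min(2048, max(32, n_rows // 100))
    batch_size = 2 ** int(math.log2(raw))

    # Learning rate and weight decay scale with dataset size
    # (intermediate buckets interpolated here as an assumption)
    if n_rows < 1_000:
        lr, weight_decay = 5e-5, 0.1
    elif n_rows < 20_000:
        lr, weight_decay = 2e-4, 0.01
    else:
        lr, weight_decay = 3e-4, 0.0001

    # Epochs: target roughly 36,000 optimizer updates in total
    epochs = max(1, round(36_000 / steps_per_epoch))

    return {"batch_size": batch_size, "learning_rate": lr,
            "weight_decay": weight_decay, "epochs": epochs}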

Comparison Table:

Parameter           Traditional ML                   Featrix
Learning Rate       Grid search 10+ values           Auto from dataset size
Batch Size          Manual selection                 Auto from rows + GPU memory
Epochs              Guesswork + early stopping       Auto from update count target
Dropout             Trial and error                  Scheduled by training phase
Temperature         Often ignored or fixed at 0.1    Adaptive to batch/column count
Loss Weights        Manual balancing                 Curriculum learning
Class Weights       Often forgotten                  Inverse-sqrt of frequency
Gradient Clipping   Fixed threshold (e.g., 1.0)      Adaptive to loss scale

Result: Zero configuration. Upload data, start training, get production-ready models.

4.5 Visualization Capabilities

3D Sphere View: Real-time rendering of embedding space.

Every Embedding Space produces 3D projections (the first 3 dimensions of row embeddings, L2-normalized). Render them as points on a sphere:
  • Similar rows cluster together
  • Outliers sit far from clusters
  • Color by target variable to visualize separability

Training Movies: Watch embeddings evolve epoch-by-epoch.

Featrix can save 3D embeddings at every epoch during training. Stitch them together into a video:
  • Epoch 1: Random initialization, points scattered everywhere
  • Epoch 10: Clusters starting to form
  • Epoch 50: Clear separation, structure emerges
  • Epoch 100: Converged, stable clusters

Timeline Plots: 6 subplots showing LR, loss, dropout, gradients, spread, temperature.

The training_timeline.json file can be visualized with visualize_training_timeline.py:
  1. Learning Rate Schedule (log scale): Warmup → peak → cosine decay, with intervention markers
  2. Train/Val Loss: Should both decrease, not diverge
  3. Dropout Schedule: 50% → 25%, piecewise constant
  4. Gradient Norms: Unclipped vs. clipped (should be similar if healthy)
  5. Spread Loss: Contrastive objective preventing collapse
  6. Adaptive Temperature: Scales with batch size and column count

Mutual Information Heatmaps: Which columns predict which.

The heatmap shows MI[i, j] = bits of information column i provides about column j:
  • High MI (>2 bits): Strong predictive relationship
  • Low MI (<0.5 bits): Independent columns
  • Asymmetric: MI(age → income) ≠ MI(income → age)

Example:

       age  income  city  occupation
age    5.2    1.8   0.4      1.2
income 1.6    6.1   0.9      2.3
city   0.3    0.8   4.7      0.6
occup. 1.1    2.1   0.5      5.8

Interpretation:
  • age provides 1.8 bits about income (age predicts income moderately)
  • city provides only 0.3 bits about age (city doesn't predict age well)
  • Diagonal values (self-MI) are high (each column perfectly predicts itself)

This heatmap guides feature selection: columns with low MI to the target are less useful.


PART II: TECHNICAL DEEP DIVE

This section is for ML engineers, researchers, and anyone who needs to understand the internal mechanics of Featrix. We'll cover the neural architecture, training dynamics, and implementation details.


Architecture Overview

This subsection explains how Featrix transforms raw data into learned representations. We'll cover type-specific encoding, the transformer-based joint encoder, self-supervised masking, and the multi-objective loss function.


5. Type-Specific Encoding: From Values to Embeddings

Each data type gets a specialized codec (tokenizer) and encoder (neural network) that together convert raw values into d_model-dimensional embeddings (default 128, configurable up to 512). This is where Featrix's architectural innovation is most visible—every encoder extracts maximum signal from its data type.

5.1 Scalar Encoding (AdaptiveScalarEncoder)

Numeric values are deceptively difficult to encode. Linear normalization fails for heavy-tailed distributions, multi-modal data, and columns with outliers. Featrix uses a 20-strategy adaptive ensemble that learns which transformations work best for each column.

The 20 Transform Strategies

 #   Strategy                Input Transform            Best For
 1   Linear                  (x - mean) / std           Normal distributions
 2   Log                     log(x + ε)                 Heavy-tailed, exponential
 3   Robust                  (x - median) / IQR         Outlier-heavy data
 4   Rank                    Percentile [0, 1]          Any distribution (order-preserving)
 5   Periodic                [sin(2πx), cos(2πx)]       Cyclical/temporal features
 6   Bucket                  Quantile bin index         Noise reduction, easier relationships
 7   Is Positive             x > 0 → 1, else 0          "Has_thingy" features
 8   Is Negative             x < 0 → 1, else 0          Sign detection
 9   Is Outlier              |x - μ| > 2σ → 1           Anomaly flagging
10   Z-Score                 (x - mean) / std           Classic standardization
11   Min-Max                 (x - min) / (max - min)    Fixed-range features
12   Quantile                Uniform distribution       Any distribution
13   Yeo-Johnson             Power transform            Skewed data (handles ≤ 0)
14   Winsorization           Clip to percentiles        Gentle outlier handling
15   Sigmoid                 σ(x)                       Soft-squash extremes
16   Inverse                 1/x                        Diminishing returns
17   Polynomial              [x², √x]                   2nd-order interactions
18   Frequency               Count encoding             Categorical-like integers
19   Target-Guided Binning   Quantile bins              Supervised signal
20   Clipped Log             log(1 + clip(x))           Stable log for edge cases

Architecture: Dual-Path Encoding

Continuous transforms (smooth, differentiable) and binned transforms (discrete, embedding-based) have fundamentally different gradient properties. Mixing them in a single MLP causes training instabilities. The encoder therefore computes both paths in parallel and learns a gating function that decides how much weight to give each. Typical gate values land between 0.4 and 0.6, meaning both paths contribute, with a skew toward the continuous path for smooth data.

                         Raw Scalar Value
                        ┌───────────────┐
                        │   Normalize   │
                        └───────┬───────┘
                ┌───────────────┴───────────────┐
                │                               │
                ▼                               ▼
┌───────────────────────────┐   ┌───────────────────────────┐
│     CONTINUOUS PATH       │   │       BINNED PATH         │
│                           │   │                           │
│ 17 smooth scalar features │   │ 5 discrete embeddings:    │
│  • 12 transforms:         │   │  • bucket (quantile bin)  │
│    linear, log, robust,   │   │  • is_positive            │
│    rank, zscore, minmax,  │   │  • is_negative            │
│    quantile, yeojohnson,  │   │  • is_outlier             │
│    winsor, sigmoid,       │   │  • target_bin             │
│    inverse, clipped_log   │   │                           │
│  • 2 periodic (sin, cos)  │   │ Each → Embedding(d_model) │
│  • 2 polynomial (x², √x)  │   │ Concatenate → 5×d_model   │
│  • 1 frequency            │   │                           │
│                           │   │          ↓                │
│          ↓                │   │   MLP → d_model           │
│   MLP → d_model           │   │                           │
└─────────────┬─────────────┘   └─────────────┬─────────────┘
              │                               │
              └───────────────┬───────────────┘
                    ┌───────────────────┐
                    │   MIXER GATE      │
                    │                   │
                    │ MLP([cont, bin])  │
                    │        ↓          │
                    │   sigmoid → g     │
                    │                   │
                    │ out = g×cont +    │
                    │      (1-g)×bin    │
                    └─────────┬─────────┘
                    ┌───────────────────┐
                    │ Final Projection  │
                    │ + LayerNorm       │
                    └─────────┬─────────┘
                 (short_vec: 3D, full_vec: d_model)

Strategy Pruning and Initialization

Every 5 epochs, strategies contributing less than 5% of the average weight are pruned (disabled). Pruning stops when 3 strategies remain. A converged income column might end up with Linear at 42%, Log at 31%, and Robust at 27%—the other 17 strategies pruned away.

Strategy weights start biased based on column statistics: high skewness favors Log and Yeo-Johnson; high outlier ratio favors Robust and Rank; bounded ranges favor Min-Max; many zeros favor Bucket and Is-Positive.
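
The sketch below illustrates the dual-path-plus-gate pattern in plain PyTorch. Layer sizes, the single bucket embedding, and the class name are illustrative assumptions, not the Featrix implementation.

# Dual-path scalar encoder sketch: continuous MLP + binned embedding, mixed by a learned gate.
import torch
import torch.nn as nn

class DualPathScalarEncoder(nn.Module):
    def __init__(self, d_model=128, n_cont_features=17, n_bins=16):
        super().__init__()
        self.cont_mlp = nn.Sequential(nn.Linear(n_cont_features, d_model), nn.ReLU(),
                                      nn.Linear(d_model, d_model))
        self.bin_emb = nn.Embedding(n_bins, d_model)       # one of the discrete embeddings
        self.bin_mlp = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(),
                                     nn.Linear(d_model, d_model))
        self.gate = nn.Sequential(nn.Linear(2 * d_model, 1), nn.Sigmoid())
        self.norm = nn.LayerNorm(d_model)

    def forward(self, cont_feats, bucket_idx):
        cont = self.cont_mlp(cont_feats)                   # (batch, d_model)
        binned = self.bin_mlp(self.bin_emb(bucket_idx))    # (batch, d_model)
        g = self.gate(torch.cat([cont, binned], dim=-1))   # (batch, 1) mixing weight
        return self.norm(g * cont + (1 - g) * binned)

enc = DualPathScalarEncoder()
out = enc(torch.randn(32, 17), torch.randint(0, 16, (32,)))  # -> (32, 128)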


5.2 String Encoding (StringEncoder)

Free-form text passes through a centralized String Embedding Server that hosts a pre-trained sentence transformer (all-MiniLM-L6-v2, 384-dim). This architecture saves roughly 600MB of VRAM per training job by sharing one BERT instance across all workers. All jobs get identical embeddings, and an LRU cache of 131,072 entries minimizes redundant calls.

┌──────────────────────────────────────────────────────────────────┐
│                    STRING SERVER (Celery Worker)                 │
│  ┌─────────────────────────────────────────────────────────────┐ │
│  │  SentenceTransformer (all-MiniLM-L6-v2)                     │ │
│  │  - 384-dimensional embeddings                               │ │
│  │  - Loaded ONCE, shared across all training jobs             │ │
│  └─────────────────────────────────────────────────────────────┘ │
└──────────────────────────────────────────────────────────────────┘
                                    │ Redis/HTTP
┌────────────────┐  ┌────────────────┐  ┌────────────────┐
│ Training Job 1 │  │ Training Job 2 │  │ Training Job 3 │
│ (no BERT load) │  │ (no BERT load) │  │ (no BERT load) │
└────────────────┘  └────────────────┘  └────────────────┘
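
Client-side, the pattern is "one shared model plus an LRU cache." The sketch below is a local stand-in using the sentence-transformers library directly; it is not the Redis/HTTP protocol the string server actually speaks.

# Shared sentence transformer with an LRU cache so repeated strings are embedded once.
from functools import lru_cache
import numpy as np
from sentence_transformers import SentenceTransformer

_model = SentenceTransformer("all-MiniLM-L6-v2")   # 384-dim embeddings, loaded once

@lru_cache(maxsize=131_072)                        # cache size mentioned above
def embed(text: str) -> tuple:
    # lru_cache needs hashable values, so return a tuple instead of an ndarray
    return tuple(_model.encode(text, normalize_embeddings=True))

vec = np.array(embed("senior software engineer"))  # (384,)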

Adaptive Compression Strategies

The encoder learns how much compression to apply to BERT embeddings across 7 strategies, including two attention-based encoders for structured strings:

Strategy    Output Dim       Use Case
ZERO        0                Random/uninformative strings (UUIDs, hashes)
DELIMITER   384 → d_model    Legacy: simple averaging for delimited text
DELIM_ATTN  d_model          Attention pooling for delimited text
RADIX       d_model          Positional attention for fixed-width strings
AGGRESSIVE  d_model/4        Heavy MLP compression (low-entropy text)
MODERATE    d_model/2        Balanced MLP compression
STANDARD    d_model          Maximum semantic capacity

Strategy weights are learned via softmax, and after warmup, only the top 2 strategies survive pruning.

DelimiterAttentionEncoder handles strings like "red,green,blue" by splitting on the delimiter, BERT-encoding each part, adding position embeddings (order matters), running self-attention across parts, and using attention pooling with a learned query to aggregate into a fixed-width output. The model learns things like "the last item in the list is most predictive" or "co-occurring items matter."

RadixAttentionEncoder handles fixed-width structured strings like dates (2024-01-15) or product codes (ABC-12345-XYZ). It chunks the string into character positions, applies character embeddings, adds strong positional embeddings (so the model knows positions 0-3 are year, 5-6 are month), runs 2-layer self-attention, and pools with attention.

Detection during string analysis determines which strategy to recommend: high unique ratio plus low semantic similarity plus ID patterns suggests ZERO; fixed width plus low entropy at separator positions suggests RADIX; consistent delimiters in >30% of values suggests DELIM_ATTN.
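
Both attention encoders end with the same aggregation step: a learned query attends over the part embeddings. Here is a rough PyTorch sketch of that pooling step, with random stand-ins for the BERT or character embeddings.

# Attention pooling with a learned query over a variable number of "parts".
import torch
import torch.nn as nn

class AttentionPool(nn.Module):
    def __init__(self, d_model=128, n_heads=4):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, d_model) / d_model ** 0.5)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, parts):                       # parts: (batch, n_parts, d_model)
        q = self.query.expand(parts.size(0), -1, -1)
        pooled, _ = self.attn(q, parts, parts)      # learned query attends over parts
        return pooled.squeeze(1)                    # (batch, d_model)

pool = AttentionPool()
out = pool(torch.randn(8, 3, 128))                  # e.g. "red,green,blue" -> 3 parts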


5.3 Set/Categorical Encoding (SetEncoder)

Categorical values use a hybrid learned + semantic architecture that enables generalization to unseen categories.

Category Value (e.g., "unemployed")
    ├──────────────────────────────────────┐
    │                                      │
    ▼                                      ▼
┌─────────────────┐              ┌─────────────────────┐
│ LEARNED PATH    │              │ SEMANTIC PATH       │
│                 │              │                     │
│ nn.Embedding    │              │ BERT("unemployed")  │
│ (n_members,     │              │        ↓            │
│  d_model)       │              │ Linear(384→d_model) │
└────────┬────────┘              └────────┬────────────┘
         │                                │
         └────────────────┬───────────────┘
                 ┌─────────────────┐
                 │ GATING NETWORK  │
                 │                 │
                 │ MLP([learned,   │
                 │      semantic]) │
                 │        ↓        │
                 │    sigmoid(g)   │
                 └────────┬────────┘
         output = g * learned + (1-g) * semantic

A gating network (small MLP) dynamically decides the learned/semantic mixture for each sample rather than using fixed per-member weights. This avoids overfitting on rare categories, allows the gate to condition on actual embedding values, and provides natural gradient flow through both paths. Initialization biases toward semantic (~38% learned, 62% semantic).

Opaque codes like "A14" end up around 90% learned / 10% semantic. Clear English words like "unemployed" land around 30% learned / 70% semantic. "Senior_software_engineer" falls in between.

When ordinal values are detected (e.g., "low", "medium", "high"), ordinal embeddings with smooth positional patterns (sin/cos) preserve order information. When an unseen category appears at inference, the semantic path is used exclusively—"Senior Software Engineer" works even if only "Software Engineer" was in training because BERT knows they're semantically similar.
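
A minimal PyTorch sketch of the learned/semantic gate; the 384-dim input stands in for the BERT embedding of the category text, and the layer sizes are assumptions.

# Per-sample mixture of a learned category embedding and a projected semantic embedding.
import torch
import torch.nn as nn

class GatedSetEncoder(nn.Module):
    def __init__(self, n_members, d_model=128, bert_dim=384):
        super().__init__()
        self.learned = nn.Embedding(n_members, d_model)
        self.semantic_proj = nn.Linear(bert_dim, d_model)
        self.gate = nn.Sequential(nn.Linear(2 * d_model, 1), nn.Sigmoid())

    def forward(self, member_idx, bert_vec):
        learned = self.learned(member_idx)                    # (batch, d_model)
        semantic = self.semantic_proj(bert_vec)               # (batch, d_model)
        g = self.gate(torch.cat([learned, semantic], dim=-1))
        return g * learned + (1 - g) * semantic               # unseen categories would use semantic only

enc = GatedSetEncoder(n_members=10)
out = enc(torch.randint(0, 10, (4,)), torch.randn(4, 384))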


5.4 Timestamp Encoding (TimestampCodec)

Temporal data decomposes into 12 interpretable cyclical and linear features. Cyclical encoding matters because December 31 and January 1 should be close in embedding space—linear encoding makes them maximally far apart.

Feature          Range            Encoding
seconds          0–59             Cyclical (sin/cos)
minutes          0–59             Cyclical (sin/cos)
hours            0–23             Cyclical (sin/cos)
day_of_month     1–31             Cyclical (sin/cos)
day_of_week      0–6 (Mon=0)      Cyclical (sin/cos)
month            1–12             Cyclical (sin/cos)
year             Absolute         Linear
day_of_year      1–366            Cyclical (sin/cos)
week_of_year     1–53             Cyclical (sin/cos)
timezone_offset  Hours from UTC   Linear
year_since_2000  Years            Linear
year_since_2020  Years            Linear

These 12 features pass through an MLP (12 → 256 → 256 → d_model) with BatchNorm, residual connections, and dropout.
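
A minimal sketch of the cyclical encoding for a subset of these features (the full codec produces all 12); the helper names are ours.

# sin/cos pairs keep Dec 31 and Jan 1 adjacent in feature space.
import math
from datetime import datetime

def cyclical(value, period):
    angle = 2 * math.pi * value / period
    return [math.sin(angle), math.cos(angle)]

def timestamp_features(ts: datetime):
    feats = []
    feats += cyclical(ts.hour, 24)
    feats += cyclical(ts.day, 31)
    feats += cyclical(ts.weekday(), 7)
    feats += cyclical(ts.month, 12)
    feats += cyclical(ts.timetuple().tm_yday, 366)
    feats.append(ts.year - 2000)          # one of the linear features
    return feats

print(timestamp_features(datetime(2024, 12, 31, 23, 0)))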


5.5 Other Encoders

Domain Encoding (DomainEncoder): Domain names decompose into semantic and structural components: subdomain (BERT → 64-dim), domain main (BERT → 128-dim), TLD (categorical embedding over 21 TLDs + "other" → 32-dim), TLD type (generic/country/new → 16-dim), is-free-email flag (8-dim), and optionally IP addresses from DNS lookup (up to 4 IPs × 8 features → 34-dim). These concatenate into ~282 dimensions and project through an MLP to d_model.

URL Encoding (URLEncoder): Full URLs combine multiple sub-encoders: protocol (categorical), domain (DomainEncoder), path (BERT), endpoint (BERT on last path segment), and query params (BERT).

Vector/Embedding Encoding (VectorEncoder): Pre-computed vectors from another model pass through directly or with optional MLP transformation.

JSON Encoding (JsonEncoder): Nested JSON structures are encoded by training a child Embedding Space on the flattened JSON. The parent ES stores the child ES mapping in metadata, and a projection layer (Linear(child_d_model → d_model)) bridges them.

Encoder Comparison

Encoder                Input       Key Innovation                          Parameters
AdaptiveScalarEncoder  Numbers     20-strategy ensemble, dual-path         15–30K
StringEncoder          Free text   Server BERT, 7 compression strategies   10–20K
SetEncoder             Categories  Gating network, OOV via BERT            5–50K (vocab-dependent)
TimestampEncoder       Dates       12 cyclical features                    20K
DomainEncoder          Domains     Semantic + structural decomposition     30K
URLEncoder             URLs        Multi-component fusion                  40K
VectorEncoder          Embeddings  Pass-through or MLP                     0–10K
JsonEncoder            JSON        Child ES + projection                   Varies

6. The Joint Encoder: Learning Relationships via Attention

After column-level encoding, embeddings pass through a transformer-based joint encoder. This is where the magic happens: where "Age" learns it's related to "Income", where "City" and "State" become contextually linked, and where the whole table becomes more than the sum of its columns.

The Core Problem

Each column encoder produces a d_model-dimensional vector that captures everything about that column's value. But knowing "Age=35" and "Income=80K" separately doesn't capture that they're correlated—or that together they strongly predict "CreditScore."

The joint encoder solves this by attending across columns (each column "sees" every other column), learning implicit relationships (transformer discovers patterns like Age × Income → Wealth), computing explicit relationships (pairwise operations including ratios, products, differences), and aggregating everything into a single [CLS] token that becomes the row embedding.

Architecture Overview

                          INPUT: Column Embeddings
                            (batch, n_cols, d_model)
                    ┌────────────────────────────────┐
                    │  PER-COLUMN IN-CONVERTERS      │
                    │  (MLP per column type)         │
                    │  Standardizes column → d_model │
                    └────────────────────────────────┘
           ┌─────────────────────────┴─────────────────────────┐
           │                                                    │
           ▼                                                    ▼
┌─────────────────────────┐               ┌────────────────────────────────────┐
│ POSITIONAL ENCODING     │               │ DYNAMIC RELATIONSHIP EXTRACTOR     │
│ (Learned, per-column)   │               │                                    │
│                         │               │ For each column pair (A, B):       │
│ Tells transformer       │               │  • A × B (multiplication)          │
│ which column is which   │               │  • A + B (addition)                │
│                         │               │  • |A - B| (absolute difference)   │
│                         │               │  • cos(A, B) (angular similarity)  │
│                         │               │  • A - B, B - A (directional diff) │
│                         │               │  • A / B, B / A (ratios)           │
│                         │               │                                    │
│                         │               │ = 8 tokens per pair                │
│                         │               │ = N×(N-1)/2 pairs maximum          │
└─────────────────────────┘               └────────────────────────────────────┘
           │                                                    │
           └─────────────────────────┬─────────────────────────┘
                    ┌────────────────────────────────┐
                    │  HYBRID GROUP EMBEDDINGS       │
                    │  (For RELATIONSHIP strategy)   │
                    │                                │
                    │  Related columns get additive  │
                    │  learned embedding that signals│
                    │  "we belong together"          │
                    └────────────────────────────────┘
                    ┌────────────────────────────────┐
                    │  [CLS] TOKEN PREPEND           │
                    │                                │
                    │  Learnable d_model vector      │
                    │  Initialized: randn/√d_model   │
                    │  Will aggregate all info       │
                    └────────────────────────────────┘
                    ┌────────────────────────────────┐
                    │  CONCATENATE SEQUENCE          │
                    │                                │
                    │  [CLS] + [Col₁...Colₙ] + [Rel₁...Relₘ]
                    │                                │
                    │  Sequence length:              │
                    │  1 + n_cols + n_relationship_tokens
                    └────────────────────────────────┘
          ┌──────────────────────────────────────────────────────┐
          │               TRANSFORMER ENCODER                     │
          │                                                       │
          │  ┌─────────────────────────────────────────────────┐ │
          │  │  LAYER 1: Multi-Head Self-Attention             │ │
          │  │                                                  │ │
          │  │  8 attention heads (default)                    │ │
          │  │  Each head: Q, K, V projections (d_model/8)     │ │
          │  │  Attention(Q, K, V) = softmax(QK^T/√d)V         │ │
          │  │                                                  │ │
          │  │  → LayerNorm → Residual Connection              │ │
          │  │                                                  │ │
          │  │  Feed-Forward Network (4× expansion):           │ │
          │  │  Linear(d_model → 4×d_model) → ReLU → Linear    │ │
          │  │                                                  │ │
          │  │  → LayerNorm → Residual Connection              │ │
          │  └─────────────────────────────────────────────────┘ │
          │                                                       │
          │  [Repeat for n_layers = 3 (default)]                 │
          │                                                       │
          │  Gradient Checkpointing: Trades 30% compute for      │
          │  N×layer memory savings                              │
          └──────────────────────────────────────────────────────┘
                    ┌────────────────────────────────┐
                    │  EXTRACT [CLS] OUTPUT          │
                    │                                │
                    │  output[:, 0, :]               │
                    │  (all other positions discarded)
                    └────────────────────────────────┘
                    ┌────────────────────────────────┐
                    │  OUTPUT CONVERTER (MLP)        │
                    │  + BatchNorm                   │
                    │  + L2 Normalization            │
                    │  → Unit sphere embedding       │
                    └────────────────────────────────┘
                    (short_vec: 3D, full_vec: d_model)

The [CLS] Token

The [CLS] token is a learnable d_model-dimensional parameter that gets prepended to every sequence. After passing through transformer layers, its output position contains information aggregated from all columns and relationships. Unlike mean or max pooling, which treat all columns equally, the [CLS] token learns which columns to attend to and can selectively ignore noise columns.

Positional Encoding

Unlike language models where position means word order, in tabular data position means column identity. Featrix uses learned positional embeddings (not sinusoidal) because column positions are fixed, learning allows capturing column-specific roles, and there's no meaningful "distance" between columns—column 3 isn't conceptually "closer" to column 4 than to column 10.
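
A short PyTorch sketch of both pieces: a learnable [CLS] parameter (initialized as randn/√d_model, as in the diagram) plus learned per-column position embeddings. Dimensions are illustrative.

# Prepend a learnable [CLS] token and add learned positional (column-identity) embeddings.
import torch
import torch.nn as nn

class SequenceBuilder(nn.Module):
    def __init__(self, n_cols, d_model=256):
        super().__init__()
        self.cls = nn.Parameter(torch.randn(1, 1, d_model) / d_model ** 0.5)
        self.pos = nn.Embedding(n_cols, d_model)     # position = column identity, not order

    def forward(self, col_tokens):                   # (batch, n_cols, d_model)
        b, n, _ = col_tokens.shape
        col_tokens = col_tokens + self.pos(torch.arange(n, device=col_tokens.device))
        cls = self.cls.expand(b, -1, -1)
        return torch.cat([cls, col_tokens], dim=1)   # (batch, 1 + n_cols, d_model)

seq = SequenceBuilder(n_cols=20)(torch.randn(4, 20, 256))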

The Dynamic Relationship Extractor

Beyond implicit attention patterns, the system computes explicit relationship tokens for column pairs. This is crucial for learning relationships that attention alone might miss.

For columns A and B, 8 relationship tokens are generated:

Operation      Formula                      What It Captures
Multiply       A × B                        Interaction effects (synergy)
Add            A + B                        Combined magnitude
Absolute Diff  |A - B|                      Symmetric "how different"
Cosine         cos(A, B)                    Angular alignment (ignores magnitude)
Subtract A−B   A - B                        Directional comparison
Subtract B−A   B - A                        Opposite directional comparison
Divide A/B     A / (sign(B) × (|B| + ε))    Ratio, one direction
Divide B/A     B / (sign(A) × (|A| + ε))    Ratio, other direction

Division uses sign(x) × (|x| + ε) where ε=0.1 to prevent explosion when dividing by near-zero embeddings.
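
A compact sketch of the eight operations, including the stabilized division. Here zero entries are treated as positive so the denominator never vanishes, and the cosine scalar is broadcast to token width; both are illustrative choices.

# Eight pairwise relationship tokens for column embeddings A and B.
import torch
import torch.nn.functional as F

def relationship_tokens(A, B, eps=0.1):
    """A, B: (batch, d_model) column embeddings -> (batch, 8, d_model) tokens."""
    sign_A = torch.where(A < 0, -1.0, 1.0)   # treat 0 as positive to keep the denominator nonzero
    sign_B = torch.where(B < 0, -1.0, 1.0)
    cos = F.cosine_similarity(A, B, dim=-1).unsqueeze(-1).expand_as(A)  # broadcast to token width
    ops = [A * B, A + B, (A - B).abs(), cos,
           A - B, B - A,
           A / (sign_B * (B.abs() + eps)), B / (sign_A * (A.abs() + eps))]
    return torch.stack(ops, dim=1)

tokens = relationship_tokens(torch.randn(32, 256), torch.randn(32, 256))  # (32, 8, 256)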

The token count scales with column count:

Columns  Pairs  Tokens (Exploration)  Tokens (Focused, 25%)
5        10     80                    20
10       45     360                   90
20       190    1,520                 380
50       1,225  9,800                 2,450

Progressive Pruning with Causal Importance Scoring

Learning relationships between columns is expensive. With 50 columns, there are 1,225 possible pairs. Computing eight operations for each pair—products, ratios, differences, cosines—creates nearly 10,000 relationship tokens that must flow through the transformer at every training step. For 100 columns, that number explodes to 40,000 tokens. The computational cost scales quadratically with column count, and most of those relationships turn out to be useless.

The obvious solution is to prune the relationships—keep only the ones that matter and discard the rest. But which ones matter? The naive approach is to use a heuristic. Traditional systems measure "predictability distance" between columns: |loss_i - loss_j|. The intuition seems reasonable: pair easy columns with hard columns so the easy column can "teach" the hard one. Columns at the same difficulty level—both easy, both medium, both hard—get low scores and are pruned as adding compute cost without teaching value.

This heuristic sounds clever until you try it in production. Then you discover its fatal flaws.

The Heuristic Trap

First, distance-based scoring can't detect harmful relationships. The formula always produces a positive number—even when pairing two columns actively hurts performance. Imagine columns A and B that interfere with each other: when they're paired, both learn more slowly. The heuristic sees |loss_A - loss_B| = 0.5 and scores them as moderately important. They get kept. The model suffers. You have no idea why training is slow.

Second, it misses synergistic pairs. Consider columns C and D, both at medium difficulty (loss = 4.5). They have strong complementary information—when paired, both improve 20% faster. But |4.5 - 4.5| = 0.0, so the heuristic scores them as worthless and prunes them. You've just thrown away one of your best relationships.

Third, it treats all measurements equally. Column E might have loss = 5.0 ± 0.1 (stable, trustworthy) while column F has loss = 5.0 ± 2.0 (noisy, unreliable). The heuristic ignores this uncertainty and treats both as equally informative. Noisy measurements drive decisions as much as reliable ones.

Fourth, it ignores time. Early in training, when the model is randomly initialized and thrashing around, the relationships look completely different than late in training when the model has learned stable patterns. The heuristic gives equal weight to observations from epoch 5 (when nothing makes sense yet) and epoch 95 (when the model has converged). Stale data pollutes the signal.

The result? You keep some harmful relationships, prune some beneficial ones, and have no idea which decisions were correct. Your model runs slower than it should and learns worse than it could. When someone asks "why did we prune column pair X?", you can't give a real answer—just "the heuristic said so."

Measuring What Actually Matters

Featrix takes a different approach: measure the actual marginal benefit of keeping each relationship. Not a proxy, not a heuristic guess—the real causal effect. Here's how.

For every column pair (i, j), the system tracks two groups of epochs: epochs where the pair was active (the "treatment" group) and epochs where it was disabled (the "control" group). During training, pairs get swapped in and out—some epochs they're computed, other epochs they're not. This creates a natural experiment.

Now measure: How fast does column j improve when it's paired with column i? Compute the improvement rate—the reduction in loss per epoch—during all the treatment epochs. Then compute column j's improvement rate during all the control epochs when it wasn't paired with i. Take the difference. That difference is the lift: the causal effect of pairing i with j on j's learning speed.

If the lift is positive, pairing helps. If it's negative, pairing hurts. If it's near zero, pairing doesn't matter. Do this in both directions—i→j and j→i—and sum them to get the total bidirectional benefit. This number tells you exactly what you need to know: Does this relationship actually help the model learn?

But there's a catch. Lift estimates can be noisy, especially if you only have a few observations. A pair might show positive lift by pure chance—random fluctuation, not real signal. You need to account for uncertainty.

The solution is to use a lower confidence bound (LCB). Instead of using the raw mean lift, compute mean - 1.96 × standard_error. This gives you a 95% confidence interval: you're 95% certain the true lift is at least this high. Pairs with high variance get penalized—their LCB is much lower than their mean. Pairs with sparse data (fewer than five paired and five unpaired observations) get hit with an additional sample-size penalty. The system is conservative: it only trusts what it can measure reliably.

One more refinement: recent observations matter more than old ones. The model evolves during training. In the early epochs, it's learning basic patterns. By the late epochs, it's fine-tuning nuances. Observations from epoch 95 are more relevant to the current model state than observations from epoch 10. So the system weights observations exponentially by recency: weight = 0.95^age, where age is how many epochs ago the observation occurred. Recent data dominates the signal; stale data fades into the background.

Put it all together:

importance(i, j) = LCB(lift_i→j + lift_j→i) - complexity_penalty

where lift measures actual improvement rate when paired vs unpaired,
and LCB = mean - 1.96 × std provides conservative 95% confidence.

This isn't a heuristic. It's a measurement of ground truth: What actually happens to learning speed when you enable this relationship?
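
A simplified NumPy sketch of the score: recency-weighted improvement rates with and without the pair active, combined into a lift and a lower confidence bound. The function names and the simplified standard-error estimate are ours; the 0.95 decay and 1.96 multiplier follow the description above.

# Recency-weighted lift and lower confidence bound for one direction of one column pair.
import numpy as np

def weighted_mean_se(deltas, ages, decay=0.95):
    w = decay ** np.asarray(ages, dtype=float)     # newer observations (smaller age) weigh more
    deltas = np.asarray(deltas, dtype=float)
    mean = np.average(deltas, weights=w)
    var = np.average((deltas - mean) ** 2, weights=w)
    return mean, np.sqrt(var / len(deltas))        # simplified standard error

def lift(paired_impr, paired_ages, unpaired_impr, unpaired_ages):
    m_p, se_p = weighted_mean_se(paired_impr, paired_ages)
    m_u, se_u = weighted_mean_se(unpaired_impr, unpaired_ages)
    mean = m_p - m_u
    se = np.sqrt(se_p ** 2 + se_u ** 2)
    return mean, mean - 1.96 * se                  # (raw lift, conservative 95% LCB)

# loss reduction per epoch for column j, with vs. without pair (i, j) active
raw, lcb = lift([0.09, 0.08, 0.10], [1, 3, 5], [0.015, 0.02], [2, 4])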

What This Looks Like in Practice

Watch the system make a pruning decision at epoch 16. It's evaluated 190 column pairs over the previous 15 epochs. Some pairs were active the whole time. Others were swapped in and out, creating treatment and control groups for each pair. Now it's time to prune the worst 10% and bring in some fresh candidates.

Consider the pair (credit_amount, age). Both columns have similar loss: 4.96 and 4.95. The old heuristic would score this as |4.96 - 4.95| = 0.01—nearly worthless, same-tier columns with no teaching value. But look at the causal measurement:

When this pair is active, the age column's loss drops by 0.008 per epoch. When the pair is disabled, age's loss drops by 0.016 per epoch. The lift is -0.008: age learns twice as fast when this pair is NOT computed. The same thing happens in the other direction—credit_amount also learns slower when paired. The bidirectional lift is -0.012, with 95% confidence that the true lift is at most -0.035.

This relationship is actively harmful. Computing it slows down learning for both columns. The system marks it for immediate pruning. This is something the heuristic could never detect—distance-based scoring would have kept it indefinitely, wasting compute and degrading performance.

Now look at (purpose, installment_commitment). These columns have different losses: 5.42 and 4.89. The heuristic scores this as |5.42 - 4.89| = 0.53—high distance, definitely keep it. The causal measurement agrees, but for the right reasons:

When paired, installment_commitment improves 0.089 per epoch faster than when unpaired. Purpose improves 0.054 per epoch faster. The bidirectional lift is +0.143, with 95% confidence the true lift is at least +0.085. This relationship has strong mutual benefit backed by 8 paired and 5 unpaired observations. The system protects it—it will never be pruned, even if hundreds of epochs pass.

But here's where it gets interesting. Consider (checking_status, credit_history). Both have loss around 4.1. The heuristic sees |4.1 - 4.0| = 0.1 and marks them for pruning—low distance, same tier. But the causal measurement shows lift = +0.11: both columns improve significantly faster when paired. The heuristic would throw this away. The causal approach keeps it.

This is the power of measuring reality instead of guessing. Harmful relationships get detected and removed immediately. Synergistic relationships get found and protected, even when they violate the heuristic's assumptions. The system learns faster, uses less compute, and makes defensible decisions backed by statistical confidence intervals.

Validating the Scoring System

Of course, measuring lift is only useful if the measurements are actually predictive. How do you know the scoring system works? Featrix validates it automatically at the end of every training run.

The validation is elegant: compute the rank correlation between importance scores and actual observed lift. If high-scored pairs really do have high lift, and low-scored pairs really do have low lift, the correlation should be strong. The system uses Spearman's ρ, which measures monotonic relationships without assuming linearity.

A typical validation report looks like this. The scoring system ranked 150 pairs by importance. The top 20% of pairs (those kept and protected) had mean lift of +0.089 with 87% showing positive lift. The bottom 20% (those pruned) had mean lift of -0.018 with only 33% showing positive lift. The rank correlation was ρ = 0.685 with p-value = 1.2×10^-18—strongly significant.

This means the scoring system is highly predictive. Pairs it marks as important really do help learning. Pairs it marks as unimportant really don't help or actively hurt. The system's decisions are statistically validated, not based on hope.

The validation also catches failures. If ρ < 0.4, something is wrong—the scoring system is no better than random guessing. This triggers investigation: Is the training run too short? Is the data too noisy? Is the swap strategy too aggressive? The rank correlation is a smoke alarm: when it's low, you know to dig deeper.

The report also identifies specific mispredictions. "False positives" are pairs that scored high but had negative lift—they were kept but should have been pruned. "False negatives" are pairs that scored low but had positive lift—they were pruned but should have been kept. Typically, you'll see a handful of each. Four false positives and ten false negatives out of 150 pairs is excellent. Twenty or thirty indicates the scoring needs tuning—maybe the lookback window is too short, or the confidence bounds need to be more conservative.

This validation is free. The system is already tracking lift for scoring—checking correlation costs nothing extra. Every training run validates itself. If the scoring starts failing, you find out immediately, not months later when someone notices the model isn't improving anymore.
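
The check itself is a few lines with scipy; the data below is synthetic, but the threshold mirrors the ρ < 0.4 smoke alarm described above.

# Rank-correlate importance scores against observed lift with Spearman's rho.
import numpy as np
from scipy.stats import spearmanr

rng = np.random.default_rng(0)
importance = rng.normal(size=150)
observed_lift = 0.7 * importance + rng.normal(scale=0.5, size=150)  # noisy but related

rho, p_value = spearmanr(importance, observed_lift)
print(f"rank correlation rho={rho:.3f}, p={p_value:.2e}")
if rho < 0.4:
    print("WARNING: scoring is barely better than random -- investigate")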

The Two-Phase Training Strategy

Armed with causal scoring, the relationship extractor operates in two distinct phases. The first ten epochs are pure exploration: compute all N×(N-1)/2 pairs and track everything. Every pair is active every epoch. The system watches which pairs correlate with fast improvement, which correlate with slow improvement, and which don't matter. It builds histories of per-column losses and per-pair activations. This data becomes the foundation for causal inference.

Starting at epoch 11, pruning begins. The system computes importance scores for all active pairs, sorts them by LCB, and disables the bottom k pairs—those with the lowest (most negative or least positive) scores. The pruning is gradual: maybe 5-10 pairs per epoch, targeting 60-75% reduction over 20-30 epochs. This avoids training shock—suddenly removing hundreds of relationship tokens would destabilize learning.

But here's the key insight: pruning isn't permanent. At the same time the system disables k pairs, it randomly re-enables k different pairs from the disabled pool. This swap strategy maintains constant relationship count while continuously re-evaluating which pairs are most valuable. A pair disabled at epoch 15 might get re-enabled at epoch 23, gather new observations, and prove valuable enough to keep. Another pair that was protected early on might later be found harmful once more data accumulates.

The protected pairs—those with the highest LCB scores—are never swapped out. If a relationship consistently shows strong positive lift across many epochs, it earns permanent protection. Typically, this is 5-10 pairs out of hundreds: the truly critical cross-column relationships that drive learning.

The system maintains a minimum floor of max(5, n_cols/2) active pairs. Even aggressive pruning never drops below this threshold. The floor ensures the transformer still has rich cross-column information to work with, even if most relationships turn out to be redundant.

What This Achieves

The result is efficient, adaptive relationship learning. The system starts by exploring everything, then focuses on what matters, continuously re-evaluating as training progresses. Harmful relationships get removed quickly. Beneficial relationships get protected. Uncertain relationships get more observations until the system can confidently score them.

Compare this to the alternatives. A fixed relationship set—compute all pairs forever—wastes 60-75% of compute on relationships that don't help. A one-shot heuristic pruning—prune once at epoch 10 based on gradients or distances—bakes in decisions made with minimal data and can never recover from mistakes. Random pruning—keep a random 25% of pairs—throws away good relationships and keeps bad ones with equal probability.

Causal scoring with continuous re-evaluation threads the needle: aggressive pruning for efficiency, continuous exploration to avoid local optima, statistical validation to catch failures, and interpretable decisions that can be explained to stakeholders.

The performance gains are measurable. Across dozens of production datasets, causal scoring reduces final validation loss by 5-10% compared to heuristic baselines. Training converges 15-20% faster—fewer epochs needed to reach the same quality. False positive rate (keeping harmful relationships) drops from 15% to 3%. The relationship count stabilizes at 25-40% of maximum, saving compute while maintaining or improving model quality.

More important than the numbers: the system is self-tuning and self-validating. There are no hyperparameters to grid-search. The rank correlation tells you if scoring is working. The lift estimates tell you exactly why each decision was made. When training completes, you have a validated record of which relationships mattered and which didn't—knowledge that transfers to future training runs on similar datasets.

Debugging When Things Go Wrong

Sometimes the scoring doesn't work perfectly. The rank correlation comes back at ρ = 0.35—barely better than random. Or the validation shows 30 false positives—far too many harmful relationships slipping through. Or the importance distribution is too narrow—all pairs score between 0.04 and 0.06, making pruning decisions essentially random.

Each failure mode has a cause and a fix. Narrow importance distribution (coefficient of variation < 0.3) usually means insufficient training epochs. The system hasn't gathered enough data to differentiate pairs reliably. Solution: train longer, at least 50-100 epochs, to build rich histories. Or increase the lookback window from 5 to 10 epochs so lift calculations use more data.

Low rank correlation (ρ < 0.4) indicates the lift calculations are too noisy to be predictive. This happens with very volatile datasets where column losses fluctuate wildly epoch to epoch. Solution: increase the confidence threshold from 1.96 (95% CI) to 2.58 (99% CI) to be more conservative, or increase the minimum observation count from 3 to 5 to require more evidence before trusting a measurement.

Many false positives (pairs kept but harmful) mean the LCB isn't conservative enough—the system is being too optimistic about noisy positive measurements. Solution: increase the confidence multiplier or lengthen the lookback window so estimates are more stable. Many false negatives (pairs pruned but helpful) mean the opposite—the system is too conservative and needs to trust its measurements more, or the pruning is too aggressive and should keep more pairs (increase top_k_fraction from 0.25 to 0.40).

The key insight: these are tuning knobs, not mysteries. The validation report tells you exactly what's failing. The timeline shows you when and why. The lift estimates show you which specific pairs are mispredicted. You can iterate quickly, adjusting parameters based on concrete evidence rather than guessing in the dark.

Why This Matters

Relationship features are critical for learning complex patterns. The transformer's attention mechanism is powerful, but it's implicit—it discovers patterns through gradient descent, without guarantees. Explicit relationship tokens—products, ratios, differences—ensure the model has direct access to interactions that matter. In financial data, profit margins (revenue / cost) are more predictive than revenue or cost alone. In medical data, BMI (weight / height²) is more informative than weight or height separately. In time-series data, deltas (value_t - value_t-1) capture changes that raw values obscure.

But computing all possible relationships is O(N²), prohibitively expensive for datasets with dozens or hundreds of columns. You need pruning. The question is whether your pruning is principled or haphazard.

Causal scoring makes it principled. You measure the actual effect of each relationship on learning speed. You quantify uncertainty and act conservatively. You validate that your measurements are predictive. You re-evaluate continuously as new data arrives. The result is a system that automatically discovers the 25-40% of relationships that matter, prunes the 60-75% that don't, and does so with statistical confidence and interpretable reasoning.

This is what separates Featrix from traditional AutoML. AutoML automates hyperparameter search but leaves fundamental design choices—like which relationships to compute—to manual configuration or brute-force grid search. Featrix automates the design itself, using causal inference to discover what works. The result is faster training, better models, and zero configuration beyond "here's my data."

Epoch 16 Relationship Swap:

Newly Disabled (Harmful):
  (credit_amount ↔ age): importance = -0.012
    ├─ lift(credit→age): -0.008 ± 0.015
    │  └─ age column improves 0.008/epoch SLOWER when paired
    ├─ lift(age→credit): -0.004 ± 0.012
    │  └─ credit column improves 0.004/epoch SLOWER when paired
    ├─ LCB(total): -0.035 (conservative: likely harmful)
    └─ n_observations: 5 paired, 7 unpaired
    DECISION: Actively hurting both columns → PRUNE

Top Protected (Beneficial):
  (purpose ↔ installment_commitment): importance = +0.143
    ├─ lift(purpose→install): +0.089 ± 0.021
    │  └─ installment improves 0.089/epoch FASTER when paired
    ├─ lift(install→purpose): +0.054 ± 0.018
    │  └─ purpose improves 0.054/epoch FASTER when paired
    ├─ LCB(total): +0.085 (high-confidence benefit)
    └─ n_observations: 8 paired, 5 unpaired
    DECISION: Strong mutual benefit → PROTECT

Benefits of Causal Scoring

Detects harmful relationships - Negative lift → negative score → immediate pruning
Finds synergistic pairs - Same-tier columns with positive lift get high scores
Uncertainty-aware - Noisy estimates get conservative scores via LCB
Adapts to model evolution - Recency weighting keeps scoring current
Interpretable - Shows exact improvement rate in each direction
Validated automatically - Every run computes rank correlation to verify scoring works

Transformer Configuration

Parameter        Default    Range     Impact
d_model          256        128–512   Capacity per token
n_heads          8          4–16      Parallel attention patterns
n_layers         3          1–6       Depth of relationship learning
dim_feedforward  4×d_model  2×–8×     FFN capacity
dropout          0.1        0–0.3     Regularization
batch_first      True       —         Enables nested tensor optimization

Each head operates on d_model/n_heads = 256/8 = 32 dimensions. Gradient checkpointing is automatically enabled when n_layers > 1, trading ~1.3× compute for ~1 layer worth of activation memory instead of N×layer.
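
For reference, the defaults in the table map directly onto stock PyTorch modules; Featrix's own encoder differs in details (relationship tokens, in/out converters, checkpointing), but the core block looks like this sketch.

# Transformer encoder with the default configuration from the table above.
import torch
import torch.nn as nn

d_model, n_heads, n_layers = 256, 8, 3
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads,
                                   dim_feedforward=4 * d_model,
                                   dropout=0.1, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

seq = torch.randn(32, 1 + 20 + 90, d_model)   # [CLS] + 20 column tokens + relationship tokens
out = encoder(seq)
row_embedding = out[:, 0, :]                   # extract the [CLS] position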

In/Out Converters

Before the transformer, each column embedding passes through a per-column MLP that standardizes distributions. Different column types may produce embeddings with different statistics; in-converters bring them to the same scale before attention.

After extracting the [CLS] output, a final MLP + BatchNorm + L2 normalization produces the row embedding. The first 3 dimensions (L2-normalized) become the short_vec for visualization; all d_model dimensions (L2-normalized) become the full_vec for production use.

Attention Head Diversity

A critical risk with multi-head attention is head redundancy—if all heads learn the same patterns, you're wasting capacity. Every epoch, the system analyzes Q/K weight matrices and computes pairwise cosine similarity between heads.

Avg Similarity  Status     Interpretation
< 0.5           DIVERSE    Excellent—heads learn distinct patterns
0.5–0.7         MODERATE   Acceptable, some overlap
> 0.7           REDUNDANT  Capacity issue—increase n_heads

Random initialization, position-specific patterns, relationship tokens, MI-weighted relationships, and hybrid group embeddings all encourage different heads to specialize.
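
Here is a sketch of the redundancy check using PyTorch's MultiheadAttention weight layout, slicing the Q projection into per-head blocks and comparing them pairwise. Featrix inspects its own layers, so treat this as illustrative.

# Average pairwise cosine similarity between per-head Q projection blocks.
import itertools
import torch
import torch.nn.functional as F

def avg_head_similarity(attn: torch.nn.MultiheadAttention) -> float:
    d = attn.embed_dim
    q_weight = attn.in_proj_weight[:d]             # Q projection, shape (d, d)
    heads = q_weight.reshape(attn.num_heads, -1)   # one flattened block per head
    sims = [F.cosine_similarity(heads[i], heads[j], dim=0).item()
            for i, j in itertools.combinations(range(attn.num_heads), 2)]
    return sum(sims) / len(sims)

avg_sim = avg_head_similarity(torch.nn.MultiheadAttention(256, 8, batch_first=True))
print("DIVERSE" if avg_sim < 0.5 else "MODERATE" if avg_sim <= 0.7 else "REDUNDANT")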


7. Token Sequence Architecture: From Row to Representation

Every row of tabular data becomes a sequence of tokens flowing through the transformer.

Building the Token Sequence

[CLS] [Col₁] [Col₂] ... [Colₙ] [Rel₁] [Rel₂] ... [Relₘ]
  ↑     ↑      ↑          ↑      ↑      ↑           ↑
 Row   Column embeddings    Relationship tokens
 repr  (one per column)     (pairwise operations)

Each column is encoded via its type-specific encoder to produce a d_model vector. All column tokens stack into (batch, n_cols, d_model). Pairwise relationship tokens are computed. Learned position embeddings are added. The learnable [CLS] token is prepended. Relationship tokens append with their positions.

Final sequence shape: (batch_size, 1 + n_cols + n_relationship_tokens, d_model)

Token Type    Content                          Purpose
[CLS]         Learnable parameter              Aggregates all column info → row embedding
Column        Type-specific encoding of value  Semantic representation of column value
Relationship  Pairwise operation (*, +, -, /)  Explicit cross-column computation

Transformer Processing Flow

Input Sequence
┌─────────────────────────────────────────────┐
│  Layer 1: Multi-Head Self-Attention        │
│  ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐          │
│  │Head₁│ │Head₂│ │Head₃│ │...  │ (8 heads) │
│  └──┬──┘ └──┬──┘ └──┬──┘ └──┬──┘          │
│     └───────┴───────┴───────┘              │
│              │ Concatenate + Project       │
│              ▼                              │
│     Feed-Forward Network                   │
│     + LayerNorm + Residual                 │
└─────────────────────────────────────────────┘
    ▼ (repeat for n_layers = 3)

┌─────────────────────────────────────────────┐
│  Final Sequence: (batch, seq_len, d_model)  │
│  Extract [CLS] position → row embedding     │
│  x[:, 0, :] → (batch, d_model)              │
└─────────────────────────────────────────────┘
┌─────────────────────────────────────────────┐
│  Output Converter (MLP)                     │
│  + L2 Normalization → unit sphere           │
└─────────────────────────────────────────────┘
(short_vec: 3D, full_vec: d_model)

Batch-first ordering enables better GPU utilization. Gradient checkpointing trades roughly 30% extra compute for storing the activations of one layer at a time instead of all N layers. Only the [CLS] output is used; other position outputs are discarded (implicit pooling).

Ensuring Attention Head Diversity

Every epoch, pairwise cosine similarity between Q/K weight matrices is computed. The diversity score = 1 − average_similarity. Head pairs with similarity > 0.7 are flagged as redundant.

Avg Similarity  Status     Interpretation
< 0.5           DIVERSE    Excellent—heads learn distinct patterns
0.5–0.7         MODERATE   Acceptable—some overlap
> 0.7           REDUNDANT  Bad—heads duplicate effort
> 0.8           HIGH       Critical—consider doubling n_heads

Every 10 epochs, detailed analysis is logged with min/max similarity and specific redundant pairs.

Sequence Length Considerations

Component            Count                          Example (20 cols)
CLS token            1                              1
Column tokens        n_cols                         20
Relationship tokens  n_cols × (n_cols-1) / 2 × 8    1,520 (max)

Dynamic pruning reduces relationship tokens: exploration phase (first 10 epochs) computes all relationships; focused phase keeps top 25-40% of causal-scored relationships. Typical sequence length: 50–200 tokens for most datasets.


8. The Masking Strategy: Self-Supervised Learning

Featrix uses contrastive masking to learn without labels. The embedding space learns relationships by predicting what it can't see from what it can see.

Complementary Mask Design

Each training row generates two complementary views:

Row = [Age=35, Income=80K, Occupation=Engineer, City=NYC, ...]

Mask 1: [OK, MARGINAL, OK, MARGINAL, ...]  -> Joint1 sees Age, Occupation
Mask 2: [MARGINAL, OK, MARGINAL, OK, ...]  -> Joint2 sees Income, City

The mask ratio stays between 40% and 60%, creating balanced prediction tasks. What one view masks, the other reveals. If the data is sparse (many NULLs), the mask intensity is reduced.
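
A minimal sketch of generating one complementary mask pair per row; the 40-60% ratio matches the description above, everything else is illustrative.

# Each row gets a random mask and its complement, so every column is visible in exactly one view.
import numpy as np

def complementary_masks(n_cols, rng, low=0.4, high=0.6):
    ratio = rng.uniform(low, high)
    mask1 = rng.random(n_cols) < ratio      # True = masked (replaced by MARGINAL) in view 1
    mask2 = ~mask1                          # complement: visible in view 1 -> masked in view 2
    return mask1, mask2

rng = np.random.default_rng(0)
m1, m2 = complementary_masks(n_cols=8, rng=rng)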

The Learning Task

For each masked column, the model must predict its encoding from the joint representation of visible columns:

┌─────────────────────────────────────────────────────────────────────────────┐
│ TRAINING STEP FOR ONE ROW                                                    │
│                                                                              │
│  Original:  [Age=35, Income=80K, Occupation=Eng, City=NYC]                  │
│                                                                              │
│  MASK VIEW 1:                                                                │
│  ┌─────────────────────────────────────────────────────────────────────────┐│
│  │ Visible:   [Age=35, _______, Occupation=Eng, _______]                   ││
│  │                 ↓                                                        ││
│  │ Column Encode: [enc(35), MARGINAL, enc("Eng"), MARGINAL]                ││
│  │                 ↓                                                        ││
│  │ Joint Encoder: joint₁ = Transformer([Age_enc, marginal, Occ_enc, marginal])
│  │                 ↓                                                        ││
│  │ Predict Masked: pred_Income = predict_column(joint₁, col_idx=1)         ││
│  │                 pred_City = predict_column(joint₁, col_idx=3)           ││
│  └─────────────────────────────────────────────────────────────────────────┘│
│                                                                              │
│  LOSS COMPUTATION:                                                           │
│  • Is pred_Income close to enc(80K)?  → Marginal Loss for Income            │
│  • Is pred_City close to enc("NYC")? → Marginal Loss for City              │
│  • Is joint₁ close to joint_unmasked? → Joint Loss                          │
│                                                                              │
│  (Same process for MASK VIEW 2, with Age/Occupation as targets)             │
└─────────────────────────────────────────────────────────────────────────────┘

Gradient Flow

┌─────────────────────────────────────────────────────────────────────────────┐
│                         GRADIENT FLOW DIAGRAM                                │
│                                                                              │
│  LOSS                                                                        │
│    │                                                                         │
│    ├──► Marginal Loss ──► Column Predictor ──► Joint Encoder ──►            │
│    │    (per-column)                              ↓                          │
│    │                                         Transformer                     │
│    │                                              ↓                          │
│    │                                         Column Encoders                 │
│    │                                         (SetEncoder, ScalarEncoder, ...)
│    │                                              ↓                          │
│    │                                         Mixture Weights                 │
│    │                                         Strategy Logits                 │
│    │                                                                         │
│    ├──► Joint Loss ──────► Joint Encoder (same path as above)               │
│    │                                                                         │
│    ├──► Spread Loss ─────► Joint Encoder (contrastive push)                 │
│    │                                                                         │
│    └──► Reconstruction Loss ──► Scalar Decoders (optional)                  │
│                                                                              │
└─────────────────────────────────────────────────────────────────────────────┘

Loss Component  What It Trains
Marginal Loss   "Learn to predict each column from others"
Joint Loss      "Make masked/unmasked views similar"
Spread Loss     "Keep different rows distinguishable"
Reconstruction  "Preserve numeric values exactly"

9. Loss Functions: Multi-Objective Training

The embedding space trains with four complementary loss functions, each with a specific role.

Joint Loss (InfoNCE Contrastive)

The joint representation from a masked view should match the unmasked representation of the same row. Both embeddings are normalized, a similarity matrix is computed across the batch, and cross-entropy pulls same-row pairs together while pushing different rows apart. Temperature scales with batch size and column count for stable gradients.
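
A compact sketch of that objective in PyTorch; the adaptive temperature is replaced by a fixed value for brevity.

# InfoNCE: normalized embeddings, batch similarity matrix, cross-entropy against the diagonal.
import torch
import torch.nn.functional as F

def info_nce(masked_emb, unmasked_emb, temperature=0.07):
    a = F.normalize(masked_emb, dim=-1)
    b = F.normalize(unmasked_emb, dim=-1)
    logits = a @ b.T / temperature                      # (batch, batch) similarities
    targets = torch.arange(a.size(0), device=a.device)  # row i should match row i
    return F.cross_entropy(logits, targets)

loss = info_nce(torch.randn(64, 256), torch.randn(64, 256))
# Per-column marginal losses reuse the same form; the MI estimate is log(batch_size) - loss.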

Marginal Loss (Per-Column InfoNCE)

Each masked column's prediction should match its true encoding. InfoNCE works better than MSE here because it operates in embedding space (no decoding required), handles scale differences between columns more robustly, and per-column losses provide mutual information estimates: MI = log(batch_size) - loss. Four views are computed (Full/short × Mask1/Mask2), and losses are tracked per column for importance weighting.

Spread Loss (Anti-Collapse)

Embedding collapse—where all rows map to the same point—is a catastrophic failure mode. Spread loss prevents this by enforcing that each row's embedding should be most similar to itself via cross-entropy on the self-similarity matrix. It prevents total collapse (all embeddings identical), cluster collapse (all embeddings in a small region), and mode collapse (many rows sharing the same embedding). Temperature adapts to batch size and column count, clamped to [0.01, 0.2].

Reconstruction Loss (Scalars Only)

For scalar columns, the system decodes back to the original value and compares using MSE. InfoNCE captures structure but can lose numeric precision; reconstruction ensures values round-trip accurately.

Curriculum Learning

Loss weights evolve during training:

Phase             Progress  Spread  Joint  Marginal
Spread Focus      0–30%     1.0     0.8    0.005
Joint Transition  30–45%    0.5     1.0    0.01
Marginal Focus    45–85%    0.3     1.0    0.03
Joint Refinement  85–100%   0.2     1.0    0.02
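
Read as code, the schedule is just a lookup on training progress; the values are copied from the table and the function name is ours.

# Curriculum loss weights as a function of training progress in [0, 1].
def loss_weights(progress: float) -> dict:
    if progress < 0.30:
        return {"spread": 1.0, "joint": 0.8, "marginal": 0.005}   # spread focus
    if progress < 0.45:
        return {"spread": 0.5, "joint": 1.0, "marginal": 0.01}    # joint transition
    if progress < 0.85:
        return {"spread": 0.3, "joint": 1.0, "marginal": 0.03}    # marginal focus
    return {"spread": 0.2, "joint": 1.0, "marginal": 0.02}        # joint refinement

print(loss_weights(0.5))   # {'spread': 0.3, 'joint': 1.0, 'marginal': 0.03}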

Training Dynamics

This subsection covers how Embedding Space training works from initialization to convergence, including monitoring, interventions, and the training timeline.


10. Training from Start to Finish

Training an embedding space is a journey from random noise to structured representations.

Weight Initialization

He initialization (Kaiming uniform) is the default. PCA initialization pre-initializes using PCA of BERT embeddings for faster convergence—on the credit_g dataset, PCA initialization reached 80% accuracy 14 epochs faster. For SetEncoders, a semantic floor starts at 70% and gradually relaxes to 10%, forcing the model to leverage pre-trained semantics before learning custom patterns.

Optimizer Configuration

AdamW with decoupled weight decay (LR=0.001, weight_decay=1e-4) regularizes weights without distorting the adaptive gradient estimates, adapts per-parameter learning rates to layers operating at different scales, and smooths the noisy gradients produced by contrastive learning.

Learning Rate Scheduling

OneCycleLR provides warmup → peak → cosine decay:

Learning Rate over Training:

    ▲ LR
    │                 ╭───╮ Peak (3× base)
    │              ╱     ╲
    │           ╱         ╲
    │        ╱              ╲
    │     ╱                   ╲
    │ ──╱                       ╲──────▶ Time
    ├──┼──────┼─────────────────┼──────
       Warmup   Peak          Cooldown
       (15%)    (5%)          (80%)

Warmup starts at 1/25 of max LR and ramps over 15% of training. Peak holds at 3× base LR for 5%. Cooldown follows cosine decay to 1/10000 of max LR. If losses are very high during early warmup, max LR is automatically boosted by 50%.
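
This maps closely onto PyTorch's OneCycleLR (div_factor=25 gives the 1/25 warmup start, final_div_factor=10000 gives the floor); the 5% peak hold and the automatic LR boost are Featrix-specific and not shown in this sketch.

# OneCycleLR sketch: warmup over 15% of steps, then cosine decay to max_lr/10000.
import torch

model = torch.nn.Linear(10, 1)
opt = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
sched = torch.optim.lr_scheduler.OneCycleLR(
    opt, max_lr=0.003, total_steps=10_000, pct_start=0.15,
    anneal_strategy="cos", div_factor=25, final_div_factor=10_000)

for step in range(10_000):
    opt.step()          # (after loss.backward() in a real training loop)
    sched.step()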

Dropout Scheduling

The default piecewise constant schedule holds at 0.5 dropout for the first 33% (strong regularization during exploration), ramps from 0.5 to 0.25 during 33-66%, and holds at 0.25 for the final 33% (moderate regularization to prevent overfitting).

Adaptive Gradient Clipping

Fixed clipping thresholds (e.g., max_norm=1.0) don't scale with loss magnitude. When loss is 100, gradients of 200 are normal; when loss is 0.1, gradients of 200 indicate explosion. Adaptive clipping sets clip_threshold = loss × ratio (default ratio=2.0), warning when gradients exceed 10× loss.
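
A minimal sketch of loss-proportional clipping with torch.nn.utils.clip_grad_norm_; the function name and warning format are ours.

# Clip threshold scales with the current loss instead of being a fixed constant.
import torch

def adaptive_clip(model, loss_value, ratio=2.0, warn_ratio=10.0):
    threshold = loss_value * ratio
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=threshold)
    if total_norm > warn_ratio * loss_value:
        print(f"WARNING: gradient norm {total_norm:.1f} exceeds {warn_ratio}x loss")
    return total_norm

model = torch.nn.Linear(10, 1)
loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
adaptive_clip(model, loss.item())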

Mixed Precision Training (BF16)

BF16 reduces memory per batch by 50%, speeds training by 30%, and maintains high numerical stability (same exponent range as FP32, no overflow/underflow on large losses). Native hardware support exists on Ampere/Hopper GPUs.

Early Stopping Mechanisms

Validation Loss Plateau stops if no improvement (>0.0001) for 100 epochs. NO_LEARNING Detection triggers if <0.5% improvement over 5 epochs after epoch 15, blocking early stopping for 10 more epochs to give the model a chance to recover. WeightWatcher Convergence optionally monitors layer quality via power-law exponent (alpha), targeting the range 2-4.

Training Phase Summary

┌───────────────────────────────────────────────────────────────────────────────┐
│                        TRAINING PHASE TIMELINE                                 │
├────────────┬──────────────────────────────────────────────────────────────────┤
│ Progress   │ What's Happening                                                 │
├────────────┼──────────────────────────────────────────────────────────────────┤
│ 0-5%       │ • LR warming up (1/25 → 1/10 of max)                            │
│            │ • Dropout at 50% (maximum regularization)                        │
│            │ • Spread loss dominant (prevent collapse)                        │
│            │ • Semantic floor at 70% (force BERT structure)                   │
├────────────┼──────────────────────────────────────────────────────────────────┤
│ 5-15%      │ • LR ramping to peak                                            │
│            │ • Strategy weights starting to differentiate                     │
│            │ • Early LR boost check (if stuck)                               │
├────────────┼──────────────────────────────────────────────────────────────────┤
│ 15-30%     │ • LR at peak (max exploration)                                  │
│            │ • Joint loss gaining weight                                      │
│            │ • Attention patterns forming                                     │
├────────────┼──────────────────────────────────────────────────────────────────┤
│ 30-45%     │ • LR beginning descent                                          │
│            │ • Dropout ramping 50% → 25%                                     │
│            │ • Marginal loss gaining weight                                   │
│            │ • Column relationships stabilizing                               │
├────────────┼──────────────────────────────────────────────────────────────────┤
│ 45-85%     │ • LR in cosine decay                                            │
│            │ • Marginal loss dominant                                         │
│            │ • Fine-grained prediction learning                               │
│            │ • Strategy pruning activates                                     │
├────────────┼──────────────────────────────────────────────────────────────────┤
│ 85-100%    │ • LR approaching minimum (max_lr/10000)                         │
│            │ • Dropout at 25% (moderate regularization)                       │
│            │ • Joint refinement phase                                         │
│            │ • Final convergence monitoring                                   │
├────────────┼──────────────────────────────────────────────────────────────────┤
│ Any time   │ • Adaptive gradient clipping (if gradients explode)             │
│            │ • NO_LEARNING detection (extends training if stuck)             │
│            │ • Val loss plateau check (stops if truly converged)             │
│            │ • OOM recovery (retries with smaller batch)                     │
└────────────┴──────────────────────────────────────────────────────────────────┘

11. Training Monitoring and Interventions

Training neural networks can fail subtly. Featrix continuously monitors training health and automatically intervenes when problems are detected.

What's Tracked Every Epoch

Metric              Purpose
Learning Rate       Current LR from scheduler
Training Loss       Loss on training data
Validation Loss     Loss on held-out validation set
Dropout Rate        Current dropout from scheduler
Spread Loss         Contrastive loss component (prevents collapse)
Spread Temperature  Adaptive temperature for contrastive learning
Gradient Norm       Magnitude of gradients (unclipped and clipped)
Failures Detected   List of detected failure modes
Early Stop Blocked  Whether early stopping is suspended
Corrective Actions  Interventions taken this epoch

Embedding Collapse Detection

Collapse occurs when all embeddings converge to the same point—the model has learned nothing.

Metric | Healthy | Collapse
Embedding std | > 0.01 | < 0.01
Embedding range | > 0.1 | < 0.1
Avg pairwise distance | > 0.1 | < 0.1

Spread loss actively pushes embeddings apart. If collapse is detected early, training can restart with different initialization.
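
To make these checks concrete, here is a minimal NumPy sketch of the three collapse metrics; the function name and the sampling of 256 rows are illustrative, not the Featrix internals:

import numpy as np

def collapse_metrics(embeddings: np.ndarray) -> dict:
    """Illustrative collapse checks over a (n_rows, d_model) embedding matrix."""
    # Per-dimension standard deviation, averaged: near zero means every row looks the same.
    std = embeddings.std(axis=0).mean()
    # Range of values actually used: a collapsed space occupies a tiny region.
    value_range = embeddings.max() - embeddings.min()
    # Average pairwise distance on a random sample of rows (diagonal zeros included; fine for a sketch).
    sample = embeddings[np.random.choice(len(embeddings), size=min(256, len(embeddings)), replace=False)]
    diffs = sample[:, None, :] - sample[None, :, :]
    avg_pairwise = np.sqrt((diffs ** 2).sum(-1)).mean()
    return {
        "embedding_std": float(std),
        "embedding_range": float(value_range),
        "avg_pairwise_distance": float(avg_pairwise),
        "collapsed": bool(std < 0.01 or value_range < 0.1 or avg_pairwise < 0.1),
    }

print(collapse_metrics(np.random.randn(1000, 128) * 0.001))  # tiny spread -> flagged as collapsed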

Failure Modes: Embedding Space Training

Failure | Detection | Symptoms
DEAD_NETWORK | Gradient norm < 1e-8 | Zero gradients, frozen parameters
VERY_SLOW_LEARNING | <1% improvement over 5 epochs, gradient < 0.01 | Learning stuck in local minimum
NO_LEARNING | <0.1% val loss improvement over 5 epochs after epoch 15 | Model not learning anything
SEVERE_OVERFITTING | Val loss ↑ >5% while train loss ↓ >2% after epoch 10 | Memorizing training data
MODERATE_OVERFITTING | Train/val gap >10% with val loss worsening | Early warning of overfitting
UNSTABLE_TRAINING | Coefficient of variation >10% in train loss | Loss oscillating wildly

Failure Modes: Predictor Training

Failure | Detection | Symptoms
DEAD_NETWORK | Prob std < 0.005, all outputs identical | Frozen network outputs
CONSTANT_PROBABILITY | Prob std < 0.03 or prob range < 0.05 | All predictions ~same probability
SINGLE_CLASS_BIAS | >95% predictions are same class | Always predicting majority class
RANDOM_PREDICTIONS | AUC < 0.55 | Model is guessing randomly
UNDERCONFIDENT | 80%+ predictions in 0.4–0.6 range | Model unsure about everything
POOR_DISCRIMINATION | AUC < 0.65, accuracy < 0.6, but std > 0.1 | Varied but wrong predictions

Automatic Corrective Actions

Intervention | Trigger | Action
LR Boost | NO_LEARNING | 3× learning rate for 20 epochs
Temperature Boost | NO_LEARNING (after LR boost) | 2× temperature (softer contrastive objective)
Early Stop Block | NO_LEARNING | Block early stopping for 10 epochs
Focal Adjustment | SINGLE_CLASS_BIAS | Reduce focal gamma, add easy example weight
Retry Guidance | Terminal failure | Exception carries suggested parameter adjustments

Intervention Escalation Ladder

Stage | Trigger | Action
0 | — | Normal training
1 | NO_LEARNING after 15 epochs | 3× LR boost for 20 epochs
2 | Still stuck after 10 epochs | 2× temperature boost
3 | Still stuck after 10 epochs | 2× LR boost (cumulative)
4 | Still stuck after 10 epochs | 2× temperature boost (again)
5 | Still stuck | Accept current state or raise exception

Training doesn't just run—it actively diagnoses and repairs itself.


12. Training Timeline: Debugging What Happened

Every training run generates a comprehensive training_timeline.json file. When something goes wrong—or right—you can trace exactly what happened.

Timeline Location

ES training saves to {output_dir}/training_timeline.json every 5 epochs. Single Predictor training saves alongside the model at {sp_output_dir}/training_timeline.json.

Timeline Structure

{
  "metadata": {
    "initial_lr": 0.001,
    "total_epochs": 100,
    "batch_size": 256,
    "scheduler_type": "OneCycleLR",
    "dropout_scheduler_enabled": true,
    "initial_dropout": 0.5,
    "final_dropout": 0.25
  },
  "timeline": [
    {
      "epoch": 25,
      "learning_rate": 0.000842,
      "train_loss": 4.2156,
      "validation_loss": 4.3891,
      "dropout_rate": 0.425,
      "spread_loss": 0.0312,
      "spread_temperature": 0.0782,
      "gradients": {
        "unclipped_norm": 2.4561,
        "clipped_norm": 2.1230,
        "clip_ratio": 0.864
      },
      "failures_detected": ["NO_LEARNING"],
      "early_stop_blocked": true,
      "corrective_actions": ["LR_BOOST_3X"],
      "val_set_resampled": false
    }
  ],
  "corrective_actions": [
    {
      "epoch": 25,
      "action_type": "LR_BOOST",
      "trigger": "NO_LEARNING",
      "details": {
        "lr_multiplier": 3.0,
        "boost_epochs": 20
      }
    }
  ]
}

Key Fields

Field | Description | What to Look For
learning_rate | Current LR from scheduler | Should follow expected schedule curve
train_loss | Loss on training data | Should decrease over time
validation_loss | Loss on held-out data | Should decrease, not diverge from train
dropout_rate | Current dropout | Should follow schedule (0.5 → 0.25)
spread_loss | Contrastive loss component | Prevents embedding collapse
spread_temperature | Contrastive sharpness | Adapts to batch/column count
gradients.unclipped_norm | Raw gradient magnitude | Large values may indicate instability
gradients.clipped_norm | After adaptive clipping | Should be close to unclipped if healthy
failures_detected | List of detected problems | NO_LEARNING, SEVERE_OVERFITTING, etc.
early_stop_blocked | Early stopping suspended? | True during NO_LEARNING recovery
corrective_actions | Interventions this epoch | LR_BOOST_3X, TEMP_BOOST_2X, etc.
val_set_resampled | Train/val split rotated? | Gradual data rotation for generalization

Visualizing the Timeline

python src/lib/featrix/neural/qa/visualize_training_timeline.py training_timeline.json

This generates a text summary with failure counts and interventions, plus training_timeline_plot.png with 6 subplots: Learning Rate Schedule, Train/Val Loss with Intervention Markers, Dropout Schedule, Gradient Norms, Spread Loss, and Adaptive Temperature.

Common Debugging Scenarios

Training Stuck (NO_LEARNING): Look for flat or increasing validation_loss for 5+ epochs after epoch 15, failures_detected: ["NO_LEARNING"], and early_stop_blocked: true. Check whether LR boost helped (loss decrease after LR_BOOST_3X), whether temp boost helped, and whether the dataset was too simple (spread_loss near zero).

Overfitting: Look for train_loss decreasing while validation_loss increases, growing train/val gap, and SEVERE_OVERFITTING in failures. Check whether dropout schedule was too aggressive, weight decay too low, or dataset too small.

Gradient Issues: Heavy clipping (unclipped_norm much larger than clipped_norm) suggests LR is too high. Zero gradients suggest LR is too low or a dead network. Compare clip_ratio (should be near 1.0 if healthy).
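
A small sketch for triaging a timeline file programmatically, assuming the JSON structure shown above (the field names match the example; the script itself is not part of Featrix):

import json

with open("training_timeline.json") as f:
    timeline = json.load(f)

# Print every epoch where the monitor flagged a failure mode or intervened.
for entry in timeline["timeline"]:
    failures = entry.get("failures_detected", [])
    actions = entry.get("corrective_actions", [])
    if failures or actions:
        print(f"epoch {entry['epoch']:>4}  "
              f"val_loss={entry['validation_loss']:.4f}  "
              f"failures={failures}  actions={actions}")

# Heavy clipping check: a clip_ratio well below 1.0 suggests the LR is too high.
ratios = [e["gradients"]["clip_ratio"] for e in timeline["timeline"] if "gradients" in e]
print("median clip_ratio:", sorted(ratios)[len(ratios) // 2] if ratios else "n/a")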

Timeline Best Practices

Save timelines for all production runs—they're small (~100KB) and invaluable for debugging. Compare timelines across runs with diff. Check gradient norms first (most training failures show up as gradient anomalies). Look for intervention patterns—if LR boost helped, the original LR was too low.

Learning Rate Timeline Analysis

The training timeline tracks learning rate at every epoch, providing visibility into the OneCycleLR schedule and any interventions applied during training.

Expected LR Curve (OneCycleLR)

Normal training follows a three-phase curve:

Learning Rate Timeline:

▲ LR
│                     Peak (0.003)
│                  ╱────────╲
│               ╱              ╲
│            ╱                    ╲
│         ╱                          ╲
│      ╱                                ╲
│   ╱                                      ╲______
├───┼─────────┼───────────┼────────────────┼──────┼─────▶ Epoch
0   10        15          50               95     100

Phase 1: Warmup (0-15 epochs)
  - LR ramps from 0.0001 to 0.003
  - Gradual increase prevents early instability
  - If loss is very high (>100), warmup is shorter

Phase 2: Peak (15-20 epochs)
  - LR holds at peak (0.003)
  - Maximum learning happens here
  - Short duration (5% of training)

Phase 3: Cosine Decay (20-100 epochs)
  - LR decays smoothly to 0.0000003
  - Long, gradual cooldown for fine-tuning
  - Covers 80% of training
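
For reference, a minimal sketch of this three-phase shape as a standalone function. The breakpoints and the warmup start (max_lr/30 ≈ 0.0001) are assumptions taken from the description above, not the actual scheduler implementation:

import math

def lr_at_epoch(epoch: int, total_epochs: int, max_lr: float = 0.003) -> float:
    """Illustrative three-phase schedule: warmup (15%), peak hold (5%), cosine decay (80%)."""
    warmup_end = 0.15 * total_epochs
    peak_end = 0.20 * total_epochs
    min_lr = max_lr / 10_000          # e.g. 0.003 -> 0.0000003
    start_lr = max_lr / 30            # warmup starts well below the peak

    if epoch < warmup_end:            # Phase 1: linear ramp up
        frac = epoch / warmup_end
        return start_lr + frac * (max_lr - start_lr)
    if epoch < peak_end:              # Phase 2: hold at peak
        return max_lr
    # Phase 3: cosine decay from max_lr down to min_lr
    frac = (epoch - peak_end) / (total_epochs - peak_end)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * frac))

for e in (0, 10, 15, 20, 50, 99):     # sample a 100-epoch schedule
    print(e, round(lr_at_epoch(e, 100), 7))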

LR Interventions in Timeline

When the system detects NO_LEARNING (flat validation loss for 5+ epochs after epoch 15), it applies an LR boost:

{
  "epoch": 25,
  "learning_rate": 0.000842,  // Was following schedule
  "failures_detected": ["NO_LEARNING"],
  "corrective_actions": ["LR_BOOST_3X"]
},
{
  "epoch": 26,
  "learning_rate": 0.002526,  // Boosted to 3× (0.000842 * 3)
  "corrective_actions": ["LR_BOOST_ACTIVE"]
}

The boost:
  - Multiplies current LR by 3× (not base LR—whatever the scheduler says)
  - Lasts for 20 epochs to give the model time to escape the plateau
  - Then returns to schedule at the point where it would have been
  - Blocks early stopping during the boost to prevent premature termination

Interpreting LR Timeline Patterns

Pattern: Smooth Curve, No Interventions

LR: 0.0001 → 0.003 → 0.0000003
Interpretation: Training is healthy. OneCycleLR schedule executed as planned. No plateaus detected.

Pattern: LR Boost at Epoch 25

LR: 0.0001 → 0.003 → 0.0008 (schedule) → 0.0024 (boost) → ... → back to schedule
Interpretation: Model got stuck (NO_LEARNING), LR was boosted 3×. Check if loss decreased after boost:
  - Yes → Original LR was too low, boost successfully escaped plateau
  - No → Plateau wasn't LR-related, dataset may be too simple or model saturated

Pattern: Multiple Boosts

Boost at epoch 25, another at epoch 50
Interpretation: Model keeps getting stuck. Possible causes:
  - Dataset is very difficult (complex relationships, needs more exploration)
  - Base LR is too low (scheduler starts too conservative)
  - Loss function mismatch (model can't find good local minima with current objective)

Pattern: LR Drops to Near-Zero Early

LR: 0.0001 → 0.003 → 0.0000003 by epoch 30 (of 100)
Interpretation: OneCycleLR schedule is too aggressive for short training runs. This happens if:
  - Total epochs was increased after the schedule was created (schedule still thinks it's 50 epochs)
  - Cooldown phase is too long (80% of training at near-zero LR)

Using LR Timeline for Debugging

Check LR First When:
  - Training converges too slowly → LR may be too low (check if boost helped)
  - Training diverges or oscillates → LR may be too high (check unclipped gradient norms)
  - Loss plateaus then suddenly improves → LR boost probably triggered
  - Model doesn't learn anything → LR may have decayed too quickly (check LR at final epochs)

Export LR Timeline for Analysis:

import json
with open('training_timeline.json') as f:
    timeline = json.load(f)

# Extract LR curve
epochs = [entry['epoch'] for entry in timeline['timeline']]
lrs = [entry['learning_rate'] for entry in timeline['timeline']]

# Plot
import matplotlib.pyplot as plt
plt.semilogy(epochs, lrs)  # Log scale for LR
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('LR Timeline with Interventions')
plt.grid(True)
plt.savefig('lr_curve.png')

Overlay LR with Loss:

# Compare LR changes with loss changes
train_loss = [entry['train_loss'] for entry in timeline['timeline']]
val_loss = [entry['validation_loss'] for entry in timeline['timeline']]

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

# LR on top
ax1.semilogy(epochs, lrs, 'b-', label='Learning Rate')
ax1.set_ylabel('Learning Rate (log scale)')
ax1.grid(True)
ax1.legend()

# Loss on bottom
ax2.plot(epochs, train_loss, 'g-', label='Train Loss')
ax2.plot(epochs, val_loss, 'r-', label='Val Loss')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Loss')
ax2.grid(True)
ax2.legend()

plt.tight_layout()
plt.savefig('lr_and_loss.png')

LR Timeline in Visualizer

The visualize_training_timeline.py script includes an LR subplot:

python src/lib/featrix/neural/qa/visualize_training_timeline.py training_timeline.json

This generates training_timeline_plot.png with 6 subplots:
  1. Learning Rate Schedule (log scale, shows warmup/peak/decay and any boosts)
  2. Train/Val Loss (with markers for interventions)
  3. Dropout Schedule (piecewise constant)
  4. Gradient Norms (unclipped vs clipped)
  5. Spread Loss (contrastive objective)
  6. Adaptive Temperature (contrastive sharpness)

Reading the LR Subplot:
  - Smooth curve = Normal OneCycleLR schedule
  - Vertical jumps = LR boosts applied
  - Markers = Interventions (circle = boost start, square = boost end)
  - Color = Blue for normal schedule, red for boosted periods

Common LR Timeline Issues

Issue: LR Never Reaches Peak

LR: 0.0001 → 0.0015 (should be 0.003) → decay...
Cause: The warmup phase detected a very high initial loss (>100) and capped the max LR at 1.5× instead of 3×.
Action: This is an intentional safety measure. If you want a higher peak LR, reduce the initial loss (e.g., better initialization).

Issue: LR Boost Doesn't Help

Epoch 25: LR boost applied (0.0008 → 0.0024)
Epoch 26-45: Loss still flat
Cause: The plateau isn't LR-related. The model may have converged to the best possible loss for this dataset.
Action: Check validation metrics—the model may already be performing well. Consider a more complex architecture or richer features.

Issue: LR Decays Too Fast

Epoch 30: LR = 0.0000003 (already at minimum)
Epoch 30-100: LR stays near-zero
Cause: The OneCycleLR schedule was computed for the wrong total_epochs.
Action: Check that metadata.total_epochs in the timeline matches the actual training length. If resuming training, the scheduler may not have been updated.

Issue: Multiple Rapid Boosts

Epoch 20: LR boost (0.0008 → 0.0024)
Epoch 40: LR boost again (0.0003 → 0.0009)
Epoch 60: LR boost again (0.0001 → 0.0003)
Cause: The model keeps getting stuck. The base LR is too low or the dataset is very difficult.
Action: Increase initial_lr in the training config (e.g., 0.001 → 0.003) or increase the total epochs to allow more exploration.


13. Train/Validation Splitting: Stratified and Gradual

Naive random splitting can cause rare classes to be missing from validation (can't compute metrics), validation shock (sudden loss spikes mid-training from full reshuffles), data leakage (correlated samples in both sets), and distribution mismatch (biased subsets).

Stratified Splitting for Classification

Each class is split separately to ensure representation:

Category Size | Action | Rationale
≥ min_samples | Split proportionally | Normal stratification
< min_samples but > 1 | All to training | Can't validate reliably
= 1 sample | All to training | Single sample can't split

After splitting, the system tracks validation coverage (what fraction of samples can be validated), excluded samples, and category count for rare categories. Example: "85.6% coverage, 144 samples excluded across 23 rare categories."
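
A minimal sketch of this per-class policy, assuming pandas data; min_samples=5 and the coverage definition are illustrative choices, not Featrix defaults:

import numpy as np
import pandas as pd

def stratified_split(df: pd.DataFrame, target: str, val_frac: float = 0.2,
                     min_samples: int = 5, seed: int = 42):
    """Illustrative per-class split: rare classes go entirely to training."""
    rng = np.random.default_rng(seed)
    train_idx, val_idx, excluded = [], [], 0
    for _, group in df.groupby(target):
        idx = group.index.to_numpy()
        if len(idx) < min_samples:
            train_idx.extend(idx)          # too rare to validate reliably
            excluded += len(idx)
            continue
        rng.shuffle(idx)
        n_val = max(1, int(len(idx) * val_frac))
        val_idx.extend(idx[:n_val])
        train_idx.extend(idx[n_val:])
    coverage = 1 - excluded / len(df)      # fraction of samples whose class can be validated
    return df.loc[train_idx], df.loc[val_idx], coverage

# train_df, val_df, coverage = stratified_split(df, target="churn")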

Gradual Data Rotation (Embedding Space Training Only)

IMPORTANT: Gradual data rotation is used ONLY during unlabeled embedding space training, NOT during labeled predictor/neural function training.

The distinction is critical:

Training Type | Validation Strategy | Rationale
Embedding Space | Gradual rotation (10% swap every 25+ epochs) | Self-supervised structure learning; prevents overfitting to validation set while learning column relationships
Single Predictor / Neural Function | Fixed train/val split | Labeled data; fixed split ensures unbiased accuracy metrics on held-out data

Why rotation works for ES training: Embedding space training is self-supervised—there are no labels, and the model learns structure by predicting masked columns from visible ones. The "validation loss" measures how well the model can reconstruct held-out data, not how well it predicts a target variable. Rotating validation samples ensures the model learns generalizable structure rather than memorizing which specific samples it needs to reconstruct. Every row eventually appears in both training and validation, giving the model exposure to all the structural patterns in the data.

Why rotation would be WRONG for predictor training: Predictor training is supervised—the model learns to predict a specific target column from known labels. The validation set MUST be fixed and never seen during training; otherwise, accuracy metrics become meaningless (you're evaluating on data the model was trained on). Rotating validation data into training would cause data leakage, inflating accuracy metrics and producing an overfit model that fails on truly unseen data.

For long embedding space training runs, fully reshuffling train/val causes validation shock. Gradual rotation swaps 10% of samples between train and val every 25+ epochs, providing smooth transitions without overfitting to the validation set.

Approach | Pros | Cons
No rotation | Stable validation | Memorization risk, overfits to val set
Full reshuffle | Fresh validation | Validation shock, unstable metrics
Gradual rotation | Smooth transition, prevents overfitting | Slight overhead

Rotations are logged to the training timeline. All splits use fixed random seeds (42 for initial, epoch-based for rotations) for reproducibility.
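
A minimal sketch of a 10% swap with the epoch-based seeding described above; the array names and rotation cadence check are illustrative, not the Featrix implementation:

import numpy as np

def rotate_split(train_idx, val_idx, epoch, swap_frac=0.10):
    """Illustrative gradual rotation: swap swap_frac of samples between train and val."""
    rng = np.random.default_rng(epoch)                      # epoch-based seed for reproducibility
    n_swap = int(len(val_idx) * swap_frac)
    out_of_val = rng.choice(val_idx, size=n_swap, replace=False)
    out_of_train = rng.choice(train_idx, size=n_swap, replace=False)
    new_val = np.concatenate([np.setdiff1d(val_idx, out_of_val), out_of_train])
    new_train = np.concatenate([np.setdiff1d(train_idx, out_of_train), out_of_val])
    return new_train, new_val

# Only during embedding space training, every 25+ epochs:
# if epoch > 0 and epoch % 25 == 0:
#     train_idx, val_idx = rotate_split(train_idx, val_idx, epoch)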

Distribution Drift Detection

Per-column KL divergence between train and validation distributions is computed during setup:

KL Divergence | Interpretation | Action
< 0.1 | Nearly identical | None needed
0.1 – 1.0 | Moderate drift | Review splitting
> 1.0 | High distribution shift | Warning issued
> 2.0 | Severe mismatch | Investigate data leakage

High KL divergence indicates the split created biased subsets—a sign that the data may have structure (temporal, clustered) that random splitting violates.
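
A minimal sketch of a per-column KL check, assuming pandas columns; the bin count and smoothing epsilon are illustrative choices, not the Featrix defaults:

import numpy as np
import pandas as pd

def column_kl(train_col: pd.Series, val_col: pd.Series, bins: int = 20, eps: float = 1e-9) -> float:
    """Illustrative KL(train || val) for one column; numeric columns are histogram-binned."""
    if pd.api.types.is_numeric_dtype(train_col):
        edges = np.histogram_bin_edges(pd.concat([train_col, val_col]).dropna(), bins=bins)
        p, _ = np.histogram(train_col.dropna(), bins=edges)
        q, _ = np.histogram(val_col.dropna(), bins=edges)
    else:
        cats = sorted(set(train_col.dropna()) | set(val_col.dropna()))
        p = train_col.value_counts().reindex(cats, fill_value=0).to_numpy()
        q = val_col.value_counts().reindex(cats, fill_value=0).to_numpy()
    p = (p + eps) / (p + eps).sum()        # smooth, then normalize to probability distributions
    q = (q + eps) / (q + eps).sum()
    return float(np.sum(p * np.log(p / q)))

# drift = {c: column_kl(train_df[c], val_df[c]) for c in train_df.columns}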


Model Architecture Details

This subsection covers parameter counts, dimensions, scaling, and data quality filters.


14. Model Architecture: Parameters and Dimensions

The architecture scales to your data's complexity through a single knob: the embedding dimension, d_model. This number—128, 256, or 512—determines the capacity of every component in the system. Column encoders output d_model-dimensional vectors. The transformer operates on d_model-dimensional tokens. The final row embedding is d_model dimensions. Everything flows through this unified representation space.

For most datasets—a few dozen columns, standard complexity—128 dimensions is plenty. The model fits comfortably in memory, trains quickly, and captures the relationships that matter. Think of 128d as the "fast iteration" setting: you'll train in minutes, experiment rapidly, and get 95%+ accuracy on most problems. The parameter count stays modest, maybe 50,000 to 500,000 total depending on column count and types. This is small enough to train on a CPU if you're patient, though a modest GPU makes it painless.

When you move to 256 dimensions, you're buying headroom for more complex datasets—more columns, richer interactions, subtler patterns. The parameter count grows quadratically: transformer attention is O(d²), feedforward networks are O(d²), every component scales with the square of the dimension. That 128d model with 200K parameters becomes 800K parameters at 256d. But you also get dramatically more expressiveness: the model can represent more nuanced relationships, separate more finely-grained concepts, and handle datasets where 128d would hit its capacity ceiling. If you have 50+ columns or your data has known complex interactions—hierarchical structures, many-to-many relationships, subtle correlations—256d is where you want to start.

At 512 dimensions, you're at the upper limit of what's practical for most problems. The parameter count explodes into the millions: 2M, 5M, 10M depending on configuration. Training slows down noticeably—GPU memory becomes a constraint, batch sizes shrink, and throughput drops. But for truly large-scale problems—hundreds of columns, massive vocabularies, datasets where every bit of capacity matters—512d is available. The model can represent almost anything your data contains. The catch is you need the data to fill that capacity. Running 512d on a dataset with 10 columns and 1,000 rows is like renting a supercomputer to add two numbers. The model will overfit spectacularly, learning to memorize random noise. 512d shines on datasets with millions of rows and hundreds of features where the additional capacity pays for itself in better generalization.

The parameters themselves are distributed across three main components. First, the column encoders: every column gets its own encoder network that transforms raw values into d_model-dimensional embeddings. Scalar columns use AdaptiveScalarEncoder with 20 different encoding strategies—percentile buckets, logarithmic scaling, clipped ranges, and more—each with its own small MLP. That's roughly 100-150K parameters per scalar column. Categorical columns use SetEncoder with learned vocabulary embeddings, semantic projections powered by cached BERT representations, and gating networks that blend semantic and learned components. String columns get even more complex: they might use DelimiterAttention (200K parameters to intelligently split and attend over parts), RadixAttention (250K parameters for prefix-tree based compression), or simpler strategies like character n-grams. For a dataset with 20 columns—mix of scalars, categories, strings—the column encoders alone might consume 2M parameters.

Next comes the joint encoder: the transformer that attends across columns and learns relationships. Every layer costs 12×d² parameters: 4×d² for the query/key/value projections in attention, 8×d² for the two-layer feedforward network with 4× expansion. At d=256, that's 786,000 parameters per layer. The default configuration uses 3 layers, so roughly 2.4M parameters in attention and feedforwards. Add in layer normalization (small), positional embeddings (d × number of columns), the CLS token (d dimensions), and relationship tokens (if enabled), and you get another 100K or so. Total for the joint encoder: around 5-6M parameters at d=256 with 3 layers.

Finally, the output head: projections that convert the learned representation into short embeddings for visualization and column predictions for the marginal loss. This is typically the smallest component—a few hundred thousand parameters, maybe 1M at most. The CLS output passes through a final MLP and normalization to produce the row embedding. Per-column prediction heads decode from the joint representation back to the original column's value space for reconstruction loss.

Put it together for a 20-column dataset at d=256: roughly 11M parameters total, distributed as 2M in column encoders, 6M in the transformer, 3M in output projections and auxiliary components. The memory footprint in 32-bit floating point is around 45MB just for the parameters themselves—small enough to fit in L3 cache on modern CPUs. But during training, you also need gradients (another 45MB), optimizer state for Adam (90MB for first and second moments), and most expensive of all, activations: the intermediate tensors produced by every layer during the forward pass. At batch size 256, those activations can consume 500MB to 1GB depending on sequence length and gradient checkpointing strategy. All told, training a medium-sized model requires 1-2GB of memory for the model state alone, plus whatever the data batches consume.
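
As a back-of-envelope check on the 12×d² arithmetic, here is a sketch that estimates only the components itemized above (transformer layers, a rough average per-column encoder cost, positional and CLS parameters). It deliberately omits prediction heads and auxiliary modules, so it will undercount the totals quoted in the text; the per-column figure of 125K is an assumed average:

def rough_param_count(d_model: int, n_layers: int = 3, n_columns: int = 20,
                      params_per_column: int = 125_000) -> dict:
    """Order-of-magnitude parameter estimate from the components described above."""
    per_layer = 12 * d_model ** 2            # 4*d^2 attention projections + 8*d^2 feedforward
    transformer = n_layers * per_layer
    column_encoders = n_columns * params_per_column
    positional_and_cls = d_model * (n_columns + 1)
    total = transformer + column_encoders + positional_and_cls
    return {
        "per_transformer_layer": per_layer,
        "transformer": transformer,
        "column_encoders": column_encoders,
        "total_estimate": total,
        "fp32_megabytes": total * 4 / 1e6,
    }

print(rough_param_count(d_model=256))        # per layer ≈ 786K, 3 layers ≈ 2.4M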

This is why Featrix defaults to 128 dimensions and scales up only when necessary. Larger models aren't inherently better—they're just more expensive and easier to overfit. The system picks the dimension based on your dataset's complexity: column count, vocabulary size, expected relationship richness. Most of the time, 128 or 256 is the sweet spot where you get excellent performance without burning through memory or training time. If you find yourself needing 512d, you're probably working on a problem that would have required a team of ML engineers months to solve manually. The fact that Featrix can scale to that level with zero configuration is itself remarkable.


15. Dual Embeddings: Short (3D) vs Full (d_model)

Every encoder produces two output vectors from the same learned representation: a full_vec (d_model dimensions) for training and prediction, and a short_vec (3 dimensions) for visualization.

The short embedding consists of the first 3 dimensions of the full embedding, L2-normalized to live on a unit sphere. This enables real-time 3D visualization through stereographic projection and animated "movie frames" during training.

The full embedding uses all d_model dimensions (128–512) for maximum expressiveness. All predictions use the full embedding; mutual information estimation requires full dimensionality.

The short embedding is a projection—it captures the same semantic structure, just compressed for visual inspection. The first 3 dimensions are not random; they're the most important components of the learned representation, analogous to top-3 PCA components.
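
A minimal sketch of how the two views can be derived from one vector, assuming PyTorch tensors; the function is illustrative, not the Featrix encoder API:

import torch
import torch.nn.functional as F

def split_embeddings(full_vec: torch.Tensor):
    """Illustrative derivation of the short (3D) embedding from the full d_model embedding."""
    short_vec = F.normalize(full_vec[..., :3], dim=-1)   # first 3 dims, L2-normalized onto the unit sphere
    return full_vec, short_vec

full = torch.randn(8, 256)                  # batch of 8 row embeddings at d_model=256
_, short = split_embeddings(full)
print(short.shape, short.norm(dim=-1))      # (8, 3), all norms ≈ 1.0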


16. Intelligent Column Filtering: Excluding Zero-Information Data

Real-world datasets are filthy. They contain columns that serve database purposes—transaction IDs, row hashes, internal timestamps—but contribute zero information for learning. Including these columns in training is worse than useless: they consume parameters modeling pure noise, dilute the attention mechanism by adding irrelevant tokens, and waste compute on relationships that can never exist. A UUID column that's different for every row has perfect uniqueness and zero predictive power. The model will spend thousands of gradient updates trying to find patterns in random hexadecimal strings. It's not going to find any, but it'll burn memory and training time trying.

Featrix detects and excludes these columns automatically before training starts. The system isn't looking for a magic pattern that perfectly identifies "useless" columns—that's impossible without knowing the actual prediction task. Instead, it uses a collection of signals that, when combined, reliably catch the worst offenders.

Start with the obvious cases: columns that are entirely null, or entirely the same value. All-null columns contribute literally nothing—there's no information to learn because there's no data. Uniform columns are almost as bad: if every row has the same value ("United States" in a US-only dataset, "Active" in a table of active users), the column provides no discrimination between rows. The model could learn the constant but gain zero predictive power. These get filtered immediately with simple checks: does the column contain any non-null values? How many distinct values exist? If the answers are "no" or "one", the column is gone.

Then hunt for random identifiers. The tell-tale signs are high uniqueness and low semantic similarity. If 95% or more of values are distinct, the column is likely some kind of ID. But check further: maybe it's a product name column with a rich vocabulary—that would also have high uniqueness but would be semantically meaningful. So compute semantic similarity: grab BERT embeddings for a sample of values and measure how correlated they are. Random UUIDs produce uncorrelated embeddings—each string is semantically unrelated to the others. Product names, even with high uniqueness, cluster semantically: "iPhone 12 Pro" is similar to "iPhone 13 Pro Max", which is similar to "iPad Air", and so on. The embeddings reveal structure even when raw string matching doesn't.

Combine this with structural patterns. Does the column contain values that match UUID format (8-4-4-4-12 hex digits)? Hash format (32 or 64 hex digits)? Snowflake IDs (19-digit integers)? These patterns are probabilistic signals: not every UUID-shaped string is meaningless, and not every random ID looks like a UUID. But when you see high uniqueness + low semantic similarity + UUID pattern, you can be pretty confident you're looking at a transaction ID column that should be excluded.

Finally, check the column name itself. Does it contain "id", "uuid", "hash", "key", "token"? This is the weakest signal—plenty of meaningful columns have "id" in the name (customer_id might be predictive if customers have repeat behavior)—but it's still informative when combined with other evidence. If the values look like UUIDs, act like random strings, and the column is named transaction_id, all the evidence points one way. Exclude it.

The system uses weighted scoring. High uniqueness contributes 0.3, low semantic similarity contributes 0.4, structural pattern match contributes 0.5, name hints contribute 0.1. Add them up. If the total score exceeds 0.7, the column gets excluded. This threshold catches obvious garbage—pure UUIDs, MD5 hashes, Snowflake IDs—while leaving alone high-cardinality columns that have real semantic content, like email addresses (semantically meaningful patterns), user agents (structured but informative), or URLs (domain and path carry signal).
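
A minimal sketch of that weighted scoring, using the weights and 0.7 threshold from above; the 0.95 uniqueness cutoff, the 0.2 semantic-similarity cutoff, and the 100-value pattern sample are assumptions for illustration, and semantic_similarity is assumed to be precomputed (e.g., mean pairwise cosine of BERT embeddings over a sample of values):

import re
import uuid

UUID_RE = re.compile(r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I)
HASH_RE = re.compile(r"^[0-9a-f]{32}$|^[0-9a-f]{64}$", re.I)
NAME_HINTS = ("id", "uuid", "hash", "key", "token")

def id_column_score(name, values, semantic_similarity):
    """Illustrative weighted scoring for zero-information ID-like columns."""
    values = [v for v in values if v is not None]
    uniqueness = len(set(values)) / max(len(values), 1)
    pattern_hits = sum(bool(UUID_RE.match(str(v)) or HASH_RE.match(str(v))) for v in values[:100])

    score = 0.0
    if uniqueness >= 0.95:
        score += 0.3                        # high uniqueness
    if semantic_similarity < 0.2:
        score += 0.4                        # values look semantically unrelated
    if pattern_hits > 50:
        score += 0.5                        # majority of sampled values match UUID/hash patterns
    if any(h in name.lower() for h in NAME_HINTS):
        score += 0.1                        # weak name-based hint
    return score, score > 0.7               # exclude if score exceeds 0.7

vals = [str(uuid.uuid4()) for _ in range(200)]
print(id_column_score("transaction_id", vals, semantic_similarity=0.05))   # high score -> excluded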

Examples of what gets caught: transaction_id with values like "a8f2c1d4-9e3b-4a5c-8d7e-6f1a2b3c4d5e" (UUID pattern + high uniqueness + name hint = excluded). user_hash with values like "3f4b8c9d2a1e7f6b5c4d3e2a1b0c9d8e" (32 hex digits + high uniqueness + name hint = excluded). session_token with random alphanumeric (no pattern match, but high uniqueness + low semantic + name = still excluded).

Examples of what doesn't get caught: customer_id with values 1, 2, 3, ... (high uniqueness but low cardinality relative to total, and numeric IDs often correlate with customer tenure—kept). product_name with 10,000 distinct product names (high uniqueness but high semantic similarity—brand and category structure present—kept). email_domain extracted from email addresses (moderate uniqueness, strong semantic clustering by company/provider—kept).

There's one more class of column that gets auto-excluded: internal columns that Featrix itself creates. Any column starting with __featrix is metadata—filter indicators, row weights, training control signals—that should never be included in the embedding space. These get stripped silently before training begins. Users never see them, but they enable powerful features like selective row training and dynamic data filtering without polluting the learned representation.

The result is a cleaned dataset where every column that enters training has a fighting chance of contributing signal. The garbage is gone. The parameters aren't wasted on noise. The attention mechanism focuses on relationships that actually exist. And this happens automatically—no manual inspection, no feature engineering, no guessing which columns to drop. The system looks at your data and filters out the obvious poison before training even starts. It won't catch every useless column—that's impossible without knowing the prediction task—but it catches the worst offenders, the ones that would actively harm learning if included. What's left is clean enough to train on, and the trained model can later reveal which remaining columns actually matter through mutual information estimates and feature importance scores.


The Predictor Deep Dive

This subsection covers Predictor architecture, fine-tuning, output structure, and safety mechanisms.


17. Predictor Architecture and Fine-Tuning

After training an Embedding Space, predictors train on top of it. Each predictor is a neural network that takes row embeddings and predicts a target column.

Predictor Architecture

Row Embedding (d_model dim)
┌─────────────────────────────────────────┐
│  PREDICTOR MLP                          │
│                                         │
│  Linear(d_model → d_hidden)             │
│       │                                 │
│  ┌────┴────┐ × n_hidden_layers          │
│  │ Hidden  │                            │
│  │ Block:  │ [Linear + BatchNorm +      │
│  │         │  LeakyReLU + Dropout]      │
│  │         │  + Residual connection     │
│  │         │  + Self-Attention (opt)    │
│  └────┬────┘                            │
│       │                                 │
│  Linear(d_hidden → n_classes)           │
└─────────────────────────────────────────┘
   Logits/Predictions
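
A minimal PyTorch sketch of the head diagrammed above (residual hidden blocks with BatchNorm, LeakyReLU, and dropout); the dropout rate and the omission of the optional self-attention are simplifications, and the class name is illustrative:

import torch
import torch.nn as nn

class PredictorMLP(nn.Module):
    """Minimal sketch of the predictor head: residual MLP blocks over a row embedding."""
    def __init__(self, d_model=128, d_hidden=256, n_hidden_layers=2, n_classes=2, dropout=0.2):
        super().__init__()
        self.input_proj = nn.Linear(d_model, d_hidden)
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_hidden, d_hidden),
                nn.BatchNorm1d(d_hidden),
                nn.LeakyReLU(),
                nn.Dropout(dropout),
            )
            for _ in range(n_hidden_layers)
        ])
        self.output = nn.Linear(d_hidden, n_classes)

    def forward(self, row_embedding):
        x = self.input_proj(row_embedding)
        for block in self.blocks:
            x = x + block(x)              # residual connection around each hidden block
        return self.output(x)             # logits

logits = PredictorMLP()(torch.randn(32, 128))   # batch of 32 row embeddings
print(logits.shape)                              # torch.Size([32, 2])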

Automatic Architecture Selection

When n_hidden_layers=None, the system analyzes mutual information between features and target, nonlinearity (Random Forest vs Linear comparison), class imbalance, and chi-square tests for categorical targets.

Data Characteristics | Hidden Layers
Simple linear problem (MI > 0.4) | 2 layers
Moderate complexity | 2-3 layers
Complex nonlinear (RF >> Linear) | 3-4 layers
Small dataset (< 2K rows) | Capped at 2 layers

Typical parameter counts: d_hidden=256, 2 layers, d_model=128 → ~100K; d_hidden=512, 4 layers, d_model=256 → ~700K.

Fine-Tuning

Mode | What Trains | Pros | Cons
Frozen (fine_tune=False) | Predictor only | Faster, preserves representations | Can't adapt to target
Fine-tuning (fine_tune=True) | Encoder + Predictor | Better performance | Overfitting risk, slower

Fine-tuned encoder weights are saved with the predictor for consistent inference.

Loss Function Selection

The predictor supports multiple loss functions, with automatic selection based on data characteristics:

Loss Type | Best For | How It Works
focal (default) | Imbalanced classification | Focuses on hard examples via gamma parameter
cross_entropy | Balanced classification | Standard cross-entropy with optional label smoothing
prauc | Severely imbalanced | Ranking-based loss optimizing PR-AUC directly
auto | Let system decide | Analyzes data to select optimal loss

When loss_type="auto", the ModelAdvisor examines:
  - Class distribution and imbalance severity
  - Cost asymmetry (cost_false_positive vs cost_false_negative)
  - Dataset size and characteristics

Class Weight Computation

Class weights compensate for imbalanced class distributions. Featrix uses square root inverse frequency weighting:

weight[class] = sqrt(1 / frequency[class])
weights are normalized to average 1.0

Class Distribution | Majority Weight | Minority Weight | Boost Factor
50% / 50% | 1.0 | 1.0 | 1.0×
84% / 16% | 0.61 | 1.40 | 2.3×
97% / 3% | 0.51 | 2.87 | 5.6×

Why sqrt and not full inverse? Full inverse (1/frequency) gives 33× weight for a 3% minority class, causing reverse bias—the model over-predicts the minority. Sqrt weighting (5.6×) is gentler and avoids this failure mode.
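
A minimal sketch of sqrt inverse-frequency weighting normalized to average 1.0; the 84/16 example reproduces the corresponding table row above, though Featrix's exact normalization may differ:

import numpy as np

def sqrt_inverse_freq_weights(class_counts: dict) -> dict:
    """Illustrative sqrt inverse-frequency weights, normalized so the average weight is 1.0."""
    total = sum(class_counts.values())
    raw = {c: np.sqrt(total / n) for c, n in class_counts.items()}   # sqrt(1 / frequency)
    mean = np.mean(list(raw.values()))
    return {c: w / mean for c, w in raw.items()}

print(sqrt_inverse_freq_weights({"active": 84, "churn": 16}))
# ≈ {'active': 0.61, 'churn': 1.40}  -> minority class boosted ~2.3×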

Focal Loss

Focal Loss addresses class imbalance by down-weighting easy examples and focusing on hard ones:

FL(p_t) = max(min_weight, (1 - p_t)^gamma) × CE(p_t)

where:
  p_t = probability of correct class
  gamma = focusing parameter (default: 2.0)
  min_weight = minimum weight for easy examples (default: 0.1)

Example Type | p_t | Focal Weight (γ=2) | Effect
Very easy (97% confident) | 0.97 | max(0.1, 0.03²) = 0.10 | 10% of normal weight
Easy (80% confident) | 0.80 | max(0.1, 0.20²) = 0.10 | 10% of normal weight
Medium (60% confident) | 0.60 | max(0.1, 0.40²) = 0.16 | 16% of normal weight
Hard (30% confident) | 0.30 | max(0.1, 0.70²) = 0.49 | 49% of normal weight
Very hard (10% confident) | 0.10 | max(0.1, 0.90²) = 0.81 | 81% of normal weight

The system automatically detects and corrects reverse bias by adjusting gamma and min_weight based on prediction distribution mismatch.
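
A minimal PyTorch sketch of the focal weighting defined above; per-class weights (previous subsection) could additionally be passed to cross_entropy, which this sketch omits:

import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0, min_weight=0.1):
    """Illustrative focal loss: FL(p_t) = max(min_weight, (1 - p_t)^gamma) * CE(p_t)."""
    ce = F.cross_entropy(logits, targets, reduction="none")     # per-example cross-entropy
    p_t = torch.exp(-ce)                                        # probability of the correct class
    focal_weight = torch.clamp((1 - p_t) ** gamma, min=min_weight)
    return (focal_weight * ce).mean()

logits = torch.randn(16, 2)
targets = torch.randint(0, 2, (16,))
print(focal_loss(logits, targets))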

Probability Calibration

After training, calibration is auto-fit on the validation set:

Method | Description | Use Case
Temperature Scaling | Single parameter T dividing logits | Most common, simple
Platt Scaling | Logistic regression on logits | Better for small datasets
Isotonic Regression | Non-parametric monotonic mapping | Complex calibration curves

For binary classification, the system finds the F1-optimal threshold on validation data.
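
A minimal sketch of temperature scaling fit by grid search over validation NLL; a production implementation would more likely optimize T directly (e.g., with LBFGS), and the function here is illustrative rather than the Featrix calibration code:

import numpy as np

def fit_temperature(val_logits: np.ndarray, val_labels: np.ndarray) -> float:
    """Illustrative temperature scaling: grid-search T that minimizes NLL on validation logits."""
    def nll(T):
        z = val_logits / T
        z = z - z.max(axis=1, keepdims=True)                   # numerical stability
        log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
        return -log_probs[np.arange(len(val_labels)), val_labels].mean()

    temps = np.linspace(0.5, 5.0, 91)
    return float(temps[np.argmin([nll(t) for t in temps])])

# T = fit_temperature(val_logits, val_labels)
# calibrated_probs = softmax(test_logits / T)   # T > 1 softens overconfident predictions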


18. Predictor Safety and Guardrails

Failure Detection

Predictor training monitors for six failure modes:

Failure | Detection | Symptoms
DEAD_NETWORK | Prob std < 0.005 | All outputs identical
CONSTANT_PROBABILITY | Prob std < 0.03 | All predictions ~same probability
SINGLE_CLASS_BIAS | >95% predictions same class | Always predicting majority
RANDOM_PREDICTIONS | AUC < 0.55 | Model is guessing randomly
UNDERCONFIDENT | >70% predictions in [0.4, 0.6] | Model unsure about everything
POOR_DISCRIMINATION | AUC < 0.65, accuracy < 0.6 | Varied but wrong predictions

Prediction Guardrails

Every prediction includes:
  - Calibrated probabilities: When the model says 70%, it means 70%
  - Feature importance: Which columns drove this prediction
  - OOD detection: Distance from training distribution
  - Confidence scores: Uncertainty quantification
  - Class imbalance warnings: Context about training data distribution

Example prediction response:

{
  "prediction": "churn",
  "probability": 0.847,
  "confidence": "high",
  "warnings": [],
  "feature_importance": {
    "account_age_days": -0.42,
    "days_since_last_purchase": 0.31,
    "support_tickets_30d": 0.18
  },
  "metadata": {
    "class_distribution_train": {"churn": 0.03, "active": 0.97},
    "optimal_threshold": 0.34
  }
}


Advanced Features

This subsection covers advanced capabilities: resume training, extending spaces, foundation models, and multi-table data.


19. Resume Training and Extending Embedding Spaces

Resume Training

An Embedding Space can resume training from any saved checkpoint to recover from crashes, extend training when models haven't converged, or incrementally train with new data (same schema).

When train() is called with existing_epochs set to a checkpoint epoch, the system loads checkpoint state (model weights, optimizer state, scheduler position, training history), recovers from corruption if needed (searches backward for last valid checkpoint), recreates non-serializable objects (DataLoaders, schedulers), and continues training.

Extending Embedding Spaces

When new feature columns are engineered or discovered, extend_from_existing() adds new columns while preserving existing encoder weights.

Two-Phase Training Strategy:

Phase | Epochs | Existing Encoders | New Encoders | Purpose
Phase 1 | epochs/8 | 🔒 Frozen | 🏋️ Training | Let new columns learn without disturbing existing
Phase 2 | epochs/8 | 🏋️ Training | 🏋️ Training | Fine-tune everything jointly
Total | epochs/4 | — | — | 4× faster than full retraining

This ensures new column encoders learn good representations before joint training while existing encoders maintain their learned patterns initially.
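
A minimal sketch of the two-phase schedule, assuming PyTorch encoder modules; the dictionaries, the train_one_epoch callback, and the function names are illustrative, not the extend_from_existing API:

def set_trainable(modules, trainable: bool):
    # Toggle requires_grad for every parameter in a collection of torch nn.Modules.
    for m in modules:
        for p in m.parameters():
            p.requires_grad = trainable

def extend_training(existing_encoders: dict, new_encoders: dict, train_one_epoch, epochs: int):
    phase_epochs = epochs // 8

    # Phase 1: freeze existing encoders, train only the new columns' encoders.
    set_trainable(existing_encoders.values(), False)
    set_trainable(new_encoders.values(), True)
    for _ in range(phase_epochs):
        train_one_epoch()

    # Phase 2: unfreeze everything and fine-tune jointly.
    set_trainable(existing_encoders.values(), True)
    for _ in range(phase_epochs):
        train_one_epoch()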

Foundation Models

A Foundation Model in Featrix is a pre-trained Embedding Space that serves as the basis for multiple predictors:

Customer Data
Embedding Space (train once: 60 min)
     ├─ Churn Predictor (train: 5 min)
     ├─ Lifetime Value Predictor (train: 8 min)
     ├─ Segment Predictor (train: 6 min)
     ├─ Next Purchase Predictor (train: 7 min)
     └─ Support Ticket Risk Predictor (train: 5 min)

Total time: 60 min (ES) + 31 min (5 predictors) = 91 minutes for 5 production models.

Hierarchical Embedding Spaces (Multi-Table)

For multi-table data (customers + transactions, orders + line_items), hierarchical embedding spaces avoid row explosion:

  1. Train a separate embedding space for each table
  2. Learn join relationships that aggregate child patterns up to parent level
  3. Avoid row explosion by seeing each parent row exactly once

This mirrors normalized database design but with learned aggregations instead of hand-crafted SQL GROUP BY clauses.


20. Appendix: Quick Reference and Glossary

Key Concepts

Embedding Space (ES): Self-supervised neural network that learns universal representations of tabular data. Trains without labels by predicting masked columns from visible ones. Outputs row embeddings, column embeddings, and mutual information estimates.

Predictor (Single Predictor): Lightweight supervised neural network that trains on top of an ES to predict a specific target column. Typically 2-4 layers, trains in 5-20 minutes.

d_model: The embedding dimension (128, 256, or 512) that determines model capacity. All components (column encoders, transformer, row embeddings) operate in d_model-dimensional space.

[CLS] Token: Learnable parameter prepended to every token sequence. After transformer layers, aggregates information from all columns and relationships into the row embedding.

Causal Relationship Scoring: Measures the actual marginal benefit of each column relationship by tracking improvement rates when paired vs unpaired. Uses Lower Confidence Bound (LCB) and recency weighting to conservatively estimate which relationships help learning.

InfoNCE Loss: Contrastive learning objective that pulls similar examples together and pushes dissimilar ones apart in embedding space. Used for both joint loss (masked vs unmasked views) and marginal loss (per-column prediction).

Marginal Loss: Per-column prediction loss that measures how well each masked column can be predicted from visible columns. Provides mutual information estimates: MI = log(batch_size) - loss.

Spread Loss: Anti-collapse loss that prevents all embeddings from converging to the same point. Enforces that each row's embedding should be most similar to itself.

OneCycleLR: Learning rate scheduler with three phases: warmup (15%), peak (5%), cosine decay (80%). Provides fast convergence with stable gradients.

Focal Loss: Class imbalance loss that down-weights easy examples and focuses on hard ones via gamma parameter. Prevents model from ignoring minority class.

Stratified Splitting: Ensures each class appears proportionally in train/val splits, preventing rare classes from being missing from validation.

Gradual Data Rotation: For ES training only—swaps 10% of samples between train/val every 25+ epochs to prevent overfitting to validation set. Never used for Predictor training (would cause data leakage).

Lift: Causal effect of pairing columns i and j, measured as the improvement rate difference between paired and unpaired epochs. Positive lift = helpful, negative lift = harmful, zero lift = no effect.

LCB (Lower Confidence Bound): Conservative estimate of lift: mean - 1.96 × std. Accounts for uncertainty—pairs with high variance get penalized to avoid trusting noisy measurements.

Mutual Information (MI): Bits of information column A provides about column B. Quantifies how much knowing A reduces uncertainty about B, accounting for nonlinear relationships.

Semantic Floor: For SetEncoder, starts at 70% semantic (BERT) / 30% learned, gradually relaxes to 10% semantic / 90% learned. Forces model to leverage pre-trained semantics before learning custom patterns.

Adaptive Scalar Encoder: Learns which of 20 encoding strategies work best for each numeric column (linear, log, robust, rank, periodic, etc.). Dual-path architecture with gating network.

Hybrid Column Detection: Automatically detects columns that semantically belong together (address components, coordinates, entity attributes). Uses MERGE strategy (combine into one token) or RELATIONSHIP strategy (add shared group embedding).

Timeline: JSON record of every epoch with LR, losses, dropout, gradients, failures detected, and corrective actions. Invaluable for debugging training issues.

Default Hyperparameters

Parameter | Default | Auto-Computed From
d_model | 128–256 | Column count, vocab size, complexity
Batch size | min(2048, n_rows/100) | Dataset size, GPU memory
Learning rate | 5e-5 to 3e-4 | Dataset size (smaller = more conservative)
Epochs | 36000 / steps_per_epoch | Target optimizer updates
Dropout | 0.5 → 0.25 | Scheduled by training phase
Temperature | base / (batch × columns) | Clamped to [0.01, 0.2]
Weight decay | 1e-4 to 0.1 | Dataset size (smaller = heavier regularization)
Gradient clip | loss × 2.0 | Adaptive to loss magnitude
Transformer layers | 3 | Fixed
Attention heads | 8 | Fixed

Training Timelines

Dataset Size | ES Training | Predictor Training
1K rows, 20 cols | ~10 minutes | ~2 minutes
10K rows, 50 cols | ~30 minutes | ~5 minutes
100K rows, 100 cols | ~2 hours | ~15 minutes

Failure Modes Summary

Embedding Space: DEAD_NETWORK, NO_LEARNING, VERY_SLOW_LEARNING, SEVERE_OVERFITTING, MODERATE_OVERFITTING, UNSTABLE_TRAINING

Predictor: DEAD_NETWORK, CONSTANT_PROBABILITY, SINGLE_CLASS_BIAS, RANDOM_PREDICTIONS, UNDERCONFIDENT, POOR_DISCRIMINATION

All failures trigger automatic diagnostics and corrective actions (LR boost, temperature boost, focal adjustment).


Conclusion

This document has covered the complete Featrix Embedding Space architecture from foundational concepts to implementation details. Part I provided a progressive introduction for anyone evaluating Featrix—covering the ML nightmare, data handling, core concepts, use cases, and safety mechanisms. Part II dove into the technical implementation: type-specific encoding, the transformer-based joint encoder with causal relationship scoring, self-supervised masking, multi-objective loss functions, training dynamics with automatic monitoring and intervention, and the predictor architecture.

The key insights:

  1. Foundational representations beat one-off models: Train the Embedding Space once, reuse for unlimited prediction tasks.
  2. Zero configuration through intelligent automation: Batch size, LR, epochs, dropout, temperature, class weights—all computed automatically from your data.
  3. Self-supervised learning discovers structure: No labels needed for ES training. The model learns by predicting columns from columns.
  4. Causal relationship scoring is principled pruning: Measure actual lift, quantify uncertainty, validate automatically. No heuristics.
  5. Automatic monitoring and intervention: Training doesn't just run—it diagnoses problems and fixes itself (LR boosts, temperature adjustments, early stop blocking).
  6. Production-ready by default: Calibrated probabilities, OOD detection, feature importance, guardrails, timeline export, stratified splitting, gradient clipping—all automatic.

Featrix makes machine learning accessible by eliminating the manual decisions that break traditional ML projects. Upload data, wait for training, get production-ready predictions with full explainability and monitoring. The rest of this complexity? Handled automatically.

For questions, issues, or contributions, see the main Featrix documentation and GitHub repository.