Ah, the sweet smell of progress in the morning! OpenAI has just dropped their shiny new o1-mini and o1 models, and the internet is abuzz with hot takes faster than you can say “AGI winter is coming.” You’ve probably seen the YouTube videos: wide-eyed tech bros marveling at the models’ ability to solve differential equations or explain quantum mechanics to their goldfish.
But let’s be real: you’re not here for another “OMG, AI can [insert mundane task]!” video. No, you, dear reader, are a person of substance. You crave the meaty, first-principles concepts that will help you navigate the choppy waters of AI progress without drowning in a sea of hype. You’re tired of every Tom, Dick, and Elon barging into the AI field like a bull in a china shop, asking everyone to consider their groundbreaking idea of a spherically symmetric cow in n-dimensional Hilbert space.
That’s why we’re diving into inference scaling today. It’s a crucial concept that’s driving the impressive performance of these new models, and understanding it is key to grasping the current state of AI technology.
Now, if you’re here to debate whether we’re on the cusp of AGI or if these models are secretly plotting to overthrow humanity and replace us with more efficient toasters, I’m going to have to ask you to see yourself out. There are plenty of Twitter threads and subreddits where you can indulge in that particular flavor of speculation. We’re here for the nitty-gritty, the nuts and bolts, the “how does this actually work?” of it all.
So, strap in, silence your phone (unless you’re reading this on it, in which case, carry on), and let’s embark on a journey into the heart of inference scaling. It’s time to learn why OpenAI is betting big on this approach, and why you should care. Welcome to the bleeding edge of AI, where the cows are spherical, the Hilbert spaces are infinite, and the potential for groundbreaking insights is limitless.
Disclaimer: Before we dive in, I just want to be super clear: what follows is my best understanding of inference scaling based on publicly available information. As for what’s really going on inside OpenAI’s secret volcano lair? Well, your guess is as good as mine. They’re about as forthcoming with their methods as a cat is with its tax returns.
Navigation Prompt: If you are familiar with the different layers of an LLM, how they work, and what roles they play, you can skip this section; otherwise, please read on.
The Assembly Line of Sentences: How Language Models Work (Infer)
Now that we’ve cleared the air of AGI speculation and spherical bovines, let’s delve into the intricacies underpinning inference scaling. Note that this is not only about throwing more FLOPs at the problem. To understand how inference scaling works, we first need to understand how basic inference works.
Imagine, for a moment, that you’re running a factory. Not just any factory, mind you, but one that produces bespoke, artisanal sentences on demand. Your raw materials? Vectors of floating-point numbers. Your end product? Everything from Shakespearean sonnets to Python code.
Welcome to the world of Large Language Model inferencing. It’s a bit like running a just-in-time manufacturing operation, except instead of assembling cars, you’re assembling words. And let me tell you, the logistics are fascinating.
Let’s break down this process, shall we? It all starts with a prompt. Think of this as the order form for your sentence factory. “I need a limerick about database optimization, stat!” your customer (who is suspiciously often yourself) demands. Your factory springs into action:
The Tokenizer: This is your receiving department. It breaks down the incoming order into bite-sized pieces the rest of the factory can work with. “Database” might become “data” and “base”, while “optimization” could be “optim” and “ization”. It’s like those old “FRAGILE: HANDLE WITH CARE” stamps, except here it’s “LANGUAGE: HANDLE WITH VECTORS”. The tokenizer layer is the gatekeeper of language models, transforming raw text into a format the model can understand. Think of it as a translator that converts human-readable text into a sequence of numbers (tokens) that the AI can process.
Tokenize(text) = [token_1, token_2, ..., token_n]
where \(token_i = vocabulary_{index}(subword_i)\). Algorithmically, the process works in three sub-steps (a minimal code sketch follows the list):
- Text Segmentation: The input text is broken down into subwords, words, or characters, depending on the tokenization strategy.
- Vocabulary Lookup: Each subword is mapped to a unique integer index in the model’s vocabulary.
- Token Generation: The sequence of these indices forms the final token sequence.
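To make these sub-steps concrete, here is a minimal sketch in Python. The toy vocabulary and the greedy longest-match segmentation are illustrative assumptions; real tokenizers (BPE, WordPiece, SentencePiece) learn their subword vocabularies and merge rules from data.

```python
# A toy tokenizer: greedy longest-match against a hand-made vocabulary.
# Real tokenizers learn their subword vocabularies from data; this is only a sketch.
vocab = {"data": 0, "base": 1, "optim": 2, "ization": 3, " ": 4}

def tokenize(text):
    tokens = []
    i = 0
    while i < len(text):
        # Step 1: segmentation - find the longest vocabulary entry starting here.
        match = next((text[i:j] for j in range(len(text), i, -1) if text[i:j] in vocab), None)
        if match is None:
            raise ValueError(f"no subword covers {text[i]!r}")
        # Step 2: vocabulary lookup - map the subword to its integer index.
        # Step 3: the accumulated indices form the token sequence.
        tokens.append(vocab[match])
        i += len(match)
    return tokens

print(tokenize("database optimization"))  # -> [0, 1, 4, 2, 3]
```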
The Embedding Layer: This is where the magic of turning words into numbers happens. It’s as if each word is run through a very complex, very mathy Instagram filter. “Data” doesn’t just mean “data” anymore; it’s now a list of 768 floating-point numbers that somehow capture the essence of “data-ness”. Imagine the embedding layer as a massive, multidimensional dictionary or lookup table. Here is its structure:
Rows: Each row in this table corresponds to a token in your vocabulary. If your model has a vocabulary of 50,000 words, your table has 50,000 rows.
Columns: Each column represents a dimension of the embedding. If you’re using 300-dimensional embeddings, your table has 300 columns.
Cells: Each cell in this table contains a single number, typically a floating-point value.
Embedding(token_id) = LookupTable[token_id]
where LookupTable is a matrix of shape (vocab_size, embedding_dim)
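A minimal sketch of that lookup, assuming a randomly initialized table (in a trained model these numbers are learned along with everything else):

```python
import numpy as np

vocab_size, embedding_dim = 50_000, 300   # the illustrative sizes from the text
rng = np.random.default_rng(0)

# One row per token, one column per embedding dimension.
# Random placeholders here; a trained model learns these values.
lookup_table = rng.standard_normal((vocab_size, embedding_dim), dtype=np.float32)

def embed(token_ids):
    # Embedding(token_id) = LookupTable[token_id] is literally row selection.
    return lookup_table[token_ids]

vectors = embed([0, 1, 4, 2, 3])
print(vectors.shape)  # (5, 300): one 300-dimensional vector per token
```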
The Transformer Layers: These layers have been covered very well in multiple posts all over the net; see [3] for more detail. This is the real engine of your factory. Imagine a room full of savants, each one looking at your partial sentence, conferring with other savants based on information they have, and then scribbling notes about what word should come next. Now imagine doing this dozens of times in parallel while sharing information. That’s what’s happening here, except the savants are matrix multiplications and the notes are more vectors. There are other parts to these layers, but at the heart of this communication and computation is the attention layer, which can be specified by Key, Query, Value triplets (K, Q, V) as follows: \[ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V \]
where the Key represents the information each savant has, the Query is each one asking “hey, who has the information I need?”, and the Value is the information they share with each other once they identify the overlap between what’s needed and what’s available, based on the dot product between Q and K. Here \(d_k\), the dimensionality of the keys, scales the dot products so they don’t blow up as the vectors get longer.
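In code, the attention formula is just a couple of matrix multiplications and a softmax. Below is a minimal single-head sketch in plain NumPy, with random matrices standing in for learned weights and no masking or multi-head splitting:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # shift for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)           # how well each query matches each key
    weights = softmax(scores, axis=-1)        # turn the matches into a distribution
    return weights @ V                        # blend the values accordingly

seq_len, d_model = 4, 8                       # toy sizes: 4 tokens, 8-dim vectors
rng = np.random.default_rng(0)
X = rng.standard_normal((seq_len, d_model))   # token embeddings
W_q, W_k, W_v = (rng.standard_normal((d_model, d_model)) for _ in range(3))

out = attention(X @ W_q, X @ W_k, X @ W_v)
print(out.shape)  # (4, 8): one updated vector per token
```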
The Output Layer: This is your quality control department. It takes all those scribbled savant notes and turns them into actual probabilities for each word in your vocabulary. “The” might get a 2% chance, “database” a 0.5% chance, and “aardvark” a 0.0001% chance. (Hey, you never know when you might need an aardvark in your limerick about databases).
This is embodied by our Softmax layer. It takes the raw scores (often called logits) associated with each possible token in the vocabulary and transforms them into a probability distribution. Essentially, for each position in a sequence, the softmax layer calculates the likelihood of each token in the vocabulary appearing in that position. It does this by exponentiating the scores and then normalizing them so that they sum to 1. This process ensures that the model outputs a valid probability distribution over all possible tokens, allowing it to make predictions about the most likely next token in a sequence or to classify tokens into different categories. Mathematically, you can write it as: \[ softmax(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} \]
where i is the current token and j iterates over all tokens in the vocabulary.
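A tiny illustration of that final step, with made-up logits over a three-word vocabulary (the numbers are invented purely to show the mechanics):

```python
import numpy as np

vocab = ["the", "database", "aardvark"]
logits = np.array([2.1, 0.7, -4.0])      # made-up raw scores from the output layer

probs = np.exp(logits - logits.max())    # exponentiate (shifted for stability) ...
probs /= probs.sum()                     # ... then normalize so they sum to 1

for word, p in zip(vocab, probs):
    print(f"{word:>9}: {p:.4f}")
# the: ~0.80, database: ~0.20, aardvark: ~0.002 - a valid distribution over the vocabulary
```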
Token Selection and Iteration: This is where your factory’s assembly line completes its cycle by choosing the next token based on the Softmax output, transforming a probability distribution into the next token in the sequence. This step is also where you decide how wild you want your factory’s output to be: greedy search tends to produce more repetitive text, while sampling methods can lead to more varied and sometimes surprising results.
Finally, we get a sequence of tokens that we convert back into readable text by inverting the lookup from Step 1: where tokenization mapped subwords to vocabulary indices, detokenization maps each index back to its subword and stitches the pieces together.
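A sketch of the selection-and-decode step, reusing the toy vocabulary from the tokenizer sketch above. Greedy selection just takes the argmax; temperature sampling draws from a sharpened (low temperature) or flattened (high temperature) version of the distribution:

```python
import numpy as np

rng = np.random.default_rng(0)

def greedy(probs):
    # Deterministic: always pick the single most likely token.
    return int(np.argmax(probs))

def sample(probs, temperature=1.0):
    # Stochastic: reshape the distribution, then draw a token from it.
    logits = np.log(probs) / temperature
    p = np.exp(logits - logits.max())
    p /= p.sum()
    return int(rng.choice(len(p), p=p))

def detokenize(token_ids, vocab):
    # The inverse of tokenization: map indices back to subwords and join them.
    id_to_subword = {idx: sub for sub, idx in vocab.items()}
    return "".join(id_to_subword[t] for t in token_ids)

vocab = {"data": 0, "base": 1, "optim": 2, "ization": 3, " ": 4}
print(detokenize([0, 1, 4, 2, 3], vocab))  # -> "database optimization"
```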
By repeating this process, your AI factory can churn out everything from simple sentences to complex essays - or an entire essay about why PostgreSQL is actually a misunderstood performance art piece. (Don’t judge. GPT-4o has some weird hobbies.) How you generate tokens, iterate, and post-process in this last step is what defines how “thoughtful” your LLM sounds. Let’s dig into that in the next section.
Navigation Prompt: The details of the algorithms in Step 5 are what inference scaling is all about. If you are familiar with the various token generation strategies, skip ahead; otherwise, please read on.
In AI’s Deep Reflection: Where Probability Flirts with Search and Models Judge Other Models’ Pickup Lines
For years, the AI community has been on a bulk-up routine that would make bodybuilders jealous - just feed the model more parameters and data. Inference scaling changes this up by giving AI models a gym membership and a personal trainer instead. It asks the question - instead of always making our models bigger, how do we make them think harder? Let’s dive into the three main approaches for that with a bit more technical detail this time.
1. Parallel Sampling: The Brute Force Charmer
Parallel sampling is a straightforward yet powerful technique in the world of inference scaling. At its core, it’s about generating multiple independent solutions and selecting the best one.
How It Works:
When presented with an input, the AI model doesn’t generate just one answer. Instead, it produces N complete, independent answers. This process is akin to running the model N times in parallel, each time generating a full response to the input.
The key to this method’s effectiveness lies in its evaluation step. We employ a sophisticated evaluator, an Outcome Reward Model (ORM). This is an AI model, trained to assess the quality of generated answers based on various criteria. The ORM examines each of the N answers and assigns them a score. This scoring isn’t just based on correctness, but can include factors such as the quality of reasoning, clarity of explanation, and adherence to the task requirements.
Once all N answers have been scored, we can use a “best-of-N weighted” selection method to choose the final answer. This method is more nuanced than simply selecting the highest-scoring response. Instead, it considers all answers that arrived at the same conclusion and sums their scores. The conclusion with the highest total score is then selected as the final output.
Mathematical Formulation:
For a given input x:
- Generate outputs: y₁, y₂, …, yₙ
- Score each output: s₁ = ORM(y₁), s₂ = ORM(y₂), …, sₙ = ORM(yₙ)
- Final selection: y* = argmax(s₁, s₂, …, sₙ)
In practice, the selection process might be more complex, involving grouping similar outputs and summing their scores before selecting the highest-scoring group.
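Here is a minimal sketch of best-of-N weighted selection. It assumes three hypothetical callables that your own stack would provide: `generate(x)` samples one complete answer, `orm_score(x, y)` returns the ORM’s score for an answer, and `conclusion(y)` extracts the final answer (say, the number at the end of a math solution) so that answers can be grouped:

```python
from collections import defaultdict

def best_of_n_weighted(x, generate, orm_score, conclusion, n=16):
    """Generate n independent answers, group them by their final conclusion,
    and return an answer from the group with the highest summed ORM score."""
    scored = [(y, orm_score(x, y)) for y in (generate(x) for _ in range(n))]

    totals = defaultdict(float)   # conclusion -> summed score across its group
    best = {}                     # conclusion -> (best score, best answer) in its group
    for y, s in scored:
        c = conclusion(y)
        totals[c] += s
        if c not in best or s > best[c][0]:
            best[c] = (s, y)

    winner = max(totals, key=totals.get)   # the conclusion with the largest total
    return best[winner][1]                 # its strongest individual answer
```

Raising `n` buys quality with compute, which is exactly the trade-off discussed below.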
Strengths:
- Simple and scalable. Improving results often involves simply increasing N, effectively leveraging additional computational resources to boost performance - a clear compute-for-quality trade-off.
- Works well for easier questions or tasks where generating a large number of attempts is likely to produce at least one high-quality answer.
- Works especially well when the space of possible good answers is relatively large and diverse.
Weaknesses:
- Computationally inefficient for complex problems.
- While it will eventually find a good solution if N is large enough, it may use significantly more resources than more targeted approaches.
- The compute costs can become prohibitive if N is set too high, especially for resource-intensive models.
- Each generation is independent, potentially repeating work or mistakes made in other attempts. It doesn’t learn from or build upon partial successes within a single attempt.
All in all, it is a simple, brute-force approach to inference scaling. Its straightforward nature makes it an attractive option, especially when dealing with tasks where the generation of multiple diverse attempts is likely to yield at least one high-quality result. However, its effectiveness needs to be balanced against its potential computational costs, especially for more complex tasks or when scaling to very large N.
2. Beam Search: The Chess Grandmaster of Inference
Unlike parallel sampling, which generates complete independent answers, beam search constructs solutions incrementally, making decisions at each step about which partial solutions are most promising to develop further.
How It Works:
- Initialization: The process begins by generating N initial predictions for the first step of the solution. These represent different starting points for the answer.
- Scoring: A Process Reward Model (PRM) is used to score each of these N initial steps. The PRM evaluates the quality and potential of each partial solution. Note that scoring partial steps, rather than complete answers, is what distinguishes a PRM from the ORM described above.
- Pruning: Instead of keeping all N initial steps, beam search retains only the top N/M highest-scoring ones, where M is the beam width. This step focuses the search on the most promising partial solutions.
- Expansion: For each of the retained partial solutions, the model generates M new proposals for the next step. This brings the total number of candidates back to N (N/M * M = N).
- Iteration: Steps 2-4 are repeated until either a complete solution is reached or a maximum number of rounds is hit.
This process allows beam search to maintain a diverse set of partially completed solutions, continually evaluating and refining them as it progresses.
Mathematical Formulation:
At each step t:
- Generate candidates: C_t = {c₁, c₂, …, cₙ}
- Score candidates: S_t = {PRM(c₁), PRM(c₂), …, PRM(cₙ)}
- Keep top k: B_t = top_k(C_t, S_t, k=N/M)
- Expand: C_{t+1} = ⋃_{b ∈ B_t} expand(b, M)
where:
- C_t is the set of candidate partial solutions at step t
- S_t is the set of scores for these candidates
- B_t is the set of top-scoring candidates retained
- expand(b, M) generates M new candidates from partial solution b
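A sketch of the loop under the same notation, assuming four hypothetical callables: `first_steps(x, n)` proposes n initial steps, `expand(x, partial, m)` proposes m continuations of a partial solution, `prm_score(x, partial)` is the PRM, and `is_complete(partial)` detects a finished answer:

```python
def beam_search(x, first_steps, expand, prm_score, is_complete, n=16, m=4, max_rounds=10):
    """Step-wise search: score candidates with a PRM, keep the top n/m,
    expand each survivor into m continuations, and repeat."""
    beam_width = n // m
    candidates = first_steps(x, n)                           # N initial first steps

    for _ in range(max_rounds):
        # Score and prune: keep only the top N/M partial solutions.
        kept = sorted(candidates, key=lambda c: prm_score(x, c), reverse=True)[:beam_width]

        finished = [c for c in kept if is_complete(c)]
        if finished:
            return finished[0]                               # best complete solution so far

        # Expand: each survivor proposes M next steps, restoring N candidates.
        candidates = [nxt for c in kept for nxt in expand(x, c, m)]

    # Round cap hit: return the highest-scoring candidate we have.
    return max(candidates, key=lambda c: prm_score(x, c))
```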
Strengths:
- Complex problem-solving where the solution is built incrementally.
- Natural language generation tasks requiring coherent, long-form responses.
- Multi-step planning or decision-making processes.
Weaknesses:
- There is potential for pruning prematurely. By discarding lower-scoring partial solutions early, beam search might miss unconventional but ultimately superior solutions that don’t appear promising in their early stages.
- While more efficient than exhaustive search, beam search can still be computationally intensive, especially with large beam widths or for problems requiring many steps.
- Its performance can be highly dependent on the choice of N and M. Optimal values may vary significantly between different types of problems.
Beam search is particularly suited for problems where the solution quality depends on a series of interconnected decisions. Its ability to maintain and explore multiple promising solution paths makes it a go-to approach for complex, multi-step reasoning tasks in AI.
3. Revision-Based Approach: The Perfectionist’s Dream
This approach focuses on the iterative improvement of an initial response. Unlike parallel sampling or beam search, which generate multiple independent answers or explore multiple paths simultaneously, this approach generates a single answer and then repeatedly refines it.
How It Works:
- Initial Generation: The model produces an initial answer to the given input.
- Iterative Refinement: The model is then presented with its previous answer(s) along with the original input and asked to generate an improved version.
- Repetition: This process of refinement is repeated for a predetermined number of iterations or until some convergence criterion is met.
- Selection: Once all iterations are complete, the best version is selected using either an Outcome Reward Model (ORM) or a majority voting mechanism.
Mathematical Formulation:
1. Initial generation: y₀ = model(x)
2. for i = 1 to k:
yᵢ = model(x, y₀, y₁, ..., yᵢ₋₁)
3. Final selection:
y* = argmax(ORM(y₀), ORM(y₁), ..., ORM(yₖ))
or
y* = mode(y₀, y₁, ..., yₖ)
Where:
- x is the input
- yᵢ is the i-th revision
- k is the total number of revisions
- ORM is the Outcome Reward Model
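A sketch of the loop, assuming two hypothetical callables: `revise(x, history)` is the fine-tuned revision model (given the input plus all earlier drafts it produces the next draft; with an empty history it produces the initial answer y₀), and `orm_score(x, y)` is the ORM:

```python
from collections import Counter

def revise_and_select(x, revise, orm_score, k=4, use_majority=False):
    """Produce an initial answer, refine it k times, then select the best draft
    either by ORM score or by majority vote over the drafts."""
    drafts = [revise(x, [])]                     # y0: the initial generation
    for _ in range(k):
        drafts.append(revise(x, drafts))         # y_i conditioned on all earlier drafts

    if use_majority:
        return Counter(drafts).most_common(1)[0][0]          # most common draft
    return max(drafts, key=lambda y: orm_score(x, y))        # highest-scoring draft
```

Note that the majority vote here compares drafts verbatim for simplicity; in practice you would compare their extracted final answers.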
Key Components:
- Revision Model: This is typically a fine-tuned version of the base language model, specifically trained to improve upon previous answers. The training process involves exposing the model to sequences of increasingly better answers to the same question.
- Outcome Reward Model (ORM): This is a separate model trained to evaluate the quality of complete answers. It’s used to score each revision and select the best one.
- Majority Voting: An alternative to ORM, this method selects the most common answer among all revisions. It’s particularly useful when multiple independent revision sequences are generated.
Strengths:
- This approach allows for gradual refinement of the answer, potentially leading to high-quality results, especially for medium-difficulty questions.
- It can be more computationally efficient than generating multiple complete answers, as in parallel sampling.
- By considering previous attempts, the model can build upon its own insights and correct its mistakes.
Weaknesses:
- This approach may get stuck in a local optimum, repeatedly making similar refinements without considering radically different approaches.
- The quality of the final output can be heavily influenced by the quality of the initial answer.
- Training a model to effectively revise its own work is a complex task, requiring sophisticated training data and techniques.
The revision-based approach represents a distinct strategy in inference scaling, mimicking a facet of the human creative process of drafting and refining ideas. As I have mentioned elsewhere, I don’t think that’s the whole story as far as human creativity is concerned. However, the ability of a model to iteratively improve upon its own output makes for a powerful tool in the AI toolkit, especially for tasks where quality can be enhanced through careful refinement and consideration of previous attempts.
Conclusion: The Art of Thinking Harder, Not Just Bigger
This post has become way longer than originally intended, so thanks for staying with it. For more details on performance metrics and specific use cases for each technique, I highly recommend checking out reference [3] below.
As I wrap up this journey, it’s clear to me that we are in the middle of a paradigm shift (no, not AGI). We’re moving from an era of “bigger is better” at training time to a more nuanced approach that also folds in clever inference algorithms, one in which ORM and PRM models are the unsung heroes. Their quality sets the upper bound on how well any inference scaling strategy can perform.
And while these techniques will give us copilots capable of more nuanced approaches to a problem, this shift also means that what used to be a heavy training-related capital expense borne by the model provider will increasingly become an operational expense borne by the user.