Stream of Search: Teaching Language Models the Language of Search
Generative AI needs to learn from mistakes
Language models have demonstrated remarkable abilities in various tasks, from natural language processing to code generation. However, these models often struggle with complex problem-solving that requires planning, searching, and backtracking. A recent paper by Gandhi et al. (2024) introduces a novel approach called Stream of Search (SoS), which aims to teach language models to solve problems by searching in language, without relying on external components.
The Stream of Search framework systematizes the elements of search into a unified language, enabling the representation of diverse search strategies in a common format. By training language models on these "streams of search," the authors demonstrate that the models can learn to solve problems more effectively than those trained solely on optimal solution trajectories. Furthermore, the SoS models can self-improve by optimizing for correctness using reinforcement learning techniques such as STaR and APA.
In this article, we will delve into the Stream of Search framework, exploring its key components, the problem setup, and the experimental results that showcase its potential for enhancing the problem-solving capabilities of language models.
The Language of Search
At the core of the Stream of Search framework is a vocabulary of primitive operations that define the components of various search algorithms.
These operations include:
- Current State (sc): the state currently being explored.
- Goal State (sg): the target state.
- State Queue (Sq): the frontier states of the trajectory that have not yet been explored.
- State Expansion Function (SE): a function that generates states adjacent to the current state, based on a transition function.
- Exploration Choice: the order in which states are explored (e.g., breadth-first or depth-first).
- Pruning: discarding states or subtrees that are unlikely to lead to a solution.
- Backtracking: moving between explored nodes to choose the next state to expand.
- Goal Check: checking whether the current state is the goal state.
- Heuristic (h): a function that approximates the distance of the current state from the goal state, guiding the search.
By representing these operations in language, the Stream of Search framework allows for the creation of a dataset with diverse search strategies. Some operations, such as the current state, goal state, backtracking, goal checks, and exploration choices, are explicitly represented in the search trajectory. Others, like heuristic functions, state values, and pruning strategies, are kept implicit, encouraging the model to internalize abstract representations that can be improved through training.
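To make the vocabulary concrete, here is a minimal sketch of how these primitives might compose into a search loop that narrates itself as text. The function names, the textual markers, and the depth-first control flow are illustrative assumptions; the paper's trajectories use their own phrasing and cover several strategies.

```python
from typing import Callable, List, Optional

def verbalized_dfs(
    state: str,
    is_goal: Callable[[str], bool],           # goal check
    expand: Callable[[str], List[str]],       # state expansion function (SE)
    heuristic: Callable[[str], float],        # guides the exploration choice
    log: List[str],                           # the emitted "stream of search"
    depth: int = 10,
) -> Optional[str]:
    """Depth-first search that narrates each primitive operation as text.
    A hedged illustration of a stream of search, not the paper's exact format."""
    log.append(f"Current state: {state}")
    if is_goal(state):                        # goal check
        log.append("Goal reached.")
        return state
    if depth == 0:
        return None
    children = sorted(expand(state), key=heuristic)   # exploration choice via heuristic
    for child in children:                    # pruning could drop unpromising children here
        log.append(f"Exploring: {child}")
        found = verbalized_dfs(child, is_goal, expand, heuristic, log, depth - 1)
        if found is not None:
            return found
        log.append(f"Backtracking to: {state}")       # backtracking made explicit
    return None
```

Note that the heuristic is an explicit function here; in the SoS training data, heuristic values and pruning decisions are left implicit so that the model has to internalize its own versions.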
![](https://substackcdn.com/image/fetch/w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Ffe964225-47c1-4fec-a6bc-081e6af71b4c_860x300.png)
Problem Setup: The Game of Countdown
To demonstrate the utility of the Stream of Search framework, the authors focus on a generalization of the 24 Game called Countdown. In this game, a set of input numbers must be combined using simple arithmetic operations to reach a target number. Countdown presents a challenging search problem due to its high branching factor, requiring planning, search, and backtracking to solve.
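For intuition, here is a rough brute-force solver for the game, assuming the usual rules: repeatedly combine two of the remaining numbers with +, -, * or / until one of them equals the target. The exhaustive enumeration below is my own illustration of why the branching factor is high; it is not one of the paper's symbolic search strategies.

```python
from itertools import permutations
from typing import List, Optional

def solve_countdown(nums: List[int], target: int) -> Optional[List[str]]:
    """Exhaustively combine pairs of numbers until one equals the target.
    Returns the operations used, or None if no combination works."""
    if target in nums:
        return []
    for i, j in permutations(range(len(nums)), 2):
        x, y = nums[i], nums[j]
        rest = [n for k, n in enumerate(nums) if k not in (i, j)]
        candidates = [(x + y, f"{x}+{y}"), (x * y, f"{x}*{y}"), (x - y, f"{x}-{y}")]
        if y != 0 and x % y == 0:
            candidates.append((x // y, f"{x}/{y}"))
        for value, op in candidates:
            sub = solve_countdown(rest + [value], target)
            if sub is not None:
                return [f"{op}={value}"] + sub
    return None

print(solve_countdown([3, 7, 25, 50], 75))  # e.g. ['3+7=10', '25+50=75']
```

Even with only four input numbers, every step branches over all ordered pairs and all operators, which is exactly the combinatorial explosion that planning, pruning, and backtracking are meant to tame.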
The authors construct a synthetic dataset of 500,000 search trajectories using a set of diverse, suboptimal symbolic search strategies based on breadth-first search (BFS) and depth-first search (DFS), combined with two simple heuristic functions (both sketched in code after the list):
- The absolute difference between the sum of the remaining numbers and the target.
- The distance of the remaining numbers to the factors of the target.
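Here is a rough sketch of what these two heuristics could look like in code. The exact formulas used by the paper's trajectory generators may differ, so treat both functions as assumptions about the general idea rather than the authors' implementation.

```python
from typing import List

def sum_heuristic(remaining: List[int], target: int) -> int:
    """Absolute difference between the sum of the remaining numbers and the target."""
    return abs(sum(remaining) - target)

def factor_heuristic(remaining: List[int], target: int) -> int:
    """Distance of the remaining numbers to the factors of the target
    (one possible reading of 'distance to the factors of the target')."""
    factors = [f for f in range(1, target + 1) if target % f == 0]
    return sum(min(abs(n - f) for f in factors) for n in remaining)
```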
The search trajectories are serialized as strings, representing the list of tree nodes or states in the order of traversal. The dataset is then used to train language models under two conditions (illustrated with example strings after the list):
- Optimal Paths (OP): the model is trained to predict the correct and optimal solution path for every problem in the dataset.
- Stream of Search (SoS): the model is trained on full search trajectories sampled from the different search strategies.
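To illustrate the difference between the two conditions, here is what a training string might look like in each case. The wording and layout are invented for illustration and are not copied from the paper's serialization format.

```python
# Optimal Paths (OP): only the clean solution is shown.
op_example = (
    "Goal: reach 75 from [3, 7, 25, 50]\n"
    "25+50=75\n"
    "Goal reached."
)

# Stream of Search (SoS): the full trajectory, including dead ends and backtracking.
sos_example = (
    "Goal: reach 75 from [3, 7, 25, 50]\n"
    "Exploring: 50-3=47, remaining [7, 25, 47]\n"
    "Exploring: 47+25=72, remaining [7, 72]\n"
    "Backtracking.\n"
    "Exploring: 25+50=75, remaining [3, 7, 75]\n"
    "Goal reached."
)
```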
![](https://substackcdn.com/image/fetch/w_1456,c_limit,f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F75ce4c63-66ed-451f-bae8-baec6fdb7ea3_823x511.png)
Experimental Results
The authors train a GPT-Neo model with 250M parameters on the Countdown dataset and evaluate its performance on held-out problems. The results demonstrate that the model trained on streams of search (SoS) outperforms the model trained on optimal solutions (OP). The SoS model achieves an accuracy of 51.27% on held-out inputs, compared to 25.73% for the OP model. This finding highlights the importance of exposing models to the messy process of problem-solving, including exploration and backtracking, rather than only the ideal solution steps.
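For a rough idea of the training setup, here is a causal-language-modeling sketch using Hugging Face Transformers. The checkpoint name (the nearest public GPT-Neo size), the hyperparameters, and the choice to start from a pretrained model are all assumptions on my part, not details taken from the paper.

```python
from datasets import Dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)

# In practice this would be the 500,000 serialized search trajectories.
trajectories = ["Goal: reach 75 from [3, 7, 25, 50]\nExploring: 25+50=75\nGoal reached."]

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, max_length=1024)

data = Dataset.from_dict({"text": trajectories}).map(tokenize, batched=True,
                                                     remove_columns=["text"])

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="sos-model", per_device_train_batch_size=8,
                           num_train_epochs=1, learning_rate=1e-4),
    train_dataset=data,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),  # causal LM objective
)
trainer.train()
```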
To understand the strategies employed by the trained SoS model, the authors measure the alignment of the model-generated search trajectories with symbolic strategies. They find that the SoS model does not predominantly use any single strategy from its training data but instead exhibits a higher correlation with strategies that use the sum heuristic.
Policy Improvement with Stream of Search
The authors further investigate whether the SoS model can self-improve with feedback based on correctness and efficiency. They employ two reinforcement learning strategies: expert iteration using STaR (Self-Taught Reasoner) and Advantage-Induced Policy Alignment (APA).
In the STaR approach, the authors use problems from the training dataset to generate 100,000 correct trajectories, fine-tune the model on them, and repeat the process until performance on the validation set converges. After three iterations of STaR fine-tuning, the SoS+STaR model solves about 5% more of the held-out-inputs test set than the base SoS model.
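The STaR-style loop can be sketched roughly as follows. The number of sampling attempts, the stopping rule, and the helper callables are placeholders rather than the paper's exact procedure.

```python
from typing import Callable, List

def star_round(model, problems: List, sample: Callable, is_correct: Callable,
               finetune: Callable, attempts: int = 4):
    """One round of expert iteration: sample trajectories, keep only the correct
    ones, and fine-tune on them. Rounds repeat until validation accuracy converges."""
    kept = []
    for problem in problems:
        for _ in range(attempts):
            trajectory = sample(model, problem)    # the model searches in language
            if is_correct(trajectory, problem):    # did the search reach the target?
                kept.append(trajectory)
                break
    return finetune(model, kept)                   # becomes the next round's policy
```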
Alternatively, the authors use APA, an Actor-Critic reinforcement learning technique that involves creating a copy of the language model to serve as a value network, which is then used to enhance the policy (the original language model). A straightforward reward function is defined, considering the correctness and length of the generated trajectory. The authors observe that updating the reference policy whenever the validation reward converges results in further policy improvement. After fine-tuning with APA, the SoS model achieves an improvement of about 6% over the base SoS model.
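The exact reward is not reproduced in this summary, so the function below is only a guess at its general shape: a bonus for reaching the target, discounted by trajectory length so that shorter, more efficient searches score higher.

```python
def trajectory_reward(trajectory: str, solved: bool, max_len: int = 4096) -> float:
    """Hypothetical reward combining correctness and length; not the paper's exact formula."""
    length_penalty = min(len(trajectory), max_len) / max_len   # 0 (short) to 1 (at budget)
    return 1.0 - 0.5 * length_penalty if solved else -1.0
```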
Analysis of the fine-tuned models reveals that both the STaR and APA models visit more states associated with the 'multiply' heuristic, which measures distance to the factors of the target. The APA model, in particular, diverges more from the symbolic strategies, indicating that it employs different strategies for searching and potentially discovers novel heuristics and search methods.
To further evaluate the improved models, the authors select 10,000 problems from the SoS training set that went unsolved during dataset generation, along with 10,000 difficult problems that none of the symbolic strategies used to build the SoS training data can solve. Remarkably, the models solve approximately 36% of the previously unsolved problems and about 4% of the difficult problems.
Discussion and Future Directions
The SoS framework introduces a new approach to teaching language models to solve problems by searching in language, without relying on external components. By systematizing the elements of search into a unified language, the authors demonstrate that training language models on diverse streams of search leads to superior performance compared to models trained solely on optimal trajectories.
This addresses common criticisms of language models as planners and problem-solvers, such as the snowballing of errors and difficulty with lookahead. By teaching models to backtrack and explore alternative paths, SoS enables language models to consider multiple possible outcomes before committing to a course of action. Crucially, SoS leads language models to learn an internal 'world model' for search, allowing for more adaptable and generalizable search than symbolic methods that rely on an explicit environment model.
While the empirical results in the paper are restricted to the game of Countdown, the authors are optimistic that the SoS framework can be extended to more challenging, real-world tasks. Future research could explore integrating subgoals and hierarchical planning, as well as incorporating reflection and self-evaluation to enable models to discover and improve novel search strategies.
Generating the initial SoS dataset can be challenging, as it is not always feasible to create symbolic search algorithms to solve problems. An important question for future research is how well search abilities transfer between domains and between formal and informal domains.
Conclusion
SoS is a promising step forward in teaching language models to solve problems through structured search with backtracking, heuristic state evaluation, and world modeling. By exposing language models to diverse search strategies and iteratively refining them, the approach unlocks the potential of these models to tackle complex problems and discover new ways to solve them.
Frameworks like this will play a crucial role in enhancing the problem-solving capabilities of language models. By embracing the messy process of exploration and backtracking, and learning from productive mistakes, language models can develop more robust and adaptable search strategies, paving the way for their application in a wide range of real-world problem-solving scenarios.
AI that isn’t able to explore - and learn from mistakes - will always be limited by design, unless we find a way to create flawless and complete training data for everything we want it to do in this world, which is, of course, extremely unlikely.
👍 If you enjoyed this article, give it a like and share it with your peers.
And in case you want to continue reading, here’s my previous research summary on StructRAG, a framework that combines the best of both worlds - graph-based and standard RAG: