Language Agent Tree Search
Language Agent Tree Search (LATS), by Zhou et al., is a general LLM agent search algorithm that combines reflection/evaluation and search (specifically Monte Carlo tree search) to achieve stronger overall task performance by leveraging inference-time compute.
It has four main steps (a minimal sketch of the loop follows the list):
- Select: pick the best next state to progress from, based on its aggregate value.
- Expand and simulate: sample n potential actions to take and execute them in parallel.
- Reflect + Evaluate: observe the outcomes of these actions and score the decisions based on reflection (and possibly external feedback, if available).
- Backpropagate: update the scores of the nodes along the trajectory, back up to the root, based on the outcomes.
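Schematically, the whole procedure boils down to the loop sketched below. This is only an orientation aid; every callable is a placeholder for a component built later in this notebook, not a fixed API.

```python
def lats_sketch(question, make_root, select, expand_and_evaluate, best_trajectory, max_depth=5):
    """Schematic LATS loop; each callable stands in for a component defined later."""
    root = make_root(question)                # initial response + reflection -> root node
    while not root.is_solved and root.height <= max_depth:
        node = select(root)                   # 1. pick the node with the highest UCT score
        expand_and_evaluate(node)             # 2-3. sample N actions, reflect on and score them
                                              #      (backpropagation updates the ancestors)
    return best_trajectory(root)              # return the best trajectory found
```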
import json
import logging
import os
import uuid
from typing import Any, Dict, List
from autogen import AssistantAgent, ConversableAgent, GroupChat, UserProxyAgent, config_list_from_json
# Configure logging
logging.basicConfig(level=logging.INFO)
# Set environment variables
os.environ["AUTOGEN_USE_DOCKER"] = "0" # Disable Docker usage globally for Autogen
os.environ["OPENAI_API_KEY"] = "YOUR_API_KEY"
Prerequisites
Install autogen (the LLM framework and agents). Please ensure the package is installed before running this script.
# Directly create the config_list with the API key
config_list = [{"model": "gpt-4o-mini", "api_key": "YOUR_API_KEY"}]
if not config_list:
raise ValueError("Failed to create configuration. Please check the API key.")
Reflection Class
The reflection chain will score agent outputs based on the decision and the tool responses.
from pydantic import BaseModel, Field
class Reflection(BaseModel):
reflections: str = Field(
description="The critique and reflections on the sufficiency, superfluency,"
" and general quality of the response"
)
score: int = Field(
description="Score from 0-10 on the quality of the candidate response.",
ge=0,
le=10,
)
found_solution: bool = Field(description="Whether the response has fully solved the question or task.")
def as_message(self):
return {"role": "human", "content": f"Reasoning: {self.reflections}\nScore: {self.score}"}
@property
def normalized_score(self) -> float:
return self.score / 10.0
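As a quick illustration of how these fields are used downstream, here is a hand-built instance (the values are made up):

```python
# Illustrative only: construct a Reflection by hand and inspect its helpers.
example = Reflection(
    reflections="Concise and accurate, but does not cite any sources.",
    score=7,
    found_solution=False,
)
print(example.normalized_score)  # 0.7 -- the reward that gets backpropagated through the tree
print(example.as_message())      # the message dict appended to a node's trajectory
```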
Tree State
LATS is based on a (greedy) Monte Carlo tree search. At each search step, it picks the node with the highest “upper confidence bound” (UCT), a metric that balances exploitation (highest average reward) and exploration (lowest visit count). Starting from that node, it generates N (5 in this case) new candidate actions to take and adds them to the tree. It stops searching either when it has generated a valid solution OR when it has reached the maximum number of rollouts (here, a cap on the search tree depth).
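The selection score itself is the standard UCT formula: average reward plus an exploration bonus that shrinks as a node accumulates visits. A standalone illustration with made-up numbers (the actual implementation lives in Node.upper_confidence_bound below):

```python
import math

def uct(average_reward, visits, parent_visits, exploration_weight=1.0):
    # UCT = average reward + c * sqrt(ln(parent visits) / visits)
    return average_reward + exploration_weight * math.sqrt(math.log(parent_visits) / visits)

# A barely-visited child can outrank a well-explored one despite a lower average reward.
print(uct(average_reward=0.6, visits=4, parent_visits=10))  # ~1.36
print(uct(average_reward=0.5, visits=1, parent_visits=10))  # ~2.02
```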
import math
import os
from collections import deque
from typing import Optional
class Node:
def __init__(
self,
messages: List[Dict[str, str]],
reflection: Optional[Reflection] = None,
parent: Optional["Node"] = None,
):
self.messages = messages
self.parent = parent
self.children: List["Node"] = []
self.value = 0.0
self.visits = 0
self.reflection = reflection
self.depth = parent.depth + 1 if parent is not None else 1
self._is_solved = reflection.found_solution if reflection else False
if self._is_solved:
self._mark_tree_as_solved()
if reflection:
self.backpropagate(reflection.normalized_score)
def __repr__(self) -> str:
return (
f"<Node value={self.value:.2f}, visits={self.visits}," f" depth={self.depth}, is_solved={self._is_solved}>"
)
@property
def is_solved(self) -> bool:
"""If any solutions exist, we can end the search."""
return self._is_solved
@property
def is_terminal(self):
return not self.children
@property
def best_child(self):
"""Select the child with the highest UCT to search next."""
if not self.children:
return None
all_nodes = self._get_all_children()
return max(all_nodes, key=lambda child: child.upper_confidence_bound())
@property
def best_child_score(self):
"""Return the child with the highest value."""
if not self.children:
return None
return max(self.children, key=lambda child: int(child.is_solved) * child.value)
@property
def height(self) -> int:
"""Check for how far we've rolled out the tree."""
if self.children:
return 1 + max([child.height for child in self.children])
return 1
def upper_confidence_bound(self, exploration_weight=1.0):
"""Return the UCT score. This helps balance exploration vs. exploitation of a branch."""
if self.parent is None:
raise ValueError("Cannot obtain UCT from root node")
if self.visits == 0:
return self.value
# Encourages exploitation of high-value trajectories.
# self.value already stores the running average reward (see backpropagate),
# so it is used directly as the exploitation term.
average_reward = self.value
exploration_term = math.sqrt(math.log(self.parent.visits) / self.visits)
return average_reward + exploration_weight * exploration_term
def backpropagate(self, reward: float):
"""Update the score of this node and its parents."""
node = self
while node:
node.visits += 1
node.value = (node.value * (node.visits - 1) + reward) / node.visits
node = node.parent
def get_messages(self, include_reflections: bool = True):
if include_reflections and self.reflection:
return self.messages + [self.reflection.as_message()]
return self.messages
def get_trajectory(self, include_reflections: bool = True) -> List[Dict[str, str]]:
"""Get messages representing this search branch."""
messages = []
node = self
while node:
messages.extend(node.get_messages(include_reflections=include_reflections)[::-1])
node = node.parent
# Reverse the final back-tracked trajectory to return in the correct order
return messages[::-1] # root solution, reflection, child 1, ...
def _get_all_children(self):
all_nodes = []
nodes = deque()
nodes.append(self)
while nodes:
node = nodes.popleft()
all_nodes.extend(node.children)
for n in node.children:
nodes.append(n)
return all_nodes
def get_best_solution(self):
"""Return the best solution from within the current sub-tree."""
all_nodes = [self] + self._get_all_children()
best_node = max(
all_nodes,
# We filter out all non-terminal, non-solution trajectories
key=lambda node: int(node.is_terminal and node.is_solved) * node.value,
)
return best_node
def _mark_tree_as_solved(self):
parent = self.parent
while parent:
parent._is_solved = True
parent = parent.parent
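Before moving on, here is a tiny hand-built tree showing how reflections drive backpropagation and selection (the messages and scores are made up):

```python
# Illustrative only: a two-node tree scored by hand-written reflections.
root = Node(
    messages=[{"role": "assistant", "content": "First draft answer."}],
    reflection=Reflection(reflections="Partially correct.", score=5, found_solution=False),
)
child = Node(
    messages=[{"role": "assistant", "content": "Improved answer."}],
    reflection=Reflection(reflections="Much better.", score=9, found_solution=False),
    parent=root,
)
root.children.append(child)

print(root)             # the root's visits/value were updated by the child's backpropagation
print(root.best_child)  # the descendant with the highest UCT score (here, the only child)
print(root.height)      # 2: the tree has been rolled out one level
```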
The main component is the tree, represented by the root node.
from typing_extensions import TypedDict
class TreeState(TypedDict):
# The full tree
root: Node
# The original input
input: str
Define Language Agent
Our agent will have three primary LLM-powered processes:
- Reflect: score the action based on the tool response.
- Initial response: to create the root node and start the search.
- Expand: generate 5 candidate “next steps” from the best spot in the current tree.
For more “grounded” tool applications (such as code synthesis), you could integrate code execution into the reflection/reward step. This kind of external feedback is very useful; a sketch of one way to wire it in follows.
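As one illustration of such grounding, a reward for a code-synthesis task could blend the LLM's self-reflection with the result of actually executing the candidate. The helper below is hypothetical and is not wired into this example's pipeline; it just sketches the idea.

```python
import subprocess
import sys
import tempfile

# Hypothetical helper: combine self-reflection with external feedback from running code.
def grounded_score(candidate_code: str, reflection: Reflection) -> float:
    """Blend the reflection score with a pass/fail signal from executing the candidate."""
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(candidate_code)
        path = f.name
    result = subprocess.run([sys.executable, path], capture_output=True, timeout=30)
    executed_ok = 1.0 if result.returncode == 0 else 0.0
    # Weight the external signal more heavily than the self-reflection score.
    return 0.3 * reflection.normalized_score + 0.7 * executed_ok
```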
Tools
For our example, you could give the language agent a search engine or other tools; in this simplified version we only give it a UserProxyAgent with code-execution (tool-use) capability.
# Define the UserProxyAgent with tool-use (code execution) capability
user_proxy = UserProxyAgent(
name="user",
human_input_mode="NEVER",
max_consecutive_auto_reply=10,
code_execution_config={
"work_dir": "web",
"use_docker": False,
},
)
# Create a ConversableAgent without tools
assistant_agent = ConversableAgent(
name="assistant_agent",
system_message="You are an AI assistant capable of helping with various tasks.",
human_input_mode="NEVER",
code_execution_config=False,
)
Reflection
Self-reflection allows the agent to bootstrap, improving its future responses based on the outcome of previous ones. In agents this is more powerful, since the agent can also use external feedback to improve.
reflection_prompt = """
Reflect and grade the assistant response to the user question below.
User question: {input}
Assistant response: {candidate}
Provide your reflection in the following format:
Reflections: [Your detailed critique and reflections]
Score: [A score from 0-10]
Found Solution: [true/false]
"""
reflection_agent = AssistantAgent(
name="reflection_agent",
system_message="You are an AI assistant that reflects on and grades responses.",
llm_config={
"config_list": config_list,
"temperature": 0.2,
},
)
def reflection_chain(inputs: Dict[str, Any]) -> Reflection:
try:
candidate_content = ""
if "candidate" in inputs:
candidate = inputs["candidate"]
if isinstance(candidate, list):
candidate_content = (
candidate[-1]["content"]
if isinstance(candidate[-1], dict) and "content" in candidate[-1]
else str(candidate[-1])
)
elif isinstance(candidate, dict):
candidate_content = candidate.get("content", str(candidate))
elif isinstance(candidate, str):
candidate_content = candidate
else:
candidate_content = str(candidate)
formatted_prompt = [
{"role": "system", "content": "You are an AI assistant that reflects on and grades responses."},
{
"role": "user",
"content": reflection_prompt.format(input=inputs.get("input", ""), candidate=candidate_content),
},
]
response = reflection_agent.generate_reply(formatted_prompt)
# Parse the response
response_str = str(response)
lines = response_str.split("\n")
reflections = next((line.split(": ", 1)[1] for line in lines if line.startswith("Reflections:")), "")
score_str = next((line.split(": ", 1)[1] for line in lines if line.startswith("Score:")), "0")
try:
if "/" in score_str:
numerator, denominator = map(int, score_str.split("/"))
score = int((numerator / denominator) * 10)
else:
score = int(score_str)
except ValueError:
logging.warning(f"Invalid score value: {score_str}. Defaulting to 0.")
score = 0
found_solution = next(
(line.split(": ", 1)[1].lower() == "true" for line in lines if line.startswith("Found Solution:")), False
)
if not reflections:
logging.warning("No reflections found in the response. Using default values.")
reflections = "No reflections provided."
return Reflection(reflections=reflections, score=score, found_solution=found_solution)
except Exception as e:
logging.error(f"Error in reflection_chain: {str(e)}", exc_info=True)
return Reflection(reflections=f"Error in reflection: {str(e)}", score=0, found_solution=False)
Initial Response
We start with a single root node, generated by this first step. It responds to the user input either with a tool invocation or a response.
# Create Autogen agents
assistant = AssistantAgent(name="assistant", llm_config={"config_list": config_list}, code_execution_config=False)
user = UserProxyAgent(
name="user",
human_input_mode="NEVER",
max_consecutive_auto_reply=10,
code_execution_config={"work_dir": "web", "use_docker": False},
)
# Define a function to create the initial prompt
def create_initial_prompt(input_text):
return [
{"role": "system", "content": "You are an AI assistant."},
{"role": "user", "content": input_text},
]
# Function to generate initial response
def generate_initial_response(state: TreeState) -> TreeState:
chat_messages = create_initial_prompt(state["input"])
try:
# Ensure chat_messages is a list of dictionaries
if not isinstance(chat_messages, list):
chat_messages = [{"role": "user", "content": chat_messages}]
logging.info(f"Generating initial response for input: {state['input']}")
logging.debug(f"Chat messages: {chat_messages}")
response = assistant.generate_reply(chat_messages)
logging.debug(f"Raw response from assistant: {response}")
# Ensure response is properly formatted as a string
if isinstance(response, str):
content = response
elif isinstance(response, dict) and "content" in response:
content = response["content"]
elif isinstance(response, list) and len(response) > 0:
content = response[-1].get("content", str(response[-1]))
else:
content = str(response)
content = content.strip()
if not content:
raise ValueError("Generated content is empty after processing")
logging.debug(f"Processed content: {content[:100]}...") # Log first 100 chars
# Generate reflection
reflection_input = {"input": state["input"], "candidate": content}
logging.info("Generating reflection on the initial response")
reflection = reflection_chain(reflection_input)
logging.debug(f"Reflection generated: {reflection}")
# Create Node with messages as a list containing a single dict
messages = [{"role": "assistant", "content": content}]
root = Node(messages=messages, reflection=reflection)
logging.info("Initial response and reflection generated successfully")
return TreeState(root=root, input=state["input"])
except Exception as e:
logging.error(f"Error in generate_initial_response: {str(e)}", exc_info=True)
return TreeState(root=None, input=state["input"])
# Example usage of the generate_initial_response function
initial_prompt = "Why is the sky blue?"
initial_state = TreeState(input=initial_prompt, root=None)
result_state = generate_initial_response(initial_state)
if result_state["root"] is not None:
print(result_state["root"].messages[0]["content"])
else:
print("Failed to generate initial response.")
Starting Node
We will package up the candidate generation and reflection in a single node of our graph. This is represented by the following function:
# Define the function to generate the initial response
def generate_initial_response(state: TreeState) -> TreeState:
"""Generate the initial candidate response using Autogen components."""
assistant = AssistantAgent(name="assistant", llm_config={"config_list": config_list}, code_execution_config=False)
# Generate initial response
initial_message = [
{"role": "system", "content": "You are an AI assistant."},
{"role": "user", "content": state["input"]},
]
try:
logging.info(f"Generating initial response for input: {state['input']}")
response = assistant.generate_reply(initial_message)
logging.debug(f"Raw response from assistant: {response}")
# Ensure response is properly formatted as a string
if isinstance(response, str):
content = response
elif isinstance(response, dict):
content = response.get("content", "")
if not content:
content = json.dumps(response)
elif isinstance(response, list):
content = " ".join(str(item) for item in response)
else:
content = str(response)
# Ensure content is always a string and not empty
content = content.strip()
if not content:
raise ValueError("Generated content is empty after processing")
logging.debug(f"Final processed content (first 100 chars): {content[:100]}...")
# Generate reflection
logging.info("Generating reflection on the initial response")
reflection_input = {"input": state["input"], "candidate": content}
reflection = reflection_chain(reflection_input)
logging.debug(f"Reflection generated: {reflection}")
if not isinstance(reflection, Reflection):
raise TypeError(f"Invalid reflection type: {type(reflection)}. Expected Reflection, got {type(reflection)}")
# Create Node with messages as a list containing a single dict
messages = [{"role": "assistant", "content": content}]
logging.debug(f"Creating Node with messages: {messages}")
root = Node(messages=messages, reflection=reflection)
logging.info("Initial response and reflection generated successfully")
logging.debug(f"Created root node: {root}")
return TreeState(root=root, input=state["input"])
except Exception as e:
logging.error(f"Error in generate_initial_response: {str(e)}", exc_info=True)
return TreeState(root=None, input=state["input"])
Candidate Generation
The following code prompts the same LLM to generate N additional candidates to check.
# This generates N candidate values for a single input to sample actions from the environment
def generate_candidates(messages: list, config: dict):
n = config.get("N", 5)
assistant = AssistantAgent(name="assistant", llm_config={"config_list": config_list}, code_execution_config=False)
candidates = []
for _ in range(n):
try:
# Use the assistant to generate a response
last_message = messages[-1]["content"] if messages and isinstance(messages[-1], dict) else str(messages[-1])
response = assistant.generate_reply([{"role": "user", "content": last_message}])
if isinstance(response, str):
candidates.append(response)
elif isinstance(response, dict) and "content" in response:
candidates.append(response["content"])
elif (
isinstance(response, list) and response and isinstance(response[-1], dict) and "content" in response[-1]
):
candidates.append(response[-1]["content"])
else:
candidates.append(str(response))
except Exception as e:
logging.error(f"Error generating candidate: {str(e)}")
candidates.append("Failed to generate candidate.")
if not candidates:
logging.warning("No candidates were generated.")
return candidates
expansion_chain = generate_candidates
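As with the reflection chain, the expansion step can be tried in isolation (this also calls the LLM, so it needs a valid API key; responses will vary):

```python
# Illustrative only: sample three candidate continuations for a single user message.
sample_candidates = expansion_chain(
    [{"role": "user", "content": "Suggest a title for a blog post about Monte Carlo tree search."}],
    {"N": 3},
)
for i, candidate in enumerate(sample_candidates, 1):
    print(f"Candidate {i}: {candidate[:80]}")
```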
Candidate generation node
We will package the candidate generation and reflection steps into the following “expand” node. In this implementation the candidates are generated and reflected on sequentially; these calls could be batched or parallelized to speed up execution.
def expand(state: TreeState, config: Dict[str, Any]) -> dict:
root = state["root"]
best_candidate: Node = root.best_child if root.children else root
messages = best_candidate.get_trajectory()
# Generate N candidates using Autogen's generate_candidates function
new_candidates = generate_candidates(messages, config)
# Reflect on each candidate using Autogen's AssistantAgent
reflections = []
for candidate in new_candidates:
reflection = reflection_chain({"input": state["input"], "candidate": candidate})
reflections.append(reflection)
# Grow tree
child_nodes = [
Node([{"role": "assistant", "content": candidate}], parent=best_candidate, reflection=reflection)
for candidate, reflection in zip(new_candidates, reflections)
]
best_candidate.children.extend(child_nodes)
# We have already extended the tree directly, so we just return the state
return state
Create Tree
With those two steps defined, we are ready to define the search loop. After each agent step, we have the option of finishing.
from typing import Any, Dict, Literal
def should_loop(state: Dict[str, Any]) -> Literal["expand", "end"]:
"""Determine whether to continue the tree search."""
root = state["root"]
if root.is_solved:
return "end"
if root.height > 5:
return "end"
return "expand"
def run_lats(input_query: str, max_iterations: int = 10):
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
try:
state = {"input": input_query, "root": None}
try:
state = generate_initial_response(state)
if not isinstance(state, dict) or "root" not in state or state["root"] is None:
logger.error("Initial response generation failed or returned invalid state")
return "Failed to generate initial response."
logger.info("Initial response generated successfully")
except Exception as e:
logger.error(f"Error generating initial response: {str(e)}", exc_info=True)
return "Failed to generate initial response due to an unexpected error."
for iteration in range(max_iterations):
action = should_loop(state)
if action == "end":
logger.info(f"Search ended after {iteration + 1} iterations")
break
try:
state = expand(
state,
{
"N": 5,
"input_query": input_query,
},
)
logger.info(f"Completed iteration {iteration + 1}")
except Exception as e:
logger.error(f"Error during iteration {iteration + 1}: {str(e)}", exc_info=True)
continue
if not isinstance(state, dict) or "root" not in state or state["root"] is None:
return "No valid solution found due to an error in the search process."
solution_node = state["root"].get_best_solution()
best_trajectory = solution_node.get_trajectory(include_reflections=False)
if not best_trajectory:
return "No solution found in the search process."
result = (
best_trajectory[-1].get("content") if isinstance(best_trajectory[-1], dict) else str(best_trajectory[-1])
)
logger.info("LATS search completed successfully")
return result
except Exception as e:
logger.error(f"An unexpected error occurred during LATS execution: {str(e)}", exc_info=True)
return f"An unexpected error occurred: {str(e)}"
# Example usage:
result = run_lats("Write a research report on deep learning.")
print(result)
# Example usage of the LATS algorithm with Autogen
import logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
def run_lats_example(question):
try:
logger.info(f"Processing question: {question}")
result = run_lats(question)
logger.info(f"LATS algorithm completed. Result: {result[:100]}...") # Log first 100 chars of result
print(f"Question: {question}")
print(f"Answer: {result}")
except Exception as e:
logger.error(f"An error occurred while processing the question: {str(e)}", exc_info=True)
print(f"An error occurred: {str(e)}")
finally:
print("---")
# List of example questions
questions = [
"Explain how epigenetic modifications can influence gene expression across generations and the implications for evolution.",
"Discuss the challenges of grounding ethical theories in moral realism, especially in light of the is-ought problem introduced by Hume.",
"How does the Riemann Hypothesis relate to the distribution of prime numbers, and why is it significant in number theory?",
"Describe the challenges and theoretical underpinnings of unifying general relativity with quantum mechanics, particularly focusing on string theory and loop quantum gravity.",
]
# Run LATS algorithm for each question
for i, question in enumerate(questions, 1):
print(f"\nExample {i}:")
run_lats_example(question)
logger.info("All examples processed.")
Conclusion
Congrats on implementing LATS! This is a technique that can be reasonably fast and effective at solving complex agent tasks. A few notes that you probably observed above:
- While LATS is effective, the tree rollout process can require additional inference-time compute. If you plan to integrate this into a production application, consider streaming intermediate steps so users can see the thought process and access intermediate results. Alternatively, you could use it to generate fine-tuning data to enhance single-shot accuracy and avoid lengthy rollouts (a sketch of exporting trajectories for this purpose follows below). The cost of using LATS has decreased significantly since it was first proposed and is expected to keep decreasing.
- The effectiveness of the candidate selection process depends on the quality of the rewards generated. In this example, we exclusively use self-reflection as feedback, but if you have access to external feedback sources (such as code test execution), those should be incorporated as suggested above.
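For the fine-tuning idea mentioned above, one possible approach is to persist the best trajectory from each finished search as a training record. The helper below is a hypothetical sketch (the file name and record fields are arbitrary), not part of the pipeline above.

```python
import json

# Hypothetical sketch: save the best trajectory from a finished search as fine-tuning data.
def export_trajectory(state: TreeState, path: str = "lats_traces.jsonl") -> None:
    """Append the best trajectory (without reflections) to a JSONL file."""
    best = state["root"].get_best_solution()
    record = {
        "input": state["input"],
        "messages": best.get_trajectory(include_reflections=False),
        "value": best.value,
    }
    with open(path, "a") as f:
        f.write(json.dumps(record) + "\n")
```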