search-agents

Code for the paper 🌳 Tree Search for Language Model Agents

https://github.com/kohjingyu/search-agents

Science Score: 54.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • ✓ CITATION.cff file: found
  • ✓ codemeta.json file: found
  • ✓ .zenodo.json file: found
  • ○ DOI references
  • ✓ Academic publication links: arxiv.org
  • ○ Academic email domains
  • ○ Institutional organization owner
  • ○ JOSS paper metadata
  • ○ Scientific vocabulary similarity: low similarity (16.0%) to scientific vocabulary

Keywords

agents llms machine-learning
Last synced: 6 months ago

Repository

Code for the paper 🌳 Tree Search for Language Model Agents

Basic Info
Statistics
  • Stars: 192
  • Watchers: 4
  • Forks: 22
  • Open Issues: 1
  • Releases: 1
Topics
agents llms machine-learning
Created over 1 year ago · Last pushed over 1 year ago
Metadata Files
Readme · License · Citation

README.md

Tree Search for Language Model Agents

[Website] [Paper]

Overview

We propose an inference-time tree search algorithm to enable language model agents to perform exploration and multi-step planning in interactive web environments. This repository demonstrates how to run our method on the VisualWebArena and WebArena benchmarks.
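
The method performs best-first search at inference time: a language model policy proposes candidate actions, the environment is used to expand them, and a value function scores the resulting states to decide which branch to explore next. The sketch below illustrates this loop only; `env.step`, `propose_actions`, and `score_state` are hypothetical stand-ins for this repository's actual components, and the hyperparameter defaults are placeholders rather than the paper's settings.

```python
import heapq
import itertools

def tree_search(env, root_state, propose_actions, score_state,
                branching=5, max_depth=5, budget=20):
    """Best-first search sketch with hypothetical stand-in callables.

    propose_actions(state) -> candidate actions from the LM policy
    score_state(state)     -> scalar value estimate (e.g., from GPT-4o)
    env.step(state, action) -> successor state in the web environment
    """
    tie = itertools.count()  # tie-breaker so states themselves are never compared
    frontier = [(-score_state(root_state), next(tie), root_state, [])]
    best_value, best_plan = float("-inf"), []

    while frontier and budget > 0:
        neg_value, _, state, plan = heapq.heappop(frontier)
        budget -= 1
        if -neg_value > best_value:  # track the highest-value trajectory seen
            best_value, best_plan = -neg_value, plan
        if len(plan) >= max_depth:
            continue  # depth limit reached; do not expand this node
        for action in propose_actions(state)[:branching]:
            child = env.step(state, action)
            heapq.heappush(frontier,
                           (-score_state(child), next(tie), child, plan + [action]))
    return best_plan
```

Note that in a real browser environment, expanding a sibling branch requires restoring an earlier page state, which this sketch glosses over.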

TODOs

  • [ ] Add other options besides gpt-4o for the value function

News

  • [07/24/2024]: Released trajectories of the gpt-4o agent.
  • [06/19/2024]: GitHub repo released.

Install

```bash
# Python 3.10 or 3.11 recommended
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
playwright install
pip install -e .
```

End-to-end Evaluation on (V)WA

  1. Set up the standalone environments. Please check out this page for details.

  2. Configure the URLs for each website. First, export the DATASET to be visualwebarena:

```bash
export DATASET=visualwebarena
```

Then, set the URL for the websites:

```bash
export CLASSIFIEDS="<your_classifieds_domain>:9980"
export CLASSIFIEDS_RESET_TOKEN="4b61655535e7ed388f0d40a93600254c"  # Default reset token for classifieds site, change if you edited its docker-compose.yml
export SHOPPING="<your_shopping_site_domain>:7770"
export REDDIT="<your_reddit_domain>:9999"
export WIKIPEDIA="<your_wikipedia_domain>:8888"
export HOMEPAGE="<your_homepage_domain>:4399"
```

If you want to run on the WebArena tasks instead, make sure to also set up the CMS, GitLab, and map environments, and then set their respective environment variables:

```bash
export DATASET=webarena
export SHOPPING_ADMIN="<your_e_commerce_cms_domain>:7780/admin"
export GITLAB="<your_gitlab_domain>:8023"
export MAP="<your_map_domain>:3000"
```

  3. Generate config files for each test example:

```bash
python scripts/generate_test_data.py
```

You will see *.json files generated in the config_files folder. Each file contains the configuration for one test example.

  4. Obtain and save the auto-login cookies for all websites:

```bash
bash prepare.sh
```

  5. Set up API keys.

If using OpenAI models, set a valid OpenAI API key (starting with sk-) as the environment variable:

```bash
export OPENAI_API_KEY=your_key
```

  6. Launch the evaluation. For example, to reproduce our GPT-4o + Search agent, you can run the script provided:

```bash
bash scripts/run_vwa_shopping_search.sh
```

This script runs the search agent with the default hyperparameters from our paper on the full set of VWA shopping tasks. Note that baselines that include a captioning model run it on GPU by default (e.g., BLIP-2-T5XL as the captioning model takes up approximately 12GB of GPU VRAM). Similarly, the other bash scripts in scripts/ reproduce the results on the other VWA sites and the text-only WA environment.
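
For reference, the captioning model can be loaded with Hugging Face transformers roughly as follows. This is a generic BLIP-2 sketch, not this repository's code, and the float16 setting is an assumption made to keep memory usage near the figure quoted above.

```python
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Generic BLIP-2-T5XL captioning sketch (not this repo's code).
# float16 is an assumption to keep VRAM near the ~12GB quoted above.
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-flan-t5-xl", torch_dtype=torch.float16
).to("cuda")

def caption(image):
    """Return a short caption for a PIL image."""
    inputs = processor(images=image, return_tensors="pt").to("cuda", torch.float16)
    output_ids = model.generate(**inputs, max_new_tokens=32)
    return processor.decode(output_ids[0], skip_special_tokens=True).strip()
```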

By default, the scripts run experiments with the search agents. If you wish to reproduce the baseline results (without search), set --agent_type prompt when executing run.py.
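
Concretely, that looks like the following; every flag other than --agent_type is elided here, and the authoritative flag list is the one in the provided scripts:

```bash
# Baseline (no search): keep all other flags from the script unchanged;
# "..." stands in for the remaining arguments and is not literal syntax.
python run.py --agent_type prompt ...
```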

Running Llama-3 models

If you wish to run the Llama-3 models we have in our paper, first set up a vLLM OpenAI compatible server. Then, update the OPENAI_BASE_URL environment variable in scripts/run_llama_vwa_shopping_search.sh to reflect the URL that the vLLM server is running on. This particular script shows how to run the Llama-3 agent on the VWA shopping environment; it is otherwise very similar to the OpenAI scripts for running on the other environments.
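
A minimal sketch of that setup is below. The checkpoint name and port are placeholders; the exact Llama-3 model and serving configuration used in the paper should be taken from the paper and scripts, not from this example.

```bash
# Serve a Llama-3 checkpoint behind vLLM's OpenAI-compatible API.
# Model name and port here are placeholders, not the paper's settings.
python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Meta-Llama-3-8B-Instruct \
    --port 8000

# Point the agent at the server, matching OPENAI_BASE_URL in
# scripts/run_llama_vwa_shopping_search.sh:
export OPENAI_BASE_URL="http://localhost:8000/v1"
```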

Agent Trajectories

We release the agent trajectories and results of the gpt-4o agent (with gpt-4o as the reward function) here. They are saved in the same format specified in run.py.

Citation

If you find our methods or code useful, please consider citing our paper:

```bibtex
@article{koh2024tree,
  title={Tree Search for Language Model Agents},
  author={Koh, Jing Yu and McAleer, Stephen and Fried, Daniel and Salakhutdinov, Ruslan},
  journal={arXiv preprint arXiv:2407.01476},
  year={2024}
}
```

Acknowledgements

Our code is heavily based on the VisualWebArena codebase and the WebArena codebase.

Owner

  • Name: Jing Yu Koh
  • Login: kohjingyu
  • Kind: user
  • Company: Carnegie Mellon University

ML PhD student at CMU. Previously at Google Research.

Citation (CITATION.cff)

@article{koh2024tree,
  title={Tree Search for Language Model Agents},
  author={Koh, Jing Yu and McAleer, Stephen and Fried, Daniel and Salakhutdinov, Ruslan},
  journal={arXiv preprint},
  year={2024}
}

GitHub Events

Total
  • Issues event: 2
  • Watch event: 72
  • Fork event: 6
Last Year
  • Issues event: 2
  • Watch event: 72
  • Fork event: 6

Dependencies

requirements.txt pypi
  • Farama-Notifications ==0.0.4
  • Jinja2 ==3.1.2
  • MarkupSafe ==2.1.3
  • Pillow ==10.0.1
  • PyYAML ==6.0.1
  • Pygments ==2.16.1
  • accelerate ==0.22.0
  • aiohttp ==3.8.5
  • aiolimiter ==1.1.0
  • aiosignal ==1.3.1
  • annotated-types ==0.5.0
  • anyio ==3.7.1
  • appnope ==0.1.3
  • asttokens ==2.4.0
  • async-timeout ==4.0.3
  • attrs ==23.1.0
  • backcall ==0.2.0
  • beartype ==0.12.0
  • beautifulsoup4 ==4.12.2
  • cachetools ==5.3.3
  • certifi ==2023.7.22
  • cfgv ==3.4.0
  • charset-normalizer ==3.2.0
  • click ==8.1.7
  • cloudpickle ==2.2.1
  • comm ==0.1.4
  • contourpy ==1.1.1
  • cycler ==0.12.1
  • datasets ==2.14.4
  • debugpy ==1.8.0
  • decorator ==5.1.1
  • dill ==0.3.7
  • distlib ==0.3.7
  • distro ==1.9.0
  • evaluate ==0.4.0
  • exceptiongroup ==1.1.3
  • execnet ==2.0.2
  • executing ==2.0.0
  • fastjsonschema ==2.18.1
  • filelock ==3.12.2
  • fonttools ==4.43.1
  • frozenlist ==1.4.0
  • fsspec ==2023.6.0
  • google-ai-generativelanguage ==0.6.5
  • google-api-core ==2.15.0
  • google-api-python-client ==2.133.0
  • google-auth ==2.26.1
  • google-auth-httplib2 ==0.2.0
  • google-cloud-aiplatform ==1.38.1
  • google-cloud-bigquery ==3.14.1
  • google-cloud-core ==2.4.1
  • google-cloud-resource-manager ==1.11.0
  • google-cloud-storage ==2.14.0
  • google-crc32c ==1.5.0
  • google-generativeai ==0.7.0
  • google-resumable-media ==2.7.0
  • googleapis-common-protos ==1.62.0
  • gradio_client ==0.5.2
  • greenlet ==2.0.2
  • grpc-google-iam-v1 ==0.13.0
  • grpcio ==1.64.1
  • grpcio-status ==1.62.2
  • gymnasium ==0.29.1
  • h11 ==0.14.0
  • httpcore ==0.18.0
  • httplib2 ==0.22.0
  • httpx ==0.25.0
  • huggingface-hub ==0.16.4
  • identify ==2.5.30
  • idna ==3.4
  • imageio ==2.34.1
  • iniconfig ==2.0.0
  • ipykernel ==6.25.2
  • ipython ==8.16.1
  • jedi ==0.19.1
  • joblib ==1.3.2
  • jsonschema ==4.19.1
  • jsonschema-specifications ==2023.7.1
  • jupyter_client ==8.4.0
  • jupyter_core ==5.4.0
  • kiwisolver ==1.4.5
  • lazy_loader ==0.4
  • matplotlib ==3.8.0
  • matplotlib-inline ==0.1.6
  • mpmath ==1.3.0
  • multidict ==6.0.4
  • multiprocess ==0.70.15
  • mypy ==0.991
  • mypy-extensions ==1.0.0
  • nbclient ==0.6.8
  • nbformat ==5.9.2
  • nbmake ==1.4.6
  • nest-asyncio ==1.5.8
  • networkx ==3.1
  • nltk ==3.8.1
  • nodeenv ==1.8.0
  • numpy ==1.25.2
  • openai ==1.3.5
  • opencv-python ==4.8.1.78
  • packaging ==23.1
  • pandas ==2.0.3
  • parso ==0.8.3
  • pexpect ==4.8.0
  • pickleshare ==0.7.5
  • platformdirs ==3.11.0
  • playwright ==1.37.0
  • pluggy ==1.3.0
  • pre-commit ==3.0.1
  • prompt-toolkit ==3.0.39
  • proto-plus ==1.23.0
  • protobuf ==4.24.3
  • psutil ==5.9.5
  • ptyprocess ==0.7.0
  • pure-eval ==0.2.2
  • py ==1.11.0
  • pyarrow ==12.0.1
  • pyasn1 ==0.6.0
  • pyasn1_modules ==0.4.0
  • pydantic ==2.4.2
  • pydantic_core ==2.10.1
  • pyee ==9.0.4
  • pyparsing ==3.1.1
  • pytest ==7.1.2
  • pytest-asyncio ==0.21.1
  • pytest-xdist ==3.3.1
  • python-dateutil ==2.8.2
  • pytz ==2023.3
  • pyzmq ==25.1.1
  • referencing ==0.30.2
  • regex ==2023.8.8
  • requests ==2.31.0
  • responses ==0.18.0
  • rpds-py ==0.10.6
  • rsa ==4.9
  • safetensors ==0.3.3
  • scikit-image ==0.22.0
  • scipy ==1.13.1
  • sentencepiece ==0.1.99
  • shapely ==2.0.4
  • six ==1.16.0
  • sniffio ==1.3.0
  • soupsieve ==2.5
  • stack-data ==0.6.3
  • sympy ==1.12
  • text-generation ==0.6.1
  • tifffile ==2024.5.22
  • tiktoken ==0.7.0
  • tokenizers ==0.14.0
  • tomli ==2.0.1
  • torch ==2.0.1
  • tornado ==6.3.3
  • tqdm ==4.66.1
  • traitlets ==5.11.2
  • transformers ==4.34.0
  • types-requests ==2.31.0.10
  • types-tqdm ==4.66.0.1
  • typing_extensions ==4.7.1
  • tzdata ==2023.3
  • uritemplate ==4.1.1
  • urllib3 ==2.0.4
  • virtualenv ==20.24.5
  • wcwidth ==0.2.8
  • websockets ==11.0.3
  • xxhash ==3.3.0
  • yarl ==1.9.2
setup.py pypi