docprompting

Data and code for "DocPrompting: Generating Code by Retrieving the Docs" @ICLR 2023

https://github.com/shuyanzhou/docprompting

Science Score: 54.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (11.6%) to scientific vocabulary

Keywords

code-generation natural-language-processing nl-to-code reading-comprehension
Last synced: 6 months ago

Repository

Data and code for "DocPrompting: Generating Code by Retrieving the Docs" @ICLR 2023

Basic Info
  • Host: GitHub
  • Owner: shuyanzhou
  • License: apache-2.0
  • Language: Python
  • Default Branch: main
  • Homepage:
  • Size: 47.6 MB
Statistics
  • Stars: 243
  • Watchers: 9
  • Forks: 19
  • Open Issues: 6
  • Releases: 0
Topics
code-generation natural-language-processing nl-to-code reading-comprehension
Created over 3 years ago · Last pushed about 2 years ago
Metadata Files
Readme License Citation

README.md

DocPrompting: Generating Code by Retrieving the Docs

This is the official implementation of

Shuyan Zhou, Uri Alon, Frank F. Xu, Zhiruo Wang, Zhengbao Jiang, Graham Neubig, "DocPrompting: Generating Code by Retrieving the Docs", ICLR'2023 (Spotlight)

January 2023 - The paper was accepted to ICLR'2023 as a Spotlight!


Publicly available source-code libraries are continuously growing and changing. This makes it impossible for models of code to keep current with all available APIs by simply training these models on existing code repositories. We introduce DocPrompting: a natural-language-to-code generation approach that explicitly leverages documentation by (1) retrieving the relevant documentation pieces given an NL intent, and (2) generating code based on the NL intent and the retrieved documentation.
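To make the two-step idea above concrete, here is a minimal sketch of the pipeline, with a toy word-overlap retriever standing in for the repo's real retrievers (BM25 / CodeT5) and the generator call left abstract, since it depends on the model used.

```python
# Toy sketch of the DocPrompting pipeline: retrieve docs, then generate code from intent + docs.
def retrieve_docs(nl_intent, doc_pool, top_k=10):
    # Score each doc by word overlap with the intent (a stand-in for BM25 / dense retrieval).
    intent_words = set(nl_intent.lower().split())
    scored = sorted(doc_pool, key=lambda d: len(intent_words & set(d.lower().split())), reverse=True)
    return scored[:top_k]

def docprompting(nl_intent, doc_pool, generate_fn, top_k=10):
    # 1. Retrieve the relevant documentation pieces given the NL intent.
    docs = retrieve_docs(nl_intent, doc_pool, top_k)
    # 2. Generate code conditioned on the intent and the retrieved docs.
    return generate_fn(nl_intent, docs)
```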

In this repository we provide the best model in each setting described in the paper.

(Figure: DocPrompting overview)

Table of Contents


Huggingface 🤗 Dataset & Evaluation

In this work, we introduce a new natural-language-to-bash generation benchmark, tldr, and re-split CoNaLa so that the dev and test sets contain unseen functions. The datasets and the corresponding evaluations are available on Huggingface:

  • tldr and eval
  • CoNaLa and eval

```python
import datasets
import evaluate

tldr = datasets.load_dataset('neulab/tldr')
tldr_metric = evaluate.load('neulab/tldr_eval')

conala = datasets.load_dataset('neulab/docprompting-conala')
conala_metric = evaluate.load('neulab/python_bleu')
```
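The loaded metrics can then be scored against model outputs. The snippet below assumes the standard `evaluate` interface (compute(predictions=..., references=...)); check the evaluation cards for the exact argument names and score fields.

```python
# Hypothetical scoring of a single prediction with the tldr metric loaded above.
predictions = ["makepkg --clean"]
references = ["makepkg --clean {{path/to/directory}}"]
result = tldr_metric.compute(predictions=predictions, references=references)
print(result)
```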

Huggingface 🤗 Models

We make the following models available on Huggingface:

  • neulab/docprompting-tldr-gpt-neo-125M
  • neulab/docprompting-tldr-gpt-neo-1.3B

Example usage

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("neulab/docprompting-tldr-gpt-neo-1.3B")
model = AutoModelForCausalLM.from_pretrained("neulab/docprompting-tldr-gpt-neo-1.3B")

# prompt template
prompt = f"""{tokenizer.bos_token} Potential manual 0: makepkg - package build utility
Potential manual 1: -c, --clean Clean up leftover work files and directories after a successful build.
Potential manual 2: -r, --rmdeps Upon successful build, remove any dependencies installed by makepkg during dependency auto-resolution and installation when using -s
Potential manual 3: CONTENT_OF_THE_MANUAL_3
...
Potential manual 10: CONTENT_OF_THE_MANUAL_10"""
prompt += f"{tokenizer.sep_token} clean up work directories after a successful build {tokenizer.sep_token}"

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
gen_tokens = model.generate(
    input_ids,
    num_beams=5,
    max_new_tokens=150,
    num_return_sequences=2,
    pad_token_id=tokenizer.eos_token_id
)
gen_tokens = gen_tokens.reshape(1, -1, gen_tokens.shape[-1])[0][0]

# to text and clean
gen_code = tokenizer.decode(gen_tokens)
gen_code = gen_code.split(tokenizer.sep_token)[2].strip().split(tokenizer.eos_token)[0].strip()
print(gen_code)
# makepkg --clean {{path/to/directory}}
```

Example script

An example script for tldr that uses the retrieved docs is available here.

Other models

The other models require the customized implementations in our repo; please read through the corresponding sections to use them. These models are:

  1. the sparse retriever based on BM25 for tldr
  2. the dense retriever based on CodeT5 for CoNaLa
  3. the FiD T5 generator for tldr
  4. the FiD CodeT5 generator for CoNaLa


The following instructions are for reproducing the results in the paper.

Preparation

Download the data for CoNaLa and tldr from the link.

```bash
# unzip
unzip docprompting_data.zip
# move to the data folder
mv docprompting_data/* data
```

Download the trained generator weights from the link.

```bash
unzip docprompting_generator_models.zip
# move to the model folder
mv docprompting_generator_models/* models/generator
```

Retrieval

Dense retrieval

(CoNaLa as an example)

The code is based on SimCSE.

  1. Run inference with our trained model on CoNaLa (Python).

```bash
python retriever/simcse/run_inference.py \
    --model_name "neulab/docprompting-codet5-python-doc-retriever" \
    --source_file data/conala/conala_nl.txt \
    --target_file data/conala/python_manual_firstpara.tok.txt \
    --source_embed_save_file data/conala/.tmp/src_embedding \
    --target_embed_save_file data/conala/.tmp/tgt_embedding \
    --sim_func cls_distance.cosine \
    --num_layers 12 \
    --save_file data/conala/retrieval_results.json
```

We observed that whether or not the model normalizes the embeddings can affect the retrieval results, so we selected this hyper-parameter (--normalize_embed) on the validation set. A minimal sketch of this CLS-embedding cosine retrieval is shown below.

The results will be saved to data/conala/retrieval_results.json.
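For intuition, here is a rough sketch of what CLS-embedding cosine retrieval (the cls_distance.cosine similarity above) boils down to. This is not the repo's run_inference.py: it assumes the retriever checkpoint loads through AutoModel and ignores batching and embedding caching.

```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

model_name = "neulab/docprompting-codet5-python-doc-retriever"  # assumption: loadable via AutoModel
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).eval()

def cls_embed(texts):
    # Encode texts and take the first-token ("CLS") hidden state from the encoder.
    batch = tokenizer(texts, padding=True, truncation=True, max_length=32, return_tensors="pt")
    encoder = model.get_encoder() if hasattr(model, "get_encoder") else model
    with torch.no_grad():
        out = encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
    return out.last_hidden_state[:, 0]

queries = ["check if a file exists"]  # one NL intent per line, as in conala_nl.txt
docs = [
    "os.path.exists(path) Return True if path refers to an existing path.",
    "open(file, mode='r') Open file and return a file object.",
]

q, d = cls_embed(queries), cls_embed(docs)
# Optionally L2-normalize (the --normalize_embed choice), then rank docs by cosine similarity.
scores = F.cosine_similarity(q.unsqueeze(1), d.unsqueeze(0), dim=-1)  # (num_queries, num_docs)
print(scores.argsort(dim=-1, descending=True))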

  2. Train your own retriever.

```bash
python retriever/simcse/run_train.py \
    --num_layers 12 \
    --model_name_or_path Salesforce/codet5-base \
    --sim_func cls_distance.cosine \
    --temp 0.05 \
    --train_file data/conala/train_retriever_sup_unsup.json \
    --eval_file data/conala/dev_retriever.json \
    --output_dir models/retriever/docprompting_codet5_python_doc_retriever \
    --eval_src_file data/conala/conala_nl.txt \
    --eval_tgt_file data/conala/python_manual_firstpara.tok.txt \
    --eval_root_folder data/conala \
    --eval_oracle_file data/conala/cmd_dev.oracle_man.full.json \
    --run_name docprompting_codet5_python_doc_retriever \
    --num_train_epochs 10 \
    --per_device_train_batch_size 512 \
    --learning_rate 1e-5 \
    --max_seq_length 32 \
    --evaluation_strategy steps \
    --metric_for_best_model recall@10 \
    --load_best_model_at_end \
    --eval_steps 125 \
    --overwrite_output_dir \
    --do_train \
    --eval_form retrieval "$@"
```

  • train_retriever_sup_unsup.json contains the supervised data (CoNaLa training and mined examples) and the unsupervised data (duplicated sentences from the docs) for training the retriever. A sketch of the contrastive objective appears after this list.
  • Name the saved model carefully: if you train a CodeT5-based retriever, make sure codet5 appears in the model name.
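For reference, the retriever training above follows a SimCSE-style contrastive setup with in-batch negatives and a temperature (--temp 0.05). The snippet below is a minimal sketch of that objective, not the repo's exact training loop.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(query_emb, doc_emb, temperature=0.05):
    # query_emb, doc_emb: (batch, dim); row i of doc_emb is the positive doc for query i,
    # and the other rows in the batch act as negatives (in-batch negatives).
    query_emb = F.normalize(query_emb, dim=-1)
    doc_emb = F.normalize(doc_emb, dim=-1)
    sim = query_emb @ doc_emb.T / temperature             # (batch, batch) cosine similarities
    labels = torch.arange(sim.size(0), device=sim.device)
    return F.cross_entropy(sim, labels)                   # diagonal entries are the positives

# Toy usage with random embeddings:
loss = contrastive_loss(torch.randn(8, 768), torch.randn(8, 768))
print(loss.item())
```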

Sparse retrieval

(tldr as an example)

There are two stages in the retrieval procedure for tldr: the first stage retrieves the bash command, and the second stage retrieves the potentially relevant paragraphs that describe the usage of its arguments. An illustrative sketch follows the list below.

  1. Build the index with Elasticsearch.

```bash
python retriever/bm25/main.py \
    --retrieval_stage 0
```

  2. First-stage retrieval.

```bash
python retriever/bm25/main.py \
    --retrieval_stage 1 \
    --split {cmd_train, cmd_dev, cmd_test}
```

  3. Second-stage retrieval.

```bash
python retriever/bm25/main.py \
    --retrieval_stage 2 \
    --split {cmd_train, cmd_dev, cmd_test}
```
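To make the two stages concrete, here is an illustrative sketch using a plain BM25 package (rank_bm25, which is not a dependency of this repo); the actual implementation uses Elasticsearch through retriever/bm25/main.py.

```python
from rank_bm25 import BM25Okapi  # pip install rank-bm25 (illustration only)

nl_query = "clean up work directories after a successful build".split()

# Stage 1: retrieve the most likely bash command from short command descriptions.
cmd_names = ["makepkg", "tar", "grep"]
cmd_descs = ["package build utility", "archiving utility", "search text using patterns"]
stage1 = BM25Okapi([d.split() for d in cmd_descs])
best_cmd = cmd_names[int(stage1.get_scores(nl_query).argmax())]

# Stage 2: within that command's manual, retrieve paragraphs describing relevant arguments.
paragraphs = [
    "-c, --clean Clean up leftover work files and directories after a successful build.",
    "-r, --rmdeps Upon successful build, remove any dependencies installed by makepkg.",
]
stage2 = BM25Okapi([p.split() for p in paragraphs])
ranked = stage2.get_scores(nl_query).argsort()[::-1]
print(best_cmd, [paragraphs[i] for i in ranked])
```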


Generation

FiD generation

The code is based on FiD. A training or evaluation file should be converted to a format compatible with FiD; an example is here.
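If you need to build such a file yourself, a record typically looks like the sketch below. The keys follow the original FiD repo's convention ("question", "target", "ctxs"), and the raw example fields are hypothetical, so double-check both against the linked example file.

```python
import json

# Hypothetical raw example: an NL intent, its oracle code, and retrieved doc paragraphs.
example = {
    "nl": "clean up work directories after a successful build",
    "cmd": "makepkg --clean",
    "retrieved_docs": [
        {"id": "makepkg_-c", "text": "-c, --clean Clean up leftover work files ..."},
    ],
}

# FiD-style record (keys per the original FiD repo; verify against the linked example).
fid_record = {
    "question": example["nl"],
    "target": example["cmd"],
    "ctxs": [{"title": d["id"], "text": d["text"]} for d in example["retrieved_docs"]],
}

with open("fid.example.json", "w") as f:
    json.dump([fid_record], f, indent=2)
```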

Important note: FiD has a strong dependency on the transformers version (3.0.2); failing to match this version may make the results irreproducible.

  1. Run generation. Here is an example with our trained model on Python CoNaLa.

```bash
ds='conala'
python generator/fid/test_reader_simple.py \
    --model_path models/generator/${ds}.fid.codet5.top10/checkpoint/best_dev \
    --tokenizer_name models/generator/codet5-base \
    --eval_data data/${ds}/fid.cmd_test.codet5.t10.json \
    --per_gpu_batch_size 8 \
    --n_context 10 \
    --name ${ds}.fid.codet5.top10 \
    --checkpoint_dir models/generator \
    --result_tag test_same \
    --main_port 81692
```

The results will be saved to models/generator/{name}/test_results_test_same.json.

To evaluate pass@k, we need more generations, so we use nucleus sampling (instead of beam search).

```bash
ds='conala'
t=1.0 # set this from 0.2, 0.4, 0.6, ..., 1.0; use the dev set to find the best temperature
python generator/fid/test_reader_simple.py \
    --model_path models/generator/${ds}.fid.codet5.top10/checkpoint/best_dev \
    --tokenizer_name models/generator/codet5-base \
    --eval_data data/${ds}/fid.cmd_test.codet5.t10.ns200.json \
    --per_gpu_batch_size 8 \
    --n_context 10 \
    --name ${ds}.fid.codet5.top10.ns200 \
    --checkpoint_dir models/generator \
    --result_tag test_same \
    --num_beams 1 \
    --temperature $t \
    --top_p 0.95 \
    --num_return_sequences 200 \
    --main_port 81692
```

Then run this script:

```bash
python dataset_helper/conala/execution_eval.py --result_file data/${ds}/fid.cmd_test.codet5.t10.ns200.json
```
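For reference, pass@k over such sampled generations is usually estimated with the standard unbiased estimator of Chen et al. (2021), sketched below; the repo's execution_eval.py script handles the actual execution-based evaluation for you.

```python
import numpy as np

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate: n generated samples, c of which pass, evaluation budget k."""
    if n - c < k:
        return 1.0
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

# e.g. 200 samples per problem (as above), 3 of which execute correctly, evaluated at k=10
print(pass_at_k(n=200, c=3, k=10))
```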

  2. Train your own generator.

```bash
ds='conala'
# --model_name initializes the generator from the codet5-base model
python generator/fid/train_reader.py \
    --seed 1996 \
    --train_data data/${ds}/fid.cmd_train.codet5.t10.json \
    --eval_data data/${ds}/fid.cmd_dev.codet5.t10.json \
    --model_name models/generator/codet5-base \
    --per_gpu_batch_size 4 \
    --n_context 10 \
    --name ${ds}.fid.codet5.top10 \
    --checkpoint_dir models/generator/ \
    --eval_freq 500 \
    --accumulation_steps 2 \
    --main_port 30843 \
    --total_steps 20000 \
    --warmup_steps 2000

ds='tldr'
python generator/fid/train_reader.py \
    --dataset tldr \
    --train_data data/${ds}/fid.cmd_train.codet5.t10.json \
    --eval_data data/${ds}/fid.cmd_model_select.codet5.t10.json \
    --model_name models/generator/codet5-base \
    --per_gpu_batch_size 4 \
    --n_context 10 \
    --eval_metric token_f1 \
    --name ${ds}.fid.codet5.top10 \
    --checkpoint_dir models/generator/ \
    --eval_freq 1000 \
    --accumulation_steps 2 \
    --main_port 32420 \
    --total_steps 20000 \
    --warmup_steps 2000
```

  • Examples in fid.cmd_model_select.codet5.t10.json are the same as those in fid.cmd_dev.codet5.t10.json; the difference is that it uses the oracle first-stage retrieval results (the oracle bash command name).

Data

The data folder contains the two benchmarks we curated or re-split:

  • tldr
  • CoNaLa

For each dataset, we provide:

  1. The natural language intent (entry nl).
  2. The oracle code (entry cmd): Bash for tldr and Python for CoNaLa.
  3. The oracle docs (entry oracle_man). In the data files we only provide the manual ids; their contents can be found in {dataset}/{dataset}_docs.json, as in the sketch after this list.
  4. Other data in different formats for the different modules.
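Resolving the oracle doc ids to their contents might look like the sketch below; the file paths, id strings, and example fields are assumptions for illustration, so adapt them to the actual files under data/.

```python
import json

# Assumed layout: {dataset}/{dataset}_docs.json maps manual id -> doc text,
# and each example carries an `oracle_man` list of ids (ids below are hypothetical).
with open("data/conala/conala_docs.json") as f:
    doc_contents = json.load(f)

example = {"nl": "...", "cmd": "...", "oracle_man": ["doc_id_1", "doc_id_2"]}
oracle_docs = [doc_contents[doc_id] for doc_id in example["oracle_man"]]
print(oracle_docs)
```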

Resources

tldr GitHub repo. Thanks to all the contributors!
  • CoNaLa

Citation

@inproceedings{zhou23docprompting,
    title = {DocPrompting: Generating Code by Retrieving the Docs},
    author = {Shuyan Zhou and Uri Alon and Frank F. Xu and Zhiruo Wang and Zhengbao Jiang and Graham Neubig},
    booktitle = {International Conference on Learning Representations (ICLR)},
    address = {Kigali, Rwanda},
    month = {May},
    url = {https://arxiv.org/abs/2207.05987},
    year = {2023}
}

Owner

  • Name: Shuyan Zhou
  • Login: shuyanzhou
  • Kind: user
  • Location: Pittsburgh, PA

PhD student @ LTI CMU

Citation (CITATION.cff)

@article{zhou2022doccoder,
  title={DocCoder: Generating Code by Retrieving and Reading Docs},
  author={Zhou, Shuyan and Alon, Uri and Xu, Frank F and Jiang, Zhengbao and Neubig, Graham},
  journal={arXiv preprint arXiv:2207.05987},
  year={2022}
}

GitHub Events

Total
  • Issues event: 2
  • Watch event: 15
  • Issue comment event: 1
  • Fork event: 3
Last Year
  • Issues event: 2
  • Watch event: 15
  • Issue comment event: 1
  • Fork event: 3

Dependencies

generator/fid/requirements.txt pypi
  • faiss-cpu *
  • numpy *
  • tensorboard *
  • torch *
  • transformers ==3.0.2
requirements.txt pypi
  • aiohttp ==3.8.1
  • aiosignal ==1.2.0
  • alabaster ==0.7.12
  • anaconda-client ==1.7.2
  • anaconda-project ==0.9.1
  • analytics-python ==1.4.0
  • anyio ==3.5.0
  • appdirs ==1.4.4
  • argh ==0.26.2
  • argon2-cffi ==20.1.0
  • asgiref ==3.5.0
  • asn1crypto ==1.4.0
  • astor ==0.8.1
  • astroid ==2.5
  • astropy ==4.2.1
  • async-generator ==1.10
  • async-timeout ==4.0.2
  • asynctest ==0.13.0
  • atomicwrites ==1.4.0
  • attrs ==20.3.0
  • autopep8 ==1.5.6
  • babel ==2.9.0
  • backcall ==0.2.0
  • backoff ==1.10.0
  • backports.shutil-get-terminal-size ==1.0.0
  • bcrypt ==3.2.0
  • beautifulsoup4 ==4.9.3
  • bert-score ==0.3.11
  • bitarray ==2.1.0
  • bkcharts ==0.2
  • black ==19.10b0
  • bleach ==3.3.0
  • bokeh ==2.3.2
  • boto ==2.49.0
  • bottleneck ==1.3.2
  • brotlipy ==0.7.0
  • certifi ==2020.12.5
  • cffi ==1.14.5
  • chardet ==4.0.0
  • charset-normalizer ==2.0.12
  • click ==7.1.2
  • cloudpickle ==1.6.0
  • clyent ==1.2.2
  • colorama ==0.4.4
  • contextlib2 ==0.6.0.post1
  • cryptography ==3.4.7
  • cycler ==0.10.0
  • cython ==0.29.23
  • cytoolz ==0.11.0
  • dask ==2021.4.0
  • datasets ==1.2.1
  • decorator ==5.0.6
  • defusedxml ==0.7.1
  • diff-match-patch ==20200713
  • dill ==0.3.4
  • distributed ==2021.4.1
  • docker-pycreds ==0.4.0
  • docutils ==0.17.1
  • editdistance *
  • elasticsearch *
  • entrypoints ==0.3
  • et-xmlfile ==1.0.1
  • faiss-cpu ==1.7.2
  • fastapi ==0.75.0
  • fastcache ==1.1.0
  • ffmpy ==0.3.0
  • filelock ==3.0.12
  • flake8 ==3.9.0
  • flask ==1.1.2
  • fonttools ==4.31.2
  • frozenlist ==1.3.0
  • fsspec ==0.9.0
  • future ==0.18.2
  • gevent ==21.1.2
  • gitdb ==4.0.9
  • gitpython ==3.1.27
  • glob2 ==0.7
  • gmpy2 ==2.0.8
  • gradio ==2.8.12
  • greenlet ==1.0.0
  • h11 ==0.13.0
  • h5py ==2.10.0
  • heapdict ==1.0.1
  • html5lib ==1.1
  • idna ==2.10
  • imageio ==2.9.0
  • imagesize ==1.2.0
  • importlib-metadata ==3.10.0
  • iniconfig ==1.1.1
  • intervaltree ==3.1.0
  • ipykernel ==5.3.4
  • ipython ==7.22.0
  • ipython-genutils ==0.2.0
  • ipywidgets ==7.6.3
  • isort ==5.8.0
  • itsdangerous ==1.1.0
  • jdcal ==1.4.1
  • jedi ==0.17.2
  • jeepney ==0.6.0
  • jinja2 ==2.11.3
  • joblib ==1.0.1
  • json5 ==0.9.5
  • jsonschema ==3.2.0
  • jupyter ==1.0.0
  • jupyter-client ==6.1.12
  • jupyter-console ==6.4.0
  • jupyter-core ==4.7.1
  • jupyter-packaging ==0.7.12
  • jupyter-server ==1.4.1
  • jupyterlab ==3.0.14
  • jupyterlab-pygments ==0.1.2
  • jupyterlab-server ==2.4.0
  • jupyterlab-widgets ==1.0.0
  • keyring ==22.3.0
  • kiwisolver ==1.3.1
  • lazy-object-proxy ==1.6.0
  • libarchive-c ==2.9
  • linkify-it-py ==1.0.3
  • llvmlite ==0.36.0
  • locket ==0.2.1
  • lxml ==4.6.3
  • markdown-it-py ==2.0.1
  • markupsafe ==1.1.1
  • matplotlib ==3.5.1
  • mccabe ==0.6.1
  • mdit-py-plugins ==0.3.0
  • mdurl ==0.1.0
  • mistune ==0.8.4
  • mkl-fft ==1.3.0
  • mkl-random ==1.2.1
  • mkl-service ==2.3.0
  • mock ==4.0.3
  • monotonic ==1.6
  • more-itertools ==8.7.0
  • mpmath ==1.2.1
  • msgpack ==1.0.2
  • multidict ==6.0.2
  • multipledispatch ==0.6.0
  • multiprocess ==0.70.12.2
  • mypy-extensions ==0.4.3
  • nbclassic ==0.2.6
  • nbclient ==0.5.3
  • nbconvert ==6.0.7
  • nbformat ==5.1.3
  • nest-asyncio ==1.5.1
  • networkx ==2.5
  • nltk ==3.6.1
  • nose ==1.3.7
  • notebook ==6.3.0
  • numba ==0.53.1
  • numexpr ==2.7.3
  • numpy ==1.20.1
  • numpydoc ==1.1.0
  • olefile ==0.46
  • openpyxl ==3.0.7
  • orjson ==3.6.7
  • packaging ==20.9
  • pandas ==1.1.5
  • pandocfilters ==1.4.3
  • paramiko ==2.10.3
  • parso ==0.7.0
  • partd ==1.2.0
  • path ==15.1.2
  • pathlib2 ==2.3.5
  • pathspec ==0.7.0
  • pathtools ==0.1.2
  • patsy ==0.5.1
  • pep8 ==1.7.1
  • pexpect ==4.8.0
  • pickleshare ==0.7.5
  • pillow ==8.2.0
  • pip ==21.0.1
  • pkginfo ==1.7.0
  • pluggy ==0.13.1
  • ply ==3.11
  • prettytable ==2.1.0
  • prometheus-client ==0.10.1
  • promise ==2.3
  • prompt-toolkit ==3.0.17
  • protobuf ==3.19.4
  • psutil ==5.8.0
  • ptyprocess ==0.7.0
  • py ==1.10.0
  • pyarrow ==7.0.0
  • pycodestyle ==2.6.0
  • pycosat ==0.6.3
  • pycparser ==2.20
  • pycrypto ==2.6.1
  • pycryptodome ==3.14.1
  • pycurl ==7.43.0.6
  • pydantic ==1.9.0
  • pydocstyle ==6.0.0
  • pydub ==0.25.1
  • pyerfa ==1.7.3
  • pyflakes ==2.2.0
  • pygments ==2.8.1
  • pylint ==2.7.4
  • pyls-black ==0.4.6
  • pyls-spyder ==0.3.2
  • pynacl ==1.5.0
  • pyodbc ==4.0.0
  • pyopenssl ==20.0.1
  • pyparsing ==2.4.7
  • pyrsistent ==0.17.3
  • pysocks ==1.7.1
  • pytest ==6.2.3
  • python-dateutil ==2.8.1
  • python-jsonrpc-server ==0.4.0
  • python-language-server ==0.36.2
  • python-multipart ==0.0.5
  • pytz ==2021.1
  • pywavelets ==1.1.1
  • pywsd ==1.2.4
  • pyxdg ==0.27
  • pyyaml ==5.4.1
  • pyzmq ==20.0.0
  • qdarkstyle ==2.8.1
  • qtawesome ==1.0.2
  • qtconsole ==5.0.3
  • qtpy ==1.9.0
  • regex ==2021.4.4
  • requests ==2.25.1
  • rope ==0.18.0
  • rtree ==0.9.7
  • ruamel-yaml-conda ==0.15.100
  • sacremoses ==0.0.49
  • scikit-image ==0.18.1
  • scikit-learn ==0.24.0
  • scipy ==1.5.4
  • seaborn ==0.11.1
  • secretstorage ==3.3.1
  • send2trash ==1.5.0
  • sentencepiece ==0.1.96
  • sentry-sdk ==1.5.8
  • setproctitle ==1.2.2
  • setuptools ==49.3.0
  • shortuuid ==1.0.8
  • simplegeneric ==0.8.1
  • singledispatch ==0.0.0
  • six ==1.15.0
  • smmap ==5.0.0
  • sniffio ==1.2.0
  • snowballstemmer ==2.1.0
  • sortedcollections ==2.1.0
  • sortedcontainers ==2.3.0
  • soupsieve ==2.2.1
  • sphinx ==4.0.1
  • sphinxcontrib-applehelp ==1.0.2
  • sphinxcontrib-devhelp ==1.0.2
  • sphinxcontrib-htmlhelp ==1.0.3
  • sphinxcontrib-jsmath ==1.0.1
  • sphinxcontrib-qthelp ==1.0.3
  • sphinxcontrib-serializinghtml ==1.1.4
  • sphinxcontrib-websupport ==1.2.4
  • spyder ==4.2.5
  • spyder-kernels ==1.10.2
  • sqlalchemy ==1.4.15
  • starlette ==0.17.1
  • statsmodels ==0.12.2
  • sympy ==1.8
  • tables ==3.6.1
  • tblib ==1.7.0
  • termcolor ==1.1.0
  • terminado ==0.9.4
  • testpath ==0.4.4
  • textdistance ==4.2.1
  • threadpoolctl ==2.1.0
  • three-merge ==0.1.1
  • tifffile ==2020.10.1
  • tokenizers ==0.9.4
  • toml ==0.10.2
  • tomli ==2.0.1
  • toolz ==0.11.1
  • torch ==1.7.1
  • torchaudio ==0.7.2
  • torchvision ==0.8.2
  • tornado ==6.1
  • tqdm ==4.49.0
  • traitlets ==5.0.5
  • transformers ==4.2.1
  • typed-ast ==1.4.2
  • typing-extensions ==3.7.4.3
  • uc-micro-py ==1.0.1
  • ujson ==4.0.2
  • unicodecsv ==0.14.1
  • urllib3 ==1.26.4
  • uvicorn ==0.17.6
  • wandb ==0.12.11
  • watchdog ==1.0.2
  • wcwidth ==0.2.5
  • webencodings ==0.5.1
  • werkzeug ==1.0.1
  • wheel ==0.36.2
  • whoosh ==2.7.4
  • widgetsnbextension ==3.5.1
  • wn ==0.9.1
  • wrapt ==1.12.1
  • wurlitzer ==2.1.0
  • xlrd ==2.0.1
  • xlsxwriter ==1.3.8
  • xlwt ==1.3.0
  • xxhash ==3.0.0
  • yapf ==0.31.0
  • yarl ==1.7.2
  • yaspin ==2.1.0
  • zict ==2.0.0
  • zipp ==3.4.1
  • zope.event ==4.5.0
  • zope.interface ==5.3.0
generator/fid/setup.py pypi