docprompting
Data and code for "DocPrompting: Generating Code by Retrieving the Docs" @ICLR 2023
Statistics
- Stars: 243
- Watchers: 9
- Forks: 19
- Open Issues: 6
- Releases: 0
Topics
Metadata Files
README.md
DocPrompting: Generating Code by Retrieving the Docs
This is the official implementation of
Shuyan Zhou, Uri Alon, Frank F. Xu, Zhiruo Wang, Zhengbao Jiang, Graham Neubig, "DocPrompting: Generating Code by Retrieving the Docs", ICLR'2023 (Spotlight)
January 2023 - The paper was accepted to ICLR'2023 as a Spotlight!
Publicly available source-code libraries are continuously growing and changing. This makes it impossible for models of code to keep current with all available APIs by simply training these models on existing code repositories. We introduce DocPrompting: a natural-language-to-code generation approach that explicitly leverages documentation by 1. retrieving the relevant documentation pieces given an NL intent, and 2. generating code based on the NL intent and the retrieved documentation.
In this repository we provide the best model in each setting described in the paper.

Table of contents
- Quick Dataset&Eval Access through 🤗
- Quick Models Loading 🤗
- Preparation
- Retrieval
- Generation
- Data
- Resources
- Citation
Huggingface 🤗 Dataset & Evaluation
In this work, we introduce tldr, a new natural-language-to-bash generation benchmark,
and re-split CoNaLa so that the dev and test sets contain unseen functions.
The datasets and the corresponding evaluations are available on Hugging Face:
* tldr and eval
* CoNaLa and eval
```python
import datasets
import evaluate

# tldr benchmark and its evaluation metric
tldr = datasets.load_dataset('neulab/tldr')
tldr_metric = evaluate.load('neulab/tldr_eval')

# re-split CoNaLa and its evaluation metric
conala = datasets.load_dataset('neulab/docprompting-conala')
conala_metric = evaluate.load('neulab/python_bleu')
```
Huggingface 🤗 Models
We make the following models available on Huggingface:
- neulab/docprompting-tldr-gpt-neo-125M
- neulab/docprompting-tldr-gpt-neo-1.3B
Example usage
```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("neulab/docprompting-tldr-gpt-neo-1.3B")
model = AutoModelForCausalLM.from_pretrained("neulab/docprompting-tldr-gpt-neo-1.3B")

# prompt template
prompt = f"""{tokenizer.bos_token} Potential manual 0: makepkg - package build utility
Potential manual 1: -c, --clean Clean up leftover work files and directories after a successful build.
Potential manual 2: -r, --rmdeps Upon successful build, remove any dependencies installed by makepkg during dependency auto-resolution and installation when using -s
Potential manual 3: CONTENT_OF_THE_MANUAL_3
...
Potential manual 10: CONTENT_OF_THE_MANUAL_10"""
prompt += f"{tokenizer.sep_token} clean up work directories after a successful build {tokenizer.sep_token}"

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
gen_tokens = model.generate(
    input_ids,
    num_beams=5,
    max_new_tokens=150,
    num_return_sequences=2,
    pad_token_id=tokenizer.eos_token_id
)
# keep only the top-ranked sequence
gen_tokens = gen_tokens.reshape(1, -1, gen_tokens.shape[-1])[0][0]

# to text and clean
gen_code = tokenizer.decode(gen_tokens)
gen_code = gen_code.split(tokenizer.sep_token)[2].strip().split(tokenizer.eos_token)[0].strip()
print(gen_code)
# makepkg --clean {{path/to/directory}}
```
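The prompt in the example above follows a fixed layout: the BOS token, a numbered list of "Potential manual i:" entries, then the intent wrapped in separator tokens. The sketch below shows one way to assemble such a prompt and to extract the generated command from the decoded output. The helper names `build_prompt` and `extract_code` and the placeholder token strings are hypothetical and not part of the repository; in practice the tokens would come from the loaded tokenizer. Joining the manuals with newlines is also an assumption.

```python
from typing import List


def build_prompt(manuals: List[str], intent: str, bos_token: str, sep_token: str) -> str:
    """Assemble a prompt of the form: <bos> manuals <sep> intent <sep> (hypothetical helper)."""
    doc_lines = [f"Potential manual {i}: {doc}" for i, doc in enumerate(manuals)]
    # Assumption: one manual per line.
    docs = "\n".join(doc_lines)
    return f"{bos_token} {docs}{sep_token} {intent} {sep_token}"


def extract_code(decoded: str, sep_token: str, eos_token: str) -> str:
    """Return the text after the second separator, cut at the first EOS token."""
    return decoded.split(sep_token)[2].strip().split(eos_token)[0].strip()


# Placeholder tokens for illustration only; use tokenizer.bos_token etc. in practice.
BOS, SEP, EOS = "<bos>", "<sep>", "<eos>"
demo_prompt = build_prompt(
    ["makepkg - package build utility", "-c, --clean Clean up leftover work files"],
    "clean up work directories after a successful build",
    BOS,
    SEP,
)
# Simulate a decoded model output: the prompt followed by the generated command.
demo_decoded = demo_prompt + " makepkg --clean {{path/to/directory}}" + EOS
print(extract_code(demo_decoded, SEP, EOS))
```

Keeping prompt construction and output parsing in two small functions makes it easier to swap in a different number of retrieved manuals without touching the generation call.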
Example script
An example script for tldr that uses the retrieved docs is here
Other models
Other models require the customized implementations in our repo; please read through the corresponding sections to use them. These models are:
1. sparse retriever based on BM25 for tldr
2. dense retriever based on CodeT5 for CoNaLa
3. FiD T5 generator for tldr
4. FiD CodeT5 generator for CoNaLa
The following instructions are for reproducing the results in the paper.
Preparation
Download data for CoNaLa and tldr from link
```bash
# unzip
unzip docprompting_data.zip
# move to the data folder
mv docprompting_data/* data
```
Download trained generator weights from link
```bash
unzip docprompting_generator_models.zip
# move to the model folder
mv docprompting_generator_models/* models/generator
```
Retrieval
Dense retrieval
(CoNaLa as an example)
The code is based on SimCSE
- Run inference with our trained model on CoNaLa (Python)
```bash
python retriever/simcse/run_inference.py \
    --model_name "neulab/docprompting-codet5-python-doc-retriever" \
    --source_file data/conala/conala_nl.txt \
    --target_file data/conala/python_manual_firstpara.tok.txt \
    --source_embed_save_file data/conala/.tmp/src_embedding \
    --target_embed_save_file data/conala/.tmp/tgt_embedding \
    --sim_func cls_distance.cosine \
    --num_layers 12 \
    --save_file data/conala/retrieval_results.json
```
We observed that whether or not the embeddings are normalized can affect the retrieval results. We therefore selected this hyper-parameter (--normalize_embed) on the validation set.
The results will be saved to data/conala/retrieval_results.json.
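To illustrate why normalizing the embeddings can change the retrieval results, here is a minimal, dependency-free sketch of top-k retrieval over toy vectors. It is not the repository's implementation; the function names and the two-dimensional vectors are made up for illustration. Without normalization the score is a raw dot product, which favors long vectors; with normalization it becomes cosine similarity.

```python
import math
from typing import List, Sequence, Tuple


def l2_normalize(vec: Sequence[float]) -> List[float]:
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def retrieve(query: Sequence[float], docs: List[Sequence[float]],
             top_k: int = 2, normalize: bool = True) -> List[Tuple[int, float]]:
    """Return (doc_index, score) pairs for the top_k highest-scoring docs."""
    if normalize:
        query = l2_normalize(query)
        docs = [l2_normalize(d) for d in docs]
    scores = [(i, dot(query, d)) for i, d in enumerate(docs)]
    return sorted(scores, key=lambda s: s[1], reverse=True)[:top_k]


# Toy embeddings: doc 0 is long but points away from the query,
# doc 1 is short but almost parallel to it.
toy_query = [1.0, 0.0]
toy_docs = [[10.0, 10.0], [0.9, 0.1]]
print(retrieve(toy_query, toy_docs, top_k=1, normalize=False))  # doc 0 ranks first
print(retrieve(toy_query, toy_docs, top_k=1, normalize=True))   # doc 1 ranks first
```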
- Train your own retriever
```bash
python retriever/simcse/run_train.py \
    --num_layers 12 \
    --model_name_or_path Salesforce/codet5-base \
    --sim_func cls_distance.cosine \
    --temp 0.05 \
    --train_file data/conala/train_retriever_sup_unsup.json \
    --eval_file data/conala/dev_retriever.json \
    --output_dir models/retriever/docprompting_codet5_python_doc_retriever \
    --eval_src_file data/conala/conala_nl.txt \
    --eval_tgt_file data/conala/python_manual_firstpara.tok.txt \
    --eval_root_folder data/conala \
    --eval_oracle_file data/conala/cmd_dev.oracle_man.full.json \
    --run_name docprompting_codet5_python_doc_retriever \
    --num_train_epochs 10 \
    --per_device_train_batch_size 512 \
    --learning_rate 1e-5 \
    --max_seq_length 32 \
    --evaluation_strategy steps \
    --metric_for_best_model recall@10 \
    --load_best_model_at_end \
    --eval_steps 125 \
    --overwrite_output_dir \
    --do_train \
    --eval_form retrieval \
    "$@"
```
- train_retriever_sup_unsup.json contains the supervised data (CoNaLa training and mined) and the unsupervised data (duplication of sentences in a doc) for training the retriever.
- Be careful with the saved model name. If using codet5, make sure codet5 is in the name.
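The training command above passes a cosine similarity function and a temperature (--temp 0.05), which matches the SimCSE-style contrastive objective with in-batch negatives. The following is a minimal sketch of that kind of loss on toy vectors, assuming each query's positive document sits at the same index and all other documents in the batch act as negatives. It is an illustration only: the function names and vectors are hypothetical, and the repository's actual loss may differ in its details.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; assumes both vectors are non-zero."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def in_batch_contrastive_loss(queries: List[Sequence[float]],
                              docs: List[Sequence[float]],
                              temp: float = 0.05) -> float:
    """Mean cross-entropy where docs[i] is the positive for queries[i]
    and every other docs[j] in the batch is a negative."""
    total = 0.0
    for i, query in enumerate(queries):
        logits = [cosine(query, d) / temp for d in docs]
        # log-sum-exp with max subtraction for numerical stability
        max_logit = max(logits)
        log_denominator = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))
        total += log_denominator - logits[i]
    return total / len(queries)


toy_queries = [[1.0, 0.0], [0.0, 1.0]]
aligned_docs = [[0.9, 0.1], [0.1, 0.9]]   # positives at matching indices
shuffled_docs = [[0.1, 0.9], [0.9, 0.1]]  # positives at the wrong indices
print(in_batch_contrastive_loss(toy_queries, aligned_docs))   # close to zero
print(in_batch_contrastive_loss(toy_queries, shuffled_docs))  # much larger
```

A small temperature such as 0.05 sharpens the softmax, so even modest differences in cosine similarity produce a large gap between the aligned and the shuffled loss.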
Sparse retrieval
(tldr as an example)
There are two stages in the retrieval procedure for tldr.
The first stage retrieves the bash command, and the second stage retrieves the potentially relevant paragraphs that describe the usage of the arguments.
1. Build the index with Elasticsearch
```bash
python retriever/bm25/main.py \
    --retrieval_stage 0
```
2. First-stage retrieval
```bash
python retriever/bm25/main.py \
    --retrieval_stage 1 \
    --split {cmd_train, cmd_dev, cmd_test}
```
3. Second-stage retrieval
```bash
python retriever/bm25/main.py \
    --retrieval_stage 2 \
    --split {cmd_train, cmd_dev, cmd_test}
```
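The two-stage procedure above can be sketched without Elasticsearch using a small in-memory BM25 scorer. This is an illustration of the idea only, not the repository's code: the toy manuals, the whitespace tokenizer, and the helper names are all hypothetical, and k1=1.2, b=0.75 are simply common BM25 defaults.

```python
import math
from collections import Counter
from typing import Dict, List


def tokenize(text: str) -> List[str]:
    # Assumption: lowercase whitespace tokenization is enough for the toy data.
    return text.lower().split()


def bm25_scores(query: str, docs: List[str], k1: float = 1.2, b: float = 0.75) -> List[float]:
    """Score every doc against the query with BM25 (docs must be non-empty)."""
    tokenized = [tokenize(d) for d in docs]
    n_docs = len(tokenized)
    avg_len = sum(len(d) for d in tokenized) / n_docs
    doc_freq = Counter(term for d in tokenized for term in set(d))
    scores = []
    for doc in tokenized:
        term_freq = Counter(doc)
        score = 0.0
        for term in tokenize(query):
            if term not in term_freq:
                continue
            idf = math.log(1 + (n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5))
            length_norm = k1 * (1 - b + b * len(doc) / avg_len)
            score += idf * term_freq[term] * (k1 + 1) / (term_freq[term] + length_norm)
        scores.append(score)
    return scores


def best_match(query: str, candidates: Dict[str, str]) -> str:
    """Return the key of the highest-scoring candidate."""
    keys = list(candidates)
    scores = bm25_scores(query, [candidates[k] for k in keys])
    return keys[max(range(len(keys)), key=scores.__getitem__)]


# Hypothetical toy manuals: one summary per command, then paragraphs per command.
command_summaries = {
    "makepkg": "makepkg package build utility",
    "tar": "tar archiving utility for files",
}
command_paragraphs = {
    "makepkg": {
        "--clean": "clean up leftover work files and directories after a successful build",
        "--rmdeps": "remove dependencies installed during dependency resolution",
    },
    "tar": {"-x": "extract files from an archive"},
}

intent = "clean up work directories after a successful package build"
command = best_match(intent, command_summaries)               # stage 1: pick the command
paragraph = best_match(intent, command_paragraphs[command])   # stage 2: pick a paragraph
print(command, paragraph)
```

Restricting the second stage to the paragraphs of the command found in the first stage keeps the candidate pool small, which is the main point of splitting the retrieval in two.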
Generation
FID generation
The code is based on FiD. A training or evaluation file should be converted to the format compatible with FiD. An example is here
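As a rough sketch of such a conversion, the snippet below builds one record in the layout used by the upstream FiD project (id, question, target, answers, and a list of ctxs with title and text). Whether this repository's files use exactly these fields is an assumption, so the linked example file remains the reference. The helper name and the sample values are hypothetical.

```python
import json
from typing import Dict, List


def to_fid_example(example_id: str, intent: str, code: str,
                   retrieved_docs: List[Dict[str, str]], n_context: int = 10) -> Dict:
    """Build one FiD-style record from an intent, its target code, and retrieved docs."""
    return {
        "id": example_id,
        "question": intent,   # the natural language intent
        "target": code,       # the code to generate
        "answers": [code],
        # keep at most n_context retrieved docs, in retrieval order
        "ctxs": [{"title": d["title"], "text": d["text"]} for d in retrieved_docs[:n_context]],
    }


# Hypothetical sample values for illustration.
sample_record = to_fid_example(
    "0",
    "clean up work directories after a successful build",
    "makepkg --clean {{path/to/directory}}",
    [{"title": "makepkg", "text": "-c, --clean Clean up leftover work files"}],
)
print(json.dumps(sample_record, indent=2))
```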
Important note: FiD has a strong dependency on the version of transformers (3.0.2). Using a different version might lead to irreproducible results.
1. Run generation. Here is an example with our trained model on Python CoNaLa
```bash
ds='conala'
python generator/fid/test_reader_simple.py \
    --model_path models/generator/${ds}.fid.codet5.top10/checkpoint/best_dev \
    --tokenizer_name models/generator/codet5-base \
    --eval_data data/${ds}/fid.cmd_test.codet5.t10.json \
    --per_gpu_batch_size 8 \
    --n_context 10 \
    --name ${ds}.fid.codet5.top10 \
    --checkpoint_dir models/generator \
    --result_tag test_same \
    --main_port 81692
```
The results will be saved to models/generator/{name}/test_results_test_same.json
To evaluate pass@k, we need more generations, so we use nucleus sampling (instead of beam search) for the generation.
```bash
ds='conala'
t=1.0 # set this from 0.2, 0.4, 0.6, .. 1.0. Use the dev set to find the best temperature
python generator/fid/test_reader_simple.py \
    --model_path models/generator/${ds}.fid.codet5.top10/checkpoint/best_dev \
    --tokenizer_name models/generator/codet5-base \
    --eval_data data/${ds}/fid.cmd_test.codet5.t10.ns200.json \
    --per_gpu_batch_size 8 \
    --n_context 10 \
    --name ${ds}.fid.codet5.top10.ns200 \
    --checkpoint_dir models/generator \
    --result_tag test_same \
    --num_beams 1 \
    --temperature $t \
    --top_p 0.95 \
    --num_return_sequences 200 \
    --main_port 81692
```
Then run this script
```bash
python dataset_helper/conala/execution_eval.py --result_file data/${ds}/fid.cmd_test.codet5.t10.ns200.json
```
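For reference, pass@k is commonly computed with the unbiased estimator 1 - C(n-c, k) / C(n, k), where n is the number of sampled generations per problem and c is the number that pass the tests. The sketch below implements that estimator; whether execution_eval.py computes it in exactly this way is an assumption, and the function names and sample counts are hypothetical.

```python
from math import comb
from typing import List, Tuple


def pass_at_k(n: int, c: int, k: int) -> float:
    """Probability that at least one of k samples drawn from n is correct,
    given that c of the n samples are correct."""
    if k > n:
        raise ValueError("k must not exceed the number of samples n")
    if n - c < k:
        # Fewer than k incorrect samples: every draw of k contains a correct one.
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)


def mean_pass_at_k(results: List[Tuple[int, int]], k: int) -> float:
    """Average pass@k over problems, each given as (n_samples, n_correct)."""
    return sum(pass_at_k(n, c, k) for n, c in results) / len(results)


# Hypothetical counts: 200 samples per problem, as in the command above.
toy_results = [(200, 20), (200, 0), (200, 200)]
for k in (1, 10, 100):
    print(k, mean_pass_at_k(toy_results, k))
```

Sampling many generations (200 here) and then estimating pass@k for smaller k gives a lower-variance estimate than drawing exactly k samples per problem.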
- Train your own generator
```bash
ds='conala'
# --model_name initializes the generator with the codet5-base model
python generator/fid/train_reader.py \
    --seed 1996 \
    --train_data data/${ds}/fid.cmd_train.codet5.t10.json \
    --eval_data data/${ds}/fid.cmd_dev.codet5.t10.json \
    --model_name models/generator/codet5-base \
    --per_gpu_batch_size 4 \
    --n_context 10 \
    --name ${ds}.fid.codet5.top10 \
    --checkpoint_dir models/generator/ \
    --eval_freq 500 \
    --accumulation_steps 2 \
    --main_port 30843 \
    --total_steps 20000 \
    --warmup_steps 2000

ds='tldr'
python generator/fid/train_reader.py \
    --dataset tldr \
    --train_data data/${ds}/fid.cmd_train.codet5.t10.json \
    --eval_data data/${ds}/fid.cmd_model_select.codet5.t10.json \
    --model_name models/generator/codet5-base \
    --per_gpu_batch_size 4 \
    --n_context 10 \
    --eval_metric token_f1 \
    --name ${ds}.fid.codet5.top10 \
    --checkpoint_dir models/generator/ \
    --eval_freq 1000 \
    --accumulation_steps 2 \
    --main_port 32420 \
    --total_steps 20000 \
    --warmup_steps 2000
```
* Examples in fid.cmd_model_select.codet5.t10.json are the same as in fid.cmd_dev.codet5.t10.json.
The difference is that it uses the oracle first-stage retrieval results (oracle bash name).
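The tldr training command selects checkpoints with a token-level F1 metric (--eval_metric token_f1). A minimal sketch of such a metric is given below, assuming whitespace tokenization and multiset overlap between the predicted and the reference command. The repository's metric may tokenize or normalize commands differently, so treat this as an illustration only.

```python
from collections import Counter


def token_f1(prediction: str, reference: str) -> float:
    """Harmonic mean of token precision and recall between two commands."""
    pred_tokens = prediction.split()
    ref_tokens = reference.split()
    if not pred_tokens or not ref_tokens:
        # Both empty counts as a perfect match, one empty as a miss.
        return float(pred_tokens == ref_tokens)
    # Multiset intersection: each token counts at most as often as it appears in both.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


# The prediction misses the path argument: precision 1.0, recall 2/3.
print(token_f1("makepkg --clean", "makepkg --clean {{path/to/directory}}"))
```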
Data
The data folder contains the two benchmarks we curated or re-split.
* tldr
* CoNaLa
On each dataset, we provide
1. Natural language intent (entry nl)
2. Oracle code (entry cmd)
* Bash for tldr
* Python for CoNaLa
3. Oracle docs (entry oracle_man)
* In the data files, we only provide the manual ids; their contents can be found in {dataset}/{dataset}_docs.json.
4. Other data in different formats for different modules
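Since the data files store only manual ids under oracle_man, the doc contents have to be looked up separately. The sketch below shows that lookup on in-memory toy data, assuming the docs file is a JSON object mapping each manual id to its text. The ids, the sample entry, and the helper name are hypothetical.

```python
from typing import Dict, List


def resolve_oracle_docs(example: Dict, docs: Dict[str, str]) -> List[str]:
    """Map the manual ids in example['oracle_man'] to their contents,
    skipping ids that are missing from the docs mapping."""
    return [docs[doc_id] for doc_id in example["oracle_man"] if doc_id in docs]


# Hypothetical stand-in for the parsed contents of {dataset}/{dataset}_docs.json.
toy_docs = {
    "makepkg_0": "makepkg - package build utility",
    "makepkg_1": "-c, --clean Clean up leftover work files and directories after a successful build.",
}
# Hypothetical data entry with the fields described above.
toy_example = {
    "nl": "clean up work directories after a successful build",
    "cmd": "makepkg --clean {{path/to/directory}}",
    "oracle_man": ["makepkg_0", "makepkg_1"],
}
for text in resolve_oracle_docs(toy_example, toy_docs):
    print(text)
```

With real data, the mapping would be obtained by parsing the docs file with json.load before calling the helper.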
Resources
Citation
@inproceedings{zhou23docprompting,
title = {DocPrompting: Generating Code by Retrieving the Docs},
author = {Shuyan Zhou and Uri Alon and Frank F. Xu and Zhiruo Wang and Zhengbao Jiang and Graham Neubig},
booktitle = {International Conference on Learning Representations (ICLR)},
address = {Kigali, Rwanda},
month = {May},
url = {https://arxiv.org/abs/2207.05987},
year = {2023}
}
Owner
- Name: Shuyan Zhou
- Login: shuyanzhou
- Kind: user
- Location: Pittsburgh, PA
- Website: shuyanzhou.com
- Twitter: shuyanzhxyc
- Repositories: 3
- Profile: https://github.com/shuyanzhou
PhD student @ LTI CMU
Citation (CITATION.cff)
@article{zhou2022doccoder,
title={DocCoder: Generating Code by Retrieving and Reading Docs},
author={Zhou, Shuyan and Alon, Uri and Xu, Frank F and Jiang, Zhengbao and Neubig, Graham},
journal={arXiv preprint arXiv:2207.05987},
year={2022}
}
GitHub Events
Total
- Issues event: 2
- Watch event: 15
- Issue comment event: 1
- Fork event: 3
Last Year
- Issues event: 2
- Watch event: 15
- Issue comment event: 1
- Fork event: 3
Dependencies
- faiss-cpu *
- numpy *
- tensorboard *
- torch *
- transformers ==3.0.2
- aiohttp ==3.8.1
- aiosignal ==1.2.0
- alabaster ==0.7.12
- anaconda-client ==1.7.2
- anaconda-project ==0.9.1
- analytics-python ==1.4.0
- anyio ==3.5.0
- appdirs ==1.4.4
- argh ==0.26.2
- argon2-cffi ==20.1.0
- asgiref ==3.5.0
- asn1crypto ==1.4.0
- astor ==0.8.1
- astroid ==2.5
- astropy ==4.2.1
- async-generator ==1.10
- async-timeout ==4.0.2
- asynctest ==0.13.0
- atomicwrites ==1.4.0
- attrs ==20.3.0
- autopep8 ==1.5.6
- babel ==2.9.0
- backcall ==0.2.0
- backoff ==1.10.0
- backports.shutil-get-terminal-size ==1.0.0
- bcrypt ==3.2.0
- beautifulsoup4 ==4.9.3
- bert-score ==0.3.11
- bitarray ==2.1.0
- bkcharts ==0.2
- black ==19.10b0
- bleach ==3.3.0
- bokeh ==2.3.2
- boto ==2.49.0
- bottleneck ==1.3.2
- brotlipy ==0.7.0
- certifi ==2020.12.5
- cffi ==1.14.5
- chardet ==4.0.0
- charset-normalizer ==2.0.12
- click ==7.1.2
- cloudpickle ==1.6.0
- clyent ==1.2.2
- colorama ==0.4.4
- contextlib2 ==0.6.0.post1
- cryptography ==3.4.7
- cycler ==0.10.0
- cython ==0.29.23
- cytoolz ==0.11.0
- dask ==2021.4.0
- datasets ==1.2.1
- decorator ==5.0.6
- defusedxml ==0.7.1
- diff-match-patch ==20200713
- dill ==0.3.4
- distributed ==2021.4.1
- docker-pycreds ==0.4.0
- docutils ==0.17.1
- editdistance *
- elasticsearch *
- entrypoints ==0.3
- et-xmlfile ==1.0.1
- faiss-cpu ==1.7.2
- fastapi ==0.75.0
- fastcache ==1.1.0
- ffmpy ==0.3.0
- filelock ==3.0.12
- flake8 ==3.9.0
- flask ==1.1.2
- fonttools ==4.31.2
- frozenlist ==1.3.0
- fsspec ==0.9.0
- future ==0.18.2
- gevent ==21.1.2
- gitdb ==4.0.9
- gitpython ==3.1.27
- glob2 ==0.7
- gmpy2 ==2.0.8
- gradio ==2.8.12
- greenlet ==1.0.0
- h11 ==0.13.0
- h5py ==2.10.0
- heapdict ==1.0.1
- html5lib ==1.1
- idna ==2.10
- imageio ==2.9.0
- imagesize ==1.2.0
- importlib-metadata ==3.10.0
- iniconfig ==1.1.1
- intervaltree ==3.1.0
- ipykernel ==5.3.4
- ipython ==7.22.0
- ipython-genutils ==0.2.0
- ipywidgets ==7.6.3
- isort ==5.8.0
- itsdangerous ==1.1.0
- jdcal ==1.4.1
- jedi ==0.17.2
- jeepney ==0.6.0
- jinja2 ==2.11.3
- joblib ==1.0.1
- json5 ==0.9.5
- jsonschema ==3.2.0
- jupyter ==1.0.0
- jupyter-client ==6.1.12
- jupyter-console ==6.4.0
- jupyter-core ==4.7.1
- jupyter-packaging ==0.7.12
- jupyter-server ==1.4.1
- jupyterlab ==3.0.14
- jupyterlab-pygments ==0.1.2
- jupyterlab-server ==2.4.0
- jupyterlab-widgets ==1.0.0
- keyring ==22.3.0
- kiwisolver ==1.3.1
- lazy-object-proxy ==1.6.0
- libarchive-c ==2.9
- linkify-it-py ==1.0.3
- llvmlite ==0.36.0
- locket ==0.2.1
- lxml ==4.6.3
- markdown-it-py ==2.0.1
- markupsafe ==1.1.1
- matplotlib ==3.5.1
- mccabe ==0.6.1
- mdit-py-plugins ==0.3.0
- mdurl ==0.1.0
- mistune ==0.8.4
- mkl-fft ==1.3.0
- mkl-random ==1.2.1
- mkl-service ==2.3.0
- mock ==4.0.3
- monotonic ==1.6
- more-itertools ==8.7.0
- mpmath ==1.2.1
- msgpack ==1.0.2
- multidict ==6.0.2
- multipledispatch ==0.6.0
- multiprocess ==0.70.12.2
- mypy-extensions ==0.4.3
- nbclassic ==0.2.6
- nbclient ==0.5.3
- nbconvert ==6.0.7
- nbformat ==5.1.3
- nest-asyncio ==1.5.1
- networkx ==2.5
- nltk ==3.6.1
- nose ==1.3.7
- notebook ==6.3.0
- numba ==0.53.1
- numexpr ==2.7.3
- numpy ==1.20.1
- numpydoc ==1.1.0
- olefile ==0.46
- openpyxl ==3.0.7
- orjson ==3.6.7
- packaging ==20.9
- pandas ==1.1.5
- pandocfilters ==1.4.3
- paramiko ==2.10.3
- parso ==0.7.0
- partd ==1.2.0
- path ==15.1.2
- pathlib2 ==2.3.5
- pathspec ==0.7.0
- pathtools ==0.1.2
- patsy ==0.5.1
- pep8 ==1.7.1
- pexpect ==4.8.0
- pickleshare ==0.7.5
- pillow ==8.2.0
- pip ==21.0.1
- pkginfo ==1.7.0
- pluggy ==0.13.1
- ply ==3.11
- prettytable ==2.1.0
- prometheus-client ==0.10.1
- promise ==2.3
- prompt-toolkit ==3.0.17
- protobuf ==3.19.4
- psutil ==5.8.0
- ptyprocess ==0.7.0
- py ==1.10.0
- pyarrow ==7.0.0
- pycodestyle ==2.6.0
- pycosat ==0.6.3
- pycparser ==2.20
- pycrypto ==2.6.1
- pycryptodome ==3.14.1
- pycurl ==7.43.0.6
- pydantic ==1.9.0
- pydocstyle ==6.0.0
- pydub ==0.25.1
- pyerfa ==1.7.3
- pyflakes ==2.2.0
- pygments ==2.8.1
- pylint ==2.7.4
- pyls-black ==0.4.6
- pyls-spyder ==0.3.2
- pynacl ==1.5.0
- pyodbc ==4.0.0
- pyopenssl ==20.0.1
- pyparsing ==2.4.7
- pyrsistent ==0.17.3
- pysocks ==1.7.1
- pytest ==6.2.3
- python-dateutil ==2.8.1
- python-jsonrpc-server ==0.4.0
- python-language-server ==0.36.2
- python-multipart ==0.0.5
- pytz ==2021.1
- pywavelets ==1.1.1
- pywsd ==1.2.4
- pyxdg ==0.27
- pyyaml ==5.4.1
- pyzmq ==20.0.0
- qdarkstyle ==2.8.1
- qtawesome ==1.0.2
- qtconsole ==5.0.3
- qtpy ==1.9.0
- regex ==2021.4.4
- requests ==2.25.1
- rope ==0.18.0
- rtree ==0.9.7
- ruamel-yaml-conda ==0.15.100
- sacremoses ==0.0.49
- scikit-image ==0.18.1
- scikit-learn ==0.24.0
- scipy ==1.5.4
- seaborn ==0.11.1
- secretstorage ==3.3.1
- send2trash ==1.5.0
- sentencepiece ==0.1.96
- sentry-sdk ==1.5.8
- setproctitle ==1.2.2
- setuptools ==49.3.0
- shortuuid ==1.0.8
- simplegeneric ==0.8.1
- singledispatch ==0.0.0
- six ==1.15.0
- smmap ==5.0.0
- sniffio ==1.2.0
- snowballstemmer ==2.1.0
- sortedcollections ==2.1.0
- sortedcontainers ==2.3.0
- soupsieve ==2.2.1
- sphinx ==4.0.1
- sphinxcontrib-applehelp ==1.0.2
- sphinxcontrib-devhelp ==1.0.2
- sphinxcontrib-htmlhelp ==1.0.3
- sphinxcontrib-jsmath ==1.0.1
- sphinxcontrib-qthelp ==1.0.3
- sphinxcontrib-serializinghtml ==1.1.4
- sphinxcontrib-websupport ==1.2.4
- spyder ==4.2.5
- spyder-kernels ==1.10.2
- sqlalchemy ==1.4.15
- starlette ==0.17.1
- statsmodels ==0.12.2
- sympy ==1.8
- tables ==3.6.1
- tblib ==1.7.0
- termcolor ==1.1.0
- terminado ==0.9.4
- testpath ==0.4.4
- textdistance ==4.2.1
- threadpoolctl ==2.1.0
- three-merge ==0.1.1
- tifffile ==2020.10.1
- tokenizers ==0.9.4
- toml ==0.10.2
- tomli ==2.0.1
- toolz ==0.11.1
- torch ==1.7.1
- torchaudio ==0.7.2
- torchvision ==0.8.2
- tornado ==6.1
- tqdm ==4.49.0
- traitlets ==5.0.5
- transformers ==4.2.1
- typed-ast ==1.4.2
- typing-extensions ==3.7.4.3
- uc-micro-py ==1.0.1
- ujson ==4.0.2
- unicodecsv ==0.14.1
- urllib3 ==1.26.4
- uvicorn ==0.17.6
- wandb ==0.12.11
- watchdog ==1.0.2
- wcwidth ==0.2.5
- webencodings ==0.5.1
- werkzeug ==1.0.1
- wheel ==0.36.2
- whoosh ==2.7.4
- widgetsnbextension ==3.5.1
- wn ==0.9.1
- wrapt ==1.12.1
- wurlitzer ==2.1.0
- xlrd ==2.0.1
- xlsxwriter ==1.3.8
- xlwt ==1.3.0
- xxhash ==3.0.0
- yapf ==0.31.0
- yarl ==1.7.2
- yaspin ==2.1.0
- zict ==2.0.0
- zipp ==3.4.1
- zope.event ==4.5.0
- zope.interface ==5.3.0