regressor
Regressing how many citations a given paper will have based on its abstract, title, year
Science Score: 44.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file: Found CITATION.cff file
- ✓ codemeta.json file: Found codemeta.json file
- ✓ .zenodo.json file: Found .zenodo.json file
- ○ DOI references
- ○ Academic publication links
- ○ Committers with academic emails
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: Low similarity (6.4%) to scientific vocabulary
Keywords
Repository
Basic Info
Statistics
- Stars: 0
- Watchers: 0
- Forks: 0
- Open Issues: 0
- Releases: 0
Topics
Metadata Files
README.md
Citation Regressor
In this project we build a model that estimates how many citations a given paper will receive, based on its abstract, title, publication year, and the author co-authorship network. A minimal sketch of the core pipeline is shown below.
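The sketch assumes a pandas DataFrame `data` with `title`, `abstract`, `year`, and `n_citation` columns (the column names used in the notebook further down); it is illustrative, not the exact notebook code.

```python
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import MinMaxScaler
from xgboost import XGBRegressor

# Embed title + abstract, append the (min-max scaled) year,
# and regress the log-transformed citation count.
df = data.copy()
df["text"] = df["title"] + " " + df["abstract"].fillna("")
encoder = SentenceTransformer("all-mpnet-base-v2")
year = MinMaxScaler().fit_transform(df[["year"]])
X = np.hstack([encoder.encode(df["text"].tolist()), year])
y = np.log1p(df["n_citation"])      # tame the heavy-tailed target

model = XGBRegressor(n_estimators=300, max_depth=8, learning_rate=0.1)
model.fit(X, y)
pred_citations = np.expm1(model.predict(X))  # back to raw counts
```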
Regressor
```markdown
Top Predicted Citations:
      title                                               year  topic             predicted_citations
6131  DSI-Net: Deep Synergistic Interaction Network ...  2021  LLM               1.189171e+13
1077  Uni-ControlNet: All-in-One Control to Text-to-...  2023  Diffusion Models  8.380618e+12
246   Anchor Diffusion for Unsupervised Video Object...  2019  Diffusion Models  3.803213e+12
```
Author-Author Network
The idea is that scholarly paper authors have relationships with one another that may be a strong predictor of how many citations a given paper will receive (a sketch of the graph construction follows the captured metrics below).

```markdown
Test Set Regression Metrics:
RMSE: 144.20
MAE: 60.33
R2 Score: -0.27

Top Predicted Citations:
       title                                               year  topic             predicted_citations
4119   The Good, the Bad, and the Expert: How Consume...  2015  NaN               2697.807129
4740   Chinese Term Recognition and Extraction Based ...  2008  NaN               2314.161621
4045   High efficient hardware allocation framework o...  2015  NaN               1178.770020
10845  Score-Based Generative Modeling through Stocha...  2020  Diffusion Models  1175.487061
```
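A minimal sketch of how the co-authorship graph can be assembled, assuming a semicolon-separated `authors` column as in the notebook's GNN cell:

```python
from itertools import combinations

# One undirected edge per pair of co-authors on the same paper.
edges = set()
for authors in df["authors"].fillna(""):
    names = [a.strip() for a in authors.split(";") if a.strip()]
    for u, v in combinations(sorted(names), 2):
        edges.add((u, v))
# 'edges' can then be mapped to integer node ids and fed to a GNN,
# as the notebook does with PyTorch Geometric's GCNConv.
```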
Owner
- Name: Arman Yazdani
- Login: theveryhim
- Kind: user
- Location: Tehran-Iran
- Repositories: 1
- Profile: https://github.com/theveryhim
Currently an Electrical Engineering B.Sc. student at Sharif University of Technology.
Citation (Citation Regressor.ipynb)
{
"cells": [
{
"cell_type": "markdown",
"id": "b118e482",
"metadata": {},
"source": [
"# Load"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a19ca19a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 100000 entries, 0 to 99999\n",
"Data columns (total 9 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 Unnamed: 0 100000 non-null int64 \n",
" 1 abstract 82990 non-null object\n",
" 2 authors 99999 non-null object\n",
" 3 n_citation 100000 non-null int64 \n",
" 4 references 87625 non-null object\n",
" 5 title 100000 non-null object\n",
" 6 venue 82305 non-null object\n",
" 7 year 100000 non-null int64 \n",
" 8 id 100000 non-null object\n",
"dtypes: int64(3), object(6)\n",
"memory usage: 6.9+ MB\n"
]
}
],
"source": [
"import pandas as pd\n",
"data = pd.read_csv(\"dblp-v11.csv\")\n",
"data.info()"
]
},
{
"cell_type": "markdown",
"id": "b55cd22c",
"metadata": {},
"source": [
"# Regressor"
]
},
{
"cell_type": "markdown",
"id": "2fcdeddb",
"metadata": {},
"source": [
"### Model selection"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a1580b0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e4d02ec814d84b8eb48ba5f8d444adc9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Batches: 0%| | 0/8 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Training XGBoost...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/dist-packages/xgboost/core.py:158: UserWarning: [02:07:06] WARNING: /workspace/src/common/error_msg.cc:27: The tree method `gpu_hist` is deprecated since 2.0.0. To use GPU training, set the `device` parameter to CUDA instead.\n",
"\n",
" E.g. tree_method = \"hist\", device = \"cuda\"\n",
"\n",
" warnings.warn(smsg, UserWarning)\n",
"/usr/local/lib/python3.11/dist-packages/xgboost/core.py:158: UserWarning: [02:07:06] WARNING: /workspace/src/learner.cc:740: \n",
"Parameters: { \"predictor\" } are not used.\n",
"\n",
" warnings.warn(smsg, UserWarning)\n",
"/usr/local/lib/python3.11/dist-packages/xgboost/core.py:158: UserWarning: [02:07:09] WARNING: /workspace/src/common/error_msg.cc:27: The tree method `gpu_hist` is deprecated since 2.0.0. To use GPU training, set the `device` parameter to CUDA instead.\n",
"\n",
" E.g. tree_method = \"hist\", device = \"cuda\"\n",
"\n",
" warnings.warn(smsg, UserWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"XGBoost Validation RMSE: 54.23\n",
"\n",
"Training RandomForest...\n",
"RandomForest Validation RMSE: 54.32\n",
"\n",
"Best Model: XGBRegressor(base_score=None, booster=None, callbacks=None,\n",
" colsample_bylevel=None, colsample_bynode=None,\n",
" colsample_bytree=None, device='cuda:0', early_stopping_rounds=None,\n",
" enable_categorical=False, eval_metric=None, feature_types=None,\n",
" gamma=None, grow_policy=None, importance_type=None,\n",
" interaction_constraints=None, learning_rate=0.05, max_bin=None,\n",
" max_cat_threshold=None, max_cat_to_onehot=None,\n",
" max_delta_step=None, max_depth=6, max_leaves=None,\n",
" min_child_weight=None, missing=nan, monotone_constraints=None,\n",
" multi_strategy=None, n_estimators=200, n_jobs=None,\n",
" num_parallel_tree=None, predictor='gpu_predictor', ...)\n",
"Final RMSE: 54.23\n",
"Final MAE: 23.98\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sentence_transformers import SentenceTransformer\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from xgboost import XGBRegressor, DMatrix\n",
"import torch\n",
"\n",
"# Check for GPU availability\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(f\"Using device: {device}\")\n",
"\n",
"# Sample data for faster prototyping\n",
"df = data.sample(n=1_000, random_state=1)\n",
"\n",
"# Preprocess data\n",
"# Handle missing values\n",
"df['abstract'] = df['abstract'].fillna('')\n",
"df['text'] = df['title'] + \" \" + df['abstract']\n",
"\n",
"# Prepare features\n",
"scaler = MinMaxScaler()\n",
"df['year_scaled'] = scaler.fit_transform(df[['year']])\n",
"\n",
"# Prepare target with log transformation\n",
"df['log_citation'] = np.log1p(df['n_citation'])\n",
"\n",
"# Generate embeddings using GPU\n",
"model = SentenceTransformer('all-mpnet-base-v2', device='cuda')\n",
"embeddings = model.encode(\n",
" df['text'].tolist(),\n",
" batch_size=128,\n",
" show_progress_bar=True,\n",
" device='cuda'\n",
")\n",
"\n",
"# Combine features\n",
"X = np.hstack([embeddings, df['year_scaled'].values.reshape(-1, 1)])\n",
"y = df['log_citation'].values\n",
"\n",
"# Train-validation split\n",
"X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)\n",
"\n",
"# Define and train models with GridSearch\n",
"models = [\n",
" {\n",
" 'name': 'XGBoost',\n",
" 'model': XGBRegressor(tree_method = \"hist\", device = \"cuda\", predictor='gpu_predictor'),\n",
" 'params': {\n",
" 'n_estimators': [200, 300],\n",
" 'max_depth': [6, 8],\n",
" 'learning_rate': [0.05, 0.1],\n",
" 'subsample': [0.8, 1.0]\n",
" }\n",
" },\n",
" {\n",
" 'name': 'RandomForest',\n",
" 'model': RandomForestRegressor(),\n",
" 'params': {\n",
" 'n_estimators': [200, 300],\n",
" 'max_depth': [None, 20],\n",
" 'min_samples_split': [3, 5]\n",
" }\n",
" }\n",
"]\n",
"\n",
"best_model = None\n",
"best_score = float('inf')\n",
"\n",
"for m in models:\n",
" print(f\"\\nTraining {m['name']}...\")\n",
" grid = GridSearchCV(\n",
" m['model'],\n",
" m['params'],\n",
" cv=3,\n",
" scoring='neg_mean_squared_error',\n",
" n_jobs=-1\n",
" )\n",
"\n",
" # Use DMatrix for XGBoost to ensure GPU compatibility\n",
" if m['name'] == 'XGBoost':\n",
" dtrain = DMatrix(X_train, label=y_train, enable_categorical=True)\n",
" dval = DMatrix(X_val, label=y_val, enable_categorical=True)\n",
" grid.fit(dtrain.get_data(), dtrain.get_label())\n",
" else:\n",
" grid.fit(X_train, y_train)\n",
"\n",
" # Evaluate\n",
" if m['name'] == 'XGBoost':\n",
" val_pred = grid.best_estimator_.predict(dval.get_data())\n",
" else:\n",
" val_pred = grid.best_estimator_.predict(X_val)\n",
"\n",
" rmse = np.sqrt(mean_squared_error(np.expm1(y_val), np.expm1(val_pred)))\n",
" print(f\"{m['name']} Validation RMSE: {rmse:.2f}\")\n",
"\n",
" if rmse < best_score:\n",
" best_score = rmse\n",
" best_model = grid.best_estimator_\n",
"\n",
"# Evaluate best model\n",
"print(f\"\\nBest Model: {best_model}\")\n",
"if isinstance(best_model, XGBRegressor):\n",
" final_pred = best_model.predict(DMatrix(X_val, enable_categorical=True).get_data())\n",
"else:\n",
" final_pred = best_model.predict(X_val)\n",
"\n",
"final_rmse = np.sqrt(mean_squared_error(np.expm1(y_val), np.expm1(final_pred)))\n",
"final_mae = mean_absolute_error(np.expm1(y_val), np.expm1(final_pred))\n",
"\n",
"print(f\"Final RMSE: {final_rmse:.2f}\")\n",
"print(f\"Final MAE: {final_mae:.2f}\")"
]
},
{
"cell_type": "markdown",
"id": "421503ab",
"metadata": {},
"source": [
"### Training selected model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f89182a2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a99a85c89b954b45a5f54c866e17a3ef",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Batches: 0%| | 0/79 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0]\tvalidation_0-rmse:1.73811\n",
"[1]\tvalidation_0-rmse:1.70124\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/dist-packages/xgboost/core.py:158: UserWarning: [03:03:57] WARNING: /workspace/src/learner.cc:740: \n",
"Parameters: { \"predictor\" } are not used.\n",
"\n",
" warnings.warn(smsg, UserWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2]\tvalidation_0-rmse:1.66960\n",
"[3]\tvalidation_0-rmse:1.64754\n",
"[4]\tvalidation_0-rmse:1.62967\n",
"[5]\tvalidation_0-rmse:1.61450\n",
"[6]\tvalidation_0-rmse:1.60012\n",
"[7]\tvalidation_0-rmse:1.58993\n",
"[8]\tvalidation_0-rmse:1.58093\n",
"[9]\tvalidation_0-rmse:1.57076\n",
"[10]\tvalidation_0-rmse:1.56556\n",
"[11]\tvalidation_0-rmse:1.56057\n",
"[12]\tvalidation_0-rmse:1.55657\n",
"[13]\tvalidation_0-rmse:1.55478\n",
"[14]\tvalidation_0-rmse:1.55072\n",
"[15]\tvalidation_0-rmse:1.54898\n",
"[16]\tvalidation_0-rmse:1.54518\n",
"[17]\tvalidation_0-rmse:1.54427\n",
"[18]\tvalidation_0-rmse:1.54305\n",
"[19]\tvalidation_0-rmse:1.54245\n",
"[20]\tvalidation_0-rmse:1.54107\n",
"[21]\tvalidation_0-rmse:1.53897\n",
"[22]\tvalidation_0-rmse:1.53853\n",
"[23]\tvalidation_0-rmse:1.53753\n",
"[24]\tvalidation_0-rmse:1.53837\n",
"[25]\tvalidation_0-rmse:1.53842\n",
"[26]\tvalidation_0-rmse:1.53715\n",
"[27]\tvalidation_0-rmse:1.53738\n",
"[28]\tvalidation_0-rmse:1.53842\n",
"[29]\tvalidation_0-rmse:1.53723\n",
"[30]\tvalidation_0-rmse:1.53832\n",
"[31]\tvalidation_0-rmse:1.53784\n",
"[32]\tvalidation_0-rmse:1.53869\n",
"[33]\tvalidation_0-rmse:1.53745\n",
"[34]\tvalidation_0-rmse:1.53706\n",
"[35]\tvalidation_0-rmse:1.53698\n",
"[36]\tvalidation_0-rmse:1.53679\n",
"[37]\tvalidation_0-rmse:1.53661\n",
"[38]\tvalidation_0-rmse:1.53629\n",
"[39]\tvalidation_0-rmse:1.53672\n",
"[40]\tvalidation_0-rmse:1.53643\n",
"[41]\tvalidation_0-rmse:1.53557\n",
"[42]\tvalidation_0-rmse:1.53486\n",
"[43]\tvalidation_0-rmse:1.53486\n",
"[44]\tvalidation_0-rmse:1.53502\n",
"[45]\tvalidation_0-rmse:1.53571\n",
"[46]\tvalidation_0-rmse:1.53635\n",
"[47]\tvalidation_0-rmse:1.53587\n",
"[48]\tvalidation_0-rmse:1.53554\n",
"[49]\tvalidation_0-rmse:1.53529\n",
"[50]\tvalidation_0-rmse:1.53521\n",
"[51]\tvalidation_0-rmse:1.53492\n",
"[52]\tvalidation_0-rmse:1.53551\n",
"[53]\tvalidation_0-rmse:1.53580\n",
"[54]\tvalidation_0-rmse:1.53575\n",
"[55]\tvalidation_0-rmse:1.53537\n",
"[56]\tvalidation_0-rmse:1.53573\n",
"[57]\tvalidation_0-rmse:1.53613\n",
"[58]\tvalidation_0-rmse:1.53596\n",
"[59]\tvalidation_0-rmse:1.53620\n",
"[60]\tvalidation_0-rmse:1.53617\n",
"[61]\tvalidation_0-rmse:1.53645\n",
"[62]\tvalidation_0-rmse:1.53656\n",
"[63]\tvalidation_0-rmse:1.53563\n",
"[64]\tvalidation_0-rmse:1.53587\n",
"[65]\tvalidation_0-rmse:1.53584\n",
"[66]\tvalidation_0-rmse:1.53546\n",
"[67]\tvalidation_0-rmse:1.53558\n",
"[68]\tvalidation_0-rmse:1.53642\n",
"[69]\tvalidation_0-rmse:1.53650\n",
"[70]\tvalidation_0-rmse:1.53685\n",
"[71]\tvalidation_0-rmse:1.53671\n",
"[72]\tvalidation_0-rmse:1.53636\n",
"[73]\tvalidation_0-rmse:1.53633\n",
"[74]\tvalidation_0-rmse:1.53618\n",
"[75]\tvalidation_0-rmse:1.53579\n",
"[76]\tvalidation_0-rmse:1.53618\n",
"[77]\tvalidation_0-rmse:1.53687\n",
"[78]\tvalidation_0-rmse:1.53681\n",
"[79]\tvalidation_0-rmse:1.53705\n",
"[80]\tvalidation_0-rmse:1.53713\n",
"[81]\tvalidation_0-rmse:1.53703\n",
"[82]\tvalidation_0-rmse:1.53746\n",
"[83]\tvalidation_0-rmse:1.53767\n",
"[84]\tvalidation_0-rmse:1.53796\n",
"[85]\tvalidation_0-rmse:1.53813\n",
"[86]\tvalidation_0-rmse:1.53808\n",
"[87]\tvalidation_0-rmse:1.53813\n",
"[88]\tvalidation_0-rmse:1.53781\n",
"[89]\tvalidation_0-rmse:1.53786\n",
"[90]\tvalidation_0-rmse:1.53778\n",
"[91]\tvalidation_0-rmse:1.53789\n",
"[92]\tvalidation_0-rmse:1.53832\n",
"[93]\tvalidation_0-rmse:1.53814\n",
"[94]\tvalidation_0-rmse:1.53796\n",
"[95]\tvalidation_0-rmse:1.53770\n",
"[96]\tvalidation_0-rmse:1.53795\n",
"[97]\tvalidation_0-rmse:1.53792\n",
"[98]\tvalidation_0-rmse:1.53820\n",
"[99]\tvalidation_0-rmse:1.53827\n",
"[100]\tvalidation_0-rmse:1.53829\n",
"[101]\tvalidation_0-rmse:1.53854\n",
"[102]\tvalidation_0-rmse:1.53884\n",
"[103]\tvalidation_0-rmse:1.53899\n",
"[104]\tvalidation_0-rmse:1.53848\n",
"[105]\tvalidation_0-rmse:1.53856\n",
"[106]\tvalidation_0-rmse:1.53863\n",
"[107]\tvalidation_0-rmse:1.53919\n",
"[108]\tvalidation_0-rmse:1.53917\n",
"[109]\tvalidation_0-rmse:1.53934\n",
"[110]\tvalidation_0-rmse:1.53985\n",
"[111]\tvalidation_0-rmse:1.54002\n",
"[112]\tvalidation_0-rmse:1.53969\n",
"[113]\tvalidation_0-rmse:1.53989\n",
"[114]\tvalidation_0-rmse:1.53983\n",
"[115]\tvalidation_0-rmse:1.53989\n",
"[116]\tvalidation_0-rmse:1.53948\n",
"[117]\tvalidation_0-rmse:1.53961\n",
"[118]\tvalidation_0-rmse:1.53941\n",
"[119]\tvalidation_0-rmse:1.53957\n",
"[120]\tvalidation_0-rmse:1.53985\n",
"[121]\tvalidation_0-rmse:1.54026\n",
"[122]\tvalidation_0-rmse:1.54030\n",
"[123]\tvalidation_0-rmse:1.54027\n",
"[124]\tvalidation_0-rmse:1.54036\n",
"[125]\tvalidation_0-rmse:1.54013\n",
"[126]\tvalidation_0-rmse:1.54027\n",
"[127]\tvalidation_0-rmse:1.54054\n",
"[128]\tvalidation_0-rmse:1.54063\n",
"[129]\tvalidation_0-rmse:1.54093\n",
"[130]\tvalidation_0-rmse:1.54071\n",
"[131]\tvalidation_0-rmse:1.54122\n",
"[132]\tvalidation_0-rmse:1.54085\n",
"[133]\tvalidation_0-rmse:1.54064\n",
"[134]\tvalidation_0-rmse:1.54074\n",
"[135]\tvalidation_0-rmse:1.54074\n",
"[136]\tvalidation_0-rmse:1.54093\n",
"[137]\tvalidation_0-rmse:1.54126\n",
"[138]\tvalidation_0-rmse:1.54151\n",
"[139]\tvalidation_0-rmse:1.54133\n",
"[140]\tvalidation_0-rmse:1.54094\n",
"[141]\tvalidation_0-rmse:1.54110\n",
"[142]\tvalidation_0-rmse:1.54095\n",
"[143]\tvalidation_0-rmse:1.54059\n",
"[144]\tvalidation_0-rmse:1.54076\n",
"[145]\tvalidation_0-rmse:1.54051\n",
"[146]\tvalidation_0-rmse:1.54051\n",
"[147]\tvalidation_0-rmse:1.54048\n",
"[148]\tvalidation_0-rmse:1.54075\n",
"[149]\tvalidation_0-rmse:1.54041\n",
"[150]\tvalidation_0-rmse:1.54046\n",
"[151]\tvalidation_0-rmse:1.54049\n",
"[152]\tvalidation_0-rmse:1.54059\n",
"[153]\tvalidation_0-rmse:1.54071\n",
"[154]\tvalidation_0-rmse:1.54093\n",
"[155]\tvalidation_0-rmse:1.54096\n",
"[156]\tvalidation_0-rmse:1.54086\n",
"[157]\tvalidation_0-rmse:1.54078\n",
"[158]\tvalidation_0-rmse:1.54068\n",
"[159]\tvalidation_0-rmse:1.54088\n",
"[160]\tvalidation_0-rmse:1.54083\n",
"[161]\tvalidation_0-rmse:1.54095\n",
"[162]\tvalidation_0-rmse:1.54101\n",
"[163]\tvalidation_0-rmse:1.54129\n",
"[164]\tvalidation_0-rmse:1.54137\n",
"[165]\tvalidation_0-rmse:1.54155\n",
"[166]\tvalidation_0-rmse:1.54153\n",
"[167]\tvalidation_0-rmse:1.54140\n",
"[168]\tvalidation_0-rmse:1.54139\n",
"[169]\tvalidation_0-rmse:1.54127\n",
"[170]\tvalidation_0-rmse:1.54141\n",
"[171]\tvalidation_0-rmse:1.54161\n",
"[172]\tvalidation_0-rmse:1.54140\n",
"[173]\tvalidation_0-rmse:1.54143\n",
"[174]\tvalidation_0-rmse:1.54157\n",
"[175]\tvalidation_0-rmse:1.54162\n",
"[176]\tvalidation_0-rmse:1.54185\n",
"[177]\tvalidation_0-rmse:1.54195\n",
"[178]\tvalidation_0-rmse:1.54197\n",
"[179]\tvalidation_0-rmse:1.54194\n",
"[180]\tvalidation_0-rmse:1.54193\n",
"[181]\tvalidation_0-rmse:1.54178\n",
"[182]\tvalidation_0-rmse:1.54176\n",
"[183]\tvalidation_0-rmse:1.54183\n",
"[184]\tvalidation_0-rmse:1.54184\n",
"[185]\tvalidation_0-rmse:1.54175\n",
"[186]\tvalidation_0-rmse:1.54184\n",
"[187]\tvalidation_0-rmse:1.54186\n",
"[188]\tvalidation_0-rmse:1.54193\n",
"[189]\tvalidation_0-rmse:1.54198\n",
"[190]\tvalidation_0-rmse:1.54193\n",
"[191]\tvalidation_0-rmse:1.54202\n",
"[192]\tvalidation_0-rmse:1.54176\n",
"[193]\tvalidation_0-rmse:1.54169\n",
"[194]\tvalidation_0-rmse:1.54145\n",
"[195]\tvalidation_0-rmse:1.54146\n",
"[196]\tvalidation_0-rmse:1.54163\n",
"[197]\tvalidation_0-rmse:1.54149\n",
"[198]\tvalidation_0-rmse:1.54164\n",
"[199]\tvalidation_0-rmse:1.54162\n",
"[200]\tvalidation_0-rmse:1.54159\n",
"[201]\tvalidation_0-rmse:1.54151\n",
"[202]\tvalidation_0-rmse:1.54147\n",
"[203]\tvalidation_0-rmse:1.54151\n",
"[204]\tvalidation_0-rmse:1.54146\n",
"[205]\tvalidation_0-rmse:1.54146\n",
"[206]\tvalidation_0-rmse:1.54140\n",
"[207]\tvalidation_0-rmse:1.54141\n",
"[208]\tvalidation_0-rmse:1.54140\n",
"[209]\tvalidation_0-rmse:1.54138\n",
"[210]\tvalidation_0-rmse:1.54134\n",
"[211]\tvalidation_0-rmse:1.54140\n",
"[212]\tvalidation_0-rmse:1.54135\n",
"[213]\tvalidation_0-rmse:1.54140\n",
"[214]\tvalidation_0-rmse:1.54162\n",
"[215]\tvalidation_0-rmse:1.54170\n",
"[216]\tvalidation_0-rmse:1.54171\n",
"[217]\tvalidation_0-rmse:1.54162\n",
"[218]\tvalidation_0-rmse:1.54154\n",
"[219]\tvalidation_0-rmse:1.54153\n",
"[220]\tvalidation_0-rmse:1.54143\n",
"[221]\tvalidation_0-rmse:1.54147\n",
"[222]\tvalidation_0-rmse:1.54136\n",
"[223]\tvalidation_0-rmse:1.54135\n",
"[224]\tvalidation_0-rmse:1.54139\n",
"[225]\tvalidation_0-rmse:1.54149\n",
"[226]\tvalidation_0-rmse:1.54161\n",
"[227]\tvalidation_0-rmse:1.54175\n",
"[228]\tvalidation_0-rmse:1.54174\n",
"[229]\tvalidation_0-rmse:1.54170\n",
"[230]\tvalidation_0-rmse:1.54166\n",
"[231]\tvalidation_0-rmse:1.54162\n",
"[232]\tvalidation_0-rmse:1.54165\n",
"[233]\tvalidation_0-rmse:1.54158\n",
"[234]\tvalidation_0-rmse:1.54157\n",
"[235]\tvalidation_0-rmse:1.54169\n",
"[236]\tvalidation_0-rmse:1.54154\n",
"[237]\tvalidation_0-rmse:1.54158\n",
"[238]\tvalidation_0-rmse:1.54160\n",
"[239]\tvalidation_0-rmse:1.54159\n",
"[240]\tvalidation_0-rmse:1.54167\n",
"[241]\tvalidation_0-rmse:1.54149\n",
"[242]\tvalidation_0-rmse:1.54154\n",
"[243]\tvalidation_0-rmse:1.54148\n",
"[244]\tvalidation_0-rmse:1.54161\n",
"[245]\tvalidation_0-rmse:1.54164\n",
"[246]\tvalidation_0-rmse:1.54164\n",
"[247]\tvalidation_0-rmse:1.54162\n",
"[248]\tvalidation_0-rmse:1.54169\n",
"[249]\tvalidation_0-rmse:1.54165\n",
"[250]\tvalidation_0-rmse:1.54153\n",
"[251]\tvalidation_0-rmse:1.54145\n",
"[252]\tvalidation_0-rmse:1.54148\n",
"[253]\tvalidation_0-rmse:1.54150\n",
"[254]\tvalidation_0-rmse:1.54142\n",
"[255]\tvalidation_0-rmse:1.54155\n",
"[256]\tvalidation_0-rmse:1.54149\n",
"[257]\tvalidation_0-rmse:1.54143\n",
"[258]\tvalidation_0-rmse:1.54139\n",
"[259]\tvalidation_0-rmse:1.54143\n",
"[260]\tvalidation_0-rmse:1.54141\n",
"[261]\tvalidation_0-rmse:1.54147\n",
"[262]\tvalidation_0-rmse:1.54149\n",
"[263]\tvalidation_0-rmse:1.54155\n",
"[264]\tvalidation_0-rmse:1.54157\n",
"[265]\tvalidation_0-rmse:1.54153\n",
"[266]\tvalidation_0-rmse:1.54153\n",
"[267]\tvalidation_0-rmse:1.54162\n",
"[268]\tvalidation_0-rmse:1.54167\n",
"[269]\tvalidation_0-rmse:1.54166\n",
"[270]\tvalidation_0-rmse:1.54160\n",
"[271]\tvalidation_0-rmse:1.54163\n",
"[272]\tvalidation_0-rmse:1.54175\n",
"[273]\tvalidation_0-rmse:1.54167\n",
"[274]\tvalidation_0-rmse:1.54168\n",
"[275]\tvalidation_0-rmse:1.54176\n",
"[276]\tvalidation_0-rmse:1.54172\n",
"[277]\tvalidation_0-rmse:1.54171\n",
"[278]\tvalidation_0-rmse:1.54164\n",
"[279]\tvalidation_0-rmse:1.54162\n",
"[280]\tvalidation_0-rmse:1.54163\n",
"[281]\tvalidation_0-rmse:1.54163\n",
"[282]\tvalidation_0-rmse:1.54167\n",
"[283]\tvalidation_0-rmse:1.54170\n",
"[284]\tvalidation_0-rmse:1.54177\n",
"[285]\tvalidation_0-rmse:1.54190\n",
"[286]\tvalidation_0-rmse:1.54191\n",
"[287]\tvalidation_0-rmse:1.54190\n",
"[288]\tvalidation_0-rmse:1.54199\n",
"[289]\tvalidation_0-rmse:1.54204\n",
"[290]\tvalidation_0-rmse:1.54213\n",
"[291]\tvalidation_0-rmse:1.54205\n",
"[292]\tvalidation_0-rmse:1.54207\n",
"[293]\tvalidation_0-rmse:1.54210\n",
"[294]\tvalidation_0-rmse:1.54212\n",
"[295]\tvalidation_0-rmse:1.54207\n",
"[296]\tvalidation_0-rmse:1.54203\n",
"[297]\tvalidation_0-rmse:1.54210\n",
"[298]\tvalidation_0-rmse:1.54207\n",
"[299]\tvalidation_0-rmse:1.54210\n",
"Final RMSE: 132.15\n",
"Final MAE: 31.29\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sentence_transformers import SentenceTransformer\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error\n",
"from xgboost import XGBRegressor\n",
"import torch\n",
"\n",
"# Check for GPU availability\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(f\"Using device: {device}\")\n",
"\n",
"# Sample data for faster prototyping\n",
"df = data.sample(n=10_000, random_state=1)\n",
"\n",
"# Preprocess data\n",
"df['abstract'] = df['abstract'].fillna('')\n",
"df['text'] = df['title'] + \" \" + df['abstract']\n",
"\n",
"# Scale 'year' feature\n",
"scaler = MinMaxScaler()\n",
"df['year_scaled'] = scaler.fit_transform(df[['year']])\n",
"\n",
"# Log transformation for target variable\n",
"df['log_citation'] = np.log1p(df['n_citation'])\n",
"\n",
"# Generate embeddings using GPU\n",
"model = SentenceTransformer('all-mpnet-base-v2', device='cuda')\n",
"embeddings = model.encode(\n",
" df['text'].tolist(),\n",
" batch_size=128,\n",
" show_progress_bar=True,\n",
" device='cuda'\n",
")\n",
"\n",
"# Combine features\n",
"X = np.hstack([embeddings, df['year_scaled'].values.reshape(-1, 1)])\n",
"y = df['log_citation'].values\n",
"\n",
"# Train-validation split\n",
"X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)\n",
"\n",
"# Define XGBoost model\n",
"xgb_model = XGBRegressor(\n",
" tree_method=\"hist\",\n",
" device=\"cuda\",\n",
" predictor=\"gpu_predictor\",\n",
" n_estimators=300,\n",
" max_depth=8,\n",
" learning_rate=0.1,\n",
" subsample=0.8\n",
")\n",
"\n",
"# Train model with XGBoost's built-in progress bar\n",
"xgb_model.fit(X_train, y_train, eval_set=[(X_val, y_val)], verbose=True)\n",
"\n",
"# Predictions\n",
"val_pred = xgb_model.predict(X_val)\n",
"\n",
"# Evaluate model\n",
"final_rmse = np.sqrt(mean_squared_error(np.expm1(y_val), np.expm1(val_pred)))\n",
"final_mae = mean_absolute_error(np.expm1(y_val), np.expm1(val_pred))\n",
"\n",
"print(f\"Final RMSE: {final_rmse:.2f}\")\n",
"print(f\"Final MAE: {final_mae:.2f}\")\n"
]
},
{
"cell_type": "markdown",
"id": "ba7f9386",
"metadata": {},
"source": [
"### Testing trained model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00c4118b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n",
"Encoding test data...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "be72fc8920a542269193adbbc75cdc75",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Batches: 0%| | 0/74 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Test Set Regression Metrics:\n",
"RMSE: 210.04\n",
"MAE: 107.10\n",
"R2 Score: -0.34\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import sqlite3\n",
"from sentence_transformers import SentenceTransformer\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
"from xgboost import XGBRegressor\n",
"import torch\n",
"\n",
"# ----------------------------------------\n",
"# 1. Set up and check GPU availability\n",
"# ----------------------------------------\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(f\"Using device: {device}\")\n",
"# Connect to the SQLite database and load the data.\n",
"conn = sqlite3.connect('papers.db')\n",
"test_df = pd.read_sql_query(\"SELECT * FROM papers\", conn)\n",
"conn.close()\n",
"\n",
"# Preprocess test data: handle missing values and combine title and abstract.\n",
"test_df['abstract'] = test_df['abstract'].fillna('')\n",
"test_df['text'] = test_df['title'] + \" \" + test_df['abstract']\n",
"\n",
"# Scale the 'year' feature in the test set using the same scaler fitted on training data.\n",
"test_df['year_scaled'] = scaler.transform(test_df[['year']])\n",
"\n",
"# Prepare the target variable in the test set.\n",
"# (Test set uses the column 'citations' for the citation count.)\n",
"test_df['log_citation'] = np.log1p(test_df['citations'])\n",
"\n",
"# ----------------------------------------\n",
"# 4. Generate text embeddings using a SentenceTransformer model\n",
"# ----------------------------------------\n",
"# Load the pre-trained SentenceTransformer model with GPU support\n",
"embed_model = SentenceTransformer('all-mpnet-base-v2', device='cuda')\n",
"\n",
"# Generate embeddings for the combined text from training and test data.\n",
"\n",
"print(\"Encoding test data...\")\n",
"test_embeddings = embed_model.encode(\n",
" test_df['text'].tolist(),\n",
" batch_size=128,\n",
" show_progress_bar=True,\n",
" device='cuda'\n",
")\n",
"\n",
"# ----------------------------------------\n",
"# 5. Combine features (embeddings + scaled year) and prepare datasets\n",
"# ----------------------------------------\n",
"\n",
"X_test = np.hstack([test_embeddings, test_df['year_scaled'].values.reshape(-1, 1)])\n",
"y_test = test_df['log_citation'].values\n",
"\n",
"# ----------------------------------------\n",
"# 7. Predict on the test set and evaluate regression metrics\n",
"# ----------------------------------------\n",
"# Predict in the log domain\n",
"y_test_pred_log = xgb_model.predict(X_test)\n",
"\n",
"# Convert predictions back to original citation counts using the inverse transformation\n",
"y_test_pred = np.expm1(y_test_pred_log)\n",
"y_test_true = np.expm1(y_test)\n",
"\n",
"# Calculate regression metrics\n",
"rmse = np.sqrt(mean_squared_error(y_test_true, y_test_pred))\n",
"mae = mean_absolute_error(y_test_true, y_test_pred)\n",
"r2 = r2_score(y_test_true, y_test_pred)\n",
"\n",
"print(\"\\nTest Set Regression Metrics:\")\n",
"print(f\"RMSE: {rmse:.2f}\")\n",
"print(f\"MAE: {mae:.2f}\")\n",
"print(f\"R2 Score: {r2:.2f}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d68704e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Top Predicted Citations:\n",
" title year \\\n",
"6131 DSI-Net: Deep Synergistic Interaction Network ... 2021 \n",
"1077 Uni-ControlNet: All-in-One Control to Text-to-... 2023 \n",
"246 Anchor Diffusion for Unsupervised Video Object... 2019 \n",
"\n",
" topic predicted_citations \n",
"6131 LLM 1.189171e+13 \n",
"1077 Diffusion Models 8.380618e+12 \n",
"246 Diffusion Models 3.803213e+12 \n"
]
}
],
"source": [
"test_df['predicted_citations'] = np.expm1(y_test_pred)\n",
"\n",
"# Display the top papers with the highest predicted citations\n",
"top_predictions = test_df[['title', 'year', 'topic', 'predicted_citations']].sort_values(\n",
" by='predicted_citations', ascending=False\n",
").head(3) # Show top 10 papers\n",
"\n",
"print(\"\\nTop Predicted Citations:\")\n",
"print(top_predictions)"
]
},
{
"cell_type": "markdown",
"id": "e592fecf",
"metadata": {},
"source": [
"# Author-Author Network\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "165a106c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10/100, Loss: 3.7697\n",
"Epoch 20/100, Loss: 2.6772\n",
"Epoch 30/100, Loss: 1.7634\n",
"Epoch 40/100, Loss: 1.0372\n",
"Epoch 50/100, Loss: 0.5276\n",
"Epoch 60/100, Loss: 0.2340\n",
"Epoch 70/100, Loss: 0.1039\n",
"Epoch 80/100, Loss: 0.0525\n",
"Epoch 90/100, Loss: 0.0334\n",
"Epoch 100/100, Loss: 0.0260\n",
"\n",
"Test Set Regression Metrics:\n",
"RMSE: 144.20\n",
"MAE: 60.33\n",
"R2 Score: -0.27\n",
"\n",
"Top Predicted Citations:\n",
" title year \\\n",
"4119 The Good, the Bad, and the Expert: How Consume... 2015 \n",
"4740 Chinese Term Recognition and Extraction Based ... 2008 \n",
"4045 High efficient hardware allocation framework o... 2015 \n",
"10845 Score-Based Generative Modeling through Stocha... 2020 \n",
"12922 Normalizing Flows for Probabilistic Modeling a... 2019 \n",
"15288 Large language models in medicine 2023 \n",
"14175 Diffusion Models Beat GANs on Image Synthesis 2021 \n",
"15382 RePaint: Inpainting using Denoising Diffusion ... 2022 \n",
"16567 Vote-Selling: Infrastructure and Public Services 2018 \n",
"12577 Squeeze-and-Excitation Networks 2017 \n",
"\n",
" topic predicted_citations \n",
"4119 NaN 2697.807129 \n",
"4740 NaN 2314.161621 \n",
"4045 NaN 1178.770020 \n",
"10845 Diffusion Models 1175.487061 \n",
"12922 Foundation Models 1165.278809 \n",
"15288 Generative Models 1165.196045 \n",
"14175 Generative Models 1120.444824 \n",
"15382 Generative Models 1099.473389 \n",
"16567 LLM 949.788086 \n",
"12577 Foundation Models 900.298340 \n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import numpy as np\n",
"import pandas as pd\n",
"import sqlite3\n",
"from torch_geometric.nn import GCNConv\n",
"from torch_geometric.data import Data\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
"\n",
"# -------------------------------\n",
"# 1. Load and Combine Data\n",
"# -------------------------------\n",
"\n",
"# Load post-2017 dataset from papers.db\n",
"conn = sqlite3.connect('papers.db')\n",
"papers_df = pd.read_sql_query(\"SELECT * FROM papers\", conn)\n",
"conn.close()\n",
"\n",
"# Standardize column names to match\n",
"df = df.rename(columns={'n_citation': 'citations'})\n",
"papers_df = papers_df.rename(columns={'citations': 'citations'})\n",
"\n",
"# Ensure 'authors' field is consistent across datasets\n",
"df['authors'] = df['authors'].fillna('')\n",
"papers_df['authors'] = papers_df['authors'].fillna('')\n",
"\n",
"# Combine datasets\n",
"combined_df = pd.concat([df, papers_df], ignore_index=True)\n",
"\n",
"# Process authors (semicolon-separated)\n",
"combined_df['author_list'] = combined_df['authors'].apply(lambda x: [a.strip() for a in x.split(';') if a.strip() != ''])\n",
"\n",
"# Target variable: log-transform citation counts\n",
"combined_df['log_citations'] = np.log1p(combined_df['citations'])\n",
"\n",
"# -------------------------------\n",
"# 2. Train-Test Split\n",
"# -------------------------------\n",
"train_df, test_df = train_test_split(combined_df, test_size=0.2, random_state=42)\n",
"\n",
"# -------------------------------\n",
"# 3. Build the Author-Author Graph\n",
"# -------------------------------\n",
"# Collect all unique authors from the entire dataset\n",
"all_authors = set()\n",
"for authors in combined_df['author_list']:\n",
" for a in authors:\n",
" all_authors.add(a)\n",
"all_authors = sorted(list(all_authors))\n",
"author_to_id = {author: idx for idx, author in enumerate(all_authors)}\n",
"num_authors = len(all_authors)\n",
"\n",
"# Build edge list: add an undirected edge for each pair of co-authors in a paper\n",
"edges = []\n",
"for authors in combined_df['author_list']:\n",
" author_ids = [author_to_id[a] for a in authors if a in author_to_id]\n",
" for i in range(len(author_ids)):\n",
" for j in range(i + 1, len(author_ids)):\n",
" edges.append((author_ids[i], author_ids[j]))\n",
" edges.append((author_ids[j], author_ids[i]))\n",
"\n",
"# Remove duplicate edges\n",
"edges = list(set(edges))\n",
"edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()\n",
"\n",
"# Create the PyTorch Geometric Data object\n",
"data = Data(edge_index=edge_index)\n",
"data.num_nodes = num_authors\n",
"\n",
"# -------------------------------\n",
"# 4. Map Papers to Author Node Indices\n",
"# -------------------------------\n",
"def get_author_ids(author_list):\n",
" return [author_to_id[a] for a in author_list if a in author_to_id]\n",
"\n",
"train_paper_author_ids = train_df['author_list'].apply(get_author_ids).tolist()\n",
"test_paper_author_ids = test_df['author_list'].apply(get_author_ids).tolist()\n",
"\n",
"# Convert target values to torch tensors\n",
"y_train = torch.tensor(train_df['log_citations'].values, dtype=torch.float)\n",
"y_test = torch.tensor(test_df['log_citations'].values, dtype=torch.float)\n",
"\n",
"# -------------------------------\n",
"# 5. Define the GNN Regressor Model\n",
"# -------------------------------\n",
"class AuthorGNNRegressor(nn.Module):\n",
" def __init__(self, num_nodes, in_channels=32, hidden_channels=64, out_channels=64):\n",
" super(AuthorGNNRegressor, self).__init__()\n",
" self.author_embeddings = nn.Embedding(num_nodes, in_channels)\n",
" self.conv1 = GCNConv(in_channels, hidden_channels)\n",
" self.conv2 = GCNConv(hidden_channels, out_channels)\n",
" self.regressor = nn.Linear(out_channels, 1)\n",
"\n",
" def forward(self, data, paper_author_ids):\n",
" x = self.author_embeddings.weight\n",
" x = self.conv1(x, data.edge_index)\n",
" x = torch.relu(x)\n",
" x = self.conv2(x, data.edge_index)\n",
"\n",
" paper_embeddings = []\n",
" for author_ids in paper_author_ids:\n",
" if len(author_ids) > 0:\n",
" authors_tensor = x[author_ids]\n",
" paper_emb = authors_tensor.mean(dim=0)\n",
" else:\n",
" paper_emb = torch.zeros(x.size(1), device=x.device)\n",
" paper_embeddings.append(paper_emb)\n",
" paper_embeddings = torch.stack(paper_embeddings, dim=0)\n",
" out = self.regressor(paper_embeddings)\n",
" return out.squeeze()\n",
"\n",
"# Instantiate the model, optimizer, and loss function\n",
"model = AuthorGNNRegressor(num_nodes=num_authors)\n",
"optimizer = optim.Adam(model.parameters(), lr=0.01)\n",
"criterion = nn.MSELoss()\n",
"\n",
"# -------------------------------\n",
"# 6. Training Loop\n",
"# -------------------------------\n",
"num_epochs = 100\n",
"model.train()\n",
"for epoch in range(num_epochs):\n",
" optimizer.zero_grad()\n",
" preds = model(data, train_paper_author_ids)\n",
" loss = criterion(preds, y_train)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" if (epoch + 1) % 10 == 0:\n",
" print(f\"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}\")\n",
"\n",
"# -------------------------------\n",
"# 7. Evaluation on Test Set\n",
"# -------------------------------\n",
"model.eval()\n",
"with torch.no_grad():\n",
" test_preds = model(data, test_paper_author_ids)\n",
" test_preds_citations = torch.expm1(test_preds).cpu().numpy()\n",
" y_test_citations = torch.expm1(y_test).cpu().numpy()\n",
"\n",
"rmse = np.sqrt(mean_squared_error(y_test_citations, test_preds_citations))\n",
"mae = mean_absolute_error(y_test_citations, test_preds_citations)\n",
"r2 = r2_score(y_test_citations, test_preds_citations)\n",
"\n",
"print(\"\\nTest Set Regression Metrics:\")\n",
"print(f\"RMSE: {rmse:.2f}\")\n",
"print(f\"MAE: {mae:.2f}\")\n",
"print(f\"R2 Score: {r2:.2f}\")\n",
"\n",
"# -------------------------------\n",
"# 8. Display Top Predicted Citations\n",
"# -------------------------------\n",
"test_df = test_df.copy()\n",
"test_df['predicted_citations'] = test_preds_citations\n",
"top_predictions = test_df[['title', 'year', 'topic', 'predicted_citations']].sort_values(\n",
" by='predicted_citations', ascending=False\n",
").head(10)\n",
"\n",
"print(\"\\nTop Predicted Citations:\")\n",
"print(top_predictions)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
GitHub Events
Total
- Push event: 1
- Create event: 1
Last Year
- Push event: 1
- Create event: 1
Issues and Pull Requests
Last synced: 7 months ago