regressor

Regressing how many citations a given paper will have based on its abstract, title, year

https://github.com/theveryhim/regressor

Science Score: 44.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
  • Committers with academic emails
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (6.4%) to scientific vocabulary

Keywords

data-science graph learning regression
Last synced: 6 months ago

Repository

Basic Info
  • Host: GitHub
  • Owner: theveryhim
  • License: mit
  • Language: Jupyter Notebook
  • Default Branch: main
  • Homepage:
  • Size: 5.36 MB
Statistics
  • Stars: 0
  • Watchers: 0
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created 8 months ago · Last pushed 8 months ago
Metadata Files
Readme License Citation

README.md

Citation Regressor

This project builds a model that estimates how many citations a given paper will receive, based on its abstract, title, year, and author-author network.
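As a rough illustration of the inputs described above, the sketch below assembles the raw ingredients for each paper: the combined title and abstract text, the year scaled to [0, 1], and the citation count after a `log1p` transform (the notebook uses `np.log1p` for the target). The `Paper` class, the toy records, and the helper names are hypothetical, and the author-author network is left out here.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class Paper:
    """Hypothetical record holding the fields the model uses."""
    title: str
    abstract: Optional[str]
    year: int
    n_citation: int


def build_text(p: Paper) -> str:
    # A missing abstract becomes an empty string, so only the title remains.
    return f"{p.title} {p.abstract or ''}"


def min_max_scale(values):
    # Map values linearly onto [0, 1]; a constant column maps to 0.0.
    lo, hi = min(values), max(values)
    span = hi - lo
    return [(v - lo) / span if span else 0.0 for v in values]


# Toy records, invented for illustration only.
papers = [
    Paper("Paper A", "An abstract.", 2008, 120),
    Paper("Paper B", None, 2015, 0),
    Paper("Paper C", "Another abstract.", 2020, 35),
]

texts = [build_text(p) for p in papers]
year_scaled = min_max_scale([p.year for p in papers])
# log1p compresses the heavy right tail of citation counts and is defined at 0.
log_citation = [math.log1p(p.n_citation) for p in papers]
```

In the notebook, the combined text is passed to a sentence-embedding model and the resulting vectors are stacked with the scaled year to form the feature matrix.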

Regressor

```markdown
Top Predicted Citations:
                                                  title  year             topic  predicted_citations
6131  DSI-Net: Deep Synergistic Interaction Network ...  2021               LLM         1.189171e+13
1077  Uni-ControlNet: All-in-One Control to Text-to-...  2023  Diffusion Models         8.380618e+12
246   Anchor Diffusion for Unsupervised Video Object...  2019  Diffusion Models         3.803213e+12
```
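Predicted counts on the order of 1e13 are far outside the range of real citation counts. The notebook trains on `log1p(n_citation)` and maps predictions back with `expm1`; values this large are what one would expect if that inverse is applied twice. The sketch below uses made-up log-space predictions to show the effect. The numbers are illustrative assumptions, not model output.

```python
import math

# Hypothetical log-space predictions, i.e. estimates of log1p(citations).
log_preds = [1.5, 2.5, 3.4]

# Correct inverse: a single expm1 returns citation counts.
counts = [math.expm1(p) for p in log_preds]

# Applying expm1 again treats the counts as if they were still logarithms.
double_inverted = [math.expm1(c) for c in counts]

for lp, c, d in zip(log_preds, counts, double_inverted):
    print(f"log-pred={lp:.1f}  count={c:8.2f}  expm1 applied twice={d:.3e}")
```

Read this way, the top value of 1.189171e+13 would correspond to roughly `log1p(1.189171e13) ≈ 30.1` predicted citations.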

Author-Author Network

The idea is that relationships between the authors of scholarly papers may be a strong predictor of how many citations a given paper will receive.

```markdown
Test Set Regression Metrics:
RMSE: 144.20
MAE: 60.33
R2 Score: -0.27

Top Predicted Citations:
                                                   title  year             topic  predicted_citations
4119   The Good, the Bad, and the Expert: How Consume...  2015               NaN          2697.807129
4740   Chinese Term Recognition and Extraction Based ...  2008               NaN          2314.161621
4045   High efficient hardware allocation framework o...  2015               NaN          1178.770020
10845  Score-Based Generative Modeling through Stocha...  2020  Diffusion Models          1175.487061
```
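The README does not spell out how the author-author network is built or how it is fed to the model. As a minimal sketch, assuming an undirected co-authorship graph in which two authors are linked whenever they share a paper, the code below builds that graph from per-paper author lists and derives one simple per-paper feature: the mean number of distinct collaborators of its authors. The author names, the `coauthor_graph` and `mean_author_degree` helpers, and the choice of feature are all hypothetical.

```python
from collections import defaultdict
from itertools import combinations


def coauthor_graph(author_lists):
    """Return {author: set of co-authors} for an undirected co-authorship graph."""
    graph = defaultdict(set)
    for authors in author_lists:
        unique = sorted(set(authors))
        for a in unique:
            graph[a]  # touch the key so solo authors appear as isolated nodes
        for a, b in combinations(unique, 2):
            graph[a].add(b)
            graph[b].add(a)
    return graph


def mean_author_degree(authors, graph):
    """Average number of distinct collaborators over a paper's authors."""
    if not authors:
        return 0.0
    return sum(len(graph.get(a, ())) for a in authors) / len(authors)


# Toy data: each inner list holds the authors of one paper.
papers = [
    ["Ada", "Ben", "Cy"],
    ["Ada", "Dee"],
    ["Eve"],
]

graph = coauthor_graph(papers)
features = [mean_author_degree(p, graph) for p in papers]
```

A feature like this could be appended as an extra column next to the text embeddings and the scaled year; richer graph features would follow the same pattern.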

Owner

  • Name: Arman Yazdani
  • Login: theveryhim
  • Kind: user
  • Location: Tehran-Iran

Currently an Electrical Engineering B.Sc. student at Sharif University of Technology.

Citation (Citation Regressor.ipynb)

{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b118e482",
   "metadata": {},
   "source": [
    "# Load"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a19ca19a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 100000 entries, 0 to 99999\n",
      "Data columns (total 9 columns):\n",
      " #   Column      Non-Null Count   Dtype \n",
      "---  ------      --------------   ----- \n",
      " 0   Unnamed: 0  100000 non-null  int64 \n",
      " 1   abstract    82990 non-null   object\n",
      " 2   authors     99999 non-null   object\n",
      " 3   n_citation  100000 non-null  int64 \n",
      " 4   references  87625 non-null   object\n",
      " 5   title       100000 non-null  object\n",
      " 6   venue       82305 non-null   object\n",
      " 7   year        100000 non-null  int64 \n",
      " 8   id          100000 non-null  object\n",
      "dtypes: int64(3), object(6)\n",
      "memory usage: 6.9+ MB\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "data = pd.read_csv(\"dblp-v11.csv\")\n",
    "data.info()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b55cd22c",
   "metadata": {},
   "source": [
    "# Regressor"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fcdeddb",
   "metadata": {},
   "source": [
    "### Model selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a1580b0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: cuda\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e4d02ec814d84b8eb48ba5f8d444adc9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training XGBoost...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.11/dist-packages/xgboost/core.py:158: UserWarning: [02:07:06] WARNING: /workspace/src/common/error_msg.cc:27: The tree method `gpu_hist` is deprecated since 2.0.0. To use GPU training, set the `device` parameter to CUDA instead.\n",
      "\n",
      "    E.g. tree_method = \"hist\", device = \"cuda\"\n",
      "\n",
      "  warnings.warn(smsg, UserWarning)\n",
      "/usr/local/lib/python3.11/dist-packages/xgboost/core.py:158: UserWarning: [02:07:06] WARNING: /workspace/src/learner.cc:740: \n",
      "Parameters: { \"predictor\" } are not used.\n",
      "\n",
      "  warnings.warn(smsg, UserWarning)\n",
      "/usr/local/lib/python3.11/dist-packages/xgboost/core.py:158: UserWarning: [02:07:09] WARNING: /workspace/src/common/error_msg.cc:27: The tree method `gpu_hist` is deprecated since 2.0.0. To use GPU training, set the `device` parameter to CUDA instead.\n",
      "\n",
      "    E.g. tree_method = \"hist\", device = \"cuda\"\n",
      "\n",
      "  warnings.warn(smsg, UserWarning)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "XGBoost Validation RMSE: 54.23\n",
      "\n",
      "Training RandomForest...\n",
      "RandomForest Validation RMSE: 54.32\n",
      "\n",
      "Best Model: XGBRegressor(base_score=None, booster=None, callbacks=None,\n",
      "             colsample_bylevel=None, colsample_bynode=None,\n",
      "             colsample_bytree=None, device='cuda:0', early_stopping_rounds=None,\n",
      "             enable_categorical=False, eval_metric=None, feature_types=None,\n",
      "             gamma=None, grow_policy=None, importance_type=None,\n",
      "             interaction_constraints=None, learning_rate=0.05, max_bin=None,\n",
      "             max_cat_threshold=None, max_cat_to_onehot=None,\n",
      "             max_delta_step=None, max_depth=6, max_leaves=None,\n",
      "             min_child_weight=None, missing=nan, monotone_constraints=None,\n",
      "             multi_strategy=None, n_estimators=200, n_jobs=None,\n",
      "             num_parallel_tree=None, predictor='gpu_predictor', ...)\n",
      "Final RMSE: 54.23\n",
      "Final MAE: 23.98\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from sklearn.model_selection import train_test_split, GridSearchCV\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.metrics import mean_squared_error, mean_absolute_error\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from xgboost import XGBRegressor, DMatrix\n",
    "import torch\n",
    "\n",
    "# Check for GPU availability\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "# Sample data for faster prototyping\n",
    "df = data.sample(n=1_000, random_state=1)\n",
    "\n",
    "# Preprocess data\n",
    "# Handle missing values\n",
    "df['abstract'] = df['abstract'].fillna('')\n",
    "df['text'] = df['title'] + \" \" + df['abstract']\n",
    "\n",
    "# Prepare features\n",
    "scaler = MinMaxScaler()\n",
    "df['year_scaled'] = scaler.fit_transform(df[['year']])\n",
    "\n",
    "# Prepare target with log transformation\n",
    "df['log_citation'] = np.log1p(df['n_citation'])\n",
    "\n",
    "# Generate embeddings using GPU\n",
    "model = SentenceTransformer('all-mpnet-base-v2', device='cuda')\n",
    "embeddings = model.encode(\n",
    "    df['text'].tolist(),\n",
    "    batch_size=128,\n",
    "    show_progress_bar=True,\n",
    "    device='cuda'\n",
    ")\n",
    "\n",
    "# Combine features\n",
    "X = np.hstack([embeddings, df['year_scaled'].values.reshape(-1, 1)])\n",
    "y = df['log_citation'].values\n",
    "\n",
    "# Train-validation split\n",
    "X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)\n",
    "\n",
    "# Define and train models with GridSearch\n",
    "models = [\n",
    "    {\n",
    "        'name': 'XGBoost',\n",
    "        'model': XGBRegressor(tree_method = \"hist\", device = \"cuda\", predictor='gpu_predictor'),\n",
    "        'params': {\n",
    "            'n_estimators': [200, 300],\n",
    "            'max_depth': [6, 8],\n",
    "            'learning_rate': [0.05, 0.1],\n",
    "            'subsample': [0.8, 1.0]\n",
    "        }\n",
    "    },\n",
    "    {\n",
    "        'name': 'RandomForest',\n",
    "        'model': RandomForestRegressor(),\n",
    "        'params': {\n",
    "            'n_estimators': [200, 300],\n",
    "            'max_depth': [None, 20],\n",
    "            'min_samples_split': [3, 5]\n",
    "        }\n",
    "    }\n",
    "]\n",
    "\n",
    "best_model = None\n",
    "best_score = float('inf')\n",
    "\n",
    "for m in models:\n",
    "    print(f\"\\nTraining {m['name']}...\")\n",
    "    grid = GridSearchCV(\n",
    "        m['model'],\n",
    "        m['params'],\n",
    "        cv=3,\n",
    "        scoring='neg_mean_squared_error',\n",
    "        n_jobs=-1\n",
    "    )\n",
    "\n",
    "    # Use DMatrix for XGBoost to ensure GPU compatibility\n",
    "    if m['name'] == 'XGBoost':\n",
    "        dtrain = DMatrix(X_train, label=y_train, enable_categorical=True)\n",
    "        dval = DMatrix(X_val, label=y_val, enable_categorical=True)\n",
    "        grid.fit(dtrain.get_data(), dtrain.get_label())\n",
    "    else:\n",
    "        grid.fit(X_train, y_train)\n",
    "\n",
    "    # Evaluate\n",
    "    if m['name'] == 'XGBoost':\n",
    "        val_pred = grid.best_estimator_.predict(dval.get_data())\n",
    "    else:\n",
    "        val_pred = grid.best_estimator_.predict(X_val)\n",
    "\n",
    "    rmse = np.sqrt(mean_squared_error(np.expm1(y_val), np.expm1(val_pred)))\n",
    "    print(f\"{m['name']} Validation RMSE: {rmse:.2f}\")\n",
    "\n",
    "    if rmse < best_score:\n",
    "        best_score = rmse\n",
    "        best_model = grid.best_estimator_\n",
    "\n",
    "# Evaluate best model\n",
    "print(f\"\\nBest Model: {best_model}\")\n",
    "if isinstance(best_model, XGBRegressor):\n",
    "    final_pred = best_model.predict(DMatrix(X_val, enable_categorical=True).get_data())\n",
    "else:\n",
    "    final_pred = best_model.predict(X_val)\n",
    "\n",
    "final_rmse = np.sqrt(mean_squared_error(np.expm1(y_val), np.expm1(final_pred)))\n",
    "final_mae = mean_absolute_error(np.expm1(y_val), np.expm1(final_pred))\n",
    "\n",
    "print(f\"Final RMSE: {final_rmse:.2f}\")\n",
    "print(f\"Final MAE: {final_mae:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "421503ab",
   "metadata": {},
   "source": [
    "### Training selected model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f89182a2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: cuda\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a99a85c89b954b45a5f54c866e17a3ef",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/79 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0]\tvalidation_0-rmse:1.73811\n",
      "[1]\tvalidation_0-rmse:1.70124\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.11/dist-packages/xgboost/core.py:158: UserWarning: [03:03:57] WARNING: /workspace/src/learner.cc:740: \n",
      "Parameters: { \"predictor\" } are not used.\n",
      "\n",
      "  warnings.warn(smsg, UserWarning)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2]\tvalidation_0-rmse:1.66960\n",
      "[3]\tvalidation_0-rmse:1.64754\n",
      "[4]\tvalidation_0-rmse:1.62967\n",
      "[5]\tvalidation_0-rmse:1.61450\n",
      "[6]\tvalidation_0-rmse:1.60012\n",
      "[7]\tvalidation_0-rmse:1.58993\n",
      "[8]\tvalidation_0-rmse:1.58093\n",
      "[9]\tvalidation_0-rmse:1.57076\n",
      "[10]\tvalidation_0-rmse:1.56556\n",
      "[11]\tvalidation_0-rmse:1.56057\n",
      "[12]\tvalidation_0-rmse:1.55657\n",
      "[13]\tvalidation_0-rmse:1.55478\n",
      "[14]\tvalidation_0-rmse:1.55072\n",
      "[15]\tvalidation_0-rmse:1.54898\n",
      "[16]\tvalidation_0-rmse:1.54518\n",
      "[17]\tvalidation_0-rmse:1.54427\n",
      "[18]\tvalidation_0-rmse:1.54305\n",
      "[19]\tvalidation_0-rmse:1.54245\n",
      "[20]\tvalidation_0-rmse:1.54107\n",
      "[21]\tvalidation_0-rmse:1.53897\n",
      "[22]\tvalidation_0-rmse:1.53853\n",
      "[23]\tvalidation_0-rmse:1.53753\n",
      "[24]\tvalidation_0-rmse:1.53837\n",
      "[25]\tvalidation_0-rmse:1.53842\n",
      "[26]\tvalidation_0-rmse:1.53715\n",
      "[27]\tvalidation_0-rmse:1.53738\n",
      "[28]\tvalidation_0-rmse:1.53842\n",
      "[29]\tvalidation_0-rmse:1.53723\n",
      "[30]\tvalidation_0-rmse:1.53832\n",
      "[31]\tvalidation_0-rmse:1.53784\n",
      "[32]\tvalidation_0-rmse:1.53869\n",
      "[33]\tvalidation_0-rmse:1.53745\n",
      "[34]\tvalidation_0-rmse:1.53706\n",
      "[35]\tvalidation_0-rmse:1.53698\n",
      "[36]\tvalidation_0-rmse:1.53679\n",
      "[37]\tvalidation_0-rmse:1.53661\n",
      "[38]\tvalidation_0-rmse:1.53629\n",
      "[39]\tvalidation_0-rmse:1.53672\n",
      "[40]\tvalidation_0-rmse:1.53643\n",
      "[41]\tvalidation_0-rmse:1.53557\n",
      "[42]\tvalidation_0-rmse:1.53486\n",
      "[43]\tvalidation_0-rmse:1.53486\n",
      "[44]\tvalidation_0-rmse:1.53502\n",
      "[45]\tvalidation_0-rmse:1.53571\n",
      "[46]\tvalidation_0-rmse:1.53635\n",
      "[47]\tvalidation_0-rmse:1.53587\n",
      "[48]\tvalidation_0-rmse:1.53554\n",
      "[49]\tvalidation_0-rmse:1.53529\n",
      "[50]\tvalidation_0-rmse:1.53521\n",
      "[51]\tvalidation_0-rmse:1.53492\n",
      "[52]\tvalidation_0-rmse:1.53551\n",
      "[53]\tvalidation_0-rmse:1.53580\n",
      "[54]\tvalidation_0-rmse:1.53575\n",
      "[55]\tvalidation_0-rmse:1.53537\n",
      "[56]\tvalidation_0-rmse:1.53573\n",
      "[57]\tvalidation_0-rmse:1.53613\n",
      "[58]\tvalidation_0-rmse:1.53596\n",
      "[59]\tvalidation_0-rmse:1.53620\n",
      "[60]\tvalidation_0-rmse:1.53617\n",
      "[61]\tvalidation_0-rmse:1.53645\n",
      "[62]\tvalidation_0-rmse:1.53656\n",
      "[63]\tvalidation_0-rmse:1.53563\n",
      "[64]\tvalidation_0-rmse:1.53587\n",
      "[65]\tvalidation_0-rmse:1.53584\n",
      "[66]\tvalidation_0-rmse:1.53546\n",
      "[67]\tvalidation_0-rmse:1.53558\n",
      "[68]\tvalidation_0-rmse:1.53642\n",
      "[69]\tvalidation_0-rmse:1.53650\n",
      "[70]\tvalidation_0-rmse:1.53685\n",
      "[71]\tvalidation_0-rmse:1.53671\n",
      "[72]\tvalidation_0-rmse:1.53636\n",
      "[73]\tvalidation_0-rmse:1.53633\n",
      "[74]\tvalidation_0-rmse:1.53618\n",
      "[75]\tvalidation_0-rmse:1.53579\n",
      "[76]\tvalidation_0-rmse:1.53618\n",
      "[77]\tvalidation_0-rmse:1.53687\n",
      "[78]\tvalidation_0-rmse:1.53681\n",
      "[79]\tvalidation_0-rmse:1.53705\n",
      "[80]\tvalidation_0-rmse:1.53713\n",
      "[81]\tvalidation_0-rmse:1.53703\n",
      "[82]\tvalidation_0-rmse:1.53746\n",
      "[83]\tvalidation_0-rmse:1.53767\n",
      "[84]\tvalidation_0-rmse:1.53796\n",
      "[85]\tvalidation_0-rmse:1.53813\n",
      "[86]\tvalidation_0-rmse:1.53808\n",
      "[87]\tvalidation_0-rmse:1.53813\n",
      "[88]\tvalidation_0-rmse:1.53781\n",
      "[89]\tvalidation_0-rmse:1.53786\n",
      "[90]\tvalidation_0-rmse:1.53778\n",
      "[91]\tvalidation_0-rmse:1.53789\n",
      "[92]\tvalidation_0-rmse:1.53832\n",
      "[93]\tvalidation_0-rmse:1.53814\n",
      "[94]\tvalidation_0-rmse:1.53796\n",
      "[95]\tvalidation_0-rmse:1.53770\n",
      "[96]\tvalidation_0-rmse:1.53795\n",
      "[97]\tvalidation_0-rmse:1.53792\n",
      "[98]\tvalidation_0-rmse:1.53820\n",
      "[99]\tvalidation_0-rmse:1.53827\n",
      "[100]\tvalidation_0-rmse:1.53829\n",
      "[101]\tvalidation_0-rmse:1.53854\n",
      "[102]\tvalidation_0-rmse:1.53884\n",
      "[103]\tvalidation_0-rmse:1.53899\n",
      "[104]\tvalidation_0-rmse:1.53848\n",
      "[105]\tvalidation_0-rmse:1.53856\n",
      "[106]\tvalidation_0-rmse:1.53863\n",
      "[107]\tvalidation_0-rmse:1.53919\n",
      "[108]\tvalidation_0-rmse:1.53917\n",
      "[109]\tvalidation_0-rmse:1.53934\n",
      "[110]\tvalidation_0-rmse:1.53985\n",
      "[111]\tvalidation_0-rmse:1.54002\n",
      "[112]\tvalidation_0-rmse:1.53969\n",
      "[113]\tvalidation_0-rmse:1.53989\n",
      "[114]\tvalidation_0-rmse:1.53983\n",
      "[115]\tvalidation_0-rmse:1.53989\n",
      "[116]\tvalidation_0-rmse:1.53948\n",
      "[117]\tvalidation_0-rmse:1.53961\n",
      "[118]\tvalidation_0-rmse:1.53941\n",
      "[119]\tvalidation_0-rmse:1.53957\n",
      "[120]\tvalidation_0-rmse:1.53985\n",
      "[121]\tvalidation_0-rmse:1.54026\n",
      "[122]\tvalidation_0-rmse:1.54030\n",
      "[123]\tvalidation_0-rmse:1.54027\n",
      "[124]\tvalidation_0-rmse:1.54036\n",
      "[125]\tvalidation_0-rmse:1.54013\n",
      "[126]\tvalidation_0-rmse:1.54027\n",
      "[127]\tvalidation_0-rmse:1.54054\n",
      "[128]\tvalidation_0-rmse:1.54063\n",
      "[129]\tvalidation_0-rmse:1.54093\n",
      "[130]\tvalidation_0-rmse:1.54071\n",
      "[131]\tvalidation_0-rmse:1.54122\n",
      "[132]\tvalidation_0-rmse:1.54085\n",
      "[133]\tvalidation_0-rmse:1.54064\n",
      "[134]\tvalidation_0-rmse:1.54074\n",
      "[135]\tvalidation_0-rmse:1.54074\n",
      "[136]\tvalidation_0-rmse:1.54093\n",
      "[137]\tvalidation_0-rmse:1.54126\n",
      "[138]\tvalidation_0-rmse:1.54151\n",
      "[139]\tvalidation_0-rmse:1.54133\n",
      "[140]\tvalidation_0-rmse:1.54094\n",
      "[141]\tvalidation_0-rmse:1.54110\n",
      "[142]\tvalidation_0-rmse:1.54095\n",
      "[143]\tvalidation_0-rmse:1.54059\n",
      "[144]\tvalidation_0-rmse:1.54076\n",
      "[145]\tvalidation_0-rmse:1.54051\n",
      "[146]\tvalidation_0-rmse:1.54051\n",
      "[147]\tvalidation_0-rmse:1.54048\n",
      "[148]\tvalidation_0-rmse:1.54075\n",
      "[149]\tvalidation_0-rmse:1.54041\n",
      "[150]\tvalidation_0-rmse:1.54046\n",
      "[151]\tvalidation_0-rmse:1.54049\n",
      "[152]\tvalidation_0-rmse:1.54059\n",
      "[153]\tvalidation_0-rmse:1.54071\n",
      "[154]\tvalidation_0-rmse:1.54093\n",
      "[155]\tvalidation_0-rmse:1.54096\n",
      "[156]\tvalidation_0-rmse:1.54086\n",
      "[157]\tvalidation_0-rmse:1.54078\n",
      "[158]\tvalidation_0-rmse:1.54068\n",
      "[159]\tvalidation_0-rmse:1.54088\n",
      "[160]\tvalidation_0-rmse:1.54083\n",
      "[161]\tvalidation_0-rmse:1.54095\n",
      "[162]\tvalidation_0-rmse:1.54101\n",
      "[163]\tvalidation_0-rmse:1.54129\n",
      "[164]\tvalidation_0-rmse:1.54137\n",
      "[165]\tvalidation_0-rmse:1.54155\n",
      "[166]\tvalidation_0-rmse:1.54153\n",
      "[167]\tvalidation_0-rmse:1.54140\n",
      "[168]\tvalidation_0-rmse:1.54139\n",
      "[169]\tvalidation_0-rmse:1.54127\n",
      "[170]\tvalidation_0-rmse:1.54141\n",
      "[171]\tvalidation_0-rmse:1.54161\n",
      "[172]\tvalidation_0-rmse:1.54140\n",
      "[173]\tvalidation_0-rmse:1.54143\n",
      "[174]\tvalidation_0-rmse:1.54157\n",
      "[175]\tvalidation_0-rmse:1.54162\n",
      "[176]\tvalidation_0-rmse:1.54185\n",
      "[177]\tvalidation_0-rmse:1.54195\n",
      "[178]\tvalidation_0-rmse:1.54197\n",
      "[179]\tvalidation_0-rmse:1.54194\n",
      "[180]\tvalidation_0-rmse:1.54193\n",
      "[181]\tvalidation_0-rmse:1.54178\n",
      "[182]\tvalidation_0-rmse:1.54176\n",
      "[183]\tvalidation_0-rmse:1.54183\n",
      "[184]\tvalidation_0-rmse:1.54184\n",
      "[185]\tvalidation_0-rmse:1.54175\n",
      "[186]\tvalidation_0-rmse:1.54184\n",
      "[187]\tvalidation_0-rmse:1.54186\n",
      "[188]\tvalidation_0-rmse:1.54193\n",
      "[189]\tvalidation_0-rmse:1.54198\n",
      "[190]\tvalidation_0-rmse:1.54193\n",
      "[191]\tvalidation_0-rmse:1.54202\n",
      "[192]\tvalidation_0-rmse:1.54176\n",
      "[193]\tvalidation_0-rmse:1.54169\n",
      "[194]\tvalidation_0-rmse:1.54145\n",
      "[195]\tvalidation_0-rmse:1.54146\n",
      "[196]\tvalidation_0-rmse:1.54163\n",
      "[197]\tvalidation_0-rmse:1.54149\n",
      "[198]\tvalidation_0-rmse:1.54164\n",
      "[199]\tvalidation_0-rmse:1.54162\n",
      "[200]\tvalidation_0-rmse:1.54159\n",
      "[201]\tvalidation_0-rmse:1.54151\n",
      "[202]\tvalidation_0-rmse:1.54147\n",
      "[203]\tvalidation_0-rmse:1.54151\n",
      "[204]\tvalidation_0-rmse:1.54146\n",
      "[205]\tvalidation_0-rmse:1.54146\n",
      "[206]\tvalidation_0-rmse:1.54140\n",
      "[207]\tvalidation_0-rmse:1.54141\n",
      "[208]\tvalidation_0-rmse:1.54140\n",
      "[209]\tvalidation_0-rmse:1.54138\n",
      "[210]\tvalidation_0-rmse:1.54134\n",
      "[211]\tvalidation_0-rmse:1.54140\n",
      "[212]\tvalidation_0-rmse:1.54135\n",
      "[213]\tvalidation_0-rmse:1.54140\n",
      "[214]\tvalidation_0-rmse:1.54162\n",
      "[215]\tvalidation_0-rmse:1.54170\n",
      "[216]\tvalidation_0-rmse:1.54171\n",
      "[217]\tvalidation_0-rmse:1.54162\n",
      "[218]\tvalidation_0-rmse:1.54154\n",
      "[219]\tvalidation_0-rmse:1.54153\n",
      "[220]\tvalidation_0-rmse:1.54143\n",
      "[221]\tvalidation_0-rmse:1.54147\n",
      "[222]\tvalidation_0-rmse:1.54136\n",
      "[223]\tvalidation_0-rmse:1.54135\n",
      "[224]\tvalidation_0-rmse:1.54139\n",
      "[225]\tvalidation_0-rmse:1.54149\n",
      "[226]\tvalidation_0-rmse:1.54161\n",
      "[227]\tvalidation_0-rmse:1.54175\n",
      "[228]\tvalidation_0-rmse:1.54174\n",
      "[229]\tvalidation_0-rmse:1.54170\n",
      "[230]\tvalidation_0-rmse:1.54166\n",
      "[231]\tvalidation_0-rmse:1.54162\n",
      "[232]\tvalidation_0-rmse:1.54165\n",
      "[233]\tvalidation_0-rmse:1.54158\n",
      "[234]\tvalidation_0-rmse:1.54157\n",
      "[235]\tvalidation_0-rmse:1.54169\n",
      "[236]\tvalidation_0-rmse:1.54154\n",
      "[237]\tvalidation_0-rmse:1.54158\n",
      "[238]\tvalidation_0-rmse:1.54160\n",
      "[239]\tvalidation_0-rmse:1.54159\n",
      "[240]\tvalidation_0-rmse:1.54167\n",
      "[241]\tvalidation_0-rmse:1.54149\n",
      "[242]\tvalidation_0-rmse:1.54154\n",
      "[243]\tvalidation_0-rmse:1.54148\n",
      "[244]\tvalidation_0-rmse:1.54161\n",
      "[245]\tvalidation_0-rmse:1.54164\n",
      "[246]\tvalidation_0-rmse:1.54164\n",
      "[247]\tvalidation_0-rmse:1.54162\n",
      "[248]\tvalidation_0-rmse:1.54169\n",
      "[249]\tvalidation_0-rmse:1.54165\n",
      "[250]\tvalidation_0-rmse:1.54153\n",
      "[251]\tvalidation_0-rmse:1.54145\n",
      "[252]\tvalidation_0-rmse:1.54148\n",
      "[253]\tvalidation_0-rmse:1.54150\n",
      "[254]\tvalidation_0-rmse:1.54142\n",
      "[255]\tvalidation_0-rmse:1.54155\n",
      "[256]\tvalidation_0-rmse:1.54149\n",
      "[257]\tvalidation_0-rmse:1.54143\n",
      "[258]\tvalidation_0-rmse:1.54139\n",
      "[259]\tvalidation_0-rmse:1.54143\n",
      "[260]\tvalidation_0-rmse:1.54141\n",
      "[261]\tvalidation_0-rmse:1.54147\n",
      "[262]\tvalidation_0-rmse:1.54149\n",
      "[263]\tvalidation_0-rmse:1.54155\n",
      "[264]\tvalidation_0-rmse:1.54157\n",
      "[265]\tvalidation_0-rmse:1.54153\n",
      "[266]\tvalidation_0-rmse:1.54153\n",
      "[267]\tvalidation_0-rmse:1.54162\n",
      "[268]\tvalidation_0-rmse:1.54167\n",
      "[269]\tvalidation_0-rmse:1.54166\n",
      "[270]\tvalidation_0-rmse:1.54160\n",
      "[271]\tvalidation_0-rmse:1.54163\n",
      "[272]\tvalidation_0-rmse:1.54175\n",
      "[273]\tvalidation_0-rmse:1.54167\n",
      "[274]\tvalidation_0-rmse:1.54168\n",
      "[275]\tvalidation_0-rmse:1.54176\n",
      "[276]\tvalidation_0-rmse:1.54172\n",
      "[277]\tvalidation_0-rmse:1.54171\n",
      "[278]\tvalidation_0-rmse:1.54164\n",
      "[279]\tvalidation_0-rmse:1.54162\n",
      "[280]\tvalidation_0-rmse:1.54163\n",
      "[281]\tvalidation_0-rmse:1.54163\n",
      "[282]\tvalidation_0-rmse:1.54167\n",
      "[283]\tvalidation_0-rmse:1.54170\n",
      "[284]\tvalidation_0-rmse:1.54177\n",
      "[285]\tvalidation_0-rmse:1.54190\n",
      "[286]\tvalidation_0-rmse:1.54191\n",
      "[287]\tvalidation_0-rmse:1.54190\n",
      "[288]\tvalidation_0-rmse:1.54199\n",
      "[289]\tvalidation_0-rmse:1.54204\n",
      "[290]\tvalidation_0-rmse:1.54213\n",
      "[291]\tvalidation_0-rmse:1.54205\n",
      "[292]\tvalidation_0-rmse:1.54207\n",
      "[293]\tvalidation_0-rmse:1.54210\n",
      "[294]\tvalidation_0-rmse:1.54212\n",
      "[295]\tvalidation_0-rmse:1.54207\n",
      "[296]\tvalidation_0-rmse:1.54203\n",
      "[297]\tvalidation_0-rmse:1.54210\n",
      "[298]\tvalidation_0-rmse:1.54207\n",
      "[299]\tvalidation_0-rmse:1.54210\n",
      "Final RMSE: 132.15\n",
      "Final MAE: 31.29\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.metrics import mean_squared_error, mean_absolute_error\n",
    "from xgboost import XGBRegressor\n",
    "import torch\n",
    "\n",
    "# Check for GPU availability\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "# Sample data for faster prototyping\n",
    "df = data.sample(n=10_000, random_state=1)\n",
    "\n",
    "# Preprocess data\n",
    "df['abstract'] = df['abstract'].fillna('')\n",
    "df['text'] = df['title'] + \" \" + df['abstract']\n",
    "\n",
    "# Scale 'year' feature\n",
    "scaler = MinMaxScaler()\n",
    "df['year_scaled'] = scaler.fit_transform(df[['year']])\n",
    "\n",
    "# Log transformation for target variable\n",
    "df['log_citation'] = np.log1p(df['n_citation'])\n",
    "\n",
    "# Generate embeddings using GPU\n",
    "model = SentenceTransformer('all-mpnet-base-v2', device='cuda')\n",
    "embeddings = model.encode(\n",
    "    df['text'].tolist(),\n",
    "    batch_size=128,\n",
    "    show_progress_bar=True,\n",
    "    device='cuda'\n",
    ")\n",
    "\n",
    "# Combine features\n",
    "X = np.hstack([embeddings, df['year_scaled'].values.reshape(-1, 1)])\n",
    "y = df['log_citation'].values\n",
    "\n",
    "# Train-validation split\n",
    "X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)\n",
    "\n",
    "# Define XGBoost model\n",
    "xgb_model = XGBRegressor(\n",
    "    tree_method=\"hist\",\n",
    "    device=\"cuda\",\n",
    "    predictor=\"gpu_predictor\",\n",
    "    n_estimators=300,\n",
    "    max_depth=8,\n",
    "    learning_rate=0.1,\n",
    "    subsample=0.8\n",
    ")\n",
    "\n",
    "# Train model with XGBoost's built-in progress bar\n",
    "xgb_model.fit(X_train, y_train, eval_set=[(X_val, y_val)], verbose=True)\n",
    "\n",
    "# Predictions\n",
    "val_pred = xgb_model.predict(X_val)\n",
    "\n",
    "# Evaluate model\n",
    "final_rmse = np.sqrt(mean_squared_error(np.expm1(y_val), np.expm1(val_pred)))\n",
    "final_mae = mean_absolute_error(np.expm1(y_val), np.expm1(val_pred))\n",
    "\n",
    "print(f\"Final RMSE: {final_rmse:.2f}\")\n",
    "print(f\"Final MAE: {final_mae:.2f}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba7f9386",
   "metadata": {},
   "source": [
    "### Testing trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00c4118b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: cuda\n",
      "Encoding test data...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "be72fc8920a542269193adbbc75cdc75",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/74 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Test Set Regression Metrics:\n",
      "RMSE: 210.04\n",
      "MAE: 107.10\n",
      "R2 Score: -0.34\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import sqlite3\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
    "from xgboost import XGBRegressor\n",
    "import torch\n",
    "\n",
    "# ----------------------------------------\n",
    "# 1. Set up and check GPU availability\n",
    "# ----------------------------------------\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")\n",
    "# Connect to the SQLite database and load the data.\n",
    "conn = sqlite3.connect('papers.db')\n",
    "test_df = pd.read_sql_query(\"SELECT * FROM papers\", conn)\n",
    "conn.close()\n",
    "\n",
    "# Preprocess test data: handle missing values and combine title and abstract.\n",
    "test_df['abstract'] = test_df['abstract'].fillna('')\n",
    "test_df['text'] = test_df['title'] + \" \" + test_df['abstract']\n",
    "\n",
    "# Scale the 'year' feature in the test set using the same scaler fitted on training data.\n",
    "test_df['year_scaled'] = scaler.transform(test_df[['year']])\n",
    "\n",
    "# Prepare the target variable in the test set.\n",
    "# (Test set uses the column 'citations' for the citation count.)\n",
    "test_df['log_citation'] = np.log1p(test_df['citations'])\n",
    "\n",
    "# ----------------------------------------\n",
    "# 4. Generate text embeddings using a SentenceTransformer model\n",
    "# ----------------------------------------\n",
    "# Load the pre-trained SentenceTransformer model with GPU support\n",
    "embed_model = SentenceTransformer('all-mpnet-base-v2', device='cuda')\n",
    "\n",
    "# Generate embeddings for the combined text from training and test data.\n",
    "\n",
    "print(\"Encoding test data...\")\n",
    "test_embeddings = embed_model.encode(\n",
    "    test_df['text'].tolist(),\n",
    "    batch_size=128,\n",
    "    show_progress_bar=True,\n",
    "    device='cuda'\n",
    ")\n",
    "\n",
    "# ----------------------------------------\n",
    "# 5. Combine features (embeddings + scaled year) and prepare datasets\n",
    "# ----------------------------------------\n",
    "\n",
    "X_test = np.hstack([test_embeddings, test_df['year_scaled'].values.reshape(-1, 1)])\n",
    "y_test = test_df['log_citation'].values\n",
    "\n",
    "# ----------------------------------------\n",
    "# 7. Predict on the test set and evaluate regression metrics\n",
    "# ----------------------------------------\n",
    "# Predict in the log domain\n",
    "y_test_pred_log = xgb_model.predict(X_test)\n",
    "\n",
    "# Convert predictions back to original citation counts using the inverse transformation\n",
    "y_test_pred = np.expm1(y_test_pred_log)\n",
    "y_test_true = np.expm1(y_test)\n",
    "\n",
    "# Calculate regression metrics\n",
    "rmse = np.sqrt(mean_squared_error(y_test_true, y_test_pred))\n",
    "mae = mean_absolute_error(y_test_true, y_test_pred)\n",
    "r2 = r2_score(y_test_true, y_test_pred)\n",
    "\n",
    "print(\"\\nTest Set Regression Metrics:\")\n",
    "print(f\"RMSE: {rmse:.2f}\")\n",
    "print(f\"MAE: {mae:.2f}\")\n",
    "print(f\"R2 Score: {r2:.2f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d68704e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Top Predicted Citations:\n",
      "                                                  title  year  \\\n",
      "6131  DSI-Net: Deep Synergistic Interaction Network ...  2021   \n",
      "1077  Uni-ControlNet: All-in-One Control to Text-to-...  2023   \n",
      "246   Anchor Diffusion for Unsupervised Video Object...  2019   \n",
      "\n",
      "                 topic  predicted_citations  \n",
      "6131               LLM         1.189171e+13  \n",
      "1077  Diffusion Models         8.380618e+12  \n",
      "246   Diffusion Models         3.803213e+12  \n"
     ]
    }
   ],
   "source": [
    "test_df['predicted_citations'] = np.expm1(y_test_pred)\n",
    "\n",
    "# Display the top papers with the highest predicted citations\n",
    "top_predictions = test_df[['title', 'year', 'topic', 'predicted_citations']].sort_values(\n",
    "    by='predicted_citations', ascending=False\n",
    ").head(3)  # Show top 10 papers\n",
    "\n",
    "print(\"\\nTop Predicted Citations:\")\n",
    "print(top_predictions)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e592fecf",
   "metadata": {},
   "source": [
    "# Author-Author Network\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "165a106c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10/100, Loss: 3.7697\n",
      "Epoch 20/100, Loss: 2.6772\n",
      "Epoch 30/100, Loss: 1.7634\n",
      "Epoch 40/100, Loss: 1.0372\n",
      "Epoch 50/100, Loss: 0.5276\n",
      "Epoch 60/100, Loss: 0.2340\n",
      "Epoch 70/100, Loss: 0.1039\n",
      "Epoch 80/100, Loss: 0.0525\n",
      "Epoch 90/100, Loss: 0.0334\n",
      "Epoch 100/100, Loss: 0.0260\n",
      "\n",
      "Test Set Regression Metrics:\n",
      "RMSE: 144.20\n",
      "MAE: 60.33\n",
      "R2 Score: -0.27\n",
      "\n",
      "Top Predicted Citations:\n",
      "                                                   title  year  \\\n",
      "4119   The Good, the Bad, and the Expert: How Consume...  2015   \n",
      "4740   Chinese Term Recognition and Extraction Based ...  2008   \n",
      "4045   High efficient hardware allocation framework o...  2015   \n",
      "10845  Score-Based Generative Modeling through Stocha...  2020   \n",
      "12922  Normalizing Flows for Probabilistic Modeling a...  2019   \n",
      "15288                  Large language models in medicine  2023   \n",
      "14175      Diffusion Models Beat GANs on Image Synthesis  2021   \n",
      "15382  RePaint: Inpainting using Denoising Diffusion ...  2022   \n",
      "16567   Vote-Selling: Infrastructure and Public Services  2018   \n",
      "12577                    Squeeze-and-Excitation Networks  2017   \n",
      "\n",
      "                   topic  predicted_citations  \n",
      "4119                 NaN          2697.807129  \n",
      "4740                 NaN          2314.161621  \n",
      "4045                 NaN          1178.770020  \n",
      "10845   Diffusion Models          1175.487061  \n",
      "12922  Foundation Models          1165.278809  \n",
      "15288  Generative Models          1165.196045  \n",
      "14175  Generative Models          1120.444824  \n",
      "15382  Generative Models          1099.473389  \n",
      "16567                LLM           949.788086  \n",
      "12577  Foundation Models           900.298340  \n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import sqlite3\n",
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.data import Data\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score\n",
    "\n",
    "# -------------------------------\n",
    "# 1. Load and Combine Data\n",
    "# -------------------------------\n",
    "\n",
    "# Load post-2017 dataset from papers.db\n",
    "conn = sqlite3.connect('papers.db')\n",
    "papers_df = pd.read_sql_query(\"SELECT * FROM papers\", conn)\n",
    "conn.close()\n",
    "\n",
    "# Standardize column names to match\n",
    "df = df.rename(columns={'n_citation': 'citations'})\n",
    "papers_df = papers_df.rename(columns={'citations': 'citations'})\n",
    "\n",
    "# Ensure 'authors' field is consistent across datasets\n",
    "df['authors'] = df['authors'].fillna('')\n",
    "papers_df['authors'] = papers_df['authors'].fillna('')\n",
    "\n",
    "# Combine datasets\n",
    "combined_df = pd.concat([df, papers_df], ignore_index=True)\n",
    "\n",
    "# Process authors (semicolon-separated)\n",
    "combined_df['author_list'] = combined_df['authors'].apply(lambda x: [a.strip() for a in x.split(';') if a.strip() != ''])\n",
    "\n",
    "# Target variable: log-transform citation counts\n",
    "combined_df['log_citations'] = np.log1p(combined_df['citations'])\n",
    "\n",
    "# -------------------------------\n",
    "# 2. Train-Test Split\n",
    "# -------------------------------\n",
    "train_df, test_df = train_test_split(combined_df, test_size=0.2, random_state=42)\n",
    "\n",
    "# -------------------------------\n",
    "# 3. Build the Author-Author Graph\n",
    "# -------------------------------\n",
    "# Collect all unique authors from the entire dataset\n",
    "all_authors = set()\n",
    "for authors in combined_df['author_list']:\n",
    "    for a in authors:\n",
    "        all_authors.add(a)\n",
    "all_authors = sorted(list(all_authors))\n",
    "author_to_id = {author: idx for idx, author in enumerate(all_authors)}\n",
    "num_authors = len(all_authors)\n",
    "\n",
    "# Build edge list: add an undirected edge for each pair of co-authors in a paper\n",
    "edges = []\n",
    "for authors in combined_df['author_list']:\n",
    "    author_ids = [author_to_id[a] for a in authors if a in author_to_id]\n",
    "    for i in range(len(author_ids)):\n",
    "        for j in range(i + 1, len(author_ids)):\n",
    "            edges.append((author_ids[i], author_ids[j]))\n",
    "            edges.append((author_ids[j], author_ids[i]))\n",
    "\n",
    "# Remove duplicate edges\n",
    "edges = list(set(edges))\n",
    "edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()\n",
    "\n",
    "# Create the PyTorch Geometric Data object\n",
    "data = Data(edge_index=edge_index)\n",
    "data.num_nodes = num_authors\n",
    "\n",
    "# -------------------------------\n",
    "# 4. Map Papers to Author Node Indices\n",
    "# -------------------------------\n",
    "def get_author_ids(author_list):\n",
    "    return [author_to_id[a] for a in author_list if a in author_to_id]\n",
    "\n",
    "train_paper_author_ids = train_df['author_list'].apply(get_author_ids).tolist()\n",
    "test_paper_author_ids = test_df['author_list'].apply(get_author_ids).tolist()\n",
    "\n",
    "# Convert target values to torch tensors\n",
    "y_train = torch.tensor(train_df['log_citations'].values, dtype=torch.float)\n",
    "y_test  = torch.tensor(test_df['log_citations'].values, dtype=torch.float)\n",
    "\n",
    "# -------------------------------\n",
    "# 5. Define the GNN Regressor Model\n",
    "# -------------------------------\n",
    "class AuthorGNNRegressor(nn.Module):\n",
    "    def __init__(self, num_nodes, in_channels=32, hidden_channels=64, out_channels=64):\n",
    "        super(AuthorGNNRegressor, self).__init__()\n",
    "        self.author_embeddings = nn.Embedding(num_nodes, in_channels)\n",
    "        self.conv1 = GCNConv(in_channels, hidden_channels)\n",
    "        self.conv2 = GCNConv(hidden_channels, out_channels)\n",
    "        self.regressor = nn.Linear(out_channels, 1)\n",
    "\n",
    "    def forward(self, data, paper_author_ids):\n",
    "        x = self.author_embeddings.weight\n",
    "        x = self.conv1(x, data.edge_index)\n",
    "        x = torch.relu(x)\n",
    "        x = self.conv2(x, data.edge_index)\n",
    "\n",
    "        paper_embeddings = []\n",
    "        for author_ids in paper_author_ids:\n",
    "            if len(author_ids) > 0:\n",
    "                authors_tensor = x[author_ids]\n",
    "                paper_emb = authors_tensor.mean(dim=0)\n",
    "            else:\n",
    "                paper_emb = torch.zeros(x.size(1), device=x.device)\n",
    "            paper_embeddings.append(paper_emb)\n",
    "        paper_embeddings = torch.stack(paper_embeddings, dim=0)\n",
    "        out = self.regressor(paper_embeddings)\n",
    "        return out.squeeze()\n",
    "\n",
    "# Instantiate the model, optimizer, and loss function\n",
    "model = AuthorGNNRegressor(num_nodes=num_authors)\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.01)\n",
    "criterion = nn.MSELoss()\n",
    "\n",
    "# -------------------------------\n",
    "# 6. Training Loop\n",
    "# -------------------------------\n",
    "num_epochs = 100\n",
    "model.train()\n",
    "for epoch in range(num_epochs):\n",
    "    optimizer.zero_grad()\n",
    "    preds = model(data, train_paper_author_ids)\n",
    "    loss = criterion(preds, y_train)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if (epoch + 1) % 10 == 0:\n",
    "        print(f\"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}\")\n",
    "\n",
    "# -------------------------------\n",
    "# 7. Evaluation on Test Set\n",
    "# -------------------------------\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    test_preds = model(data, test_paper_author_ids)\n",
    "    test_preds_citations = torch.expm1(test_preds).cpu().numpy()\n",
    "    y_test_citations = torch.expm1(y_test).cpu().numpy()\n",
    "\n",
    "rmse = np.sqrt(mean_squared_error(y_test_citations, test_preds_citations))\n",
    "mae = mean_absolute_error(y_test_citations, test_preds_citations)\n",
    "r2 = r2_score(y_test_citations, test_preds_citations)\n",
    "\n",
    "print(\"\\nTest Set Regression Metrics:\")\n",
    "print(f\"RMSE: {rmse:.2f}\")\n",
    "print(f\"MAE: {mae:.2f}\")\n",
    "print(f\"R2 Score: {r2:.2f}\")\n",
    "\n",
    "# -------------------------------\n",
    "# 8. Display Top Predicted Citations\n",
    "# -------------------------------\n",
    "test_df = test_df.copy()\n",
    "test_df['predicted_citations'] = test_preds_citations\n",
    "top_predictions = test_df[['title', 'year', 'topic', 'predicted_citations']].sort_values(\n",
    "    by='predicted_citations', ascending=False\n",
    ").head(10)\n",
    "\n",
    "print(\"\\nTop Predicted Citations:\")\n",
    "print(top_predictions)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

GitHub Events

Total
  • Push event: 1
  • Create event: 1
Last Year
  • Push event: 1
  • Create event: 1

Committers

Last synced: 7 months ago

All Time
  • Total Commits: 3
  • Total Committers: 1
  • Avg Commits per committer: 3.0
  • Development Distribution Score (DDS): 0.0
Past Year
  • Commits: 3
  • Committers: 1
  • Avg Commits per committer: 3.0
  • Development Distribution Score (DDS): 0.0
Top Committers
Name Email Commits
Arman Yazdani 1****m 3

Issues and Pull Requests

Last synced: 7 months ago