Upload source code structural_encoder_ablation.py
Browse files
structural_encoder_ablation.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from typing import List, Optional, Tuple, Any
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from torch_geometric.data import Batch
|
| 8 |
+
from transformers import AutoTokenizer
|
| 9 |
+
|
| 10 |
+
# Import builder from dataloader for inference
|
| 11 |
+
from dataloader import CodeGraphBuilder
|
| 12 |
+
|
| 13 |
+
from structural_encoder_v2 import RelationalGraphEncoder, StructuralEncoderV2, GatedFusion
|
| 14 |
+
|
| 15 |
+
class StructuralEncoderOnlyGraph(nn.Module):
|
| 16 |
+
"""
|
| 17 |
+
Ablation variant 1: Pure Structural Encoder.
|
| 18 |
+
Removes GraphCodeBERT and uses only the graph path (R-GNN).
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2, out_dim: int = 768):
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.device = torch.device(device)
|
| 24 |
+
|
| 25 |
+
self.graph_encoder = RelationalGraphEncoder(hidden_dim=graph_hidden_dim, out_dim=out_dim, num_layers=graph_layers)
|
| 26 |
+
self.graph_encoder.to(self.device)
|
| 27 |
+
|
| 28 |
+
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, graph_batch: Batch) -> torch.Tensor:
|
| 29 |
+
# Ignore text inputs for OnlyGraph
|
| 30 |
+
return self.graph_encoder(graph_batch)
|
| 31 |
+
|
| 32 |
+
def generate_embeddings(self, df: "pd.DataFrame", batch_size: int = 8, save_path: str | None = None, desc: str = "Structural OnlyGraph embeddings") -> np.ndarray:
|
| 33 |
+
builder = CodeGraphBuilder()
|
| 34 |
+
codes = df["code"].tolist()
|
| 35 |
+
batches = range(0, len(codes), batch_size)
|
| 36 |
+
all_embeddings: List[torch.Tensor] = []
|
| 37 |
+
|
| 38 |
+
for start in tqdm(batches, desc=desc):
|
| 39 |
+
batch_codes = codes[start:start + batch_size]
|
| 40 |
+
|
| 41 |
+
data_list = [builder.build(c) for c in batch_codes]
|
| 42 |
+
graph_batch = Batch.from_data_list(data_list)
|
| 43 |
+
|
| 44 |
+
# Dummy inputs for signature compatibility
|
| 45 |
+
dummy_ids = torch.zeros((1,1), device=self.device)
|
| 46 |
+
dummy_mask = torch.zeros((1,1), device=self.device)
|
| 47 |
+
|
| 48 |
+
with torch.no_grad():
|
| 49 |
+
out = self.forward(dummy_ids, dummy_mask, graph_batch)
|
| 50 |
+
all_embeddings.append(out.cpu())
|
| 51 |
+
|
| 52 |
+
embeddings = torch.cat(all_embeddings, dim=0).numpy().astype("float32")
|
| 53 |
+
if save_path is not None:
|
| 54 |
+
np.save(save_path, embeddings)
|
| 55 |
+
return embeddings
|
| 56 |
+
|
| 57 |
+
def load_checkpoint(self, checkpoint_path: str, map_location: str | torch.device = "cpu", strict: bool = True) -> None:
|
| 58 |
+
if not checkpoint_path:
|
| 59 |
+
raise ValueError("checkpoint_path must be provided")
|
| 60 |
+
state = torch.load(checkpoint_path, map_location=map_location)
|
| 61 |
+
if isinstance(state, dict) and "state_dict" in state:
|
| 62 |
+
state = state["state_dict"]
|
| 63 |
+
self.load_state_dict(state, strict=strict)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class StructuralEncoderConcat(StructuralEncoderV2):
|
| 67 |
+
"""
|
| 68 |
+
Ablation variant 2: Concatenation Fusion.
|
| 69 |
+
Keeps both text and graph paths but fuses them via simple concatenation + projection
|
| 70 |
+
instead of Gated Fusion.
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2):
|
| 74 |
+
super().__init__(device, graph_hidden_dim, graph_layers)
|
| 75 |
+
|
| 76 |
+
text_dim = self.text_model.config.hidden_size
|
| 77 |
+
graph_dim = self.text_model.config.hidden_size
|
| 78 |
+
|
| 79 |
+
self.concat_proj = nn.Linear(text_dim + graph_dim, text_dim)
|
| 80 |
+
self.concat_proj.to(self.device)
|
| 81 |
+
|
| 82 |
+
del self.fusion
|
| 83 |
+
|
| 84 |
+
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, graph_batch: Batch) -> torch.Tensor:
|
| 85 |
+
text_embeddings = self.encode_text(input_ids, attention_mask)
|
| 86 |
+
graph_embeddings = self.graph_encoder(graph_batch)
|
| 87 |
+
|
| 88 |
+
combined = torch.cat([text_embeddings, graph_embeddings], dim=-1)
|
| 89 |
+
return self.concat_proj(combined)
|