dv4aby committed
Commit 207014d · verified · 1 Parent(s): ce729f4

Upload source code structural_encoder_ablation.py

Files changed (1)
  1. structural_encoder_ablation.py +89 -0
structural_encoder_ablation.py ADDED
@@ -0,0 +1,89 @@
import torch
import torch.nn as nn
from typing import List
from tqdm import tqdm
import numpy as np
import pandas as pd
from torch_geometric.data import Batch

# Import builder from dataloader for inference
from dataloader import CodeGraphBuilder

from structural_encoder_v2 import RelationalGraphEncoder, StructuralEncoderV2


class StructuralEncoderOnlyGraph(nn.Module):
    """
    Ablation variant 1: Pure Structural Encoder.
    Removes GraphCodeBERT and uses only the graph path (R-GNN).
    """

    def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2, out_dim: int = 768):
        super().__init__()
        self.device = torch.device(device)

        self.graph_encoder = RelationalGraphEncoder(hidden_dim=graph_hidden_dim, out_dim=out_dim, num_layers=graph_layers)
        self.graph_encoder.to(self.device)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, graph_batch: Batch) -> torch.Tensor:
        # Text inputs are ignored in the graph-only ablation; only the graph path contributes.
        return self.graph_encoder(graph_batch)

    def generate_embeddings(self, df: pd.DataFrame, batch_size: int = 8, save_path: str | None = None, desc: str = "Structural OnlyGraph embeddings") -> np.ndarray:
        builder = CodeGraphBuilder()
        codes = df["code"].tolist()
        batches = range(0, len(codes), batch_size)
        all_embeddings: List[torch.Tensor] = []

        for start in tqdm(batches, desc=desc):
            batch_codes = codes[start:start + batch_size]

            data_list = [builder.build(c) for c in batch_codes]
            # Move the batched graphs onto the same device as the graph encoder.
            graph_batch = Batch.from_data_list(data_list).to(self.device)

            # Dummy text inputs for signature compatibility; forward() ignores them.
            dummy_ids = torch.zeros((1, 1), device=self.device)
            dummy_mask = torch.zeros((1, 1), device=self.device)

            with torch.no_grad():
                out = self.forward(dummy_ids, dummy_mask, graph_batch)
            all_embeddings.append(out.cpu())

        embeddings = torch.cat(all_embeddings, dim=0).numpy().astype("float32")
        if save_path is not None:
            np.save(save_path, embeddings)
        return embeddings

    def load_checkpoint(self, checkpoint_path: str, map_location: str | torch.device = "cpu", strict: bool = True) -> None:
        if not checkpoint_path:
            raise ValueError("checkpoint_path must be provided")
        state = torch.load(checkpoint_path, map_location=map_location)
        # Accept both raw state dicts and checkpoints wrapped in a "state_dict" key.
        if isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        self.load_state_dict(state, strict=strict)


class StructuralEncoderConcat(StructuralEncoderV2):
    """
    Ablation variant 2: Concatenation Fusion.
    Keeps both text and graph paths but fuses them via simple concatenation + projection
    instead of Gated Fusion.
    """

    def __init__(self, device: torch.device | str, graph_hidden_dim: int = 256, graph_layers: int = 2):
        super().__init__(device, graph_hidden_dim, graph_layers)

        # The graph encoder projects to the text model's hidden size, so both paths share one dimensionality.
        text_dim = self.text_model.config.hidden_size
        graph_dim = self.text_model.config.hidden_size

        self.concat_proj = nn.Linear(text_dim + graph_dim, text_dim)
        self.concat_proj.to(self.device)

        # Drop the parent's GatedFusion module; this variant fuses by concatenation + projection.
        del self.fusion

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, graph_batch: Batch) -> torch.Tensor:
        text_embeddings = self.encode_text(input_ids, attention_mask)
        graph_embeddings = self.graph_encoder(graph_batch)

        combined = torch.cat([text_embeddings, graph_embeddings], dim=-1)
        return self.concat_proj(combined)
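

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original ablation script): the demo
# DataFrame contents, the output path "only_graph_embeddings.npy", and leaving
# the Concat variant's text inputs to the existing dataloader are assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Ablation 1: graph-only embeddings straight from source code strings.
    only_graph = StructuralEncoderOnlyGraph(device=device)
    demo_df = pd.DataFrame({"code": ["def add(a, b):\n    return a + b"]})
    embeddings = only_graph.generate_embeddings(demo_df, batch_size=8, save_path="only_graph_embeddings.npy")
    print("OnlyGraph embeddings:", embeddings.shape)

    # Ablation 2: concatenation fusion. It reuses the text path inherited from
    # StructuralEncoderV2, so input_ids / attention_mask should come from that
    # model's tokenizer (prepared by the existing dataloader), e.g.:
    # concat_model = StructuralEncoderConcat(device=device)
    # fused = concat_model(input_ids, attention_mask, graph_batch)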