import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
from rdkit import Chem
from rdkit.Chem import AllChem


def smiles_to_graph(smiles):
    """Convert a SMILES string into a PyG `Data` graph with 3D coordinates."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    mol = Chem.AddHs(mol)
    try:
        # EmbedMolecule returns -1 when no conformer can be generated.
        if AllChem.EmbedMolecule(mol, AllChem.ETKDG()) != 0:
            return None
        AllChem.UFFOptimizeMolecule(mol)
    except Exception:
        return None

    conf = mol.GetConformer()
    atoms = mol.GetAtoms()
    bonds = mol.GetBonds()

    node_feats = []
    pos = []
    edge_index = []
    edge_attrs = []

    # Node features: atomic number scaled to roughly [0, 1]; positions come
    # from the embedded conformer.
    for atom in atoms:
        node_feats.append([atom.GetAtomicNum() / 100.0])
        position = conf.GetAtomPosition(atom.GetIdx())
        pos.append([position.x, position.y, position.z])

    # Each bond is stored in both directions so the graph is undirected.
    for bond in bonds:
        start = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        edge_index.append([start, end])
        edge_index.append([end, start])
        bond_class = {
            Chem.BondType.SINGLE: 0,
            Chem.BondType.DOUBLE: 1,
            Chem.BondType.TRIPLE: 2,
            Chem.BondType.AROMATIC: 3,
        }.get(bond.GetBondType(), 0)
        edge_attrs.extend([[bond_class], [bond_class]])

    return Data(
        x=torch.tensor(node_feats, dtype=torch.float),
        pos=torch.tensor(pos, dtype=torch.float),
        edge_index=torch.tensor(edge_index, dtype=torch.long).t().contiguous(),
        edge_attr=torch.tensor(edge_attrs, dtype=torch.long),
    )


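# Illustrative sanity check (a sketch, not part of the pipeline; "CCO" is
# ethanol, which has 9 atoms and 8 bonds once hydrogens are added):
def _demo_smiles_to_graph():
    data = smiles_to_graph("CCO")
    if data is not None:
        print(data.x.shape)           # torch.Size([9, 1])
        print(data.pos.shape)         # torch.Size([9, 3])
        print(data.edge_index.shape)  # torch.Size([2, 16]) -- 8 bonds, both directions

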
class EGNNLayer(MessagePassing):
    """E(n)-equivariant message-passing layer: node features are updated from
    invariant pairwise distances, and coordinates are updated along relative
    position vectors."""

    def __init__(self, node_dim):
        super().__init__(aggr='add')
        # Message MLP: takes both endpoint features plus the edge distance.
        self.node_mlp = nn.Sequential(
            nn.Linear(node_dim * 2 + 1, 128),
            nn.ReLU(),
            nn.Linear(128, node_dim)
        )
        # Coordinate MLP: a scalar gate on each relative position vector.
        self.coord_mlp = nn.Sequential(
            nn.Linear(1, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x, pos, edge_index):
        # Self-loops let each node see its own features; they contribute no
        # coordinate update because the relative vector is zero.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x_out, coord_out = self.propagate(edge_index, x=x, pos=pos)
        return x_out, pos + coord_out

    def message(self, x_i, x_j, pos_i, pos_j):
        edge_vec = pos_j - pos_i
        # Epsilon keeps the sqrt differentiable at zero distance (self-loops).
        dist = ((edge_vec ** 2).sum(dim=-1, keepdim=True) + 1e-8).sqrt()
        h = torch.cat([x_i, x_j, dist], dim=-1)
        edge_msg = self.node_mlp(h)
        coord_update = self.coord_mlp(dist) * edge_vec
        return edge_msg, coord_update

    def message_and_aggregate(self, adj_t, x):
        raise NotImplementedError("This EGNN layer does not support sparse adjacency matrices.")

    def aggregate(self, inputs, index, dim_size=None):
        # `message` returns a tuple, so aggregation is done by hand: sum both
        # the feature messages and the coordinate updates per target node.
        edge_msg, coord_update = inputs
        dim_size = dim_size if dim_size is not None else int(index.max()) + 1
        aggr_msg = torch.zeros(dim_size, edge_msg.size(-1), device=edge_msg.device).index_add_(0, index, edge_msg)
        aggr_coord = torch.zeros(dim_size, coord_update.size(-1), device=coord_update.device).index_add_(0, index, coord_update)
        return aggr_msg, aggr_coord

    def update(self, aggr_out, x):
        msg, coord_update = aggr_out
        # Residual feature update; the coordinate delta is applied in forward().
        return x + msg, coord_update


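# Illustrative equivariance check (a sketch, assuming the layer behaves as an
# E(3)-equivariant map): rotating the input coordinates should leave the node
# features unchanged and rotate the output coordinates identically.
def _check_equivariance(layer, x, pos, edge_index):
    rot, _ = torch.linalg.qr(torch.randn(3, 3))  # random orthogonal matrix
    x1, pos1 = layer(x, pos, edge_index)
    x2, pos2 = layer(x, pos @ rot.T, edge_index)
    print(torch.allclose(x1, x2, atol=1e-4))              # features: invariant
    print(torch.allclose(pos1 @ rot.T, pos2, atol=1e-4))  # coords: equivariant

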
class TimeEmbedding(nn.Module):
    """Maps a scalar diffusion timestep to a learned embedding."""

    def __init__(self, embed_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(),
            nn.Linear(32, embed_dim)
        )

    def forward(self, t):
        # Scale by the assumed number of diffusion steps (1000).
        return self.net(t.view(-1, 1).float() / 1000.0)


class OlfactoryConditioner(nn.Module):
    """Projects a multi-hot vector of scent labels to a conditioning embedding."""

    def __init__(self, num_labels, embed_dim):
        super().__init__()
        self.embedding = nn.Linear(num_labels, embed_dim)

    def forward(self, labels):
        return self.embedding(labels.float())


class EGNNDiffusionModel(nn.Module):
    """Denoiser: two EGNN layers over noisy node features conditioned on the
    timestep and scent embeddings, plus a per-edge bond-type classifier."""

    def __init__(self, node_dim, embed_dim):
        super().__init__()
        self.time_embed = TimeEmbedding(embed_dim)
        self.egnn1 = EGNNLayer(node_dim + embed_dim * 2)
        self.egnn2 = EGNNLayer(node_dim + embed_dim * 2)
        self.bond_predictor = nn.Sequential(
            nn.Linear((node_dim + embed_dim * 2) * 2, 64),
            nn.ReLU(),
            nn.Linear(64, 4)  # single, double, triple, aromatic
        )

    def forward(self, x_t, pos, edge_index, t, cond_embed):
        num_nodes = x_t.size(0)
        # Broadcast the (1, embed_dim) time and condition embeddings to all nodes.
        t_embed = self.time_embed(t).expand(num_nodes, -1)
        cond_embed = cond_embed.expand(num_nodes, -1)
        x_input = torch.cat([x_t, cond_embed, t_embed], dim=1)
        x1, pos1 = self.egnn1(x_input, pos, edge_index)
        x2, pos2 = self.egnn2(x1, pos1, edge_index)
        # Bond logits from the concatenated endpoint features of each edge.
        edge_feats = torch.cat([x2[edge_index[0]], x2[edge_index[1]]], dim=1)
        bond_logits = self.bond_predictor(edge_feats)
        # Only the first node_dim channels are the denoising prediction; the
        # refined coordinates pos2 are computed but not returned here.
        return x2[:, :x_t.shape[1]], bond_logits
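

# Minimal end-to-end sketch. The dimensions are illustrative assumptions:
# node_dim=1 matches smiles_to_graph's single feature per atom, and
# num_labels=10 / embed_dim=16 are arbitrary.
if __name__ == "__main__":
    data = smiles_to_graph("CCO")  # ethanol
    if data is not None:
        conditioner = OlfactoryConditioner(num_labels=10, embed_dim=16)
        model = EGNNDiffusionModel(node_dim=1, embed_dim=16)
        cond = conditioner(torch.zeros(1, 10))  # (1, 16) conditioning vector
        t = torch.tensor([500])                 # mid-trajectory timestep
        noise_pred, bond_logits = model(data.x, data.pos, data.edge_index, t, cond)
        print(noise_pred.shape)   # torch.Size([9, 1])
        print(bond_logits.shape)  # torch.Size([16, 4])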