# File size: 4,946 Bytes
# revision: b611e1c
import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
# UTILS: Molecule Processing with 3D Coordinates
def smiles_to_graph(smiles):
    """Convert a SMILES string into a PyG ``Data`` graph with 3D coordinates.

    Hydrogens are added, a 3D conformer is embedded with ETKDG and relaxed
    with UFF. Node features are the atomic number scaled by 1/100; each bond
    contributes two directed edges whose attribute encodes the bond class
    (0=single, 1=double, 2=triple, 3=aromatic; unknown types fall back to 0).

    Returns:
        A ``torch_geometric.data.Data`` with ``x``, ``pos``, ``edge_index``
        and ``edge_attr``, or ``None`` when the SMILES cannot be parsed or no
        conformer can be generated.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    mol = Chem.AddHs(mol)
    try:
        # EmbedMolecule signals failure by returning -1 rather than raising;
        # without this check a failed embed crashes later at GetConformer().
        if AllChem.EmbedMolecule(mol, AllChem.ETKDG()) == -1:
            return None
        AllChem.UFFOptimizeMolecule(mol)
    except (ValueError, RuntimeError):
        # RDKit raises for some degenerate inputs; treat as "no conformer".
        return None
    conf = mol.GetConformer()
    bond_classes = {
        Chem.BondType.SINGLE: 0,
        Chem.BondType.DOUBLE: 1,
        Chem.BondType.TRIPLE: 2,
        Chem.BondType.AROMATIC: 3,
    }
    node_feats = []
    pos = []
    for atom in mol.GetAtoms():
        # Normalize atomic number into roughly [0, 1].
        node_feats.append([atom.GetAtomicNum() / 100.0])
        position = conf.GetAtomPosition(atom.GetIdx())
        pos.append([position.x, position.y, position.z])
    edge_index = []
    edge_attrs = []
    for bond in mol.GetBonds():
        start = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        bond_class = bond_classes.get(bond.GetBondType(), 0)
        # Undirected bond -> two directed edges sharing one attribute.
        edge_index.extend([[start, end], [end, start]])
        edge_attrs.extend([[bond_class], [bond_class]])
    # A bond-free molecule must still yield a (2, 0) edge_index so that
    # downstream indexing (edge_index[0], .t(), etc.) keeps working.
    if edge_index:
        edge_index_t = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    else:
        edge_index_t = torch.empty((2, 0), dtype=torch.long)
    return Data(
        x=torch.tensor(node_feats, dtype=torch.float),
        pos=torch.tensor(pos, dtype=torch.float),
        edge_index=edge_index_t,
        edge_attr=torch.tensor(edge_attrs, dtype=torch.long),
    )
# EGNN Layer
class EGNNLayer(MessagePassing):
    """E(n)-equivariant message-passing layer.

    Node features are updated with summed edge messages; node coordinates are
    translated by distance-gated relative position vectors, keeping the layer
    equivariant to rotations and translations of ``pos``.
    """

    def __init__(self, node_dim):
        super().__init__(aggr='add')
        # Maps [x_i, x_j, ||pos_i - pos_j||] -> feature message of size node_dim.
        self.node_mlp = nn.Sequential(
            nn.Linear(node_dim * 2 + 1, 128),
            nn.ReLU(),
            nn.Linear(128, node_dim)
        )
        # Scalar gate on the relative position vector, computed from distance.
        self.coord_mlp = nn.Sequential(
            nn.Linear(1, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x, pos, edge_index):
        """Return updated ``(x, pos)``.

        Self-loops are added so every node receives at least one message; a
        self-loop contributes a zero coordinate update because its relative
        position vector is zero.
        """
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x_out, coord_out = self.propagate(edge_index, x=x, pos=pos)
        return x_out, pos + coord_out

    def message(self, x_i, x_j, pos_i, pos_j):
        edge_vec = pos_j - pos_i
        # Epsilon keeps sqrt() differentiable at zero-length edges (self-loops).
        dist = ((edge_vec**2).sum(dim=-1, keepdim=True) + 1e-8).sqrt()
        h = torch.cat([x_i, x_j, dist], dim=-1)
        edge_msg = self.node_mlp(h)
        coord_update = self.coord_mlp(dist) * edge_vec
        return edge_msg, coord_update

    def message_and_aggregate(self, adj_t, x):
        raise NotImplementedError("This EGNN layer does not support sparse adjacency matrices.")

    def aggregate(self, inputs, index, dim_size=None):
        # `inputs` is the (edge_msg, coord_update) tuple returned by message().
        # Size the output by dim_size (the true node count, supplied by PyG)
        # rather than index.max()+1, which would truncate the result whenever
        # the highest-indexed nodes have no incoming edges.
        edge_msg, coord_update = inputs
        num_nodes = dim_size if dim_size is not None else int(index.max()) + 1
        aggr_msg = torch.zeros(num_nodes, edge_msg.size(-1), device=edge_msg.device).index_add_(0, index, edge_msg)
        aggr_coord = torch.zeros(num_nodes, coord_update.size(-1), device=coord_update.device).index_add_(0, index, coord_update)
        return aggr_msg, aggr_coord

    def update(self, aggr_out, x):
        # Residual feature update; the coordinate delta is applied in forward().
        msg, coord_update = aggr_out
        return x + msg, coord_update
# Time Embedding
class TimeEmbedding(nn.Module):
    """Embed a scalar diffusion timestep through a small two-layer MLP.

    Timesteps are divided by 1000 to bring them into a numerically friendly
    range before projection; integer inputs are cast to float first.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(),
            nn.Linear(32, embed_dim),
        )

    def forward(self, t):
        # Accept any shape of t: flatten to a column vector of scaled floats.
        scaled = t.view(-1, 1).float() / 1000
        return self.net(scaled)
# Olfactory Conditioning
class OlfactoryConditioner(nn.Module):
    """Project an olfactory label vector into a conditioning embedding.

    A single linear layer maps ``num_labels`` inputs to ``embed_dim`` outputs.
    """

    def __init__(self, num_labels, embed_dim):
        super().__init__()
        self.embedding = nn.Linear(num_labels, embed_dim)

    def forward(self, labels):
        # Labels may arrive as integer (multi-)hot vectors; cast before projecting.
        as_float = labels.float()
        return self.embedding(as_float)
# EGNN Diffusion Model
class EGNNDiffusionModel(nn.Module):
    """Two-layer EGNN denoiser conditioned on a timestep and an odor embedding.

    Predicts denoised node features (the first ``x_t.shape[1]`` channels of the
    final node state) and, for every directed edge, logits over 4 bond classes.
    """

    def __init__(self, node_dim, embed_dim):
        super().__init__()
        self.time_embed = TimeEmbedding(embed_dim)
        # Each EGNN layer sees node features concatenated with the condition
        # and time embeddings, hence node_dim + 2 * embed_dim channels.
        self.egnn1 = EGNNLayer(node_dim + embed_dim * 2)
        self.egnn2 = EGNNLayer(node_dim + embed_dim * 2)
        # Scores a directed edge from its two endpoint feature vectors.
        self.bond_predictor = nn.Sequential(
            nn.Linear((node_dim + embed_dim * 2) * 2, 64),
            nn.ReLU(),
            nn.Linear(64, 4)
        )

    def forward(self, x_t, pos, edge_index, t, cond_embed):
        num_nodes = x_t.size(0)
        # Broadcast the per-graph time/condition embeddings to every node.
        t_embed = self.time_embed(t).expand(num_nodes, -1)
        cond = cond_embed.expand(num_nodes, -1)
        h = torch.cat([x_t, cond, t_embed], dim=1)
        h, pos = self.egnn1(h, pos, edge_index)
        h, pos = self.egnn2(h, pos, edge_index)
        src, dst = edge_index[0], edge_index[1]
        bond_logits = self.bond_predictor(torch.cat([h[src], h[dst]], dim=1))
        # Keep only the original node-feature channels of the final state.
        return h[:, :x_t.shape[1]], bond_logits