import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
from rdkit import Chem
from rdkit.Chem import AllChem


# UTILS: Molecule Processing with 3D Coordinates
def smiles_to_graph(smiles):
    """Convert a SMILES string into a PyG `Data` graph with 3D atom positions."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    mol = Chem.AddHs(mol)
    try:
        # EmbedMolecule returns -1 when 3D conformer generation fails
        if AllChem.EmbedMolecule(mol, AllChem.ETKDG()) == -1:
            return None
        AllChem.UFFOptimizeMolecule(mol)
    except Exception:
        return None

    conf = mol.GetConformer()
    atoms = mol.GetAtoms()
    bonds = mol.GetBonds()

    node_feats = []
    pos = []
    edge_index = []
    edge_attrs = []

    for atom in atoms:
        # Normalize atomic number
        node_feats.append([atom.GetAtomicNum() / 100.0])
        position = conf.GetAtomPosition(atom.GetIdx())
        pos.append([position.x, position.y, position.z])

    for bond in bonds:
        start = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        edge_index.append([start, end])
        edge_index.append([end, start])
        bond_type = bond.GetBondType()
        bond_class = {
            Chem.BondType.SINGLE: 0,
            Chem.BondType.DOUBLE: 1,
            Chem.BondType.TRIPLE: 2,
            Chem.BondType.AROMATIC: 3
        }.get(bond_type, 0)
        edge_attrs.extend([[bond_class], [bond_class]])

    return Data(
        x=torch.tensor(node_feats, dtype=torch.float),
        pos=torch.tensor(pos, dtype=torch.float),
        edge_index=torch.tensor(edge_index, dtype=torch.long).t().contiguous(),
        edge_attr=torch.tensor(edge_attrs, dtype=torch.long)
    )
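
# Example usage (sketch; "CCO"/ethanol is an arbitrary illustrative molecule, not from any dataset):
#   data = smiles_to_graph("CCO")
#   if data is not None:
#       print(data.x.shape, data.pos.shape, data.edge_index.shape, data.edge_attr.shape)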

# EGNN Layer
class EGNNLayer(MessagePassing):
    def __init__(self, node_dim):
        super().__init__(aggr='add')
        self.node_mlp = nn.Sequential(
            nn.Linear(node_dim * 2 + 1, 128),
            nn.ReLU(),
            nn.Linear(128, node_dim)
        )
        self.coord_mlp = nn.Sequential(
            nn.Linear(1, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x, pos, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x_out, coord_out = self.propagate(edge_index, x=x, pos=pos)
        return x_out, pos + coord_out

    def message(self, x_i, x_j, pos_i, pos_j):
        edge_vec = pos_j - pos_i
        dist = ((edge_vec**2).sum(dim=-1, keepdim=True) + 1e-8).sqrt()
        h = torch.cat([x_i, x_j, dist], dim=-1)
        edge_msg = self.node_mlp(h)
        coord_update = self.coord_mlp(dist) * edge_vec
        return edge_msg, coord_update

    def message_and_aggregate(self, adj_t, x):
        raise NotImplementedError("This EGNN layer does not support sparse adjacency matrices.")

    def aggregate(self, inputs, index, dim_size=None):
        edge_msg, coord_update = inputs
        # Prefer the framework-provided dim_size so isolated nodes are still counted
        num_nodes = dim_size if dim_size is not None else int(index.max()) + 1
        aggr_msg = torch.zeros(num_nodes, edge_msg.size(-1),
                               device=edge_msg.device).index_add_(0, index, edge_msg)
        aggr_coord = torch.zeros(num_nodes, coord_update.size(-1),
                                 device=coord_update.device).index_add_(0, index, coord_update)
        return aggr_msg, aggr_coord

    def update(self, aggr_out, x):
        msg, coord_update = aggr_out
        return x + msg, coord_update
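
# Shape sketch (assumed, single graph): given x of shape [N, node_dim], pos of shape [N, 3]
# and edge_index of shape [2, E], EGNNLayer(node_dim) returns features [N, node_dim] and
# updated coordinates [N, 3]; coordinates shift along relative position vectors, so the
# geometric update is equivariant to rotations and translations of the input geometry.
#   layer = EGNNLayer(node_dim=8)
#   h, p = layer(torch.randn(5, 8), torch.randn(5, 3), torch.tensor([[0, 1], [1, 0]]).t())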

# Time Embedding
class TimeEmbedding(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(),
            nn.Linear(32, embed_dim)
        )

    def forward(self, t):
        return self.net(t.view(-1, 1).float() / 1000)

# Olfactory Conditioning
class OlfactoryConditioner(nn.Module):
    def __init__(self, num_labels, embed_dim):
        super().__init__()
        self.embedding = nn.Linear(num_labels, embed_dim)

    def forward(self, labels):
        return self.embedding(labels.float())

# EGNN Diffusion Model
class EGNNDiffusionModel(nn.Module):
    def __init__(self, node_dim, embed_dim):
        super().__init__()
        self.time_embed = TimeEmbedding(embed_dim)
        self.egnn1 = EGNNLayer(node_dim + embed_dim * 2)
        self.egnn2 = EGNNLayer(node_dim + embed_dim * 2)
        self.bond_predictor = nn.Sequential(
            nn.Linear((node_dim + embed_dim * 2) * 2, 64),
            nn.ReLU(),
            nn.Linear(64, 4)
        )

    def forward(self, x_t, pos, edge_index, t, cond_embed):
        num_nodes = x_t.size(0)
        # Broadcast the per-graph time and conditioning embeddings to every node
        t_embed = self.time_embed(t).expand(num_nodes, -1)
        cond_embed = cond_embed.expand(num_nodes, -1)
        x_input = torch.cat([x_t, cond_embed, t_embed], dim=1)
        x1, pos1 = self.egnn1(x_input, pos, edge_index)
        x2, pos2 = self.egnn2(x1, pos1, edge_index)
        edge_feats = torch.cat([x2[edge_index[0]], x2[edge_index[1]]], dim=1)
        bond_logits = self.bond_predictor(edge_feats)
        # Return only the original node-feature channels plus per-edge bond-type logits
        return x2[:, :x_t.shape[1]], bond_logits
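

# Minimal smoke-test sketch (assumptions: a single molecule per forward pass; embed_dim=16,
# num_labels=5 and timestep t=10 are illustrative values, not taken from any training setup).
if __name__ == "__main__":
    data = smiles_to_graph("CCO")  # ethanol, used only to exercise the pipeline
    if data is not None:
        node_dim = data.x.size(1)              # one normalized atomic-number feature per atom
        embed_dim, num_labels = 16, 5
        conditioner = OlfactoryConditioner(num_labels, embed_dim)
        model = EGNNDiffusionModel(node_dim, embed_dim)

        labels = torch.zeros(1, num_labels)
        labels[0, 0] = 1.0                     # pretend the target carries the first odor label
        cond = conditioner(labels)             # [1, embed_dim], broadcast to all nodes in the model
        t = torch.tensor([10])                 # a single diffusion timestep

        noisy_x = data.x + 0.1 * torch.randn_like(data.x)  # stand-in for a noised sample
        x_pred, bond_logits = model(noisy_x, data.pos, data.edge_index, t, cond)
        print("node features:", x_pred.shape, "bond logits:", bond_logits.shape)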