Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| import rdkit.Chem as Chem | |
| import torch.nn.functional as F | |
| from nnutils import * | |
| from chemutils import get_mol | |
# Element vocabulary for the atom-symbol one-hot; symbols not listed are
# mapped onto the trailing 'unknown' slot by onek_encoding_unk.
ELEM_LIST = [
    'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca',
    'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown',
]

# Length of one atom feature vector; the terms follow the order of the
# one-hot groups concatenated in atom_features.
ATOM_FDIM = (
    len(ELEM_LIST)  # element symbol
    + 6             # degree 0-5
    + 5             # formal charge in {-1, -2, 1, 2, 0}
    + 4             # chiral tag 0-3
    + 1             # aromatic flag
)

# Length of one bond feature vector: 5 type/ring flags + 6-way stereo one-hot.
BOND_FDIM = 5 + 6

# Neighbour slots per row in the padded index tables built by tensorize;
# tensorize writes at most this many bond indices per row.
MAX_NB = 6
def onek_encoding_unk(x, allowable_set):
    """One-hot encode ``x`` against ``allowable_set`` as a list of bools.

    A value not present in ``allowable_set`` is treated as the set's last
    element, so the last slot acts as the catch-all "unknown" bucket.

    Args:
        x: value to encode.
        allowable_set: ordered, indexable collection of allowed values.

    Returns:
        A list with one bool per entry of ``allowable_set``.
    """
    key = x if x in allowable_set else allowable_set[-1]
    return [key == candidate for candidate in allowable_set]
def atom_features(atom):
    """Build the float feature vector (length ATOM_FDIM) for an RDKit atom.

    The vector is the concatenation, in this order, of one-hot encodings of
    the element symbol (ELEM_LIST), the degree (0-5), the formal charge
    (-1, -2, 1, 2, 0), the integer chiral tag (0-3), followed by a single
    aromaticity flag. Out-of-vocabulary values land in each group's last slot.
    """
    symbol_bits = onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
    degree_bits = onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
    charge_bits = onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
    chiral_bits = onek_encoding_unk(int(atom.GetChiralTag()), [0, 1, 2, 3])
    aromatic_bit = [atom.GetIsAromatic()]
    features = symbol_bits + degree_bits + charge_bits + chiral_bits + aromatic_bit
    return torch.Tensor(features)
def bond_features(bond):
    """Build the float feature vector (length BOND_FDIM) for an RDKit bond.

    Layout: four flags for SINGLE / DOUBLE / TRIPLE / AROMATIC bond type, one
    ring-membership flag, then a one-hot over the integer stereo code 0-5
    (codes outside that range fall into the last slot).
    """
    bond_type = bond.GetBondType()
    stereo_code = int(bond.GetStereo())
    bond_types = Chem.rdchem.BondType
    type_bits = [
        bond_type == kind
        for kind in (bond_types.SINGLE, bond_types.DOUBLE,
                     bond_types.TRIPLE, bond_types.AROMATIC)
    ]
    ring_bit = [bond.IsInRing()]
    stereo_bits = onek_encoding_unk(stereo_code, [0, 1, 2, 3, 4, 5])
    return torch.Tensor(type_bits + ring_bit + stereo_bits)
class MPN(nn.Module):
    """Message-passing network over directed bonds that embeds each molecule.

    Messages are attached to directed bonds (rows of ``fbonds``). After an
    initial projection, ``depth - 1`` refinement rounds are run. The final
    messages are summed into atoms, combined with the atom features and
    mean-pooled per molecule.

    Args:
        hidden_size: width of bond messages, atom hidden states and the
            returned molecule vectors.
        depth: number of message-passing steps including the initial
            projection; ``depth - 1`` refinement rounds are executed.
    """

    def __init__(self, hidden_size, depth):
        super(MPN, self).__init__()
        self.hidden_size = hidden_size
        self.depth = depth
        # W_i and W_h have no bias, so an all-zero input row yields a zero
        # message in every round. tensorize emits such a zero padding bond at
        # row 0, and index 0 fills unused neighbour slots. As a result, padded
        # slots contribute nothing to the neighbour sums.
        self.W_i = nn.Linear(ATOM_FDIM + BOND_FDIM, hidden_size, bias=False)
        self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
        self.W_o = nn.Linear(ATOM_FDIM + hidden_size, hidden_size)

    def forward(self, fatoms, fbonds, agraph, bgraph, scope):
        """Embed a batch of molecules.

        Args:
            fatoms: atom feature rows, (num_atoms, ATOM_FDIM) as stacked by
                ``tensorize``.
            fbonds: directed-bond input rows, (num_bonds, ATOM_FDIM + BOND_FDIM);
                ``tensorize`` puts an all-zero padding row at index 0.
            agraph: (num_atoms, MAX_NB) long tensor of the directed bonds
                entering each atom, 0-padded.
            bgraph: (num_bonds, MAX_NB) long tensor; for bond x->y, the bonds
                entering x other than y->x, 0-padded.
            scope: list of (start, num_atoms) pairs locating each molecule's
                atoms in ``fatoms``.

        Returns:
            Tensor of shape (len(scope), hidden_size): the mean atom hidden
            state of each molecule.
        """
        # create_var comes from nnutils; presumably it wraps / moves the
        # tensor to the active device — confirm there.
        fatoms = create_var(fatoms)
        fbonds = create_var(fbonds)
        agraph = create_var(agraph)
        bgraph = create_var(bgraph)

        binput = self.W_i(fbonds)
        message = F.relu(binput)
        for _ in range(self.depth - 1):
            # index_select_ND (nnutils) is assumed to gather message rows per
            # index, giving (num_bonds, MAX_NB, hidden); sum over the slots.
            nei_message = index_select_ND(message, 0, bgraph).sum(dim=1)
            message = F.relu(binput + self.W_h(nei_message))

        # Aggregate the final messages of bonds entering each atom.
        nei_message = index_select_ND(message, 0, agraph).sum(dim=1)
        ainput = torch.cat([fatoms, nei_message], dim=1)
        atom_hiddens = F.relu(self.W_o(ainput))

        # Readout: average the atom states inside each molecule's slice.
        batch_vecs = [atom_hiddens[st:st + le].mean(dim=0) for st, le in scope]
        return torch.stack(batch_vecs, dim=0)
def tensorize(mol_batch):
    """Convert a batch of SMILES strings into the graph tensors MPN.forward expects.

    Every undirected bond becomes two directed bonds (begin->end, then
    end->begin). Index 0 of the bond tables is a dummy bond with an all-zero
    input row, so 0 can be used as the padding value in ``agraph`` and
    ``bgraph``.

    Args:
        mol_batch: iterable of SMILES strings; each is parsed with ``get_mol``
            (chemutils), which is assumed to return an RDKit Mol.

    Returns:
        (fatoms, fbonds, agraph, bgraph, scope), where
        - fatoms: (num_atoms, ATOM_FDIM) stacked atom features;
        - fbonds: (num_bonds, ATOM_FDIM + BOND_FDIM); row i is the source
          atom's features followed by the bond's features (row 0 is padding);
        - agraph: (num_atoms, MAX_NB) long indices of bonds entering each atom;
        - bgraph: (num_bonds, MAX_NB) long; for bond x->y, the bonds entering x
          except those whose source is y;
        - scope: list of (first_atom_index, atom_count) per molecule.
        An atom with more than MAX_NB incoming bonds would index past the end
        of the agraph/bgraph rows.
    """
    atom_rows = []
    bond_rows = [torch.zeros(ATOM_FDIM + BOND_FDIM)]  # dummy bond 0
    bond_ends = [(-1, -1)]  # (source, target) per directed bond; 0 is the dummy
    incoming = []  # incoming[a] = indices of directed bonds whose target is a
    scope = []
    offset = 0

    for smiles in mol_batch:
        mol = get_mol(smiles)
        num_atoms = mol.GetNumAtoms()
        for atom in mol.GetAtoms():
            atom_rows.append(atom_features(atom))
            incoming.append([])
        for bond in mol.GetBonds():
            src = bond.GetBeginAtom().GetIdx() + offset
            dst = bond.GetEndAtom().GetIdx() + offset
            # Forward direction first, then the reverse, matching the index order.
            for tail, head in ((src, dst), (dst, src)):
                incoming[head].append(len(bond_ends))
                bond_ends.append((tail, head))
                bond_rows.append(torch.cat([atom_rows[tail], bond_features(bond)], 0))
        scope.append((offset, num_atoms))
        offset += num_atoms

    n_total_bonds = len(bond_ends)
    fatoms = torch.stack(atom_rows, 0)
    fbonds = torch.stack(bond_rows, 0)
    agraph = torch.zeros(offset, MAX_NB).long()
    bgraph = torch.zeros(n_total_bonds, MAX_NB).long()

    for atom_idx in range(offset):
        for slot, bond_id in enumerate(incoming[atom_idx]):
            agraph[atom_idx, slot] = bond_id

    for bond_id in range(1, n_total_bonds):
        tail, head = bond_ends[bond_id]
        for slot, prev_id in enumerate(incoming[tail]):
            # Skip bonds arriving from `head` (the reverse of this bond), so a
            # message is not sent straight back along the edge it came from.
            if bond_ends[prev_id][0] != head:
                bgraph[bond_id, slot] = prev_id

    return (fatoms, fbonds, agraph, bgraph, scope)