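# NOTE(review): "Spaces: Build error / Build error" looks like hosting-page banner text
# captured along with the source, not part of the module; kept here as a comment.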
from collections import defaultdict

import rdkit
import rdkit.Chem as Chem
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree

from vocab import Vocab
# tree_decomp stores each clique-graph edge as (MST_MAX_WEIGHT - overlap) so that
# SciPy's *minimum* spanning tree ends up preferring the largest overlaps.
MST_MAX_WEIGHT = 100
# enum_assemble stops collecting attachment configurations once it holds more
# than this many.
MAX_NCAND = 2000
def set_atommap(mol, num=0):
    """Set the atom-map number of every atom in ``mol`` to ``num`` (in place)."""
    for member in mol.GetAtoms():
        member.SetAtomMapNum(num)
def get_mol(smiles):
    """Parse ``smiles`` and return the molecule in kekulized form.

    Returns ``None`` when ``Chem.MolFromSmiles`` yields ``None``; otherwise the
    parsed molecule is kekulized in place with ``clearAromaticFlags=True`` and
    returned.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is not None:
        Chem.Kekulize(parsed, clearAromaticFlags=True)
    return parsed
def get_smiles(mol):
    """Return the SMILES of ``mol`` written with ``kekuleSmiles=True``."""
    kekule_smiles = Chem.MolToSmiles(mol, kekuleSmiles=True)
    return kekule_smiles
def decode_stereo(smiles2D):
    """Return isomeric SMILES for the stereoisomers enumerated from ``smiles2D``.

    Every enumerated isomer is round-tripped through isomeric SMILES. If the
    first isomer has nitrogen atoms carrying a chiral tag, each isomer is
    additionally emitted with those nitrogens reset to ``CHI_UNSPECIFIED``
    (the isomer objects are modified in place for that second pass).
    """
    base = Chem.MolFromSmiles(smiles2D)
    enumerated = list(EnumerateStereoisomers(base))

    # Re-parse from isomeric SMILES so every isomer is a freshly built molecule.
    isomers = [
        Chem.MolFromSmiles(Chem.MolToSmiles(iso, isomericSmiles=True))
        for iso in enumerated
    ]
    smiles3D = [Chem.MolToSmiles(iso, isomericSmiles=True) for iso in isomers]

    chiral_nitrogens = [
        atom.GetIdx()
        for atom in isomers[0].GetAtoms()
        if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"
    ]
    if not chiral_nitrogens:
        return smiles3D

    unspecified = Chem.rdchem.ChiralType.CHI_UNSPECIFIED
    for iso in isomers:
        for idx in chiral_nitrogens:
            iso.GetAtomWithIdx(idx).SetChiralTag(unspecified)
        smiles3D.append(Chem.MolToSmiles(iso, isomericSmiles=True))
    return smiles3D
def sanitize(mol):
    """Round-trip ``mol`` through ``get_smiles`` / ``get_mol``.

    Returns the re-parsed molecule, or ``None`` if either step raises or the
    SMILES cannot be parsed (``get_mol`` returning ``None``).
    """
    try:
        return get_mol(get_smiles(mol))
    except Exception:
        # Best effort by design: any failure means "not a usable molecule".
        return None
def copy_atom(atom):
    """Build a new ``Chem.Atom`` with the symbol, formal charge and atom-map number of ``atom``."""
    duplicate = Chem.Atom(atom.GetSymbol())
    charge = atom.GetFormalCharge()
    duplicate.SetFormalCharge(charge)
    map_num = atom.GetAtomMapNum()
    duplicate.SetAtomMapNum(map_num)
    return duplicate
def copy_edit_mol(mol):
    """Return an editable ``Chem.RWMol`` rebuilt from the atoms and bonds of ``mol``.

    Atoms are recreated with ``copy_atom`` in their original order and bonds
    are re-added between the same indices with the same bond type.
    """
    editable = Chem.RWMol(Chem.MolFromSmiles(''))
    for atom in mol.GetAtoms():
        editable.AddAtom(copy_atom(atom))
    for bond in mol.GetBonds():
        begin_idx = bond.GetBeginAtom().GetIdx()
        end_idx = bond.GetEndAtom().GetIdx()
        editable.AddBond(begin_idx, end_idx, bond.GetBondType())
    return editable
def get_clique_mol(mol, atoms):
    """Extract the fragment of ``mol`` spanned by ``atoms`` as its own molecule.

    The fragment is written as kekulized SMILES, parsed without sanitization,
    rebuilt through ``copy_edit_mol`` and finally passed through ``sanitize``.
    """
    fragment_smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
    raw_fragment = Chem.MolFromSmiles(fragment_smiles, sanitize=False)
    rebuilt = copy_edit_mol(raw_fragment).GetMol()
    # We assume this is not None
    return sanitize(rebuilt)
def tree_decomp(mol):
    """Decompose ``mol`` into cliques and the edges of a tree connecting them.

    Cliques are lists of atom indices: one per non-ring bond, one per ring
    reported by ``Chem.GetSymmSSSR`` (rings sharing more than two atoms are
    merged), plus single-atom cliques appended at certain branching atoms.

    Returns:
        ``(cliques, edges)`` where ``edges`` is a list of ``(i, j)`` index
        pairs into ``cliques`` taken from a spanning tree of the clique graph.
        A one-atom molecule yields ``([[0]], [])``.
    """
    n_atoms = mol.GetNumAtoms()
    if n_atoms == 1:  # special case
        return [[0]], []

    def index_by_atom(groups):
        # For every atom index, the indices of the cliques that contain it.
        table = [[] for _ in range(n_atoms)]
        for gid, members in enumerate(groups):
            for atom in members:
                table[atom].append(gid)
        return table

    cliques = [
        [bond.GetBeginAtom().GetIdx(), bond.GetEndAtom().GetIdx()]
        for bond in mol.GetBonds()
        if not bond.IsInRing()
    ]
    cliques.extend(list(ring) for ring in Chem.GetSymmSSSR(mol))
    nei_list = index_by_atom(cliques)

    # Merge Rings with intersection > 2 atoms
    # NOTE: statement order is load-bearing. The ``for atom`` loop keeps walking
    # the list object that ``extend`` grew in place, even after ``cliques[i]``
    # is rebound to the de-duplicated copy.
    for i in range(len(cliques)):
        if len(cliques[i]) <= 2:
            continue
        for atom in cliques[i]:
            for j in nei_list[atom]:
                if i >= j or len(cliques[j]) <= 2:
                    continue
                inter = set(cliques[i]) & set(cliques[j])
                if len(inter) > 2:
                    cliques[i].extend(cliques[j])
                    cliques[i] = list(set(cliques[i]))
                    cliques[j] = []

    cliques = [c for c in cliques if len(c) > 0]
    nei_list = index_by_atom(cliques)

    # Build edges and add singleton cliques
    edges = defaultdict(int)
    for atom in range(n_atoms):
        cnei = nei_list[atom]
        if len(cnei) <= 1:
            continue
        bonds = [c for c in cnei if len(cliques[c]) == 2]
        rings = [c for c in cnei if len(cliques[c]) > 4]
        # In general, if len(cnei) >= 3, a singleton should be added,
        # but 1 bond + 2 ring is currently not dealt with.
        if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2):
            cliques.append([atom])
            hub = len(cliques) - 1
            for c1 in cnei:
                edges[(c1, hub)] = 1
        elif len(rings) > 2:  # Multiple (n>2) complex rings
            cliques.append([atom])
            hub = len(cliques) - 1
            for c1 in cnei:
                edges[(c1, hub)] = MST_MAX_WEIGHT - 1
        else:
            for pos, c1 in enumerate(cnei):
                for c2 in cnei[pos + 1:]:
                    overlap = len(set(cliques[c1]) & set(cliques[c2]))
                    if edges[(c1, c2)] < overlap:
                        edges[(c1, c2)] = overlap  # c1 < c2 by construction

    # Flip the weights so the minimum spanning tree favours large overlaps.
    weighted = [pair + (MST_MAX_WEIGHT - weight,) for pair, weight in edges.items()]
    if not weighted:
        return cliques, weighted

    # Compute Maximum Spanning Tree
    row, col, data = zip(*weighted)
    n_clique = len(cliques)
    clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
    junc_tree = minimum_spanning_tree(clique_graph)
    row, col = junc_tree.nonzero()
    tree_edges = list(zip(row, col))
    return (cliques, tree_edges)
def atom_equal(a1, a2):
    """Return True when both atoms have the same element symbol and formal charge."""
    if a1.GetSymbol() != a2.GetSymbol():
        return False
    return a1.GetFormalCharge() == a2.GetFormalCharge()
#Bond type not considered because all aromatic (so SINGLE matches DOUBLE)
def ring_bond_equal(b1, b2, reverse=False):
    """Compare two bonds by their end atoms only (bond type is ignored).

    With ``reverse=True`` the end atoms of ``b2`` are swapped before the
    comparison, i.e. begin-of-b1 is matched against end-of-b2.
    """
    first_begin, first_end = b1.GetBeginAtom(), b1.GetEndAtom()
    if reverse:
        second_begin, second_end = b2.GetEndAtom(), b2.GetBeginAtom()
    else:
        second_begin, second_end = b2.GetBeginAtom(), b2.GetEndAtom()
    return atom_equal(first_begin, second_begin) and atom_equal(first_end, second_end)
def attach_mols(ctr_mol, neighbors, prev_nodes, nei_amap):
    """Graft the fragments of ``prev_nodes`` and ``neighbors`` onto ``ctr_mol``.

    ``nei_amap[nid]`` maps fragment atom indices to atom indices in
    ``ctr_mol``; it is extended in place for every fragment atom that had no
    counterpart yet (a copy of that atom is added to ``ctr_mol``).

    Fragments without bonds only transfer the atom-map number of their atom 0.
    For other fragments each bond is added when missing; an existing bond is
    replaced only if the fragment belongs to ``prev_nodes``.

    Returns the same ``ctr_mol`` object, modified in place.
    """
    parent_ids = [parent.nid for parent in prev_nodes]
    for node in prev_nodes + neighbors:
        frag = node.mol
        idx_map = nei_amap[node.nid]

        # Fragment atoms not yet present in ctr_mol get a fresh copy there.
        for atom in frag.GetAtoms():
            if atom.GetIdx() not in idx_map:
                idx_map[atom.GetIdx()] = ctr_mol.AddAtom(copy_atom(atom))

        if frag.GetNumBonds() == 0:
            source = frag.GetAtomWithIdx(0)
            target = ctr_mol.GetAtomWithIdx(idx_map[0])
            target.SetAtomMapNum(source.GetAtomMapNum())
            continue

        for bond in frag.GetBonds():
            begin = idx_map[bond.GetBeginAtom().GetIdx()]
            end = idx_map[bond.GetEndAtom().GetIdx()]
            if ctr_mol.GetBondBetweenAtoms(begin, end) is None:
                ctr_mol.AddBond(begin, end, bond.GetBondType())
            elif node.nid in parent_ids:
                # father node overrides
                ctr_mol.RemoveBond(begin, end)
                ctr_mol.AddBond(begin, end, bond.GetBondType())
    return ctr_mol
def local_attach(ctr_mol, neighbors, prev_nodes, amap_list):
    """Attach the given nodes to an editable copy of ``ctr_mol`` and return the result.

    ``amap_list`` holds ``(nei_id, ctr_atom, nei_atom)`` triples that say which
    atom of fragment ``nei_id`` coincides with which atom of ``ctr_mol``.
    """
    work_mol = copy_edit_mol(ctr_mol)
    nei_amap = {node.nid: {} for node in prev_nodes + neighbors}
    for nei_id, ctr_atom, nei_atom in amap_list:
        nei_amap[nei_id][nei_atom] = ctr_atom
    attached = attach_mols(work_mol, neighbors, prev_nodes, nei_amap)
    return attached.GetMol()
#This version records idx mapping between ctr_mol and nei_mol
def enum_attach(ctr_mol, nei_node, amap, singletons):
    """List every candidate atom mapping for attaching ``nei_node`` to ``ctr_mol``.

    Each candidate is ``amap`` extended by one or two
    ``(nei_node.nid, ctr_atom_idx, nei_atom_idx)`` triples. Centre atoms that
    ``amap`` already assigns to a node listed in ``singletons`` are excluded.
    The neighbor's bond count selects the strategy: 0 bonds (single atom),
    1 bond, or anything larger (shared atom, then shared bond).
    """
    nei_mol, nei_idx = nei_node.mol, nei_node.nid
    att_confs = []
    blocked = {atom_idx for nei_id, atom_idx, _ in amap if nei_id in singletons}
    ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetIdx() not in blocked]
    ctr_bonds = list(ctr_mol.GetBonds())
    n_nei_bonds = nei_mol.GetNumBonds()

    if n_nei_bonds == 0:
        # neighbor singleton
        nei_atom = nei_mol.GetAtomWithIdx(0)
        taken = {atom_idx for _, atom_idx, _ in amap}
        for atom in ctr_atoms:
            if atom_equal(atom, nei_atom) and atom.GetIdx() not in taken:
                att_confs.append(amap + [(nei_idx, atom.GetIdx(), 0)])
        return att_confs

    if n_nei_bonds == 1:
        # neighbor is a bond
        bond = nei_mol.GetBondWithIdx(0)
        bond_val = int(bond.GetBondTypeAsDouble())
        bond_ends = (bond.GetBeginAtom(), bond.GetEndAtom())
        for atom in ctr_atoms:
            # Optimize if atom is carbon (other atoms may change valence)
            if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
                continue
            # At most one mapping per centre atom: the begin atom wins over the end atom.
            for end_atom in bond_ends:
                if atom_equal(atom, end_atom):
                    att_confs.append(amap + [(nei_idx, atom.GetIdx(), end_atom.GetIdx())])
                    break
        return att_confs

    # intersection is an atom
    for a1 in ctr_atoms:
        for a2 in nei_mol.GetAtoms():
            if not atom_equal(a1, a2):
                continue
            # Optimize if atom is carbon (other atoms may change valence)
            if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
                continue
            att_confs.append(amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())])

    # intersection is a bond
    if ctr_mol.GetNumBonds() > 1:
        for b1 in ctr_bonds:
            for b2 in nei_mol.GetBonds():
                if ring_bond_equal(b1, b2):
                    att_confs.append(amap + [
                        (nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom().GetIdx()),
                        (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx()),
                    ])
                if ring_bond_equal(b1, b2, reverse=True):
                    att_confs.append(amap + [
                        (nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom().GetIdx()),
                        (nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx()),
                    ])
    return att_confs
#Try rings first: Speed-Up
def enum_assemble(node, neighbors, prev_nodes=None, prev_amap=None):
    """Enumerate the distinct ways of attaching ``neighbors`` to ``node``.

    Neighbors are attached one at a time, depth first. At every depth the
    candidate mappings from ``enum_attach`` are kept only if the partial
    assembly survives ``sanitize`` and its SMILES has not been seen at that
    step. The search stops growing once more than ``MAX_NCAND`` complete
    configurations have been collected.

    Args:
        node: centre node; its ``mol`` is the molecule being extended.
        neighbors: nodes to attach, in attachment order.
        prev_nodes: nodes already attached (defaults to none).
        prev_amap: ``(nei_id, ctr_atom, nei_atom)`` triples already fixed
            (defaults to none).

    Returns:
        ``(candidates, aroma_score)``: ``candidates`` is a list of
        ``(smiles, amap)`` pairs with unique SMILES that pass
        ``check_singleton``; ``aroma_score`` holds the matching
        ``check_aroma`` value for each candidate.
    """
    # None sentinels replace the former shared mutable ``[]`` defaults.
    prev_nodes = [] if prev_nodes is None else prev_nodes
    prev_amap = [] if prev_amap is None else prev_amap

    all_attach_confs = []
    singletons = [
        nei_node.nid
        for nei_node in neighbors + prev_nodes
        if nei_node.mol.GetNumAtoms() == 1
    ]

    def search(cur_amap, depth):
        if len(all_attach_confs) > MAX_NCAND:
            return
        if depth == len(neighbors):
            all_attach_confs.append(cur_amap)
            return

        nei_node = neighbors[depth]
        cand_amap = enum_attach(node.mol, nei_node, cur_amap, singletons)
        seen_smiles = set()
        viable = []
        for amap in cand_amap:
            cand_mol = local_attach(node.mol, neighbors[:depth + 1], prev_nodes, amap)
            cand_mol = sanitize(cand_mol)
            if cand_mol is None:
                continue
            smiles = get_smiles(cand_mol)
            if smiles in seen_smiles:
                continue
            seen_smiles.add(smiles)
            viable.append(amap)

        for new_amap in viable:
            search(new_amap, depth + 1)

    search(prev_amap, 0)

    cand_smiles = set()
    candidates = []
    aroma_score = []
    for amap in all_attach_confs:
        cand_mol = local_attach(node.mol, neighbors, prev_nodes, amap)
        cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))
        if cand_mol is None:
            # The re-parse can fail; skip the candidate instead of crashing
            # inside MolToSmiles(None).
            continue
        smiles = Chem.MolToSmiles(cand_mol)
        if smiles in cand_smiles or not check_singleton(cand_mol, node, neighbors):
            continue
        cand_smiles.add(smiles)
        candidates.append((smiles, amap))
        aroma_score.append(check_aroma(cand_mol, node, neighbors))

    return candidates, aroma_score
def check_singleton(cand_mol, ctr_node, nei_nodes):
    """Reject candidates in which some atom has two or more non-ring neighbors.

    The check only applies when none of the nodes is a single atom and at
    least one node has more than two atoms; otherwise it returns ``True``.
    """
    sizes = [node.mol.GetNumAtoms() for node in nei_nodes + [ctr_node]]
    has_singleton = any(size == 1 for size in sizes)
    has_ring = any(size > 2 for size in sizes)
    if has_singleton or not has_ring:
        return True

    for atom in cand_mol.GetAtoms():
        off_ring = sum(1 for nb in atom.GetNeighbors() if not nb.IsInRing())
        if off_ring > 1:
            return False
    return True
def check_aroma(cand_mol, ctr_node, nei_nodes):
    """Score a candidate by how many of its tracked atoms are aromatic.

    Returns ``0`` when fewer than two nodes have at least three atoms, or when
    no node's SMILES is listed in ``Vocab.benzynes`` / ``Vocab.penzynes``.
    Otherwise returns ``1000`` if at least ``4 * len(benzynes) +
    3 * len(penzynes)`` aromatic atoms carry one of the tracked atom-map
    numbers, and ``-0.001`` if not.
    """
    nodes = nei_nodes + [ctr_node]
    n_rings = sum(1 for node in nodes if node.mol.GetNumAtoms() >= 3)
    if n_rings < 2:
        return 0  # Only multi-ring system needs to be checked

    def map_tag(node):
        # Leaf nodes are looked up under atom-map number 0, others under their nid.
        return 0 if node.is_leaf else node.nid

    benzynes = [map_tag(node) for node in nodes if node.smiles in Vocab.benzynes]
    penzynes = [map_tag(node) for node in nodes if node.smiles in Vocab.penzynes]
    if not benzynes and not penzynes:
        return 0  # No specific aromatic rings

    tracked = benzynes + penzynes
    n_aroma_atoms = sum(
        1
        for atom in cand_mol.GetAtoms()
        if atom.GetAtomMapNum() in tracked and atom.GetIsAromatic()
    )
    required = len(benzynes) * 4 + len(penzynes) * 3
    if n_aroma_atoms >= required:
        return 1000
    return -0.001
#Only used for debugging purpose
def dfs_assemble(cur_mol, global_amap, fa_amap, cur_node, fa_node):
    """Recursively attach the subtree below ``cur_node`` to ``cur_mol`` using node labels.

    For ``cur_node`` the candidate whose SMILES equals ``cur_node.label`` is
    looked up among the ``enum_assemble`` candidates; its atom mapping is
    translated into ``global_amap`` (modified in place) and the children are
    attached to ``cur_mol`` (modified in place). Non-leaf children are then
    processed recursively.

    Args:
        cur_mol: editable molecule being grown.
        global_amap: indexable by node id; maps each node's atom indices to
            atom indices in ``cur_mol``.
        fa_amap: the mapping chosen for the parent (``(nid, a1, a2)`` triples).
        cur_node: node being expanded.
        fa_node: parent of ``cur_node`` or ``None`` for the root.

    Raises:
        ValueError: if no candidate exists or none matches ``cur_node.label``.
    """
    fa_nid = fa_node.nid if fa_node is not None else -1
    prev_nodes = [fa_node] if fa_node is not None else []

    children = [nei for nei in cur_node.neighbors if nei.nid != fa_nid]
    neighbors = [nei for nei in children if nei.mol.GetNumAtoms() > 1]
    neighbors = sorted(neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)
    singletons = [nei for nei in children if nei.mol.GetNumAtoms() == 1]
    neighbors = singletons + neighbors

    cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node.nid]

    # enum_assemble returns (candidates, aroma_score); only the candidate list
    # of (smiles, amap) pairs is wanted here. Passing the whole return value
    # to zip(*...) would pair candidates with scores instead.
    cands, _ = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
    if not cands:
        raise ValueError("enum_assemble produced no candidates for node %s" % cur_node.nid)
    cand_smiles, cand_amap = zip(*cands)
    label_idx = cand_smiles.index(cur_node.label)  # ValueError if the label is absent
    label_amap = cand_amap[label_idx]

    for nei_id, ctr_atom, nei_atom in label_amap:
        if nei_id == fa_nid:
            continue
        global_amap[nei_id][nei_atom] = global_amap[cur_node.nid][ctr_atom]

    cur_mol = attach_mols(cur_mol, children, [], global_amap)  # father is already attached
    for nei_node in children:
        if not nei_node.is_leaf:
            dfs_assemble(cur_mol, global_amap, label_amap, nei_node, cur_node)
if __name__ == "__main__":
    # Ad-hoc diagnostics: each helper reads one SMILES per line from stdin
    # (first whitespace-separated token). Only count() is run by default.
    import sys
    from mol_tree import MolTree

    lg = rdkit.RDLogger.logger()
    lg.setLevel(rdkit.RDLogger.CRITICAL)

    # Sample molecules; not referenced by the helpers below.
    smiles = [
        "O=C1[C@@H]2C=C[C@@H](C=CC2)C1(c1ccccc1)c1ccccc1",
        "O=C([O-])CC[C@@]12CCCC[C@]1(O)OC(=O)CC2",
        "ON=C1C[C@H]2CC3(C[C@@H](C1)c1ccccc12)OCCO3",
        "C[C@H]1CC(=O)[C@H]2[C@@]3(O)C(=O)c4cccc(O)c4[C@@H]4O[C@@]43[C@@H](O)C[C@]2(O)C1",
        'Cc1cc(NC(=O)CSc2nnc3c4ccccc4n(C)c3n2)ccc1Br',
        'CC(C)(C)c1ccc(C(=O)N[C@H]2CCN3CCCc4cccc2c43)cc1',
        "O=c1c2ccc3c(=O)n(-c4nccs4)c(=O)c4ccc(c(=O)n1-c1nccs1)c2c34",
        "O=C(N1CCc2c(F)ccc(F)c2C1)C1(O)Cc2ccccc2C1",
    ]

    def tree_test():
        """Print every tree node's SMILES together with its neighbors' SMILES."""
        for s in sys.stdin:
            s = s.split()[0]
            tree = MolTree(s)
            print('-------------------------------------------')
            print(s)
            for node in tree.nodes:
                print(node.smiles, [x.smiles for x in node.neighbors])

    def decode_test():
        """Reassemble each molecule from its tree and report mismatches."""
        wrong = 0
        for tot, s in enumerate(sys.stdin):
            s = s.split()[0]
            tree = MolTree(s)
            tree.recover()

            cur_mol = copy_edit_mol(tree.nodes[0].mol)
            # Slot 0 is a placeholder; slot 1 gets the identity mapping of the root fragment.
            global_amap = [{}] + [{} for node in tree.nodes]
            global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}

            dfs_assemble(cur_mol, global_amap, [], tree.nodes[0], None)

            cur_mol = cur_mol.GetMol()
            cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
            set_atommap(cur_mol)
            dec_smiles = Chem.MolToSmiles(cur_mol)

            gold_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(s))
            if gold_smiles != dec_smiles:
                print(gold_smiles, dec_smiles)
                wrong += 1
            # Running tally: mismatches so far / molecules processed.
            print(wrong, tot + 1)

    def enum_test():
        """Print every node whose label is missing from its candidate list."""
        for s in sys.stdin:
            s = s.split()[0]
            tree = MolTree(s)
            tree.recover()
            tree.assemble()
            for node in tree.nodes:
                if node.label not in node.cands:
                    print(tree.smiles)
                    print(node.smiles, [x.smiles for x in node.neighbors])
                    print(node.label, len(node.cands))

    def count():
        """Print the average number of assembly candidates per tree node."""
        cnt, n = 0, 0
        for s in sys.stdin:
            s = s.split()[0]
            tree = MolTree(s)
            tree.recover()
            tree.assemble()
            for node in tree.nodes:
                cnt += len(node.cands)
            n += len(tree.nodes)
        # The only report used to be a commented-out Python 2 print, so the
        # statistic was computed and thrown away. Guard against empty input.
        if n > 0:
            print(cnt * 1.0 / n)

    count()