Spaces:
Build error
Build error
import operator
import os, random
import pickle

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from jtmpn import JTMPN
from jtnn_enc import JTNNEncoder
from mol_tree import MolTree
from mpn import MPN
class PairTreeFolder(object):
    """Stream tensorized batches of tree pairs from a folder of pickled shards.

    Every entry in ``data_folder`` is opened and unpickled; the loaded
    sequence (presumably ``(x_tree, y_tree)`` pairs, as consumed by
    ``PairTreeDataset``) is optionally shuffled, cut into ``batch_size``
    chunks with any trailing partial chunk dropped, and tensorized through a
    ``DataLoader`` in which each dataset item is one whole batch.

    Args:
        data_folder: directory whose entries are all pickled shard files.
        vocab: vocabulary forwarded to ``PairTreeDataset``.
        batch_size: number of pairs per yielded batch.
        num_workers: ``DataLoader`` worker processes.
        shuffle: shuffle each shard's contents before batching.
        y_assm: forwarded to ``PairTreeDataset`` (assembly inputs for the y side).
        replicate: if not None, the shard list is repeated this many times.
    """

    def __init__(self, data_folder, vocab, batch_size, num_workers=4, shuffle=True, y_assm=True, replicate=None):
        self.data_folder = data_folder
        # Sorted so the shard order does not depend on the filesystem's listing order.
        self.data_files = sorted(os.listdir(data_folder))
        self.batch_size = batch_size
        self.vocab = vocab
        self.num_workers = num_workers
        self.y_assm = y_assm
        self.shuffle = shuffle
        if replicate is not None:  # replicate is an int repeat count
            self.data_files = self.data_files * replicate

    def __iter__(self):
        for fn in self.data_files:
            path = os.path.join(self.data_folder, fn)
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # this loader at trusted data.
            with open(path, 'rb') as f:
                data = pickle.load(f)
            if self.shuffle:
                random.shuffle(data)  # shuffle data before batching
            batches = [data[i:i + self.batch_size] for i in range(0, len(data), self.batch_size)]
            # Drop the trailing partial batch. The `batches and` guard covers an
            # empty shard, where batches[-1] would raise IndexError.
            if batches and len(batches[-1]) < self.batch_size:
                batches.pop()
            if not batches:
                # Shard was empty or smaller than one batch: nothing to yield,
                # so don't spin up DataLoader workers for it.
                continue
            dataset = PairTreeDataset(batches, self.vocab, self.y_assm)
            # batch_size=1 because each dataset item is already a full batch;
            # itemgetter(0) unwraps the 1-element list. Unlike a lambda it is
            # picklable, so workers also start under the "spawn" method.
            dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.num_workers, collate_fn=operator.itemgetter(0))
            for b in dataloader:
                yield b
            # Release this shard's data before loading the next one.
            del data, batches, dataset, dataloader
class MolTreeFolder(object):
    """Stream tensorized batches of trees from a folder of pickled shards.

    Every entry in ``data_folder`` is opened and unpickled; the loaded
    sequence is optionally shuffled, cut into ``batch_size`` chunks with any
    trailing partial chunk dropped, and tensorized through a ``DataLoader``
    in which each dataset item is one whole batch.

    Args:
        data_folder: directory whose entries are all pickled shard files.
        vocab: vocabulary forwarded to ``MolTreeDataset``.
        batch_size: number of trees per yielded batch.
        num_workers: ``DataLoader`` worker processes.
        shuffle: shuffle each shard's contents before batching.
        assm: forwarded to ``MolTreeDataset`` (whether to build assembly inputs).
        replicate: if not None, the shard list is repeated this many times.
    """

    def __init__(self, data_folder, vocab, batch_size, num_workers=4, shuffle=True, assm=True, replicate=None):
        self.data_folder = data_folder
        # Sorted so the shard order does not depend on the filesystem's listing order.
        self.data_files = sorted(os.listdir(data_folder))
        self.batch_size = batch_size
        self.vocab = vocab
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.assm = assm
        if replicate is not None:  # replicate is an int repeat count
            self.data_files = self.data_files * replicate

    def __iter__(self):
        for fn in self.data_files:
            path = os.path.join(self.data_folder, fn)
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # this loader at trusted data.
            with open(path, 'rb') as f:
                data = pickle.load(f)
            if self.shuffle:
                random.shuffle(data)  # shuffle data before batching
            batches = [data[i:i + self.batch_size] for i in range(0, len(data), self.batch_size)]
            # Drop the trailing partial batch. The `batches and` guard covers an
            # empty shard, where batches[-1] would raise IndexError.
            if batches and len(batches[-1]) < self.batch_size:
                batches.pop()
            if not batches:
                # Shard was empty or smaller than one batch: nothing to yield.
                continue
            dataset = MolTreeDataset(batches, self.vocab, self.assm)
            # batch_size=1 because each dataset item is already a full batch;
            # itemgetter(0) unwraps the 1-element list. Unlike a lambda it is
            # picklable, so workers also start under the "spawn" method.
            dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.num_workers, collate_fn=operator.itemgetter(0))
            for b in dataloader:
                yield b
            # Release this shard's data before loading the next one.
            del data, batches, dataset, dataloader
class PairTreeDataset(Dataset):
    """Map-style dataset whose items are pre-cut batches of ``(x, y)`` tree pairs.

    Each item is unzipped into an x-side and a y-side tree sequence and both
    are tensorized on access. The x side is always tensorized with
    ``assm=False``; the y side uses ``assm=self.y_assm``.
    """

    def __init__(self, data, vocab, y_assm):
        self.data, self.vocab, self.y_assm = data, vocab, y_assm

    def __len__(self):
        num_batches = len(self.data)
        return num_batches

    def __getitem__(self, idx):
        pair_batch = self.data[idx]
        x_trees, y_trees = zip(*pair_batch)
        x_part = tensorize(x_trees, self.vocab, assm=False)
        y_part = tensorize(y_trees, self.vocab, assm=self.y_assm)
        return x_part, y_part
class MolTreeDataset(Dataset):
    """Map-style dataset in which every item is one pre-cut batch of trees.

    Items are tensorized on access via ``tensorize`` with ``assm=self.assm``.
    """

    def __init__(self, data, vocab, assm=True):
        self.assm = assm
        self.vocab = vocab
        self.data = data

    def __len__(self):
        num_batches = len(self.data)
        return num_batches

    def __getitem__(self, idx):
        tree_batch = self.data[idx]
        return tensorize(tree_batch, self.vocab, assm=self.assm)
def tensorize(tree_batch, vocab, assm=True):
    """Build the encoder inputs for one batch of junction trees.

    Side effect: ``set_batch_nodeID`` overwrites ``idx`` and ``wid`` on every
    node of every tree in the batch.

    Args:
        tree_batch: sequence of tree objects exposing ``.smiles`` and ``.nodes``.
        vocab: vocabulary passed to ``set_batch_nodeID``.
        assm: when falsy, skip building the assembly-candidate inputs.

    Returns:
        ``(tree_batch, jtenc_holder, mpn_holder)`` when ``assm`` is falsy;
        otherwise ``(tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder, batch_idx))``,
        where ``batch_idx`` is a LongTensor giving, for each candidate passed
        to ``JTMPN.tensorize``, the position of its tree within ``tree_batch``.
    """
    set_batch_nodeID(tree_batch, vocab)
    smiles_batch = [tree.smiles for tree in tree_batch]
    jtenc_holder, mess_dict = JTNNEncoder.tensorize(tree_batch)
    mpn_holder = MPN.tensorize(smiles_batch)
    # Truthiness check: the original `assm is False` sent None/0 down the
    # assembly path.
    if not assm:
        return tree_batch, jtenc_holder, mpn_holder

    cands = []
    batch_idx = []
    for i, mol_tree in enumerate(tree_batch):
        for node in mol_tree.nodes:
            # Leaf node's attachment is determined by neighboring node's
            # attachment; nodes with exactly one candidate are skipped too
            # (presumably nothing to choose between).
            if node.is_leaf or len(node.cands) == 1:
                continue
            cands.extend((cand, mol_tree.nodes, node) for cand in node.cands)
            # One entry per candidate, kept parallel to `cands`.
            batch_idx.extend([i] * len(node.cands))
    jtmpn_holder = JTMPN.tensorize(cands, mess_dict)
    batch_idx = torch.LongTensor(batch_idx)
    return tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder, batch_idx)
def set_batch_nodeID(mol_batch, vocab):
    """Number every node in the batch and attach its vocabulary id, in place.

    Nodes are numbered consecutively from 0 across the whole batch, tree by
    tree, in ``mol_tree.nodes`` order. Each node's ``wid`` is set to
    ``vocab.get_index(node.smiles)``. Returns None.
    """
    every_node = (node for mol_tree in mol_batch for node in mol_tree.nodes)
    for position, node in enumerate(every_node):
        node.idx = position
        node.wid = vocab.get_index(node.smiles)
class PropMolTreeDataset(Dataset):
    """Map-style dataset whose items are pre-cut batches of tree/property records.

    Items are tensorized on access via ``tensorize_prop`` with ``assm=self.assm``.
    """

    def __init__(self, data, vocab, assm=True):
        self.assm = assm
        self.vocab = vocab
        self.data = data

    def __len__(self):
        num_batches = len(self.data)
        return num_batches

    def __getitem__(self, idx):
        record_batch = self.data[idx]
        return tensorize_prop(record_batch, self.vocab, assm=self.assm)
class PropMolTreeFolder(object):
    """Stream tensorized tree/property batches from a folder of pickled shards.

    Every entry in ``data_folder`` is opened and unpickled; the loaded
    sequence (presumably ``(tree, prop)`` records, as consumed by
    ``tensorize_prop``) is optionally shuffled, cut into ``batch_size``
    chunks with any trailing partial chunk dropped, and tensorized through a
    ``DataLoader`` in which each dataset item is one whole batch.

    Args:
        data_folder: directory whose entries are all pickled shard files.
        vocab: vocabulary forwarded to ``PropMolTreeDataset``.
        batch_size: number of records per yielded batch.
        num_workers: ``DataLoader`` worker processes.
        shuffle: shuffle each shard's contents before batching.
        assm: forwarded to ``PropMolTreeDataset`` (whether to build assembly inputs).
        replicate: if not None, the shard list is repeated this many times.
    """

    def __init__(self, data_folder, vocab, batch_size, num_workers=4, shuffle=True, assm=True, replicate=None):
        self.data_folder = data_folder
        # Sorted so the shard order does not depend on the filesystem's listing order.
        self.data_files = sorted(os.listdir(data_folder))
        self.batch_size = batch_size
        self.vocab = vocab
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.assm = assm
        if replicate is not None:  # replicate is an int repeat count
            self.data_files = self.data_files * replicate

    def __iter__(self):
        for fn in self.data_files:
            path = os.path.join(self.data_folder, fn)
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # this loader at trusted data.
            with open(path, 'rb') as f:
                data = pickle.load(f)
            if self.shuffle:
                random.shuffle(data)  # shuffle data before batching
            batches = [data[i:i + self.batch_size] for i in range(0, len(data), self.batch_size)]
            # Drop the trailing partial batch. The `batches and` guard covers an
            # empty shard, where batches[-1] would raise IndexError.
            if batches and len(batches[-1]) < self.batch_size:
                batches.pop()
            if not batches:
                # Shard was empty or smaller than one batch: nothing to yield.
                continue
            dataset = PropMolTreeDataset(batches, self.vocab, self.assm)
            # batch_size=1 because each dataset item is already a full batch;
            # itemgetter(0) unwraps the 1-element list. Unlike a lambda it is
            # picklable, so workers also start under the "spawn" method.
            dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.num_workers, collate_fn=operator.itemgetter(0))
            for b in dataloader:
                yield b
            # Release this shard's data before loading the next one.
            del data, batches, dataset, dataloader
def tensorize_prop(data, vocab, assm=True):
    """Tensorize a batch of ``(tree, prop)`` records.

    The records are unzipped into a tuple of trees and a tuple of property
    values, and the trees go through ``tensorize`` (which also overwrites
    node ``idx``/``wid`` via ``set_batch_nodeID``). This used to be a
    verbatim copy of ``tensorize``'s body; delegating keeps the two in sync.

    Args:
        data: sequence of 2-element ``(tree, prop)`` records.
        vocab: vocabulary forwarded to ``tensorize``.
        assm: when falsy, skip building the assembly-candidate inputs.

    Returns:
        When ``assm`` is falsy: ``(tree_batch, jtenc_holder, mpn_holder)``.
        NOTE(review): the property values are dropped in this case; the
        3-tuple shape is kept for backward compatibility with callers.
        Otherwise:
        ``(tree_batch, jtenc_holder, mpn_holder, (jtmpn_holder, batch_idx), prop)``,
        where ``prop`` is the tuple of raw property values in batch order.
    """
    tree_batch, prop = zip(*data)
    # Pass an explicit bool so either tensorize assm-check (identity or
    # truthiness) sees the same value.
    if not assm:
        return tensorize(tree_batch, vocab, assm=False)
    return tensorize(tree_batch, vocab, assm=True) + (prop,)