GNN dataset create1
Wed Sep 01 2021 09:03:36 GMT+0000 (Coordinated Universal Time)
Saved by @QuinnFox12 #gnn #dataset
class MyOwnDataset(Dataset):
    """Graph dataset that turns every row of a raw CSV file into one saved ``Data`` object.

    Each row is converted into node features, edge features, an edge index
    and a label, and the resulting graph is written to its own ``.pt`` file
    in ``processed_dir``. ``get`` loads those files back one at a time.
    """

    def __init__(self, root, transform=None, pre_transform=None, test=False):
        """Set up the dataset.

        Args:
            root: Where the dataset should be stored. This folder is split
                into raw_dir (downloaded dataset) and processed_dir
                (processed data).
            transform: Forwarded unchanged to the base ``Dataset``.
            pre_transform: Forwarded unchanged to the base ``Dataset``.
            test: When True, processed files are written to and read from
                ``data_test_{idx}.pt`` instead of ``data_{idx}.pt``.
        """
        # Assigned before the base constructor runs: PyG's Dataset.__init__
        # may call process(), which already needs to know the split.
        self.test = test
        super().__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """Raw file expected in raw_dir.

        If this file exists in raw_dir, the download is not triggered
        (the download function is not implemented here).
        """
        return 'cate_id_01.csv'

    @property
    def processed_file_names(self):
        """Processed file expected in processed_dir.

        If these files are found in processed_dir, processing is skipped.
        The name returned here is never written by process(), so processing
        is not skipped by this check.
        """
        return 'not_implemented.pt'

    def download(self):
        """Placeholder: the raw CSV is expected to already be in raw_dir."""
        # Download to `self.raw_dir`.
        # path = download_url(url, self.raw_dir)
        pass

    def process(self):
        """Read the raw CSV and save one ``Data`` object per row."""
        self.data = pd.read_csv(self.raw_paths[0])
        for index, row in tqdm(self.data.iterrows(), total=self.data.shape[0]):
            # NOTE(review): the feature helpers call GetAtoms()/GetBonds() on
            # this value, so the "categories" column must hold objects with
            # that API -- confirm against the actual CSV contents.
            cate = row["categories"]
            categories_main = row["categories_main"]

            # Get node features
            node_feats = self._get_node_features(cate)
            # Get edge features
            edge_feats = self._get_edge_features(cate)
            # Get adjacency info
            edge_index = self._get_adjacency_info(cate)
            # Get labels info
            label = self._get_labels(categories_main)

            # Keep the SMILES string on the graph only when the CSV has one,
            # so files without a "smiles" column do not raise KeyError.
            extra = {}
            if "smiles" in row:
                extra["smiles"] = row["smiles"]

            # Create data object
            data = Data(
                x=node_feats,
                edge_index=edge_index,
                edge_attr=edge_feats,
                y=label,
                **extra,
            )
            torch.save(data, self._processed_path(index))

    def _processed_path(self, index):
        """Return the processed file path for ``index``, honouring ``self.test``."""
        prefix = 'data_test' if self.test else 'data'
        return os.path.join(self.processed_dir, f'{prefix}_{index}.pt')

    def _get_node_features(self, mol):
        """Build the node feature matrix.

        This will return a matrix / 2d array of the shape
        [Number of Nodes, Node Feature size].
        """
        all_node_feats = []
        for atom in mol.GetAtoms():
            node_feats = [
                atom.GetAtomicNum(),            # Feature 1: Atomic number
                atom.GetDegree(),               # Feature 2: Atom degree
                atom.GetFormalCharge(),         # Feature 3: Formal charge
                atom.GetHybridization(),        # Feature 4: Hybridization
                atom.GetIsAromatic(),           # Feature 5: Aromaticity
                atom.GetTotalNumHs(),           # Feature 6: Total Num Hs
                atom.GetNumRadicalElectrons(),  # Feature 7: Radical Electrons
                atom.IsInRing(),                # Feature 8: In Ring
                atom.GetChiralTag(),            # Feature 9: Chirality
            ]
            # Append node features to matrix
            all_node_feats.append(node_feats)

        all_node_feats = np.asarray(all_node_feats)
        return torch.tensor(all_node_feats, dtype=torch.float)

    def _get_edge_features(self, mol):
        """Build the edge feature matrix.

        This will return a matrix / 2d array of the shape
        [Number of edges, Edge Feature size].
        """
        all_edge_feats = []
        for bond in mol.GetBonds():
            edge_feats = [
                bond.GetBondTypeAsDouble(),  # Feature 1: Bond type (as double)
                bond.IsInRing(),             # Feature 2: Rings
            ]
            # Append edge features to matrix (twice, once per direction)
            all_edge_feats += [edge_feats, edge_feats]

        all_edge_feats = np.asarray(all_edge_feats)
        return torch.tensor(all_edge_feats, dtype=torch.float)

    def _get_adjacency_info(self, mol):
        """Build the edge index as a long tensor of shape [2, Number of edges].

        We could also use rdmolops.GetAdjacencyMatrix(mol), but we want to be
        sure that the order of the indices matches the order of the edge
        features, so both directions of every bond are emitted in bond order.
        """
        edge_indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_indices += [[i, j], [j, i]]

        edge_indices = torch.tensor(edge_indices)
        edge_indices = edge_indices.t().to(torch.long).view(2, -1)
        return edge_indices

    def _get_labels(self, label):
        """Wrap a single label value in a one-element int64 tensor."""
        label = np.asarray([label])
        return torch.tensor(label, dtype=torch.int64)

    def len(self):
        """Return the number of rows in the CSV read by process()."""
        return self.data.shape[0]

    def get(self, idx):
        """Load the processed graph saved for row ``idx``.

        - Equivalent to __getitem__ in pytorch
        - Is not needed for PyG's InMemoryDataset
        """
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which may refuse pickled Data objects -- confirm
        # against the installed torch / torch_geometric versions.
        return torch.load(self._processed_path(idx))
Comments