class MyOwnDataset(Dataset):
    """Graph dataset built row-by-row from ``cate_id_01.csv``.

    Every CSV row is converted into one ``Data`` object (node features,
    edge features, adjacency and label) and stored as its own ``.pt`` file
    in ``processed_dir``; ``get`` loads those files back one at a time.
    """

    def __init__(self, root, transform=None, pre_transform=None, test=False):
        """Set up the dataset.

        Args:
            root: Where the dataset should be stored. This folder is split
                into raw_dir (downloaded dataset) and processed_dir
                (processed data).
            transform: Forwarded unchanged to the base class.
            pre_transform: Forwarded unchanged to the base class.
            test: When true, processed files are written/read using the
                ``data_test_{idx}.pt`` naming instead of ``data_{idx}.pt``.
        """
        # `process()` and `get()` read `self.test`, so it has to exist before
        # the base-class constructor runs (PyG's Dataset constructor is
        # documented to trigger processing).
        self.test = test
        super().__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """If this file exists in raw_dir, the download is not triggered
        (the download function is not implemented here)."""
        return 'cate_id_01.csv'

    @property
    def processed_file_names(self):
        """If these files are found in processed_dir, processing is skipped."""
        # NOTE(review): `process()` never writes a file with this name, so
        # processing presumably runs on every instantiation -- confirm that
        # this is intended.
        return 'not_implemented.pt'

    def download(self):
        """No-op: the raw CSV is expected to already be in ``self.raw_dir``."""
        # Download to `self.raw_dir`.
        # path = download_url(url, self.raw_dir)
        pass

    def process(self):
        """Read the raw CSV and save one ``Data`` object per row."""
        self.data = pd.read_csv(self.raw_paths[0])
        for index, row in tqdm(self.data.iterrows(), total=self.data.shape[0]):
            # NOTE(review): the feature helpers below call RDKit-style
            # methods (GetAtoms / GetBonds) on this value; a plain CSV cell
            # does not provide them -- confirm whether a conversion step is
            # missing here.
            cate = row["categories"]
            categories_main = row["categories_main"]
            # Get node features
            node_feats = self._get_node_features(cate)
            # Get edge features
            edge_feats = self._get_edge_features(cate)
            # Get adjacency info
            edge_index = self._get_adjacency_info(cate)
            # Get labels info
            label = self._get_labels(categories_main)
            # Attach the SMILES string only when the CSV actually has such a
            # column, so files without it do not raise a KeyError.
            extra = {}
            if "smiles" in row.index:
                extra["smiles"] = row["smiles"]
            # Create data object
            data = Data(x=node_feats,
                        edge_index=edge_index,
                        edge_attr=edge_feats,
                        y=label,
                        **extra)
            torch.save(data,
                       os.path.join(self.processed_dir,
                                    self._processed_name(index)))

    def _processed_name(self, idx):
        """Return the processed file name for ``idx`` in the current mode."""
        if self.test:
            return f'data_test_{idx}.pt'
        return f'data_{idx}.pt'

    def _get_node_features(self, mol) -> "torch.Tensor":
        """
        This will return a matrix / 2d array of the shape
        [Number of Nodes, Node Feature size]
        """
        all_node_feats = []
        for atom in mol.GetAtoms():
            all_node_feats.append([
                atom.GetAtomicNum(),           # Feature 1: Atomic number
                atom.GetDegree(),              # Feature 2: Atom degree
                atom.GetFormalCharge(),        # Feature 3: Formal charge
                atom.GetHybridization(),       # Feature 4: Hybridization
                atom.GetIsAromatic(),          # Feature 5: Aromaticity
                atom.GetTotalNumHs(),          # Feature 6: Total Num Hs
                atom.GetNumRadicalElectrons(), # Feature 7: Radical Electrons
                atom.IsInRing(),               # Feature 8: In Ring
                atom.GetChiralTag(),           # Feature 9: Chirality
            ])
        all_node_feats = np.asarray(all_node_feats)
        return torch.tensor(all_node_feats, dtype=torch.float)

    def _get_edge_features(self, mol) -> "torch.Tensor":
        """
        This will return a matrix / 2d array of the shape
        [Number of edges, Edge Feature size]
        """
        num_edge_feats = 2  # bond type + ring membership
        all_edge_feats = []
        for bond in mol.GetBonds():
            edge_feats = [
                bond.GetBondTypeAsDouble(),  # Feature 1: Bond type (as double)
                bond.IsInRing(),             # Feature 2: Rings
            ]
            # Append edge features to matrix (twice, per direction)
            all_edge_feats += [edge_feats, edge_feats]
        all_edge_feats = np.asarray(all_edge_feats)
        # The reshape is a no-op when bonds exist; without bonds it turns the
        # flat empty tensor into a well-formed (0, num_edge_feats) matrix.
        return torch.tensor(all_edge_feats,
                            dtype=torch.float).reshape(-1, num_edge_feats)

    def _get_adjacency_info(self, mol) -> "torch.Tensor":
        """
        We could also use rdmolops.GetAdjacencyMatrix(mol)
        but we want to be sure that the order of the indices
        matches the order of the edge features.

        Returns a long tensor of shape [2, Number of edges].
        """
        edge_indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            # One entry per direction, same order as in _get_edge_features
            edge_indices += [[i, j], [j, i]]
        if not edge_indices:
            # No bonds: an empty list cannot be viewed as (2, -1), so build
            # the empty index tensor explicitly.
            return torch.empty((2, 0), dtype=torch.long)
        edge_indices = torch.tensor(edge_indices, dtype=torch.long)
        return edge_indices.t().contiguous()

    def _get_labels(self, label) -> "torch.Tensor":
        """Wrap a single label value in a one-element int64 tensor."""
        label = np.asarray([label])
        return torch.tensor(label, dtype=torch.int64)

    def len(self):
        """Return the number of rows of the CSV loaded by ``process``."""
        return self.data.shape[0]

    def get(self, idx):
        """ - Equivalent to __getitem__ in pytorch
            - Is not needed for PyG's InMemoryDataset
        """
        # The files loaded here are the ones written by `process()`.
        return torch.load(os.path.join(self.processed_dir,
                                       self._processed_name(idx)))
# Comments