"""Build one PyTorch Geometric graph per ``label_vec`` group of email rows."""
import numpy as np
import torch
from torch_geometric.data import Data, InMemoryDataset
from tqdm import tqdm


class classDataset(InMemoryDataset):
    """In-memory PyG dataset holding one graph per distinct ``label_vec``.

    All rows (emails) of the source DataFrame that share a ``label_vec``
    value form one graph. Each distinct ``text_vec`` within a group is one
    node, and consecutive rows of the group are joined by a directed edge
    ``row[i] -> row[i + 1]`` (in the group's row order).

    Args:
        root: Root directory handed to ``InMemoryDataset``; the processed
            file is read from / written to ``self.processed_paths[0]``.
        transform: Optional transform passed through to ``InMemoryDataset``.
        pre_transform: Optional callable applied to each graph once in
            :meth:`process`, before the graphs are collated and saved.
        df: DataFrame with ``label_vec`` and ``text_vec`` columns, used by
            :meth:`process`. If omitted, a module-level ``df`` is used
            (the original behaviour).
    """

    def __init__(self, root, transform=None, pre_transform=None, df=None):
        # Must be stored before super().__init__(), which may call process().
        self._df = df
        super().__init__(root, transform, pre_transform)
        # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
        # which may refuse to unpickle Data objects; newer PyG offers
        # self.load(...) instead — confirm against the installed versions.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # No raw files: the data comes from an in-memory DataFrame.
        return []

    @property
    def processed_file_names(self):
        # Despite the .csv suffix, this file holds the torch-serialized
        # (data, slices) tuple written by process(); the name is kept so that
        # previously processed caches are still found.
        return ['./train_vec.csv']

    def download(self):
        # Nothing to download (raw_file_names is empty).
        pass

    def process(self):
        """Group the DataFrame by ``label_vec``, build one graph per group, save.

        Raises:
            RuntimeError: if neither ``df=`` nor a module-level ``df`` exists.
        """
        frame = getattr(self, '_df', None)
        if frame is None:
            # Backward compatibility: the original read a module-level ``df``.
            frame = globals().get('df')
        if frame is None:
            raise RuntimeError(
                "classDataset needs a DataFrame to process: pass df=... "
                "or define a module-level 'df'"
            )

        # Treat each email in a label category as a node: all emails with the
        # same label_vec form one graph.
        data_list = [
            self._group_to_graph(group)
            for _, group in tqdm(frame.groupby('label_vec'))
        ]

        # PyG convention: pre_transform runs once here, before collation.
        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    @staticmethod
    def _group_to_graph(group):
        """Turn the rows of one ``label_vec`` group into a ``Data`` graph."""
        # Re-encode text_vec within this group so node ids count from 0, as
        # LabelEncoder().fit_transform did (ids follow sorted unique values).
        # ``uniques[i]`` is the text_vec value of node i.
        uniques, node_ids = np.unique(
            group['text_vec'].to_numpy(), return_inverse=True
        )
        node_ids = node_ids.reshape(-1)

        # One feature row per node. Assumes text_vec values are numeric ids
        # that can be cast to int64 — TODO confirm against the caller's data.
        x = torch.as_tensor(uniques, dtype=torch.long).unsqueeze(1)

        # Chain consecutive rows: row i -> row i + 1 (empty for 1-row groups).
        edge_index = torch.as_tensor(
            np.stack([node_ids[:-1], node_ids[1:]]), dtype=torch.long
        )

        # Graph label = the group's label_vec, cast straight to int64
        # (going through float32 first would lose precision above 2**24).
        y = torch.as_tensor(group['label_vec'].to_numpy()[:1]).long()

        return Data(x=x, edge_index=edge_index, y=y)
Preview:
Download PNG
Download JPEG
Download SVG
Tip: You can change the style, width & colours of the snippet with the inspect tool before clicking Download!
Click to optimize width for Twitter