import torch
from torch_geometric.data import Data, InMemoryDataset
from tqdm import tqdm
class classDataset(InMemoryDataset):
    """In-memory graph dataset that builds one graph per ``label_vec`` group.

    Within a label group every distinct ``text_vec`` value becomes a node and
    consecutive rows of the group are joined by a directed edge, so all emails
    sharing a label form a single graph whose target ``y`` is that label.

    NOTE: ``process`` reads a module-level DataFrame named ``df`` that is not
    defined in this class. It must exist, with numeric ``label_vec`` and
    ``text_vec`` columns, before the dataset is instantiated.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        """Initialise the base dataset and load the processed graphs.

        Args:
            root: Dataset root directory forwarded to ``InMemoryDataset``.
            transform: Optional transform forwarded to ``InMemoryDataset``.
            pre_transform: Optional pre-transform forwarded to
                ``InMemoryDataset``.
        """
        super().__init__(root, transform, pre_transform)
        # Load the (data, slices) pair that process() stored with torch.save.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Return the raw file names; empty because none are declared."""
        return []

    @property
    def processed_file_names(self):
        """Return the name of the single processed file.

        NOTE(review): despite the ``.csv`` suffix, the file is written with
        ``torch.save`` in ``process`` and is not a CSV.
        """
        return ['./train_vec.csv']

    def download(self):
        """Do nothing; this dataset has nothing to download."""
        pass

    def process(self):
        """Turn the module-level ``df`` into graphs and save them collated.

        For every ``label_vec`` group:

        * ``x`` holds the distinct ``text_vec`` values in ascending order,
          shaped ``(num_nodes, 1)``;
        * ``edge_index`` links each row's node to the next row's node;
        * ``y`` is the group's label as a one-element long tensor.

        Raises:
            ValueError: If ``df`` contains no ``label_vec`` groups, so there
                is nothing to collate.
        """
        data_list = []
        # Each label category is one graph; each of its emails is a node.
        for label_vec, group in tqdm(df.groupby('label_vec')):
            text_ids = torch.tensor(group.text_vec.values, dtype=torch.long)
            # torch.unique yields the sorted distinct values (the node
            # features) and, per row, the 0-based index of its value among
            # them (the node id), so node ids start at 0 in every graph.
            unique_text, node_ids = torch.unique(
                text_ids, sorted=True, return_inverse=True
            )
            x = unique_text.unsqueeze(1)
            # Row i -> row i + 1; a single-row group gets a (2, 0) edge index.
            edge_index = torch.stack([node_ids[:-1], node_ids[1:]])
            # Build the long tensor directly: a float round-trip would lose
            # precision for large integer labels.
            y = torch.tensor([label_vec], dtype=torch.long)
            data_list.append(Data(x=x, edge_index=edge_index, y=y))

        if not data_list:
            raise ValueError(
                "df contains no 'label_vec' groups; nothing to process"
            )

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
Preview:
Download PNG
Download JPEG
Download SVG
Tip: You can change the style, width & colours of the snippet with the inspect tool before clicking Download!
Click to optimize width for Twitter