# GNN dataset create2
#
# Saved by @QuinnFox12 on Wed Sep 01 2021 09:04:08 GMT+0000
# (Coordinated Universal Time). Tags: #gnn #dataset

import numpy as np
import torch
from torch_geometric.data import Data, InMemoryDataset
from tqdm import tqdm

class classDataset(InMemoryDataset):
    """In-memory PyG dataset holding one graph per ``label_vec`` value.

    Rows of a pandas DataFrame are grouped by their ``label_vec`` column and
    each group becomes one ``Data`` graph (see :meth:`process`).

    Args:
        root: Root directory of the dataset; the collated graphs are cached
            at ``self.processed_paths[0]``.
        transform: Optional transform forwarded to ``InMemoryDataset``.
        pre_transform: Optional transform applied to every graph in
            :meth:`process` before the graphs are collated and saved.
        df: DataFrame to build the graphs from; it must have ``label_vec``
            and ``text_vec`` columns. When omitted, a module-level ``df`` is
            used (the original behaviour of this snippet).
    """

    def __init__(self, root, transform=None, pre_transform=None, df=None):
        # Set before super().__init__(): the base-class constructor runs
        # process() when the processed file does not exist yet.
        self._frame = df
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = self._load_processed(self.processed_paths[0])

    @staticmethod
    def _load_processed(path):
        """Load the ``(data, slices)`` tuple that :meth:`process` saved."""
        try:
            # torch >= 2.6 defaults to weights_only=True, which refuses to
            # unpickle PyG Data objects. The file is written by process().
            return torch.load(path, weights_only=False)
        except TypeError:
            # torch < 1.13 does not accept the weights_only keyword.
            return torch.load(path)

    @property
    def raw_file_names(self):
        # No raw files: the graphs are built from a DataFrame.
        return []

    @property
    def processed_file_names(self):
        # NOTE(review): process() writes a torch.save() pickle to this path
        # despite the ".csv" suffix; the name is kept so existing caches stay
        # valid.
        return ['./train_vec.csv']

    def download(self):
        # Nothing to download (see raw_file_names).
        pass

    def process(self):
        """Build one graph per ``label_vec`` group, collate and save them.

        For every group of rows sharing a ``label_vec`` value:

        * nodes are the distinct ``text_vec`` values of the group, indexed
          from 0 in sorted order; node ``i``'s single feature is its
          ``text_vec`` value, so ``x`` has shape ``[num_nodes, 1]``;
        * consecutive rows (in DataFrame order) are joined by a directed
          edge ``node(row_k) -> node(row_k+1)``;
        * ``y`` is the group's ``label_vec`` value cast to ``long``.

        ``text_vec`` is presumably an integer id per email (it is stored in a
        ``long`` tensor) -- confirm against the preprocessing step.
        """
        # Fall back to the module-level DataFrame the snippet originally used.
        frame = self._frame if self._frame is not None else df

        data_list = []
        for _, group in tqdm(frame.groupby('label_vec')):
            # Re-encode text_vec per group so node ids start at 0 in every
            # graph. np.unique returns the sorted distinct values plus, for
            # each row, the index of its value -- the same encoding that
            # sklearn's LabelEncoder().fit_transform yields -- so
            # node_values[i] is exactly the value of node i.
            node_values, node_ids = np.unique(
                group['text_vec'].to_numpy(), return_inverse=True
            )

            x = torch.as_tensor(node_values, dtype=torch.long).unsqueeze(1)
            # Shape [2, num_rows - 1]; empty ([2, 0]) for a single-row group.
            edge_index = torch.as_tensor(
                np.stack([node_ids[:-1], node_ids[1:]]), dtype=torch.long
            )
            # Build as long directly: going through float32 would lose
            # precision for label values above 2**24.
            y = torch.tensor([group['label_vec'].iloc[0]], dtype=torch.long)

            data_list.append(Data(x=x, edge_index=edge_index, y=y))

        # Honour the pre_transform accepted by __init__ (previously ignored).
        if self.pre_transform is not None:
            data_list = [self.pre_transform(graph) for graph in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])