Preview:
import os

os.environ["DGLBACKEND"] = "pytorch"
import pandas as pd
import torch
import dgl
from dgl.data import DGLDataset

class EntitiesDataset(DGLDataset):
    """Graph dataset built from a JSON-lines file with one entity per line.

    Every entity is expected to provide three fields:

    * ``records`` -- the nodes; each record supplies ``totalValue`` and
      ``items``, which form that node's two-column feature row.
    * ``edges`` -- the edges; each edge supplies ``a`` and ``b`` (the node
      ids handed to ``dgl.graph`` as the two endpoint lists) plus ``R1``
      and ``R2``, which form that edge's two-column feature row.
    * ``fraud`` -- the label stored alongside the entity's graph.

    Indexing the dataset yields a ``(graph, label)`` pair.
    """

    def __init__(self, entitiesFile):
        """Store the input path and initialise the base dataset.

        Args:
            entitiesFile: Location of the JSON-lines file; passed unchanged
                to ``pandas.read_json``.
        """
        # Assigned before the base constructor runs because process() reads
        # it. NOTE(review): presumably DGLDataset.__init__ is what triggers
        # process() -- confirm against the DGL version in use.
        self.entitiesFile = entitiesFile
        super().__init__(name="entities")

    def process(self):
        """Read the entities file and build one graph per entity.

        Fills ``self.graphs`` with the graphs and ``self.labels`` with a
        ``LongTensor`` holding one label per graph, in file order.
        """
        entities = pd.read_json(self.entitiesFile, lines=True)

        self.graphs = []
        self.labels = []

        for _, entity in entities.iterrows():
            edges = entity["edges"]
            records = entity["records"]

            src = torch.LongTensor([edge["a"] for edge in edges])
            dst = torch.LongTensor([edge["b"] for edge in edges])
            # Built as two rows and transposed so that an entity without
            # edges still produces a (0, 2) tensor.
            edge_features = torch.LongTensor(
                [
                    [edge["R1"] for edge in edges],
                    [edge["R2"] for edge in edges],
                ]
            ).t()

            # Fix the dtype explicitly: letting torch infer it would give
            # int64 for an entity whose values all happen to be integers and
            # float32 otherwise, so graphs would disagree on feature dtype.
            # The explicit reshape keeps the two-column layout even when
            # there are no records.
            node_features = torch.tensor(
                [[node["totalValue"], node["items"]] for node in records],
                dtype=torch.float32,
            ).reshape(len(records), 2)

            g = dgl.graph((src, dst), num_nodes=len(records))
            g.edata["feat"] = edge_features
            g.ndata["feat"] = node_features
            # Self-loops are added after the edge features are attached, so
            # the loop edges carry filler "feat" rows supplied by DGL rather
            # than values from the file.
            # NOTE(review): the fill value depends on the DGL version -- confirm.
            g = dgl.add_self_loop(g)

            self.graphs.append(g)
            self.labels.append(entity["fraud"])

        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, i):
        """Return the ``(graph, label)`` pair stored at index ``i``."""
        return self.graphs[i], self.labels[i]

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return len(self.graphs)

# Build the dataset from "./entities.jsonl" (resolved against the current
# working directory), then print the dataset object itself followed by its
# first item -- the (graph, label) pair returned by
# EntitiesDataset.__getitem__.
dataset = EntitiesDataset("./entities.jsonl")
print(dataset)
print(dataset[0])
Download PNG | Download JPEG | Download SVG

Tip: You can change the style, width & colours of the snippet with the inspect tool before clicking Download!

Click to optimize width for Twitter