import os
os.environ["DGLBACKEND"] = "pytorch"
import pandas as pd
import torch
import dgl
from dgl.data import DGLDataset
class EntitiesDataset(DGLDataset):
    """Graph dataset built from a JSON-lines file of entities.

    Every line of the file describes one entity and yields one
    ``(graph, label)`` sample. The fields read from an entity are:

    * ``"records"`` - list of objects, one graph node per record; the
      ``"totalValue"`` and ``"items"`` values form the two node features.
    * ``"edges"`` - list of objects; ``"a"`` and ``"b"`` are the source and
      destination node IDs, ``"R1"`` and ``"R2"`` form the two edge features.
    * ``"fraud"`` - the graph-level label.
    """

    def __init__(self, entitiesFile):
        """Create the dataset.

        Args:
            entitiesFile: Path (or any source accepted by ``pandas.read_json``)
                of the JSON-lines file to load.
        """
        # Must be assigned before super().__init__(): the DGLDataset
        # constructor runs the loading pipeline, which calls process().
        self.entitiesFile = entitiesFile
        super().__init__(name="entities")

    @staticmethod
    def _build_graph(edges, records):
        """Build one DGL graph from an entity's edge and record lists."""
        src = torch.tensor([edge["a"] for edge in edges], dtype=torch.long)
        dst = torch.tensor([edge["b"] for edge in edges], dtype=torch.long)

        # reshape(-1, 2) keeps the feature width at 2 even when the list is
        # empty, so every graph exposes identically shaped feature tensors.
        edge_features = torch.tensor(
            [[edge["R1"], edge["R2"]] for edge in edges],
            dtype=torch.long,
        ).reshape(-1, 2)

        # Pin the dtype: letting torch infer it yields int64 for all-integer
        # records and float32 otherwise, so graphs of the same dataset could
        # end up with mismatching node-feature dtypes.
        node_features = torch.tensor(
            [[node["totalValue"], node["items"]] for node in records],
            dtype=torch.float32,
        ).reshape(-1, 2)

        graph = dgl.graph((src, dst), num_nodes=len(records))
        graph.edata["feat"] = edge_features
        graph.ndata["feat"] = node_features

        # NOTE(review): the added self-loop edges get DGL's default fill value
        # for edata["feat"]; confirm that value is acceptable for the model.
        return dgl.add_self_loop(graph)

    def process(self):
        """Read the entities file and populate ``self.graphs``/``self.labels``.

        ``self.graphs`` becomes a list with one DGL graph per entity and
        ``self.labels`` an int64 tensor holding the matching ``"fraud"``
        values in the same order.
        """
        entities = pd.read_json(self.entitiesFile, lines=True)

        self.graphs = []
        labels = []
        for _, entity in entities.iterrows():
            self.graphs.append(
                self._build_graph(entity["edges"], entity["records"])
            )
            labels.append(entity["fraud"])

        self.labels = torch.tensor(labels, dtype=torch.long)

    def __getitem__(self, i):
        """Return the ``(graph, label)`` pair stored at index ``i``."""
        return self.graphs[i], self.labels[i]

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return len(self.graphs)
# Build the dataset from "./entities.jsonl" (resolved against the current
# working directory).
dataset = EntitiesDataset("./entities.jsonl")
# Print the dataset object itself, then its first (graph, label) sample.
print(dataset)
print(dataset[0])