# Training and Testing
#
# Tue Aug 22 2023 13:17:31 GMT+0000 (Coordinated Universal Time)
#
# Saved by @Tilores #entityresolution #frauddetection #gnn #machinelearning

# Hyperparameters: the width passed to the model, the number of passes over
# the training loader, and the Adam learning rate.
h_feats = 64
learn_iterations = 50
learn_rate = 0.01

# Derive the remaining constructor arguments from the data itself, using the
# first graph as the reference for feature widths.
first_graph = dataset.graphs[0]
node_feat_width = first_graph.ndata["feat"].shape[1]
edge_feat_width = first_graph.edata["feat"].shape[1]
# Largest label value plus one -- presumably the number of classes, assuming
# labels are 0-based integers (confirm against the dataset).
label_span = dataset.labels.max().item() + 1

model = EntityGraphModule(node_feat_width, edge_feat_width, h_feats, label_span)
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

# Train: make `learn_iterations` full passes over the training loader,
# taking one optimizer step per batch.
for _ in range(learn_iterations):
    for batched_graph, labels in train_dataloader:
        # Node and edge features are cast to float before the forward pass.
        node_inputs = batched_graph.ndata["feat"].float()
        edge_inputs = batched_graph.edata["feat"].float()

        pred = model(batched_graph, node_inputs, edge_inputs)
        loss = F.cross_entropy(pred, labels)

        # Clear gradients left from the previous batch, backpropagate this
        # batch's loss, then apply the parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluate: count test samples whose highest-scoring class (arg-max over
# axis 1 of the model output) equals the label, then report the fraction.
num_correct = 0
num_tests = 0

# Put the model in inference mode so that layers which behave differently
# during training (e.g. dropout, batch norm -- if the model has any) do not
# perturb the predictions. Assumes `model` is a torch.nn.Module, which its
# use with `model.parameters()` and torch.optim suggests.
# NOTE: call `model.train()` again before any further training.
model.eval()

# No gradients are needed for scoring; disabling autograd avoids building a
# computation graph for every test batch.
with torch.no_grad():
    for batched_graph, labels in test_dataloader:
        node_inputs = batched_graph.ndata["feat"].float()
        edge_inputs = batched_graph.edata["feat"].float()
        pred = model(batched_graph, node_inputs, edge_inputs)
        num_correct += (pred.argmax(1) == labels).sum().item()
        num_tests += len(labels)

# An empty test loader would otherwise raise ZeroDivisionError; report NaN
# so the absence of test data is visible rather than looking like a score.
acc = num_correct / num_tests if num_tests else float("nan")
print("Test accuracy:", acc)
# End of snippet.