# Train a graph classifier on `train_dataloader`, then report its accuracy on
# `test_dataloader`. `dataset`, `EntityGraphModule`, and both dataloaders are
# expected to be defined earlier in the file.

# Hyperparameters: hidden size, number of passes over the training loader,
# and the Adam learning rate.
h_feats = 64
learn_iterations = 50
learn_rate = 0.01

# Layer sizes are read off the first graph in the dataset: the second
# dimension of its node and edge "feat" tensors. The last argument is
# max(label) + 1 -- presumably the number of classes, which assumes labels are
# 0-based integers (TODO confirm against the dataset).
model = EntityGraphModule(
    dataset.graphs[0].ndata["feat"].shape[1],
    dataset.graphs[0].edata["feat"].shape[1],
    h_feats,
    dataset.labels.max().item() + 1,
)
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)

# --- Training ---------------------------------------------------------------
# NOTE(review): assumes EntityGraphModule is a torch.nn.Module (it exposes
# .parameters()), so .train()/.eval() are available.
model.train()
for _ in range(learn_iterations):
    for batched_graph, labels in train_dataloader:
        # Features are cast to float before being fed to the model.
        pred = model(
            batched_graph,
            batched_graph.ndata["feat"].float(),
            batched_graph.edata["feat"].float(),
        )
        loss = F.cross_entropy(pred, labels)
        # Clear gradients left over from the previous step before
        # back-propagating this batch's loss.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# --- Evaluation -------------------------------------------------------------
num_correct = 0
num_tests = 0
# Switch layers such as dropout / batch-norm (if the model has any) to
# inference behaviour, and skip autograd bookkeeping while scoring.
model.eval()
with torch.no_grad():
    for batched_graph, labels in test_dataloader:
        pred = model(
            batched_graph,
            batched_graph.ndata["feat"].float(),
            batched_graph.edata["feat"].float(),
        )
        # The predicted class is the index of the largest score along dim 1.
        num_correct += (pred.argmax(1) == labels).sum().item()
        num_tests += len(labels)

# An empty test loader would otherwise raise ZeroDivisionError; report NaN.
acc = num_correct / num_tests if num_tests else float("nan")
print("Test accuracy:", acc)
Preview:
Download PNG
Download JPEG
Download SVG
Tip: You can change the style, width & colours of the snippet with the inspect tool before clicking Download!
Click to optimize width for Twitter