import dgl
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import NNConv, SAGEConv


class EntityGraphModule(nn.Module):
    """Two-layer graph network producing one pooled output vector per graph.

    Layer 1 is an edge-conditioned ``NNConv``: a linear layer turns each edge's
    feature vector into the weight matrix used for that edge's message.
    Layer 2 is a ``SAGEConv`` with the ``"pool"`` aggregator. The per-node
    outputs are then averaged into a graph-level readout with
    ``dgl.mean_nodes``.
    """

    def __init__(self, node_in_feats, edge_in_feats, h_feats, num_classes):
        """Build the two convolution layers.

        Args:
            node_in_feats: Size of each input node feature vector.
            edge_in_feats: Size of each input edge feature vector.
            h_feats: Hidden size produced by the first convolution.
            num_classes: Output size of the second convolution, and hence of
                the pooled readout.
        """
        super().__init__()
        # Maps each edge feature vector to a flattened
        # (node_in_feats * h_feats) weight for NNConv. Pass the nn.Linear
        # itself rather than a lambda closing over a local. NNConv then holds
        # it as a submodule, so its weights are visible to .parameters(), the
        # optimizer, .to(device) and state_dict(), and they actually get trained.
        edge_func = nn.Linear(edge_in_feats, node_in_feats * h_feats)
        self.conv1 = NNConv(node_in_feats, h_feats, edge_func)
        self.conv2 = SAGEConv(h_feats, num_classes, "pool")

    def forward(self, g, node_features, edge_features):
        """Run both convolutions and mean-pool node outputs per graph.

        Args:
            g: DGL graph (possibly batched).
            node_features: Node feature tensor; presumably shaped
                (num_nodes, node_in_feats). TODO confirm against callers.
            edge_features: Edge feature tensor; presumably shaped
                (num_edges, edge_in_feats). TODO confirm against callers.

        Returns:
            The result of ``dgl.mean_nodes`` over the second layer's output,
            i.e. one ``num_classes``-sized row per graph in ``g``.
        """
        h = self.conv1(g, node_features, edge_features)
        h = F.relu(h)
        h = self.conv2(g, h)
        # local_scope() discards the temporary "h" node feature on exit, so
        # the caller's graph is not left modified by the readout.
        with g.local_scope():
            g.ndata["h"] = h
            return dgl.mean_nodes(g, "h")
Preview:
Download PNG
Download JPEG
Download SVG
Tip: You can change the style, width & colours of the snippet with the inspect tool before clicking Download!
Click to optimize width for Twitter