import dgl
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import NNConv, SAGEConv


class EntityGraphModule(nn.Module):
    """Two-layer GNN that produces one output vector per graph.

    Layer 1 is an edge-conditioned ``NNConv`` whose per-edge weight matrix is
    generated from the edge features by a linear layer. Layer 2 is a
    ``SAGEConv`` with the ``"pool"`` aggregator. The resulting node
    representations are mean-pooled per graph with ``dgl.mean_nodes``.

    Args:
        node_in_feats: Size of each input node feature vector.
        edge_in_feats: Size of each input edge feature vector.
        h_feats: Hidden size produced by the first convolution.
        num_classes: Output size of the second convolution, and thus of the
            pooled per-graph vector.
    """

    def __init__(self, node_in_feats, edge_in_feats, h_feats, num_classes):
        super().__init__()
        # NNConv's edge_func must map each edge's features to a flattened
        # (node_in_feats * h_feats) weight matrix, hence this output size.
        # Pass the nn.Linear itself rather than a lambda closing over a local.
        # NNConv then stores it as an attribute, which registers it as a
        # submodule. Its weights are then trained by the optimizer, moved by
        # .to(device) and saved in state_dict. A lambda also breaks pickling.
        edge_lin = nn.Linear(edge_in_feats, node_in_feats * h_feats)
        self.conv1 = NNConv(node_in_feats, h_feats, edge_lin)
        self.conv2 = SAGEConv(h_feats, num_classes, "pool")

    def forward(self, g, node_features, edge_features):
        """Compute a pooled representation for each graph in ``g``.

        Args:
            g: DGL graph (possibly a batched graph).
            node_features: Node feature tensor; presumably shaped
                (num_nodes, node_in_feats). TODO confirm against callers.
            edge_features: Edge feature tensor; presumably shaped
                (num_edges, edge_in_feats). TODO confirm against callers.

        Returns:
            The per-graph mean of the final node representations, as returned
            by ``dgl.mean_nodes``. This is one row per graph in ``g``, each of
            size ``num_classes``.

        Side effects:
            Writes the final node representations to ``g.ndata["h"]``,
            overwriting any existing entry under that key.
        """
        h = F.relu(self.conv1(g, node_features, edge_features))
        h = self.conv2(g, h)
        # mean_nodes reads the features by key from the graph, so they must
        # be stored on g first.
        g.ndata["h"] = h
        return dgl.mean_nodes(g, "h")