import dgl
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import NNConv, SAGEConv

class EntityGraphModule(nn.Module):
    """Two-layer graph network that emits one output vector per graph.

    The first layer is an edge-conditioned convolution (``NNConv``) whose
    per-edge weights are produced from the edge features by a linear layer.
    The second layer is a ``SAGEConv`` with the ``"pool"`` aggregator that
    maps hidden node states to ``num_classes`` values. Node outputs are then
    averaged over each graph with ``dgl.mean_nodes``.

    Args:
        node_in_feats: Size of each input node feature vector.
        edge_in_feats: Size of each input edge feature vector.
        h_feats: Size of the hidden node representation after ``conv1``.
        num_classes: Size of the per-node (and per-graph) output vector.
    """

    def __init__(self, node_in_feats: int, edge_in_feats: int,
                 h_feats: int, num_classes: int) -> None:
        super().__init__()
        # NNConv expects the edge function to map each edge feature vector to
        # node_in_feats * h_feats values (one weight matrix per edge).
        #
        # The nn.Linear is handed to NNConv directly instead of being wrapped
        # in a lambda: a lambda is not an nn.Module, so the linear layer's
        # weights would not be registered anywhere and would be skipped by
        # parameters(), state_dict(), .to(device) and the optimizer.
        edge_func = nn.Linear(edge_in_feats, node_in_feats * h_feats)
        self.conv1 = NNConv(node_in_feats, h_feats, edge_func)

        self.conv2 = SAGEConv(h_feats, num_classes, "pool")

    def forward(self, g, node_features, edge_features):
        """Compute graph-level outputs for ``g``.

        Args:
            g: DGL graph (or batched graph) to run the convolutions on.
            node_features: Node feature tensor passed to ``conv1``;
                presumably shaped ``(num_nodes, node_in_feats)``.
            edge_features: Edge feature tensor passed to ``conv1``;
                presumably shaped ``(num_edges, edge_in_feats)``.

        Returns:
            The result of ``dgl.mean_nodes(g, "h")``: the mean of the final
            node outputs, taken per graph in ``g``.

        Note:
            The final node outputs are stored in ``g.ndata["h"]``, so the
            input graph is modified and any existing ``"h"`` node field is
            overwritten.
        """
        h = self.conv1(g, node_features, edge_features)
        h = F.relu(h)
        h = self.conv2(g, h)
        # mean_nodes reads its input from a node-data field by name.
        g.ndata["h"] = h
        return dgl.mean_nodes(g, "h")
