forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gin.py
190 lines (160 loc) · 5.92 KB
/
gin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""
How Powerful are Graph Neural Networks
https://arxiv.org/abs/1810.00826
https://openreview.net/forum?id=ryGs6iA5Km
Author's implementation: https://github.com/weihua916/powerful-gnns
"""
import mxnet as mx
from mxnet import gluon, nd
from mxnet.gluon import nn
from dgl.nn.mxnet.conv import GINConv
from dgl.nn.mxnet.glob import AvgPooling, MaxPooling, SumPooling
class ApplyNodeFunc(nn.Block):
    """Node-update function wrapped by each GIN layer.

    Feeds node features through the wrapped MLP, batch-normalizes the
    result and applies a ReLU.

    Parameters
    ----------
    mlp : nn.Block
        Block applied first (an ``MLP`` in this file). It must expose an
        ``output_dim`` attribute, which sizes the BatchNorm layer.
    """

    def __init__(self, mlp):
        super().__init__()
        with self.name_scope():
            self.mlp = mlp
            # BatchNorm width has to match the width the MLP produces.
            self.bn = nn.BatchNorm(in_channels=mlp.output_dim)

    def forward(self, h):
        """Return ``relu(bn(mlp(h)))``."""
        return nd.relu(self.bn(self.mlp(h)))
class MLP(nn.Block):
    """Multi-layer perceptron whose final layer is linear (no activation).

    Every layer except the last is ``Dense -> BatchNorm -> ReLU``; the last
    layer is a plain ``Dense``. With ``num_layers == 1`` the model is a
    single linear layer.
    """

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        """Construct the MLP layers.

        Parameters
        ----------
        num_layers : int
            Number of linear (Dense) layers; must be at least 1.
        input_dim : int
            Dimensionality of the input features.
        hidden_dim : int
            Dimensionality of every hidden layer.
        output_dim : int
            Dimensionality produced by the final linear layer (GIN uses
            the hidden width here, so this is not necessarily a class
            count).

        Raises
        ------
        ValueError
            If ``num_layers`` is less than 1.
        """
        super(MLP, self).__init__()
        # True -> a single Dense layer; False -> hidden layers + output layer.
        self.linear_or_not = True
        self.num_layers = num_layers
        # Exposed so wrapping blocks can size layers that consume our output.
        self.output_dim = output_dim
        with self.name_scope():
            if num_layers < 1:
                raise ValueError("number of layers should be positive!")
            if num_layers == 1:
                # Linear model
                self.linear = nn.Dense(output_dim, in_units=input_dim)
            else:
                self.linear_or_not = False
                self.linears = nn.Sequential()
                self.batch_norms = nn.Sequential()
                # NOTE: Gluon derives parameter names from creation order
                # inside name_scope; keep this order stable.
                self.linears.add(nn.Dense(hidden_dim, in_units=input_dim))
                for _ in range(num_layers - 2):
                    self.linears.add(
                        nn.Dense(hidden_dim, in_units=hidden_dim)
                    )
                self.linears.add(nn.Dense(output_dim, in_units=hidden_dim))
                # One BatchNorm per hidden layer (all but the output layer).
                for _ in range(num_layers - 1):
                    self.batch_norms.add(nn.BatchNorm(in_channels=hidden_dim))

    def forward(self, x):
        """Apply the MLP to ``x`` and return the last Dense layer's output."""
        if self.linear_or_not:
            return self.linear(x)
        h = x
        for i in range(self.num_layers - 1):
            h = nd.relu(self.batch_norms[i](self.linears[i](h)))
        return self.linears[-1](h)
class GIN(nn.Block):
    """Graph Isomorphism Network (GIN) for graph-level prediction.

    Stacks ``num_layers - 1`` GINConv layers. Each representation is
    pooled over the nodes of each graph: the input features plus every
    GINConv output. Each pooled representation goes through its own Dense
    prediction head with dropout, and the per-layer scores are summed.
    """

    def __init__(
        self,
        num_layers,
        num_mlp_layers,
        input_dim,
        hidden_dim,
        output_dim,
        final_dropout,
        learn_eps,
        graph_pooling_type,
        neighbor_pooling_type,
    ):
        """Build the GIN layers, prediction heads and graph pooling.

        Parameters
        ----------
        num_layers : int
            Number of node representations used for prediction, counting
            the input features. ``num_layers - 1`` GINConv layers and
            ``num_layers`` prediction heads are created. Must be >= 1.
        num_mlp_layers : int
            Number of linear layers in each GINConv's MLP.
        input_dim : int
            Dimensionality of the input node features.
        hidden_dim : int
            Dimensionality of the hidden node representations.
        output_dim : int
            Number of classes for prediction.
        final_dropout : float
            Dropout ratio applied to each prediction head's output.
        learn_eps : bool
            If True, learn epsilon to distinguish center nodes from
            neighbors; if False, aggregate neighbors and center nodes
            altogether.
        graph_pooling_type : str
            How to aggregate all nodes of a graph: "sum", "mean" or "max".
        neighbor_pooling_type : str
            How to aggregate neighbors (sum, mean, or max); passed to
            GINConv.

        Raises
        ------
        ValueError
            If ``num_layers`` is less than 1.
        NotImplementedError
            If ``graph_pooling_type`` is not "sum", "mean" or "max".
        """
        super(GIN, self).__init__()
        # Validate before building any sub-blocks. With num_layers < 1 no
        # prediction head would exist and forward() could never succeed.
        if num_layers < 1:
            raise ValueError(
                "num_layers must be at least 1, got {}".format(num_layers)
            )
        if graph_pooling_type not in ("sum", "mean", "max"):
            raise NotImplementedError(
                "unsupported graph_pooling_type {!r}; expected 'sum', "
                "'mean' or 'max'".format(graph_pooling_type)
            )
        self.num_layers = num_layers
        self.learn_eps = learn_eps
        with self.name_scope():
            # List of MLPs
            self.ginlayers = nn.Sequential()
            self.batch_norms = nn.Sequential()
            for i in range(self.num_layers - 1):
                # Only the first GIN layer consumes the raw input features.
                in_dim = input_dim if i == 0 else hidden_dim
                mlp = MLP(num_mlp_layers, in_dim, hidden_dim, hidden_dim)
                self.ginlayers.add(
                    GINConv(
                        ApplyNodeFunc(mlp),
                        neighbor_pooling_type,
                        0,  # presumably GINConv's init_eps — confirm
                        self.learn_eps,
                    )
                )
                self.batch_norms.add(nn.BatchNorm(in_channels=hidden_dim))

            # One prediction head per representation (input + each layer).
            self.linears_prediction = nn.Sequential()
            for i in range(num_layers):
                in_units = input_dim if i == 0 else hidden_dim
                self.linears_prediction.add(
                    nn.Dense(output_dim, in_units=in_units)
                )
            self.drop = nn.Dropout(final_dropout)

            # graph_pooling_type was validated above.
            if graph_pooling_type == "sum":
                self.pool = SumPooling()
            elif graph_pooling_type == "mean":
                self.pool = AvgPooling()
            else:
                self.pool = MaxPooling()

    def forward(self, g, h):
        """Compute graph-level prediction scores.

        Parameters
        ----------
        g :
            Graph (presumably a batched DGLGraph — confirm) passed to every
            GINConv layer and to the pooling block.
        h :
            Node features with ``input_dim`` columns (assumed shape
            ``(num_nodes, input_dim)`` — TODO confirm).

        Returns
        -------
        The sum over all representations of
        ``dropout(linears_prediction[i](pool(g, h_i)))``.
        """
        hidden_rep = [h]
        for i in range(self.num_layers - 1):
            h = self.ginlayers[i](g, h)
            h = self.batch_norms[i](h)
            h = nd.relu(h)
            hidden_rep.append(h)

        score_over_layer = 0
        # Pool every representation (including the raw input) per graph.
        for i, rep in enumerate(hidden_rep):
            pooled_h = self.pool(g, rep)
            score_over_layer = score_over_layer + self.drop(
                self.linears_prediction[i](pooled_h)
            )
        return score_over_layer