-
Notifications
You must be signed in to change notification settings - Fork 1
/
mlp.py
58 lines (49 loc) · 2.08 KB
/
mlp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
This file is an implementation of a Multilayer Perceptron (MLP) based on the Neural Solvers framework: https://github.com/Photon-AI-Research/NeuralSolvers repository.
Core Developers:
Patrick Stiller (HZDR)
Maksim Zhdanov (HZDR)
Jeyhun Rustamov (HZDR)
Raj Dhansukhbhai Sutariya (HZDR)
P. Stiller, F. Bethke, M. Böhme, R. Pausch, S. Torge, A. Debus, J. Vorberger, M.Bussmann, N. Hoffmann: Large-scale Neural Solvers for Partial Differential Equations (2020).
"""
import torch
import torch.nn as nn
class MLP(nn.Module):
    """Multilayer perceptron with optional min-max input normalization.

    Architecture: ``Linear(input_size -> hidden_size)``, then ``num_hidden``
    layers of ``Linear(hidden_size -> hidden_size)``, then
    ``Linear(hidden_size -> output_size)``.  ``activation`` is applied after
    every layer except the last.  When ``normalize`` is True, inputs are
    rescaled to [-1, 1] feature-wise using the bounds ``lb``/``ub`` before
    the first layer (a common setup for PINN-style solvers).
    """

    def __init__(self, input_size, output_size, hidden_size, num_hidden,
                 lb, ub, activation=torch.tanh, normalize=True):
        """
        Args:
            input_size: number of input features.
            output_size: number of output features.
            hidden_size: width of each hidden layer.
            num_hidden: number of hidden-to-hidden layers (total linear
                layers is ``num_hidden + 2``).
            lb: per-feature lower bound(s) used for normalization.
            ub: per-feature upper bound(s) used for normalization.
            activation: elementwise nonlinearity applied between layers.
            normalize: if True, rescale inputs to [-1, 1] in ``forward``.
        """
        super(MLP, self).__init__()
        self.linear_layers = nn.ModuleList()
        self.activation = activation
        self.init_layers(input_size, output_size, hidden_size, num_hidden)
        # torch.tensor (not torch.Tensor) copies the data and handles a
        # scalar bound correctly; torch.Tensor(3.0) would allocate an
        # uninitialized tensor of length 3 instead of holding the value.
        self.lb = torch.tensor(lb, dtype=torch.float32)
        self.ub = torch.tensor(ub, dtype=torch.float32)
        self.normalize = normalize

    def init_layers(self, input_size, output_size, hidden_size, num_hidden):
        """Build the linear stack and apply Xavier-normal init, zero bias."""
        self.linear_layers.append(nn.Linear(input_size, hidden_size))
        for _ in range(num_hidden):
            self.linear_layers.append(nn.Linear(hidden_size, hidden_size))
        self.linear_layers.append(nn.Linear(hidden_size, output_size))
        # Every element appended above is an nn.Linear, so no type check
        # is needed before initializing weights and biases.
        for layer in self.linear_layers:
            nn.init.xavier_normal_(layer.weight)
            nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Apply the network; returns a tensor of shape (..., output_size)."""
        if self.normalize:
            # Feature-wise min-max rescale of x into [-1, 1].
            x = 2.0 * (x - self.lb) / (self.ub - self.lb) - 1.0
        # Activation after every layer except the final (linear) output.
        for layer in self.linear_layers[:-1]:
            x = self.activation(layer(x))
        return self.linear_layers[-1](x)

    # The device-movement overrides below must return ``self`` so the
    # standard ``model = model.to(device)`` / ``model.cuda()`` idiom keeps
    # working; the originals returned None, silently breaking chained
    # calls.  They also move the normalization bounds, which are plain
    # tensors (not buffers) and would otherwise be left behind.
    def cuda(self):
        """Move module and normalization bounds to the GPU; return self."""
        super(MLP, self).cuda()
        self.lb = self.lb.cuda()
        self.ub = self.ub.cuda()
        return self

    def cpu(self):
        """Move module and normalization bounds to the CPU; return self."""
        super(MLP, self).cpu()
        self.lb = self.lb.cpu()
        self.ub = self.ub.cpu()
        return self

    def to(self, device):
        """Move module and normalization bounds to ``device``; return self."""
        super(MLP, self).to(device)
        self.lb = self.lb.to(device)
        self.ub = self.ub.to(device)
        return self