-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.mli
61 lines (51 loc) · 2.48 KB
/
model.mli
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
open Lacaml.D
open Layer
open Matrix
open Loss
open Actv
(* Local abbreviations for the project's core types, so the signatures in
 * this interface can say [matrix], [layer] and [loss] rather than spelling
 * out the qualified [Matrix.t], [Layer.t] and [Loss.t]. *)
type matrix = Matrix.t
and layer = Layer.t
and loss = Loss.t

(* [model] is the ordered list of layers that makes up a network; input is
 * fed to the first layer of the list (see the precondition on
 * [propagate]). *)
type model = layer list
(* [network] is a complete neural network: its layers ([model]) paired
 * with a loss function ([loss]). *)
type network = { model : model; loss : loss }
(* [propagate model matrix] feeds [matrix] forward through the layers of
 * [model] and returns the resulting list of activation matrices.
 * NOTE(review): the order of the returned list (input-to-output or
 * output-to-input) is not fixed by this interface; confirm against the
 * implementation.
 * requires: the number of columns of [matrix] must match the number of
 * rows of the first layer in [model]. NOTE(review): this was phrased as
 * column space vs. row space; confirm this is the intended dimension
 * check. *)
val propagate : model -> matrix -> matrix list
(* [backpropagate model m1 m2] runs the backpropagation algorithm over
 * [model] and returns a list of matrices. It takes no loss argument.
 * [m1] is a list of matrices: presumably the activation list produced by
 * [propagate], since that is the type [propagate] returns; confirm.
 * [m2] is a single matrix. NOTE(review): this interface does not show
 * whether it is the reference label, the model's output, or the gradient
 * of the output loss; confirm against the implementation.
 * NOTE(review): what the returned matrices hold (per-layer gradients,
 * updated parameters, ...) is not established by this interface. *)
val backpropagate : model -> matrix list -> matrix -> matrix list
(* [full_pass n x y] performs a gradient update on the weights and biases
 * of [n]'s model. It propagates the input [x], computes the gradient of the
 * output loss (using the label [y]), and backpropagates that gradient to
 * update the weights and biases. It returns the network carrying the
 * updated weights and biases.
 * [x] is a data point and [y] is that data point's label. *)
val full_pass : network -> matrix -> matrix -> network
(* [train n x steps epoch ?id ()] trains the network [n] on the dataset [x],
 * sampling from it and updating [n] iteratively in memory. Training runs
 * for [epoch] epochs, each consisting of [steps] steps.
 * NOTE(review): [x] is the only data argument, so presumably it also
 * carries the labels that [full_pass] needs; confirm the expected layout.
 * The optional [?id] tags the saved network with a specific ID. The
 * trailing [unit] argument is what allows [?id] to be omitted.
 * Returns the trained network together with the list of files saved so the
 * network can be reloaded later. These are presumably (weight file, bias
 * file) pairs as produced by [save_m]; confirm. *)
val train :
  network -> matrix -> int -> int -> ?id:string -> unit ->
  network * (string * string) list
(* [infer n m] is the index of the largest value in a probability
 * distribution.
 * NOTE(review): presumably the distribution is the output of network [n]
 * for the input [m]; this interface does not show whether [m] itself is
 * the distribution; confirm how [n] is used.
 * requires: [m] is a vector. *)
val infer : network -> matrix -> int
(* [save_m id m] saves the weights and biases of every layer in model [m]
 * as matrix text files. It returns an association list of
 * (weight file name, bias file name) pairs.
 * File names start with the prefix "id-model-", followed by a number and a
 * marker saying whether the file holds a weight ("wgt") or a bias ("bias").
 * NOTE(review): presumably "id" in the prefix stands for the value of [id]
 * and the number identifies the layer; confirm the exact format against
 * the implementation. *)
val save_m : string -> model -> (string * string) list
(* [save_net id n] saves the model in network [n] under the ID [id] in the
 * matrices directory. The files associated with the model are prefixed
 * with "saved_net-[id]". It returns the list of saved file names: presumably
 * the (weight file, bias file) pairs produced by [save_m]; confirm. *)
val save_net: string -> network -> (string * string) list