-
Notifications
You must be signed in to change notification settings - Fork 4
/
tensorutils.mojo
87 lines (81 loc) · 2.48 KB
/
tensorutils.mojo
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2023, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
from math import mod, trunc
from tensor import Tensor
fn _format_scalar[type: DType](value: SIMD[type, 1]) -> String:
    """Formats one tensor element as `[-]<int>.<dddd>`.

    The fractional part is truncated (not rounded) to exactly four digits.
    It is computed arithmetically instead of by slicing the string form of
    `mod(x, 1)`. This keeps the sign of negative values (including those in
    (-1, 0)) and prevents fractions that stringify in scientific notation
    (e.g. `1e-05`) from corrupting the output.

    Parameters:
        type: The element dtype.

    Args:
        value: The scalar to format.

    Returns:
        The formatted string, e.g. `1.5000`, `-0.2500`, `3.0000`.
    """
    var sign = String("")
    var magnitude = value
    if value < 0:
        sign = String("-")
        magnitude = -value
    var whole = trunc(magnitude)
    var frac_digits = trunc((magnitude - whole) * 10000).cast[DType.int64]()
    # Float rounding of the product can land exactly on 10000 for fractions
    # very close to 1; clamp so the fraction never spills into a fifth digit.
    if frac_digits > 9999:
        frac_digits = 9999
    var frac_str = String(frac_digits)
    # Left-pad with zeros so e.g. 0.05 renders as "0500", not "500".
    while len(frac_str) < 4:
        frac_str = String("0") + frac_str
    # int64 (not int32) so integer parts >= 2**31 do not wrap around.
    return sign + String(whole.cast[DType.int64]()) + "." + frac_str


fn tensorprint[type: DType](t: Tensor[type]) -> None:
    """Prints a rank-1, rank-2 or rank-3 tensor, then its shape, rank, dtype.

    Elements are formatted by `_format_scalar` (four truncated decimals).
    Rows are wrapped in brackets like nested lists. For any other rank an
    error line is printed and the function returns without printing the
    tensor or the summary.

    Parameters:
        type: The element dtype of `t`.

    Args:
        t: The tensor to print.
    """
    var rank = t.rank()
    if rank == 0 or rank > 3:
        print("Error: Tensor rank should be: 1,2, or 3. Tensor rank is ", rank)
        return

    # View every supported rank as a (dim0, dim1, dim2) volume. Missing
    # leading axes are padded with size 1 so one triple loop handles all.
    var dim0: Int = 1
    var dim1: Int = 1
    var dim2: Int = 0
    if rank == 1:
        dim2 = t.dim(0)
    elif rank == 2:
        dim1 = t.dim(0)
        dim2 = t.dim(1)
    else:
        dim0 = t.dim(0)
        dim1 = t.dim(1)
        dim2 = t.dim(2)

    var val: SIMD[type, 1] = 0.0
    for i in range(dim0):
        if i == 0 and rank == 3:
            print("[")
        elif i > 0:
            # Blank line between consecutive 2-D slices of a rank-3 tensor.
            print()
        for j in range(dim1):
            if rank != 1:
                if j == 0:
                    print_no_newline(" [")
                else:
                    print_no_newline("\n ")
            print_no_newline("[")
            for k in range(dim2):
                # Index with the real rank; the padded axes are not indexable.
                if rank == 1:
                    val = t[k]
                elif rank == 2:
                    val = t[j, k]
                else:
                    val = t[i, j, k]
                if k == 0:
                    print_no_newline(_format_scalar(val))
                else:
                    print_no_newline(" ", _format_scalar(val))
            print_no_newline("]")
        if rank > 1:
            print_no_newline("]")
        print()
    if rank == 3:
        print("]")
    print(
        "Tensor shape:",
        t.shape().__str__(),
        ", Tensor rank:",
        rank,
        ",",
        "DType:",
        type.__str__(),
    )
    print()
fn main():
    """Demo entry point: prints a freshly constructed 2x2 float32 tensor."""
    tensorprint(Tensor[DType.float32](2, 2))