Skip to content

Commit

Permalink
[stdlib] Introduce fixed-width bin and hex
Browse files Browse the repository at this point in the history
Signed-off-by: Yiwu Chen <210at85@gmail.com>
  • Loading branch information
soraros committed Oct 22, 2024
1 parent 92e2230 commit 7b88131
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 1 deletion.
2 changes: 2 additions & 0 deletions stdlib/src/bit/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,6 @@ from .bit import (
pop_count,
rotate_bits_left,
rotate_bits_right,
bin,
hex,
)
78 changes: 78 additions & 0 deletions stdlib/src/bit/bit.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ from bit import count_leading_zeros
```
"""

from memory import bitcast, unpack_bits
from sys import llvm_intrinsic, sizeof
from sys.info import bitwidthof
from utils import Span

# ===----------------------------------------------------------------------===#
# count_leading_zeros
Expand Down Expand Up @@ -654,3 +656,79 @@ fn rotate_bits_right[
return llvm_intrinsic["llvm.fshr", __type_of(x), has_side_effect=False](
x, x, SIMD[type, width](shift)
)


# ===----------------------------------------------------------------------===#
# bin and hex
# ===----------------------------------------------------------------------===#


fn bin[dtype: DType, //](x: Scalar[dtype]) -> String:
"""Converts a scalar to a binary string.
Parameters:
dtype: The data type of the input scalar.
Args:
x: The input scalar value.
Returns:
A binary string representation of the input scalar value.
"""
alias len = dtype.bitwidth() + 1
buff = String._buffer_type(capacity=len)
_write_bin(x, buff)
buff.size = len
return String(impl=buff)


@always_inline
fn _write_bin(x: Scalar, s: Span[Byte, _]):
alias `0` = ord("0")

@parameter
if x.type.sizeof() == 1:
r = x
else:
r = byte_swap(x)
bytes = unpack_bits(r).cast[DType.uint8]()
s.unsafe_ptr().store(bytes + `0`)


# fmt: off
alias _table = SIMD[DType.uint8, 16](
ord("0"), ord("1"), ord("2"), ord("3"), ord("4"), ord("5"), ord("6"), ord("7"),
ord("8"), ord("9"), ord("a"), ord("b"), ord("c"), ord("d"), ord("e"), ord("f"),
)
# fmt: on


fn hex[dtype: DType, //](x: Scalar[dtype]) -> String:
"""Converts a scalar to a hexadecimal string.
Parameters:
dtype: The data type of the input scalar.
Args:
x: The input scalar value.
Returns:
A hexadecimal string representation of the input scalar value.
"""
alias len = dtype.sizeof() * 2 + 1
buff = String._buffer_type(capacity=len)
_write_hex(x, buff)
buff.size = len
return String(impl=buff)


@always_inline
fn _write_hex(x: Scalar, s: Span[Byte, _]):
@parameter
if x.type.sizeof() == 1:
r = x
else:
r = byte_swap(x)
bytes = bitcast[DType.uint8, x.type.sizeof()](r)
nibbles = (bytes >> 4).interleave(bytes & 0xF)
s.unsafe_ptr().store(_table._dynamic_shuffle(nibbles))
2 changes: 1 addition & 1 deletion stdlib/src/memory/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ from .arc import Arc
from .box import Box
from .memory import memcmp, memcpy, memset, memset_zero, stack_allocation
from .pointer import AddressSpace, Pointer
from .unsafe import bitcast
from .unsafe import bitcast, unpack_bits
from .unsafe_pointer import UnsafePointer
30 changes: 30 additions & 0 deletions stdlib/src/memory/unsafe.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,33 @@ fn bitcast[
return __mlir_op.`pop.bitcast`[
_type = __mlir_type[`!pop.scalar<`, new_type.value, `>`]
](val.value)


@always_inline("nodebug")
fn unpack_bits[
dtype: DType, //, width: Int = bitwidthof[dtype]()
](res: Scalar[dtype]) -> SIMD[DType.bool, width]:
"""Pack a scalar value into a SIMD vector of boolean values.
Parameters:
dtype: The data type of the input scalar value.
width: The width of the SIMD vector.
Constraints:
The bitwidth of the data type must be equal to the SIMD width.
Args:
res: The input scalar value.
Returns:
A SIMD vector where each element is a boolean value representing the
corresponding bit of the input scalar value.
"""
constrained[
bitwidthof[dtype]() == width,
"the bitwidth of the data type must be equal to the SIMD width",
]()
b = __mlir_op.`pop.bitcast`[
_type = __mlir_type[`!pop.simd<`, width.value, `, bool>`]
](res.value)
return SIMD[DType.bool, width](b)

0 comments on commit 7b88131

Please sign in to comment.