Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Add UnsafePointer generic .store() #3720

Draft
wants to merge 2 commits into
base: nightly
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions stdlib/src/memory/memory.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ from sys import (
from collections import Optional
from builtin.dtype import _integral_type_of
from memory.pointer import AddressSpace, _GPUAddressSpace
from .unsafe_pointer import _default_alignment

# ===----------------------------------------------------------------------=== #
# Utilities
Expand Down Expand Up @@ -288,6 +289,65 @@ fn memset[
_memset_impl(ptr.bitcast[Byte](), value, count * sizeof[type]())


@always_inline
fn memset[
type: Movable, address_space: AddressSpace
](ptr: UnsafePointer[type, address_space], owned value: type):
"""Stores a single element value at the given offset by moving it.

Parameters:
type: The data type of the elements.
address_space: The address space of the pointer.

Args:
ptr: UnsafePointer to the memory address to store.
value: The value to move into the memory address.
"""

alias dt = DType.get_dtype[T]()

@parameter
if dt is not DType.invalid:
memset(ptr.bitcast[SIMD[dt, 1]]().offset(offset), value^)
else:
(ptr + offset).init_pointee_move(value^)

@always_inline("nodebug")
fn memset[
type: DType,
address_space: AddressSpace,
alignment: Int,
width: Int, //,
*,
volatile: Bool = False,
](
ptr: UnsafePointer[Scalar[type], address_space, alignment],
owned value: SIMD[type, width],
):
"""Stores a single element value at the given offset by moving it.

Parameters:
type: The data type of SIMD vector elements.
address_space: The address space of the pointer.
alignment: The minimal alignment of the address.
width: The size of the SIMD vector.
volatile: Whether the operation is volatile or not.

Args:
ptr: UnsafePointer to the memory address to store.
value: The value to store.
"""
@parameter
if volatile:
__mlir_op.`pop.store`[
alignment = alignment.value, isVolatile = __mlir_attr.unit
](value, ptr.bitcast[SIMD[type, width]]().address)
else:
__mlir_op.`pop.store`[alignment = alignment.value](
value, ptr.bitcast[SIMD[type, width]]().address
)


# ===----------------------------------------------------------------------===#
# memset_zero
# ===----------------------------------------------------------------------===#
Expand Down
150 changes: 67 additions & 83 deletions stdlib/src/memory/unsafe_pointer.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -574,15 +574,53 @@ struct UnsafePointer[

@always_inline
fn store[
T: IntLike,
type: DType, //,
*,
alignment: Int = _default_alignment[type](),
volatile: Bool = False,
I: IntLike, T: Movable, //
](
inout self: UnsafePointer[T, AddressSpace.GENERIC, *_, **_],
offset: I,
owned value: T,
):
"""Stores a single element value at the given offset by moving it.

Parameters:
I: The type of offset, either `Int` or `UInt`.
T: The data type of the elements.

Args:
offset: The offset to store to.
value: The value to store.
"""
memset(self.offset(offset), value^)

@always_inline
fn store[
I: IntLike, T: Copyable, //
](
inout self: UnsafePointer[T, AddressSpace.GENERIC, *_, **_],
offset: I,
value: T,
count: Int,
):
"""Stores values at the given offset count times by copying it.

Parameters:
I: The type of offset, either `Int` or `UInt`.
T: The data type of the elements.

Args:
offset: The offset to store to.
value: The value to store.
count: The amount of times to copy the value.
"""
memset(self.offset(offset), value, count)

@always_inline
fn store[
T: IntLike, type: DType, //, *, volatile: Bool
](
self: UnsafePointer[Scalar[type], *_, **_],
offset: T,
val: Scalar[type],
value: Scalar[type],
):
"""Stores a single element value at the given offset.

Expand All @@ -593,27 +631,21 @@ struct UnsafePointer[
Parameters:
T: The type of offset, either `Int` or `UInt`.
type: The data type of SIMD vector elements.
alignment: The minimal alignment of the address.
volatile: Whether the operation is volatile or not.

Args:
offset: The offset to store to.
val: The value to store.
value: The value to store.
"""
self.offset(offset)._store[alignment=alignment, volatile=volatile](val)
memset[volatile=volatile](self.offset(offset), value)

@always_inline
fn store[
T: IntLike,
type: DType,
width: Int, //,
*,
alignment: Int = _default_alignment[type, width](),
volatile: Bool = False,
T: IntLike, type: DType, width: Int, //, *, volatile: Bool = False
](
self: UnsafePointer[Scalar[type], *_, **_],
offset: T,
val: SIMD[type, width],
value: SIMD[type, width],
):
"""Stores a single element value at the given offset.

Expand All @@ -625,26 +657,21 @@ struct UnsafePointer[
T: The type of offset, either `Int` or `UInt`.
type: The data type of SIMD vector elements.
width: The size of the SIMD vector.
alignment: The minimal alignment of the address.
volatile: Whether the operation is volatile or not.

Args:
offset: The offset to store to.
val: The value to store.
value: The value to store.
"""
self.offset(offset).store[alignment=alignment, volatile=volatile](val)
memset[volatile=volatile](self.offset(offset), value)

@always_inline
fn store[
type: DType,
offset_type: DType, //,
*,
alignment: Int = _default_alignment[type](),
volatile: Bool = False,
type: DType, offset_type: DType, //, *, volatile: Bool
](
self: UnsafePointer[Scalar[type], *_, **_],
offset: Scalar[offset_type],
val: Scalar[type],
value: Scalar[type],
):
"""Stores a single element value at the given offset.

Expand All @@ -654,30 +681,26 @@ struct UnsafePointer[
Parameters:
type: The data type of SIMD vector elements.
offset_type: The data type of the offset value.
alignment: The minimal alignment of the address.
volatile: Whether the operation is volatile or not.

Args:
offset: The offset to store to.
val: The value to store.
value: The value to store.
"""
constrained[offset_type.is_integral(), "offset must be integer"]()
self.offset(int(offset))._store[alignment=alignment, volatile=volatile](
val
)
memset[volatile=volatile](self.offset(offset), value)

@always_inline
fn store[
type: DType,
width: Int,
offset_type: DType, //,
*,
alignment: Int = _default_alignment[type, width](),
volatile: Bool = False,
](
self: UnsafePointer[Scalar[type], *_, **_],
offset: Scalar[offset_type],
val: SIMD[type, width],
value: SIMD[type, width],
):
"""Stores a single element value at the given offset.

Expand All @@ -688,48 +711,37 @@ struct UnsafePointer[
type: The data type of SIMD vector elements.
width: The size of the SIMD vector.
offset_type: The data type of the offset value.
alignment: The minimal alignment of the address.
volatile: Whether the operation is volatile or not.

Args:
offset: The offset to store to.
val: The value to store.
value: The value to store.
"""
constrained[offset_type.is_integral(), "offset must be integer"]()
self.offset(int(offset))._store[alignment=alignment, volatile=volatile](
val
)
memset[volatile=volatile](self.offset(offset), value)

@always_inline("nodebug")
fn store[
type: DType, //,
*,
alignment: Int = _default_alignment[type](),
volatile: Bool = False,
](self: UnsafePointer[Scalar[type], *_, **_], val: Scalar[type]):
type: DType, //, *, volatile: Bool
](self: UnsafePointer[Scalar[type], *_, **_], value: Scalar[type]):
"""Stores a single element value.

Constraints:
The width and alignment must be positive integer values.

Parameters:
type: The data type of SIMD vector elements.
alignment: The minimal alignment of the address.
volatile: Whether the operation is volatile or not.

Args:
val: The value to store.
value: The value to store.
"""
self._store[alignment=alignment, volatile=volatile](val)
memset[volatile=volatile](self, value)

@always_inline("nodebug")
fn store[
type: DType,
width: Int, //,
*,
alignment: Int = _default_alignment[type, width](),
volatile: Bool = False,
](self: UnsafePointer[Scalar[type], *_, **_], val: SIMD[type, width]):
type: DType, width: Int, //, *, volatile: Bool = False
](self: UnsafePointer[Scalar[type], *_, **_], value: SIMD[type, width]):
"""Stores a single element value.

Constraints:
Expand All @@ -738,36 +750,12 @@ struct UnsafePointer[
Parameters:
type: The data type of SIMD vector elements.
width: The size of the SIMD vector.
alignment: The minimal alignment of the address.
volatile: Whether the operation is volatile or not.

Args:
val: The value to store.
value: The value to store.
"""
self._store[alignment=alignment, volatile=volatile](val)

@always_inline("nodebug")
fn _store[
type: DType,
width: Int,
*,
alignment: Int = _default_alignment[type, width](),
volatile: Bool = False,
](self: UnsafePointer[Scalar[type], *_, **_], val: SIMD[type, width]):
constrained[width > 0, "width must be a positive integer value"]()
constrained[
alignment > 0, "alignment must be a positive integer value"
]()

@parameter
if volatile:
__mlir_op.`pop.store`[
alignment = alignment.value, isVolatile = __mlir_attr.unit
](val, self.bitcast[SIMD[type, width]]().address)
else:
__mlir_op.`pop.store`[alignment = alignment.value](
val, self.bitcast[SIMD[type, width]]().address
)
memset[volatile=volatile](self, value)

@always_inline("nodebug")
fn strided_load[
Expand All @@ -792,9 +780,7 @@ struct UnsafePointer[

@always_inline("nodebug")
fn strided_store[
type: DType,
T: Intable, //,
width: Int,
type: DType, T: Intable, //, width: Int
](
self: UnsafePointer[Scalar[type], *_, **_],
val: SIMD[type, width],
Expand All @@ -818,9 +804,7 @@ struct UnsafePointer[
type: DType, //,
*,
width: Int = 1,
alignment: Int = alignof[
SIMD[type, width]
]() if triple_is_nvidia_cuda() else 1,
alignment: Int = _default_alignment[type, width](),
](
self: UnsafePointer[Scalar[type], *_, **_],
offset: SIMD[_, width],
Expand Down
Loading