Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Fix make String.join() work on StringSlice #3677

Draft
wants to merge 8 commits into
base: nightly
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions stdlib/src/builtin/string_literal.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ from hashlib._hasher import _HashableWithHasher, _Hasher
from utils import StringRef, Span, StringSlice, StaticString
from utils import Writable, Writer
from utils._visualizers import lldb_formatter_wrapping_type

from utils.span import AsBytesRead
from utils.string_slice import (
_StringSliceIter,
_FormatCurlyEntry,
_CurlyEntryFormattable,
)
from collections.string import _atol

# ===----------------------------------------------------------------------===#
# StringLiteral
Expand All @@ -40,6 +41,7 @@ from utils.string_slice import (
struct StringLiteral(
Boolable,
Comparable,
CollectionElement,
CollectionElementNew,
Writable,
IntableRaising,
Expand All @@ -48,7 +50,7 @@ struct StringLiteral(
Sized,
Stringable,
FloatableRaising,
BytesCollectionElement,
AsBytesRead,
_HashableWithHasher,
):
"""This type represents a string literal.
Expand Down Expand Up @@ -415,6 +417,24 @@ struct StringLiteral(
len=self.byte_length(),
)

@always_inline
fn as_bytes_read[O: ImmutableOrigin, //](ref [O]self) -> Span[UInt8, O]:
"""Returns an immutable contiguous slice of the bytes.

Parameters:
O: The Origin of the bytes.

Returns:
An immutable contiguous slice pointing to the bytes.

Notes:
This does not include the trailing null terminator.
"""

return Span[UInt8, O](
unsafe_ptr=self.unsafe_ptr(), len=self.byte_length()
)

@always_inline
fn format[*Ts: _CurlyEntryFormattable](self, *args: *Ts) raises -> String:
"""Format a template with `*args`.
Expand Down Expand Up @@ -493,7 +513,25 @@ struct StringLiteral(
"""
return __mlir_op.`pop.string.replace`(self.value, old.value, new.value)

fn join[T: StringableCollectionElement](self, elems: List[T, *_]) -> String:
fn join[
T: StringableCollectionElement, //
](self, elems: List[T, *_]) -> String:
"""Joins string elements using the current string as a delimiter.

Parameters:
T: The types of the elements.

Args:
elems: The input values.

Returns:
The joined string.
"""
return self.as_string_slice().join(elems)

fn join_bytes[
T: BytesReadCollectionElement, //,
](self, elems: List[T, *_]) -> String:
"""Joins string elements using the current string as a delimiter.

Parameters:
Expand All @@ -505,7 +543,7 @@ struct StringLiteral(
Returns:
The joined string.
"""
return str(self).join(elems)
return self.as_string_slice().join_bytes(elems)

fn split(self, sep: String, maxsplit: Int = -1) raises -> List[String]:
"""Split the string literal by a separator.
Expand Down
8 changes: 4 additions & 4 deletions stdlib/src/builtin/value.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,13 @@ trait StringableCollectionElement(CollectionElement, Stringable):
pass


trait BytesCollectionElement(CollectionElement, AsBytes):
"""The BytesCollectionElement trait denotes a trait composition
of the `CollectionElement` and `AsBytes`.
trait BytesReadCollectionElement(CollectionElement, AsBytesRead):
"""The BytesReadCollectionElement trait denotes a trait composition
of the `CollectionElement` and `AsBytesRead`.

This is useful to have as a named entity since Mojo does not
currently support anonymous trait compositions to constrain
on `CollectionElement & AsBytes` in the parameter.
on `CollectionElement & AsBytesRead` in the parameter.
"""

pass
Expand Down
83 changes: 25 additions & 58 deletions stdlib/src/collections/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1418,7 +1418,9 @@ struct String(
_ = is_first
return result

fn join[T: StringableCollectionElement](self, elems: List[T, *_]) -> String:
fn join[
T: StringableCollectionElement, //
](self, elems: List[T, *_]) -> String:
"""Joins string elements using the current string as a delimiter.

Parameters:
Expand All @@ -1430,33 +1432,10 @@ struct String(
Returns:
The joined string.
"""
return self.as_string_slice().join(elems)

# TODO(#3403): Simplify this when the linked conditional conformance
# feature is added. Runs a faster algorithm if the concrete types are
# able to be converted to a span of bytes.
@parameter
if _type_is_eq[T, String]():
return self.fast_join(rebind[List[String]](elems))
elif _type_is_eq[T, StringLiteral]():
return self.fast_join(rebind[List[StringLiteral]](elems))
# FIXME(#3597): once StringSlice conforms to CollectionElement trait:
# if _type_is_eq[T, StringSlice]():
# return self.fast_join(rebind[List[StringSlice]](elems))
else:
var result: String = ""
var is_first = True

for e in elems:
if is_first:
is_first = False
else:
result += self
result += str(e[])

return result

fn fast_join[
T: BytesCollectionElement, //,
fn join_bytes[
T: BytesReadCollectionElement, //,
](self, elems: List[T, *_]) -> String:
"""Joins string elements using the current string as a delimiter.

Expand All @@ -1469,37 +1448,7 @@ struct String(
Returns:
The joined string.
"""
var n_elems = len(elems)
if n_elems == 0:
return String("")
var len_self = self.byte_length()
var len_elems = 0
# Calculate the total size of the elements to join beforehand
# to prevent alloc syscalls as we know the buffer size.
# This can hugely improve the performance on large lists
for e_ref in elems:
len_elems += len(e_ref[].as_bytes())
var capacity = len_self * (n_elems - 1) + len_elems
var buf = Self._buffer_type(capacity=capacity)
var self_ptr = self.unsafe_ptr()
var ptr = buf.unsafe_ptr()
var offset = 0
var i = 0
var is_first = True
while i < n_elems:
if is_first:
is_first = False
else:
memcpy(dest=ptr + offset, src=self_ptr, count=len_self)
offset += len_self
var e = elems[i].as_bytes()
var e_len = len(e)
memcpy(dest=ptr + offset, src=e.unsafe_ptr(), count=e_len)
offset += e_len
i += 1
buf.size = capacity
buf.append(0)
return String(buf^)
return self.as_string_slice().join_bytes(elems)

fn unsafe_ptr(self) -> UnsafePointer[UInt8]:
"""Retrieves a pointer to the underlying memory.
Expand Down Expand Up @@ -1535,6 +1484,24 @@ struct String(
unsafe_ptr=self._buffer.unsafe_ptr(), len=self.byte_length()
)

@always_inline
fn as_bytes_read[O: ImmutableOrigin, //](ref [O]self) -> Span[UInt8, O]:
"""Returns an immutable contiguous slice of the bytes.

Parameters:
O: The Origin of the bytes.

Returns:
An immutable contiguous slice pointing to the bytes.

Notes:
This does not include the trailing null terminator.
"""

return Span[UInt8, O](
unsafe_ptr=self.unsafe_ptr(), len=self.byte_length()
)

@always_inline
fn as_string_slice(ref [_]self) -> StringSlice[__origin_of(self)]:
"""Returns a string slice of the data owned by this string.
Expand Down
4 changes: 2 additions & 2 deletions stdlib/src/prelude/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ from builtin.value import (
Defaultable,
CollectionElement,
CollectionElementNew,
BytesCollectionElement,
BytesReadCollectionElement,
StringableCollectionElement,
EqualityComparableCollectionElement,
ComparableCollectionElement,
Expand Down Expand Up @@ -133,5 +133,5 @@ from collections.string import (
)
from hashlib.hash import hash, Hashable
from memory import Pointer, AddressSpace
from utils import AsBytes, Writable, Writer
from utils import AsBytes, AsBytesRead, Writable, Writer
from documentation import doc_private
2 changes: 1 addition & 1 deletion stdlib/src/utils/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .index import Index, IndexList, product
from .inline_string import InlineString
from .loop import unroll
from .span import AsBytes, Span
from .span import Span, AsBytes, AsBytesRead
from .static_tuple import StaticTuple
from .stringref import StringRef
from .string_slice import StaticString, StringSlice
Expand Down
20 changes: 20 additions & 0 deletions stdlib/src/utils/span.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,26 @@ trait AsBytes:
...


trait AsBytesRead:
"""The `AsBytesRead` trait denotes a type that can be returned as an
immutable byte span.
"""

fn as_bytes_read[O: ImmutableOrigin, //](ref [O]self) -> Span[Byte, O]:
"""Returns an immutable contiguous slice of the bytes.
Parameters:
O: The Origin of the bytes.
Returns:
An immutable contiguous slice pointing to the bytes.
Notes:
This does not include the trailing null terminator.
"""
...


@value
struct _SpanIter[
is_mutable: Bool, //,
Expand Down
Loading