Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Add Pointer(to=) and UnsafePointer(to=) constructors #3606

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
9 changes: 9 additions & 0 deletions mojo/stdlib/src/memory/pointer.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,15 @@ struct Pointer[
"""
self._value = _mlir_value

@always_inline("nodebug")
fn __init__(out self, *, ref [origin, address_space._value.value]to: type):
"""Constructs a Pointer from a reference to a value.

Args:
to: The value to construct a pointer to.
"""
self = Self(_mlir_value=__get_mvalue_as_litref(to))

@staticmethod
@always_inline("nodebug")
fn address_of(ref [origin, address_space]value: type) -> Self:
Expand Down
9 changes: 9 additions & 0 deletions mojo/stdlib/src/memory/unsafe_pointer.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ struct UnsafePointer[
"""
self.address = value

@always_inline("nodebug")
fn __init__(out self, *, ref [origin, address_space._value.value]to: type):
"""Constructs a Pointer from a reference to a value.

Args:
to: The value to construct a pointer to.
"""
self = Self(__mlir_op.`lit.ref.to_pointer`(__get_mvalue_as_litref(to)))

@always_inline
@implicit
fn __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# limitations under the License.
# ===----------------------------------------------------------------------=== #
# RUN: %mojo %s
from testing import assert_equal, assert_true
from testing import assert_equal, assert_true, assert_not_equal


def test_copy_reference_explicitly():
Expand Down Expand Up @@ -41,7 +41,13 @@ def test_str():
assert_true(String(a_ref).startswith("0x"))


def test_pointer_to():
var local = 1
assert_not_equal(0, Pointer(to=local)[])


def main():
test_copy_reference_explicitly()
test_equality()
test_str()
test_pointer_to()
6 changes: 6 additions & 0 deletions mojo/stdlib/test/memory/test_unsafepointer.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ def test_address_of():
_ = local


def test_pointer_to():
var local = 1
assert_not_equal(0, UnsafePointer(to=local)[])


def test_explicit_copy_of_pointer_address():
var local = 1
var ptr = UnsafePointer[Int].address_of(local)
Expand Down Expand Up @@ -358,6 +363,7 @@ def test_volatile_load_and_store_simd():

def main():
test_address_of()
test_pointer_to()

test_refitem()
test_refitem_offset()
Expand Down