From 5226290207f7ff6d1264c043388bc2b520458ce0 Mon Sep 17 00:00:00 2001 From: Kris Rowe Date: Fri, 8 Dec 2023 14:30:08 -0600 Subject: [PATCH] Fixes corruption of memory datatypes caused by short circuit logic in `modeMemory_t::slice` (#727) --- src/occa/internal/core/memory.cpp | 4 ---- tests/src/core/memory.cpp | 17 +++++++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/occa/internal/core/memory.cpp b/src/occa/internal/core/memory.cpp index 4f98132ad..fbac13a11 100644 --- a/src/occa/internal/core/memory.cpp +++ b/src/occa/internal/core/memory.cpp @@ -77,10 +77,6 @@ namespace occa { modeMemory_t* modeMemory_t::slice(const dim_t offset_, const udim_t bytes) { - - //quick return if we're not really slicing - if ((offset_ == 0) && (bytes == size)) return this; - OCCA_ERROR("ModeMemory not initialized or has been freed", modeBuffer != NULL); diff --git a/tests/src/core/memory.cpp b/tests/src/core/memory.cpp index 8dac8e9cb..63cc509a2 100644 --- a/tests/src/core/memory.cpp +++ b/tests/src/core/memory.cpp @@ -4,11 +4,13 @@ void testMalloc(); void testSlice(); void testUnwrap(); +void testCast(); int main(const int argc, const char **argv) { testMalloc(); testSlice(); testUnwrap(); + testCast(); return 0; } @@ -154,3 +156,18 @@ void testUnwrap() { delete[] host_memory; } + +void testCast() { + occa::device occa_device({{"mode", "Serial"}}); + + occa::memory occa_memory = occa_device.malloc(10); + + ASSERT_TRUE(occa::dtype::double_ == occa_memory.dtype()); + + occa::memory casted_memory = occa_memory.cast(occa::dtype::byte); + + ASSERT_TRUE(occa::dtype::double_ == occa_memory.dtype()); + ASSERT_TRUE(occa::dtype::byte == casted_memory.dtype()); + + ASSERT_EQ(occa_memory.size(), casted_memory.size()); +}