From 66531b5e5cfbce28c48117a931f6ccf281b7f006 Mon Sep 17 00:00:00 2001 From: Mit Kotak Date: Wed, 14 Aug 2024 03:22:20 -0400 Subject: [PATCH] added explicit dtypes --- inference.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/inference.cpp b/inference.cpp index ff405e3..f6a46be 100644 --- a/inference.cpp +++ b/inference.cpp @@ -13,10 +13,10 @@ int main(int argc, char *argv[]) { torch::inductor::AOTIModelContainerRunnerCuda *runner; runner = new torch::inductor::AOTIModelContainerRunnerCuda(model_path, 1); std::vector inputs = { - torch::randn({32,1}, at::kCUDA), - torch::randn({32,3}, at::kCUDA), - torch::randn({2,50}, at::kCUDA), - torch::randn({32,1}, at::kCUDA) + torch::randn({32,1}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)), + torch::randn({32,3}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)), + torch::randn({2,50}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)), + torch::randn({32}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)) }; std::vector outputs = runner->run(inputs); std::cout << "Result from the first inference:"<< std::endl;