Skip to content

Commit

Permalink
added explicit dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
mitkotak committed Aug 14, 2024
1 parent 13eccf2 commit 66531b5
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ int main(int argc, char *argv[]) {
torch::inductor::AOTIModelContainerRunnerCuda *runner;
runner = new torch::inductor::AOTIModelContainerRunnerCuda(model_path, 1);
std::vector<torch::Tensor> inputs = {
torch::randn({32,1}, at::kCUDA),
torch::randn({32,3}, at::kCUDA),
torch::randn({2,50}, at::kCUDA),
torch::randn({32,1}, at::kCUDA)
torch::randn({32,1}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)),
torch::randn({32,3}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)),
torch::randn({2,50}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)),
torch::randn({32}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA))
};
std::vector<torch::Tensor> outputs = runner->run(inputs);
std::cout << "Result from the first inference:"<< std::endl;
Expand Down

0 comments on commit 66531b5

Please sign in to comment.