PyTorch c ++扩展中的可选张量

问题描述 投票:1回答:2

我正在为pytorch编写C ++扩展,并使用c ++ api这样做。对于我的forward函数,我需要传递一个可选的张量。在函数内部,我想根据是否传递了这个可选参数来做不同的事情。通常,我们在C ++中对可选指针参数使用NULL,如果指针为NULL则检查函数内部。我不知道如何为at::Tensor类型的Torch的c ++ api做这个。

void xyz_forward(
    const at::Tensor xyz1, 
    const at::Tensor xyz2, 
    const at::Tensor optional_constraints = something)
{
     if(optional_constraints){
        //do something
     }else{
        //do something else
     }
}

请注意,我不能做const at::Tensor optional_constraints = at::ones或其他什么,因为该参数可以取任何实际值,并且可以具有不同的大小/形状。我不能将数值赋值为可选参数。有没有相同的NULL

c++ pytorch torch
2个回答
0
投票

因为我找不到任何类似的例子。在noArray()中,OpenCV API(主要用于传递视觉矩阵,如掩码),我建议你使用重载函数用于此目的

void xyz_forward(
    const at::Tensor xyz1, 
    const at::Tensor xyz2)
{
     // optional tensor wasnt passed
}

void xyz_forward(
    const at::Tensor xyz1, 
    const at::Tensor xyz2, 
    const at::Tensor optional_constraints)
{
     // optional tensor passed
}

0
投票

一种可能性是使用std::optional作为std::optional<at::Tensor> optional_constraints = std::nullopt。它在上下文中可以转换为bool,因此您可以使用if (optional_constraints)进行检查。如果传递一个,则使用.value()方法获取张量,否则默认值为std::nullopt

© www.soinside.com 2019 - 2024. All rights reserved.