我正在为pytorch编写C ++扩展,并使用c ++ api这样做。对于我的forward
函数,我需要传递一个可选的张量。在函数内部,我想根据是否传递了这个可选参数来做不同的事情。通常,我们在C ++中对可选指针参数使用NULL,如果指针为NULL则检查函数内部。我不知道如何为at::Tensor
类型的Torch的c ++ api做这个。
void xyz_forward(
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor optional_constraints = something)
{
if(optional_constraints){
//do something
}else{
//do something else
}
}
请注意,我不能做const at::Tensor optional_constraints = at::ones
或其他什么,因为该参数可以取任何实际值,并且可以具有不同的大小/形状。我不能将数值赋值为可选参数。有没有相同的NULL
?
因为我找不到任何类似的例子。在noArray()
中,OpenCV API(主要用于传递视觉矩阵,如掩码),我建议你使用重载函数用于此目的
void xyz_forward(
const at::Tensor xyz1,
const at::Tensor xyz2)
{
// optional tensor wasnt passed
}
void xyz_forward(
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor optional_constraints)
{
// optional tensor passed
}
一种可能性是使用std::optional作为std::optional<at::Tensor> optional_constraints = std::nullopt
。它在上下文中可以转换为bool
,因此您可以使用if (optional_constraints)
进行检查。如果传递一个,则使用.value()
方法获取张量,否则默认值为std::nullopt
。