我是 Halide 的初学者,在大约 1000 行 Halide 程序中遇到错误。 我已将其缩小到最小尺寸:
#include "Halide.h"
namespace {
using namespace Halide;
Var x("x"), y("y"), c("c"), yo("yo"), yi("yi");
class Cord : public Halide::Generator<Cord> {
public:
Input<Buffer<int32_t, 3>> G{ "G" };
Output<Buffer<int32_t, 3>> output{"output"};
Cord():Q("Q"), M("M") {}
void generate() {
// Algorithm
Func G2("G2");
G2(x, y, c) = G(x, y, c)*G(x, y, c);
M(x, y, c) = 2*G2(x, y, c);
Q(x, y, c) = 3*G2(x, y, c);// + M(x, y, c);
output(x, y, c) = Q(x, y, c) + M(x, y, c);
// Schedule
if (!using_autoscheduler()) {
int strip_size = 480;
G2.store_root().compute_at(Q, yi);
//G2.compute_root();
Q.compute_root()
.split(y, yo, yi, strip_size / 2);
}
}
private:
Func Q, M;
};
} // namespace
HALIDE_REGISTER_GENERATOR(Cord, cord)
给出构建错误:Func“G2$0”在以下无效位置计算... 如果我将 G2 时间表替换为注释的时间表:G2.compute_root(),它就可以正常工作。我想使用原始时间表,因为它在派生此代码的原始计算中速度更快。
为什么原来的时间表行不通?我猜测这与需要 G2 来计算 Q 以及计算用于计算输出的 M 有关。从输出中删除 M 项即可使其工作。
实际的程序,这是一个减少,有,而不是输出(x,y,c)= Q(x,y,c)+ M(x,y,c);对 Q 进行一长串计算,并分别对 M 进行一系列计算,然后将每个计算的结果以需要每个域大小相等的方式组合起来。这些长计算中有多个compute_root()。这就是为什么我做了 Q.compute_root() 而不是让它在输出循环中求值。
谢谢
原始计划不起作用,因为在 Halide 中,生产者的计算位置必须包含其所有消费者。我认为你希望 store at 位置包围消费者就足够了,但不幸的是 Halide 不支持这一点(因为我们必须推断 Q 上循环的界限以涵盖 Q 的所有用途,但也确保作为副作用计算的 G2 值足以满足 M) 中 G2 的所有用途。
您必须计算根 G2,或者计算两次。如果 G2 很小并且仅在 Q 和 M 的循环嵌套中的一处使用,则可以将其内联。否则,您可以使用clone_in制作G2的单独副本以在每个地方使用:
G2.clone_in(M).compute_at(M, yi);
G2.compute_at(Q, yi);