为了简单起见,仅以 ADD 为例。
在编译器后端,多重加法是由多个ADD指令的组合来组织的。例如
ADD(1, ADD(1,3))
。
但是由于它是结合律和交换律的,所以顺序以及哪两个先相加并不重要。因此它可以被视为等同于
ADD(1, 1, 3)
。
问题在于搜索树以找到模式
ADD(operand$0, operand$1, ...)
的有效子树的算法。
我的算法如下:
bool try_switch(
std::vector<NODE> &ret,
std::vector<NODE> &ops,
std::vector<NODE> &possible_shrink,
std::vector<NODE> &possible_grow
)
{
for (NODE shrink : possible_shrink) {
for (NODE grow : possible_grow) {
// avoid duplicate
if (!is_left_to_right(shrink, grow)) continue;
swap_nodes(ret, shrink, grow);
if (sub_nodes_match(ret, ops)) {
return true;
}
// Due to the change of picked node, the possiblity of further swich will be adjusted
Adjust_possible(shrink, grow, possible_shrink, possible_grow)
if(try_switch(ret, ops, &possible_shrink, &possible_grow)) {
return ture;
}
// Revert the adjustment
swap_nodes(ret, grow, shrink);
Adjust_possible(grow, shrink, possible_shrink, possible_grow)
}
}
return false;
}
bool match_ADD(int num_of_ops, std::vector<NODE> ops)
{
std::vector<NODE> possible_shrink;
std::vector<NODE> possible_grow;
std::vector<NODE> ret = find_n_mins_one_ADD(num_of_ops, &possible_shrink, &possible_grow);
if (sub_nodes_match(ret, ops)) {
return true;
}
return try_switch(ret, ops, &possible_shrink, &possible_grow);
}
关键问题是,对于某些可能的树节点,它们可以是
ADD
,也可以是ADD(operand$0, operand$1, ...)
的操作数。
我的想法是匹配
operand$0,...,operand$n
我需要一棵子树中 ADD
的 $n-1$ 个节点,并将该子树的所有叶子与目标 operand$i
匹配,如果所有它们匹配,则匹配成功。
如果没有,我需要放弃创建的子树中的一个节点,并用新的
ADD
节点替换它,形成一棵新的子树,并检查当前子树的所有叶子是否与operand$0, operand$1, ...
匹配
此外,通过交换节点,它会创建或删除新的可能性,因此我需要进行递归调用来回退并尝试其他可能性。
如果我没记错的话,前面算法的正确性是可以证明的,但是效率高吗??
或者是否有更好的算法我不知道??
鉴于您已经拥有有效的表达式树,您可以使用简单的递归算法连接
Add
或 Mul
。
该算法以线性时间运行,并且纯粹是自下而上的模式匹配简化器。
我留给你弄清楚如何用 C++ 编写它。
节点的顺序并不重要。
ADD(ADD(x, y), z)
ADD(x, ADD(y, z))
这两个都将被简化为 ADD(x, y, z),因为您要遍历所有子项,首先将它们展平,然后收集它们的参数(如果它们是相同的操作)。
def flatten_expr(expr):
match expr:
case Add(args):
flattened = []
for arg in args:
flat_arg = flatten_expr(arg)
match flat_arg:
case Add(nested_args):
flattened.extend(nested_args)
case _:
flattened.append(flat_arg)
return Add(flattened)
case Mul(args):
flattened = []
for arg in args:
flat_arg = flatten_expr(arg)
match flat_arg:
case Mul(nested_args):
flattened.extend(nested_args)
case _:
flattened.append(flat_arg)
return Mul(flattened)
case _:
return expr