如何批量删除琐碎的 if 语句

问题描述 投票:0回答:1

我正在重构一个大型代码库,重构的一部分包括删除一些检查。我想不出一种有效的方法。 例如,如果旧代码类似于:

if A:
  print("branch 1")
else:
  print("branch 2")

我想删除 A 并使其始终为

True
。所以我的新代码看起来像这样:

if True:
  print("branch 1")
else:
  print("branch 2") # unreachable

现在我想将代码简化为:

print("branch 1")

一些不同的场景使这个问题变得复杂: 场景一:

if A and B:
  print("branch 1")
else:
  print("branch 2")

将变成:

if B:
  print("branch 1")
else:
  print("branch 2")

场景2:

if A:
  print("branch 1")
  return
print("branch 2")

将变成:

print("branch 1")
  return

我想不出一种有效的方法。我尝试运行几个不同的 linter,看看它们是否会建议简化,但没有成功。我也不太擅长编写自己的脚本 🙃

python refactoring codemod
1个回答
0
投票

这可以使用 Python 的 AST 模块 来实现。以下代码将用

A
替换变量
True
的所有实例(可以使用
VARS
变量修改/扩展),并且还将简化布尔表达式(例如,
if True and X:
变为
if X:
)和简单的
if
语句(删除显然永远不会运行的案例)。要将其应用到名为
file.py
的文件,请将此代码保存在一个文件中(我假设该文件名为
if_simplifier.py
并且位于同一目录中)并从命令行运行以下命令:

python if_simplifier.py file.py

两个注意事项:

  1. 虽然我在几个相关示例上测试了此代码,但我可能错过了一个重要的边缘情况,因此请确保任何重新格式化的代码看起来正确,并确保测试新代码以确保代码仍然有效。
  2. 虽然
    ast.unparse()
    的工作方式保证缩进不会不好,但此代码不会删除无法访问的代码(如场景 2 中的
    print("branch 2")
    ),并且它确实删除了一些间距/格式,这将使结果代码可读性稍差。您可以手动或使用Vulture等工具检测无法访问的代码,并且可以使用Black(也可作为VS Code的扩展)等工具重新格式化代码以恢复可读性。
import ast
import sys
from typing import Any

VARS = {"A": True}


class Visitor(ast.NodeTransformer):
    def __init__(self, variables: dict[str, Any] | None = None):
        super().__init__()
        self.variables = variables or {}

    def visit_Name(self, node: ast.Name) -> ast.AST:
        self.generic_visit(node)
        if node.id in self.variables:
            return ast.Constant(value=self.variables[node.id])
        return node

    def visit_If(self, node: ast.If) -> ast.AST | list[ast.stmt]:
        self.generic_visit(node)
        if isinstance(node.test, ast.Constant):
            if node.test.value:
                return node.body
            else:
                return node.orelse
        return node

    def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST:
        self.generic_visit(node)
        if isinstance(node.op, ast.Not) and isinstance(node.operand, ast.Constant):
            return ast.Constant(value=not node.operand.value)
        return node

    def visit_BoolOp(self, node: ast.BoolOp) -> ast.AST:
        self.generic_visit(node)
        if isinstance(node.op, ast.Or):
            new_values = []
            for value in node.values:
                if isinstance(value, ast.Constant):
                    if value.value:
                        return value
                else:
                    new_values.append(value)
            if len(new_values) == 0:
                return ast.Constant(value=False)
            elif len(new_values) == 1:
                return new_values[0]
            else:
                node.values = new_values
        elif isinstance(node.op, ast.And):
            new_values = []
            for value in node.values:
                if isinstance(value, ast.Constant):
                    if not value.value:
                        return value
                else:
                    new_values.append(value)
            if len(new_values) == 0:
                return ast.Constant(value=True)
            elif len(new_values) == 1:
                return new_values[0]
            else:
                node.values = new_values
        return node


def simplify_ifs(code: str, variables: dict[str, Any] | None = None) -> str:
    tree = ast.parse(code)
    visitor = Visitor(variables=variables)
    visitor.visit(tree)
    return ast.unparse(tree)


def main():
    path = sys.argv[1]
    code = ""
    with open(path, "r") as f:
        code = f.read()
    with open(path, "w") as f:
        f.write(simplify_ifs(code, variables=VARS))


if __name__ == "__main__":
    main()
© www.soinside.com 2019 - 2024. All rights reserved.