Flask-SQLAlchemy: multiple filters on parent and child


I have the following parent class:

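from typing import List

from sqlalchemy import ForeignKey, Integer, String
from sqlalchemy.orm import Mapped, mapped_column, relationship

# Base is the project's declarative base class (not shown in the question).
class Pairing(Base):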
    __tablename__ = 'pairing'
    id: Mapped[int] = mapped_column(Integer, unique=True, primary_key=True, autoincrement=True)
    pairing_no: Mapped[str] = mapped_column(String(10), unique=False, nullable=False)
    total_expenses: Mapped[str] = mapped_column(String(10), unique=False, nullable=True)
    pairing_tafb: Mapped[str] = mapped_column(String(10), unique=False, nullable=True)

    flights: Mapped[List["Flight"]] = relationship(
        back_populates="pairing", cascade='all, delete'
    )

And the following child class:

class Flight(Base):
    __tablename__ = 'flight'
    id: Mapped[int] = mapped_column(Integer, unique=True, primary_key=True, autoincrement=True)
    destination_station: Mapped[str] = mapped_column(String, unique=False)

    pairing_no: Mapped[str] = mapped_column(ForeignKey('pairing.pairing_no'), unique=False)
    pairing: Mapped["Pairing"] = relationship(back_populates='flights')

I can successfully apply multiple filters to the Pairing class with the following:

pairing_filter = {
    'pairing_no': request.form.get('search_pairing'),
    'pairing_tafb': request.form.get('pairing_tafb'),
    'total_expenses': request.form.get('total_expenses'),
}
pairing_filter = {key: value for (key, value) in pairing_filter.items() if value}
pairings_to_display = db.session.scalars(
    select(Pairing).filter_by(**pairing_filter).order_by(Pairing.id)
).all()
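The dictionary comprehension is what makes every filter optional: it drops fields the user left blank (an empty HTML input arrives as '') or did not submit at all (None), and filter_by() called with no keyword arguments adds no criteria. A minimal standard-library sketch of the same step, using a hypothetical helper non_empty_filters that also maps form field names to column names (the form values below are made up):

```python
from typing import Dict, Mapping, Optional


def non_empty_filters(form: Mapping[str, Optional[str]],
                      field_to_column: Mapping[str, str]) -> Dict[str, str]:
    """Build keyword arguments for filter_by() from submitted form fields.

    Fields that are missing, None, or empty strings are dropped, so only
    the criteria the user actually filled in end up in the query.
    """
    filters = {}
    for field, column in field_to_column.items():
        value = form.get(field)
        if value:  # skips both None and ""
            filters[column] = value
    return filters


# Hypothetical submission: only two of the three fields were filled in.
form = {'search_pairing': 'P123', 'pairing_tafb': '', 'total_expenses': '85.20'}
fields = {
    'search_pairing': 'pairing_no',
    'pairing_tafb': 'pairing_tafb',
    'total_expenses': 'total_expenses',
}
print(non_empty_filters(form, fields))
# {'pairing_no': 'P123', 'total_expenses': '85.20'}
```

The result would then be passed on as select(Pairing).filter_by(**filters), exactly like pairing_filter above.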

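But I also want to apply filters based on the Flight class. For example, I want to further filter the pairings_to_display results so that only pairings containing a flight that lands at LHR are shown. Depending on the user's input, any combination of filters may be applied.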

How can I achieve this?

flask sqlalchemy flask-sqlalchemy
1 Answer

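To filter the results on attributes of a related model, you first have to join that model; SQLAlchemy derives the join condition from the ForeignKey between the two tables.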

# Inside the Flask view; db, Pairing and Flight are the app's own objects.
stmt = (
    db.select(Pairing)
    # Inner join on the foreign key; note it also excludes pairings with no flights.
    .join(Flight)
    .order_by(Pairing.id)
)
# Map each form field to the model attribute it filters on.
key_bindings = {
    'pairing_tafb': Pairing.pairing_tafb,
    'search_pairing': Pairing.pairing_no,
    'total_expenses': Pairing.total_expenses,
    'dest': Flight.destination_station,
}
for form_key, attr in key_bindings.items():
    value = request.form.get(form_key)
    # Skip missing or empty fields so any combination of filters works.
    if value:
        stmt = stmt.where(attr == value)
# The join returns one row per matching flight; unique() drops repeated pairings.
results = db.session.scalars(stmt).unique().all()
© www.soinside.com 2019 - 2025. All rights reserved.