I need to run an algorithm from another library on every row of a large df, but I'm having trouble converting my code to Polars expressions for better performance. Here are a couple of sample DFs:
df_products = pl.DataFrame({
    'SKU': ['apple', 'banana', 'carrot', 'date'],
    'DESCRIPTION': [
        "Wire Rope",
        "Connector",
        "Tap",
        "Zebra"
    ],
    'CATL3': [
        "Fittings",
        "Tube",
        "Tools",
        "Animal"
    ],
    'YELLOW_CAT': [
        "Rope Accessories",
        "Tube Fittings",
        "Forming Taps",
        "Striped"
    ],
    'INDEX': [0, 5, 25, 90],
    'EMBEDDINGS': [
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9],
        [10, 11, 12]
    ],
})
df_items_sm_ex = pl.DataFrame({
    'PRODUCT_INFO': ['apple', 'banana', 'carrot'],
    'SEARCH_SIMILARITY_SCORE': [
        [1., 0.87, 0.54, 0.33],
        [1., 0.83, 0.77, 0.55],
        [1., 0.92, 0.84, 0.65]
    ],
    'SEARCH_POSITION': [
        [0, 5, 25, 90],
        [1, 2, 151, 373],
        [3, 5, 95, 1500]
    ],
    'SKU': ['apple', 'banana', 'carrot'],
    'YELLOW_CAT': [
        "Rope Accessories",
        "Tube Fittings",
        "Forming Taps"
    ],
    'CATL3': [
        "Fittings",
        "Tube",
        "Tools"
    ],
    'EMBEDDINGS': [
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]
    ],
})
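For orientation: each value in SEARCH_POSITION is meant to line up with a row of df_products through its INDEX column, which is what the join in the code below relies on. In this sample data only the first row's positions (0, 5, 25, 90) all resolve to products; most positions in the other rows have no match, so a left join returns nulls for them. A quick way to check against the sample frames above:

# positions from the first sample row all resolve to a product
df_products.filter(pl.col("INDEX").is_in([0, 5, 25, 90]))    # 4 rows
# positions from the second sample row resolve to nothing
df_products.filter(pl.col("INDEX").is_in([1, 2, 151, 373]))  # 0 rows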
And here's the code:
df_items_sm_ex.select(
    pl.struct(df_items_sm_ex.columns)
    .map_elements(lambda row: build_complements(
        row, df_items, rfc, rfc_comp, engine, current_datetime
    ))
)
def build_complements(row, df_products, ml, ml_comp, engine, current_datetime):
    try:
        # step 1 - generate the base new dataframe
        output_df = build_candidate_dataframe(row, df_products)
        # step 2 - preprocess / clean / run predictions on the dataframe
        output_df = process_candidate_output(df_products, output_df, ml, ml_comp)
        # step 3 - write dataframes to SQL
        write_validate_complements(output_df, row, current_datetime, engine)
    except Exception as e:
        print(f'exception: {repr(e)}')
def build_candidate_dataframe(row, df_products):
    df_len = len(row['SEARCH_SIMILARITY_SCORE'])
    schema = {
        'QUERY': str,
        'SIMILARITY_SCORE': pl.Float32,
        'POSITION': pl.Int64,
        'QUERY_SKU': str,
        'QUERY_LEAF': str,
        'QUERY_CAT': str,
        'QUERY_EMBEDDINGS': pl.List(pl.Float32),
    }
    output_df = pl.DataFrame({
        'QUERY': [row['PRODUCT_INFO']] * df_len,
        'SIMILARITY_SCORE': row['SEARCH_SIMILARITY_SCORE'],
        'POSITION': row['SEARCH_POSITION'],
        'QUERY_SKU': [row['SKU']] * df_len,
        'QUERY_LEAF': [row['YELLOW_CAT']] * df_len,
        'QUERY_CAT': [row['CATL3']] * df_len,
        'QUERY_EMBEDDINGS': [row['EMBEDDINGS']] * df_len,
    }, schema=schema).sort("SIMILARITY_SCORE", descending=True)
    output_df = output_df.join(
        df_products[['SKU', 'EMBEDDINGS', 'INDEX', 'DESCRIPTION', 'CATL3', 'YELLOW_CAT']],
        left_on=['POSITION'], right_on=['INDEX'], how='left',
    )
    output_df = output_df.rename({
        "DESCRIPTION": "SIMILAR_PRODUCT_INFO",
        "CATL3": "SIMILAR_PRODUCT_CAT",
        "YELLOW_CAT": "SIMILAR_PRODUCT_LEAF",
    })
    return output_df
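To make the join concrete, here is a sketch of the candidate frame built for the first sample row (using the sample frames above; to_dicts()[0] stands in for the struct row that map_elements would pass in):

row = df_items_sm_ex.to_dicts()[0]
candidates = build_candidate_dataframe(row, df_products)
# one row per candidate position, sorted by score, with columns:
# QUERY, SIMILARITY_SCORE, POSITION, QUERY_SKU, QUERY_LEAF, QUERY_CAT,
# QUERY_EMBEDDINGS, plus the joined SKU and EMBEDDINGS and the renamed
# SIMILAR_PRODUCT_INFO, SIMILAR_PRODUCT_CAT, SIMILAR_PRODUCT_LEAF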
def process_candidate_output(df_products, output_df, ml, ml_comp):
    # element-wise average of the two embedding lists via a pandas round trip
    pdf = output_df.to_pandas()
    combined_embeddings = (pdf['QUERY_EMBEDDINGS'] + pdf['EMBEDDINGS']) / 2
    output_df = output_df.with_columns(pl.Series(name='COMBINED_EMBEDDINGS', values=combined_embeddings))
    output_df = output_df[[
        'QUERY', 'QUERY_SKU', 'QUERY_CAT', 'QUERY_LEAF',
        'SIMILAR_PRODUCT_INFO', 'SIMILAR_PRODUCT_CAT', 'SIMILAR_PRODUCT_LEAF',
        'SIMILARITY_SCORE', 'COMBINED_EMBEDDINGS', 'SKU', 'POSITION',
    ]]
    # drop the query product itself from the candidates
    output_df = output_df.filter(
        pl.col('SKU') != output_df['QUERY_SKU'][0]
    )
    # ML predictions
    output_df = predict_complements(output_df, ml)
    output_df = output_df.filter(
        pl.col('COMPLEMENTARY_PREDICTIONS') == 1
    )
    # Other ML predictions
    output_df = predict_required_accessories(output_df, ml_comp)
    output_df = output_df.sort(by='LABEL_PROBABILITY', descending=True)
    return output_df
For every row of the df you're pulling the data out of Polars into Python objects and then standing up a brand-new Polars df. That is very expensive in both time and memory. Instead, keep everything in Polars memory unless for some reason you absolutely can't.

Don't bother with map_elements here, since you're not returning anything into the df; just use a regular for loop (a sketch of that loop follows the function below). I've also merged your two functions into this one:
def build_and_process_candidate_output(i, df_items_sm_ex, df_products, ml, ml_comp):
    output_df = (
        df_items_sm_ex[i]
        .explode("SEARCH_SIMILARITY_SCORE", "SEARCH_POSITION")
        .rename(
            {
                "PRODUCT_INFO": "QUERY",
                "SEARCH_SIMILARITY_SCORE": "SIMILARITY_SCORE",
                "SEARCH_POSITION": "POSITION",
                "SKU": "QUERY_SKU",
                "YELLOW_CAT": "QUERY_LEAF",
                "CATL3": "QUERY_CAT",
                "EMBEDDINGS": "QUERY_EMBEDDINGS",
            }
        )
        .join(
            df_products.select(
                "SKU", "EMBEDDINGS", "INDEX", "DESCRIPTION", "CATL3", "YELLOW_CAT"
            ),
            left_on=["POSITION"],
            right_on=["INDEX"],
            how="left",
        )
        .rename(
            {
                "DESCRIPTION": "SIMILAR_PRODUCT_INFO",
                "CATL3": "SIMILAR_PRODUCT_CAT",
                "YELLOW_CAT": "SIMILAR_PRODUCT_LEAF",
            }
        )
        .select(
            "QUERY",
            "QUERY_SKU",
            "QUERY_CAT",
            "QUERY_LEAF",
            "SIMILAR_PRODUCT_INFO",
            "SIMILAR_PRODUCT_CAT",
            "SIMILAR_PRODUCT_LEAF",
            "SIMILARITY_SCORE",
            # Element-wise average of the two embedding lists, matching your
            # pandas version; this assumes the lists are the same length.
            # explode() aligns the elements, the arithmetic happens
            # element-wise, and implode().over("POSITION") regroups the
            # result into one list per row.
            ((pl.col("QUERY_EMBEDDINGS").explode() + pl.col("EMBEDDINGS").explode()) / 2)
            .implode()
            .over("POSITION")
            .alias("COMBINED_EMBEDDINGS"),
            "SKU",
            "POSITION",
        )
        # Given the sample data, this filter makes everything go away,
        # which is why supplying good sample data is important
        .filter(pl.col("SKU") != pl.col("QUERY_SKU").first())
    )
    # ML predictions
    output_df = predict_complements(output_df, ml)
    output_df = output_df.filter(pl.col("COMPLEMENTARY_PREDICTIONS") == 1)
    # Other ML predictions
    output_df = predict_required_accessories(output_df, ml_comp)
    output_df = output_df.sort(by="LABEL_PROBABILITY", descending=True)
    return output_df
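A minimal driver sketch for the plain loop mentioned above, assuming write_validate_complements keeps its original signature and accepts the one-row frame df_items_sm_ex[i] where it previously got the struct row (hypothetical wiring; adjust to your actual helpers):

for i in range(df_items_sm_ex.height):
    try:
        output_df = build_and_process_candidate_output(
            i, df_items_sm_ex, df_products, ml, ml_comp
        )
        # step 3 of the original pipeline: write the results to SQL
        write_validate_complements(output_df, df_items_sm_ex[i], current_datetime, engine)
    except Exception as e:
        print(f'exception: {repr(e)}')

If the list arithmetic in the select looks opaque, here is the same pattern on a toy frame with made-up values:

demo = pl.DataFrame({
    "POSITION": [0, 5],
    "QUERY_EMBEDDINGS": [[1.0, 2.0], [3.0, 4.0]],
    "EMBEDDINGS": [[5.0, 6.0], [7.0, 8.0]],
})
demo.select(
    ((pl.col("QUERY_EMBEDDINGS").explode() + pl.col("EMBEDDINGS").explode()) / 2)
    .implode()
    .over("POSITION")
    .alias("COMBINED_EMBEDDINGS")
)
# result rows: [3.0, 4.0] and [5.0, 6.0]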