Window function for a moving average

Question (votes: 0, answers: 2)

I'm trying to replicate this SQL window function in pandas:

SELECT avg(totalprice) OVER (
    PARTITION BY custkey
    ORDER BY orderdate
    RANGE BETWEEN interval '1' month PRECEDING AND CURRENT ROW)
FROM orders

I have this DataFrame:

from io import StringIO
import pandas as pd

myst = """cust_1,2020-10-10,100
cust_2,2020-10-10,15
cust_1,2020-10-15,200
cust_1,2020-10-16,240
cust_2,2020-12-20,25
cust_1,2020-12-25,140
cust_2,2021-01-01,5
"""
u_cols = ['customer_id', 'date', 'price']

df = pd.read_csv(StringIO(myst), sep=',', names=u_cols)
df = df.sort_values(list(df.columns))

After computing the moving average restricted to the most recent month, it should look like this:

from io import StringIO
import pandas as pd

myst = """cust_1,2020-10-10,100,100
cust_2,2020-10-10,15,15
cust_1,2020-10-15,200,150
cust_1,2020-10-16,240,180
cust_2,2020-12-20,25,25
cust_1,2020-12-25,140,140
cust_2,2021-01-01,5,15
"""
u_cols = ['customer_id', 'date', 'price', 'my_average']

my_df = pd.read_csv(StringIO(myst), sep=',', names=u_cols)
my_df = my_df.sort_values(list(my_df.columns))

As illustrated here:

https://trino.io/assets/blog/window-features/running-average-range.svg
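To pin down exactly which rows the frame in the figure covers (same customer, dates from one calendar month before the current row up to and including it), here is a minimal self-join sketch in pandas. It is a reference for the semantics rather than an efficient solution: `raw`, `row_id`, `pairs` and `means` are illustrative names, and `pd.DateOffset(months=1)` is assumed to be a good stand-in for SQL's `interval '1' month`.

```python
from io import StringIO

import pandas as pd

raw = """cust_1,2020-10-10,100
cust_2,2020-10-10,15
cust_1,2020-10-15,200
cust_1,2020-10-16,240
cust_2,2020-12-20,25
cust_1,2020-12-25,140
cust_2,2021-01-01,5
"""
df = pd.read_csv(StringIO(raw), names=["customer_id", "date", "price"], parse_dates=["date"])
df["row_id"] = range(len(df))  # hypothetical helper column identifying each row

# Pair each order with every order of the same customer (SQL-style self-join).
pairs = df.merge(df, on="customer_id", suffixes=("", "_other"))

# Keep partners dated in [date - 1 month, date]; equal dates are included,
# like the peers of CURRENT ROW in a RANGE frame.
in_frame = (pairs["date_other"] >= pairs["date"] - pd.DateOffset(months=1)) & (
    pairs["date_other"] <= pairs["date"]
)

# Average the partners' prices per original row and map back onto df.
means = pairs[in_frame].groupby("row_id")["price_other"].mean()
df["my_average"] = df["row_id"].map(means)
print(df)
```

The self-join grows quadratically with the number of orders per customer, so it is mainly useful for checking a faster approach on small data.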

I tried writing a function like this:

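import numpy as np

def mylogic(myro):
    # Note: myro['date'][0] and myro['date'][i] are *label* lookups. If this is
    # applied per group (e.g. via groupby(...).apply), a group's index labels are
    # not 0..n-1 (cust_2's are 1, 4, 6 after sorting), so these raise KeyError.
    # It also returns one mean per group, not one value per row.
    mylist = list()
    mydate = myro['date'][0]
    for i in range(len(myro)):
        if myro['date'][i] > mydate:
            mylist.append(myro['price'][i])
            mydate = myro['date'][i]
    return np.mean(mylist)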

But this raises a KeyError.
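A minimal sketch of a per-group function that avoids the label lookups (it uses `.iloc` for positions) and returns one value per row instead of one per group. It assumes the frame is "one calendar month back, inclusive" via `pd.DateOffset(months=1)`; `window_mean` and `raw` are hypothetical names, not part of the question.

```python
from io import StringIO

import pandas as pd

raw = """cust_1,2020-10-10,100
cust_2,2020-10-10,15
cust_1,2020-10-15,200
cust_1,2020-10-16,240
cust_2,2020-12-20,25
cust_1,2020-12-25,140
cust_2,2021-01-01,5
"""
df = pd.read_csv(StringIO(raw), names=["customer_id", "date", "price"], parse_dates=["date"])
df = df.sort_values(["customer_id", "date"])  # index labels are now 0, 2, 3, 5, 1, 4, 6


def window_mean(group: pd.DataFrame) -> pd.Series:
    """For each row of one customer, average the prices dated in [date - 1 month, date]."""
    dates = group["date"]
    values = []
    for i in range(len(group)):
        current = dates.iloc[i]  # positional access: no KeyError whatever the labels are
        in_frame = (dates >= current - pd.DateOffset(months=1)) & (dates <= current)
        values.append(group["price"][in_frame].mean())
    return pd.Series(values, index=group.index)  # keep the group's original labels


# Concatenate the per-customer results; column assignment aligns them by label.
df["my_average"] = pd.concat([window_mean(g) for _, g in df.groupby("customer_id")])
print(df.sort_index())
```

Iterating over `df.groupby(...)` instead of calling `apply` keeps the control flow explicit; the loop is O(n²) per customer, which is fine for small groups.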

Tags: pandas

2 answers

2 votes

You can use a rolling window over the past 30 days (`'30D'`), which approximates "one month" as 30 days:

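# rolling(on='date') needs dates sorted within each customer (the question's sort_values does this)
df['date'] = pd.to_datetime(df['date'])

df['my_average'] = (df.groupby('customer_id')
                      # 30-day rolling mean of price within each customer
                      .apply(lambda d: d.rolling('30D', on='date')['price'].mean())
                      # drop the customer_id level so the result aligns with df's index
                      .reset_index(level=0, drop=True)
                      # truncates to int; exact here only because every mean is a whole number
                      .astype(int)
                   )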

Output:

  customer_id       date  price  my_average
0      cust_1 2020-10-10    100         100
2      cust_1 2020-10-15    200         150
3      cust_1 2020-10-16    240         180
5      cust_1 2020-12-25    140         140
1      cust_2 2020-10-10     15          15
4      cust_2 2020-12-20     25          25
6      cust_2 2021-01-01      5          15
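A variant of the same idea without `apply`, as a minimal sketch: a grouped rolling mean over a `DatetimeIndex`, with `closed='both'` so a row exactly 30 days back is included, as SQL `PRECEDING` bounds include their endpoint (the default `closed='right'` excludes it). It keeps the float result instead of truncating with `astype(int)`, and 30 days remains an approximation of a calendar month. `raw` and `rolled` are illustrative names.

```python
from io import StringIO

import pandas as pd

raw = """cust_1,2020-10-10,100
cust_2,2020-10-10,15
cust_1,2020-10-15,200
cust_1,2020-10-16,240
cust_2,2020-12-20,25
cust_1,2020-12-25,140
cust_2,2021-01-01,5
"""
df = pd.read_csv(StringIO(raw), names=["customer_id", "date", "price"], parse_dates=["date"])
# Sort so each customer's dates are monotonic and groups appear in key order.
df = df.sort_values(["customer_id", "date"])

rolled = (
    df.set_index("date")               # offset windows need a datetime index
    .groupby("customer_id")["price"]
    .rolling("30D", closed="both")     # window [date - 30 days, date]
    .mean()
)
# rolled lists cust_1's rows then cust_2's, each in date order: the same
# order as the sorted df, so the values can be assigned positionally.
df["my_average"] = rolled.to_numpy()
print(df)
```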

0 votes

With DuckDB:

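(
# df.sql is not a pandas attribute: it appears to be a helper that exposes the
# DataFrame (with its index as an "index" column) as a DuckDB relation.
# With a DATE order key, "30 preceding" means 30 days; date3 is computed but unused.
df.sql.select("*,date::date date2,date_trunc('month',date2) date3")
 .select("*,avg(price) over(partition by customer_id order by date2 range between 30 preceding and current row) avg2")
 .order("index")
)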

┌───────┬─────────────┬────────────┬───────┬────────────┬────────────┬────────────┬────────┐
│ index │ customer_id │    date    │ price │ my_average │   date2    │   date3    │  avg2  │
│ int64 │   varchar   │  varchar   │ int64 │   int64    │    date    │    date    │ double │
├───────┼─────────────┼────────────┼───────┼────────────┼────────────┼────────────┼────────┤
│     0 │ cust_1      │ 2020-10-10 │   100 │        100 │ 2020-10-10 │ 2020-10-01 │  100.0 │
│     1 │ cust_2      │ 2020-10-10 │    15 │         15 │ 2020-10-10 │ 2020-10-01 │   15.0 │
│     2 │ cust_1      │ 2020-10-15 │   200 │        150 │ 2020-10-15 │ 2020-10-01 │  150.0 │
│     3 │ cust_1      │ 2020-10-16 │   240 │        180 │ 2020-10-16 │ 2020-10-01 │  180.0 │
│     4 │ cust_2      │ 2020-12-20 │    25 │         25 │ 2020-12-20 │ 2020-12-01 │   25.0 │
│     5 │ cust_1      │ 2020-12-25 │   140 │        140 │ 2020-12-25 │ 2020-12-01 │  140.0 │
│     6 │ cust_2      │ 2021-01-01 │     5 │         15 │ 2021-01-01 │ 2021-01-01 │   15.0 │
└───────┴─────────────┴────────────┴───────┴────────────┴────────────┴────────────┴────────┘