Pyspark: неправильные результаты для вычисления min и avg после groupby

#apache-spark #pyspark #apache-spark-sql

#apache-spark #pyspark #apache-spark-sql

Вопрос:

У меня есть Spark dataframe с columns id ( date_from ) и price . Пример:

 id          date_from   price
10000012    2021-08-12  19283.334
10000012    2021-05-16  4400.0
10000012    2021-06-08  5718.69
10000012    2021-07-09  15283.333
10000012    2021-07-02  9087.5
10000012    2021-07-04  15283.333
10000012    2021-06-22  9061.111
10000012    2021-06-26  9076.667
10000012    2021-06-27  9080.77
10000012    2021-07-10  15283.333
10000012    2021-08-14  19283.334
10000012    2021-05-09  4400.0
10000012    2021-05-12  4400.0
10000012    2021-06-17  9065.64
10000012    2021-05-19  4400.0
10000166    2021-05-06  5801.4287
10000166    2021-04-01  4954.375
10000166    2021-04-22  5173.7856
10000166    2021-06-27  12655.429
10000166    2021-02-23  5167.5
 

Я хочу рассчитать минимальную цену и среднюю цену. Для этого я пытался:

 groupBy_id = ["id"]
aggregate = ["price"]
funs = [min, mean]
exprs = [f(col(c)) for f in funs for c in aggregate]
df = df.groupby(*groupBy_id).agg(*exprs)
 

А также:

 df = df.groupby("id").agg(min("price").alias("min(norm_price)"),avg("price").alias("avg(norm_price)"))
 

Но некоторые min(norm_price) значения больше avg(norm_price) единиц.
Вывод:

 id,min(norm_price),avg(norm_price)
10000012,11150.0,10287.276085889778
10000166,10370.761904761903,6082.360302835207
10000185,5054.642857142857,5424.533834586466
10000421,3990.0,3990.0
 

Что я делаю не так?

Ответ №1:

Вам нужно убедиться, что norm_price имеет тип double, а не string . В противном min случае будет возвращена минимальная строка, а не минимальное число.

 df = df.withColumn('price', col('price').cast('double'))
df = df.groupby(*groupBy_id).agg(*exprs)
 

Ответ №2:

Я сделал что-то довольно простое:

 from pyspark.sql.types import (
     StringType,
     StructField,
     StructType,
     FloatType
)
from pyspark.sql import functions as F

schema = StructType([
     StructField('id', StringType(), True),
     StructField('date', StringType(), True),
     StructField('price', FloatType(), True)
])
df = spark.read.csv(
        "price.csv",
        header='true',
        schema=schema
)
df.groupBy("id").agg(F.avg('price'), F.min('price')).show()
 

Это дает мне желаемый результат:
введите описание изображения здесь