PySpark: Агрегатная функция для столбца с несколькими условиями

#python #pyspark #apache-spark-sql

Вопрос:

У меня есть два кадра данных PySpark A и B

 A
GROUP |    date    | 
 1    | 2021-02-01 |
 1    | 2021-04-01 |
 1    | 2021-07-23 | 
 1    | 2021-07-30 | 
 2    | 2021-02-01 |
 2    | 2021-04-01 |
 2    | 2021-07-23 | 
 2    | 2021-07-30 | 


B
GROUP |    date    | val
 1    | 2021-03-31 | 15
 2    | 2021-03-31 | 25
 2    | 2021-06-30 | 40
 

Я хочу присоединиться к ним таким образом, чтобы новый столбец last_reported_val был (MAX(B. дата) и B. дата Столбец val должен принимать соответствующее значение из B . Ниже приведен пример:

 GROUP |    date    | last_reported_val |   val   |
 1    | 2021-02-01 |        NULL       |   NULL  |
 1    | 2021-04-01 |      2021-03-31   |   15    |
 1    | 2021-07-23 |      2021-03-31   |   15    |
 1    | 2021-07-30 |      2021-03-31   |   15    |
 2    | 2021-02-01 |        NULL       |   NULL  |
 2    | 2021-04-01 |      2021-03-31   |   25    |
 2    | 2021-07-23 |      2021-06-30   |   40    |
 2    | 2021-07-30 |      2021-06-30   |   40    |
 

С SQL я бы сделал что-то вроде

 SELECT A.group, A.date, (select MAX(B.date) from B where B.date <= A.date and A.group = B.group) as last_reported_val, B.val
FROM A
LEFT JOIN B
on A.group = B.group
 

Как бы вы сделали это в PySpark? Я пробовал join map , но безуспешно

 A.join(B, A['GROUP'] == B['GROUP'], 'left')

...

# This raises _thread.lock error 

A.rdd.map(lambda r: (..., A.filter(
    (A['a.date'] == r['a.date']) amp; (A['group'] == r['group'])
).agg(max_('b.date')).collect())


 

Ответ №1:

Вы можете сначала присоединиться к своему условию, затем для каждой даты в B, которая удовлетворяет условию более чем для 1 даты в A, выберите максимальную дату в B, создав окно:

 from pyspark.sql import functions as F, Window as W

o = (A.alias("A").join(B.alias("B"),on=[F.col("A.GROUP")==F.col("B.GROUP"),
                                  F.col("B.date") <= F.col("A.date")]
                                , how='left')
.select("A.*",F.col("B.date").alias("last_reported_val"),"B.val"))

w = W.partitionBy("GROUP","date").orderBy(F.desc("last_reported_val"))

o.withColumn("Rnum",F.row_number().over(w)).filter("Rnum==1").drop("Rnum").show()
 

  ----- ---------- ----------------- ---- 
|GROUP|      date|last_reported_val| val|
 ----- ---------- ----------------- ---- 
|    1|2021-02-01|             null|null|
|    1|2021-04-01|       2021-03-31|  15|
|    1|2021-07-23|       2021-03-31|  15|
|    1|2021-07-30|       2021-03-31|  15|
|    2|2021-02-01|             null|null|
|    2|2021-04-01|       2021-03-31|  25|
|    2|2021-07-23|       2021-06-30|  40|
|    2|2021-07-30|       2021-06-30|  40|
 ----- ---------- ----------------- ----