Применить UDFs к фрейму данных pyspark на основе значения строки

#apache-spark #pyspark #apache-spark-sql #pyspark-dataframes #apache-spark-2.0

#apache-spark #pyspark #apache-spark-sql #pyspark-фреймы данных #apache-spark-2.0

Вопрос:

У меня есть фрейм данных pyspark со следующей схемой

  ----------- --------- ---------- ----------- 
|     userID|grouping1| grouping2|   features|
 ----------- --------- ---------- ----------- 
|12462563356|        1|        A | [5.0,43.0]|
|12462563701|        2|        A |  [1.0,8.0]|
|12462563701|        1|        B | [2.0,12.0]|
|12462564356|        1|        C |  [1.0,1.0]|
|12462565487|        3|        C |  [2.0,3.0]|
|12462565698|        2|        D |  [1.0,1.0]|
|12462565698|        1|        E |  [1.0,1.0]|
|12462566081|        2|        C |  [1.0,2.0]|
|12462566081|        1|        D | [1.0,15.0]|
|12462566225|        2|        E |  [1.0,1.0]|
|12462566225|        1|        A | [9.0,85.0]|
|12462566526|        2|        C |  [1.0,1.0]|
|12462566526|        1|        D | [3.0,79.0]|
|12462567006|        2|        D |[11.0,15.0]|
|12462567006|        1|        B |[10.0,15.0]|
|12462567006|        3|        A |[10.0,15.0]|
|12462586595|        2|        B | [2.0,42.0]|
|12462586595|        3|        D | [2.0,16.0]|
|12462589343|        3|        E |  [1.0,1.0]|
 ----------- --------- ---------- ----------- 
 

Для значений в grouping2 A , B , C и D мне нужно применить UDF_A , UDF_B , UDF_C и UDF_D соответственно. Есть ли способ, которым я могу написать что-то вроде

dataset = dataset.withColumn('outputColName', selectUDF(**params))

где selectUDF определяется как

 def selectUDF(**params):
    if row[grouping2] == A:
        return UDF_A(**params)
    elif row[grouping2] == B:
        return UDF_B(**params)
    elif row[grouping2] == C:
        return UDF_C(**params)
    elif row[grouping2] == D:
        return UDF_D(**params)
 

Используя следующий пример, чтобы проиллюстрировать, что я пытаюсь сделать
Да, я тоже так думал. Я использую следующий игрушечный код, чтобы проверить это

 >>> df = sc.parallelize([[1,2,3], [2,3,4]]).toDF(("a", "b", "c"))
>>> df.show()
 --- --- --- 
|  a|  b|  c|
 --- --- --- 
|  1|  2|  3|
|  2|  3|  4|
 --- --- --- 
>>> def udf1(col):
...     return col1*col1
...
>>> def udf2(col):
...     return col2*col2*col2
...
>>> def select_udf(col1, col2):
...     if col1 == 2:
...         return udf1(col2)
...     elif col1 == 3:
...         return udf2(col2)
...     else:
...         return 0
...
>>> from pyspark.sql.functions import col
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType
>>> select_udf = udf(select_udf, IntegerType())
>>> udf1 = udf(udf1, IntegerType())
>>> udf2 = udf(udf2, IntegerType())
>>> df.withColumn("outCol", select_udf(col("b"), col("c"))).show()
[Stage 9:============================================>              (3   1) / 4]
 

Похоже, это застряло на этом этапе навсегда. Кто-нибудь может подсказать, что здесь может быть не так?

Комментарии:

1. Похоже, это сработает. Если UDF_A UDF_B это Spark UDFS, которые возвращают тот же тип spark, то вы должны иметь возможность selectUDF обернуть udf() его и использовать. Однако вам не обязательно передавать accept **params . Вы можете просто принять одно значение grouping2 . Вы сталкиваетесь с ошибкой при выполнении этого подхода?

2. Да, я тоже так думал. Отредактировал исходное сообщение и добавил макет кода, с которым я его тестирую.

Ответ №1:

Вам не нужно selectUDF , просто используйте when expression для применения желаемого udf в зависимости от значения grouping2 столбца:

 from pyspark.sql.functions import col, when

df = df.withColumn(
    "outCol",
    when(col("grouping2") == "A", UDF_A(*params))
        .when(col("grouping2") == "B", UDF_B(*params))
        .when(col("grouping2") == "C", UDF_C(*params))
        .when(col("grouping2") == "D", UDF_D(*params))
)