проверка метода эквивалентного набора данных spark 1.5 UDAF

#scala #apache-spark #apache-spark-sql #user-defined-functions

#scala #apache-spark #apache-spark-sql #определяемые пользователем функции

Вопрос:

Может ли кто-нибудь сказать мне эквивалентную функцию для collect_set в spark 1.5?

Есть ли какие-либо обходные пути для получения аналогичных результатов, таких как collect_set(col(name)) ?

Это правильный подход :

 class CollectSetFunction[T](val colType: DataType) extends UserDefinedAggregateFunction {

  def inputSchema: StructType =
    new StructType().add("inputCol", colType)

  def bufferSchema: StructType =
    new StructType().add("outputCol", ArrayType(colType))

  def dataType: DataType = ArrayType(colType)

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, new scala.collection.mutable.ArrayBuffer[T])
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val list = buffer.getSeq[T](0)
    if (!input.isNullAt(0)) {
      val sales = input.getAs[T](0)
      buffer.update(0, list: sales)
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getSeq[T](0).toSet    buffer2.getSeq[T](0).toSet)
  }

  def evaluate(buffer: Row): Any = {
    buffer.getSeq[T](0)
  }
}
  

Ответ №1:

Его код выглядит правильно. Кроме того, я протестировал 1.6.2 в локальном режиме и получил тот же результат (см. Ниже). Я не знаю какой-либо более простой альтернативы с использованием DataFrame API. Используя RDD, это довольно просто, и может быть предпочтительнее иногда обходить RDD API в 1.5, поскольку фреймы данных реализованы не полностью.

 scala> val rdd = sc.parallelize((1 to 10)).map(x => (x%5,x))
scala> rdd.groupByKey.mapValues(_.toSet.toList)).toDF("k","set").show
 --- ------- 
|  k|    set|
 --- ------- 
|  0|[5, 10]|
|  1| [1, 6]|
|  2| [2, 7]|
|  3| [3, 8]|
|  4| [4, 9]|
 --- ------- 
  

И если вы хотите это исключить, начальная версия (которая может быть улучшена) может быть следующей

 def collectSet(df: DataFrame, k: Column, v: Column) = df
    .select(k.as("k"),v.as("v"))
    .map( r => (r.getInt(0),r.getInt(1)))
    .groupByKey()
    .mapValues(_.toSet.toList)
    .toDF("k","v")
  

но если вы хотите создать другие агрегации, вы не сможете избежать объединения.


 scala> val df = sc.parallelize((1 to 10)).toDF("v").withColumn("k", pmod('v,lit(5)))
df: org.apache.spark.sql.DataFrame = [v: int, k: int]

scala> val csudaf = new CollectSetFunction[Int](IntegerType)

scala> df.groupBy('k).agg(collect_set('v),csudaf('v)).show
 --- -------------- --------------------- 
|  k|collect_set(v)|CollectSetFunction(v)|
 --- -------------- --------------------- 
|  0|       [5, 10]|              [5, 10]|
|  1|        [1, 6]|               [1, 6]|
|  2|        [2, 7]|               [2, 7]|
|  3|        [3, 8]|               [3, 8]|
|  4|        [4, 9]|               [4, 9]|
 --- -------------- --------------------- 
  

тест 2:

 scala> val df = sc.parallelize((1 to 100000)).toDF("v").withColumn("k", floor(rand*10))
df: org.apache.spark.sql.DataFrame = [v: int, k: bigint]

scala> df.groupBy('k).agg(collect_set('v).as("a"),csudaf('v).as("b"))
         .groupBy('a==='b).count.show
 ------- -----                                                                  
|(a = b)|count|
 ------- ----- 
|   true|   10|
 ------- -----