#scala #apache-spark #apache-spark-sql #user-defined-functions
#scala #apache-spark #apache-spark-sql #определяемые пользователем функции
Вопрос:
Может ли кто-нибудь сказать мне эквивалентную функцию для collect_set в spark 1.5?
Есть ли какие-либо обходные пути для получения аналогичных результатов, таких как collect_set(col(name)) ?
Это правильный подход :
class CollectSetFunction[T](val colType: DataType) extends UserDefinedAggregateFunction {
def inputSchema: StructType =
new StructType().add("inputCol", colType)
def bufferSchema: StructType =
new StructType().add("outputCol", ArrayType(colType))
def dataType: DataType = ArrayType(colType)
def deterministic: Boolean = true
def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0, new scala.collection.mutable.ArrayBuffer[T])
}
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val list = buffer.getSeq[T](0)
if (!input.isNullAt(0)) {
val sales = input.getAs[T](0)
buffer.update(0, list: sales)
}
}
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0, buffer1.getSeq[T](0).toSet buffer2.getSeq[T](0).toSet)
}
def evaluate(buffer: Row): Any = {
buffer.getSeq[T](0)
}
}
Ответ №1:
Его код выглядит правильно. Кроме того, я протестировал 1.6.2 в локальном режиме и получил тот же результат (см. Ниже). Я не знаю какой-либо более простой альтернативы с использованием DataFrame
API. Используя RDD, это довольно просто, и может быть предпочтительнее иногда обходить RDD API в 1.5, поскольку фреймы данных реализованы не полностью.
scala> val rdd = sc.parallelize((1 to 10)).map(x => (x%5,x))
scala> rdd.groupByKey.mapValues(_.toSet.toList)).toDF("k","set").show
--- -------
| k| set|
--- -------
| 0|[5, 10]|
| 1| [1, 6]|
| 2| [2, 7]|
| 3| [3, 8]|
| 4| [4, 9]|
--- -------
И если вы хотите это исключить, начальная версия (которая может быть улучшена) может быть следующей
def collectSet(df: DataFrame, k: Column, v: Column) = df
.select(k.as("k"),v.as("v"))
.map( r => (r.getInt(0),r.getInt(1)))
.groupByKey()
.mapValues(_.toSet.toList)
.toDF("k","v")
но если вы хотите создать другие агрегации, вы не сможете избежать объединения.
scala> val df = sc.parallelize((1 to 10)).toDF("v").withColumn("k", pmod('v,lit(5)))
df: org.apache.spark.sql.DataFrame = [v: int, k: int]
scala> val csudaf = new CollectSetFunction[Int](IntegerType)
scala> df.groupBy('k).agg(collect_set('v),csudaf('v)).show
--- -------------- ---------------------
| k|collect_set(v)|CollectSetFunction(v)|
--- -------------- ---------------------
| 0| [5, 10]| [5, 10]|
| 1| [1, 6]| [1, 6]|
| 2| [2, 7]| [2, 7]|
| 3| [3, 8]| [3, 8]|
| 4| [4, 9]| [4, 9]|
--- -------------- ---------------------
тест 2:
scala> val df = sc.parallelize((1 to 100000)).toDF("v").withColumn("k", floor(rand*10))
df: org.apache.spark.sql.DataFrame = [v: int, k: bigint]
scala> df.groupBy('k).agg(collect_set('v).as("a"),csudaf('v).as("b"))
.groupBy('a==='b).count.show
------- -----
|(a = b)|count|
------- -----
| true| 10|
------- -----