Вычислить средний целевой столбец объекта с помощью Spark в Scala

#scala #apache-spark #apache-spark-sql

#scala #apache-spark #apache-spark-sql

Вопрос:

Я пытаюсь улучшить свою модель логистической регрессии, и у меня возникают проблемы с вычислением среднего целевого столбца (https://towardsdatascience.com/why-you-should-try-mean-encoding-17057262cd0 ) функции с помощью Spark Scala

Допустим, у меня есть этот набор данных:

  --- -------- --- ------ 
|id |job     |age|target|
 --- -------- --- ------ 
|1  |Doctor  |54 |1     |
|2  |Doctor  |35 |0     |
|3  |Doctor  |28 |1     |
|4  |Doctor  |75 |0     |
|5  |Teacher |29 |1     |
|6  |Teacher |37 |1     |
|7  |Engineer|60 |0     |
|8  |Engineer|38 |1     |
|9  |Waiter  |31 |1     |
|10 |Driver  |31 |0     |
 --- -------- --- ------ 
 

И я хочу вычислить среднее целевое значение между заданием и целью и получить:

  --- -------- -------- ------ 
|id |job     |job_mean|target|
 --- -------- -------- ------ 
|1  |Doctor  |0.5     |1     |
|2  |Doctor  |0.5     |0     |
|3  |Doctor  |0.5     |1     |
|4  |Doctor  |0.5     |0     |
|5  |Teacher |1.0     |1     |
|6  |Teacher |1.0     |1     |
|7  |Engineer|0.5     |0     |
|8  |Engineer|0.5     |1     |
|9  |Waiter  |1.0     |1     |
|10 |Driver  |0.0     |0     |
 --- -------- -------- ------ 
 

Я не могу найти способ сделать это эффективно с agg() помощью and withColumn() .

Это то, что я делал до сих пор, но я изо всех сил пытаюсь воспользоваться этим, я уверен, что есть лучший способ сделать это и заставить его работать :

 def targetAverage(colName: String, colValue: String): Double = {
    val targetCounts = df
      .groupBy(colName, "target")
      .count()
      .orderBy(col(colName).desc)
    val totalCounts = targetCounts
      .groupBy(colName)
      .agg(sum("count").as("count"))
      .orderBy(col(colName).desc)

    val targetCount = targetCounts.where(s"target == 1 AND $colName == $colValue").first().getAs[Long]("count")
    val totalCount = totalCounts.where(s"$colName == $colValue").first().getAs[Long]("count")
    
    targetCount.floatValue()/totalCount.floatValue()
  }
  
 

Как я могу использовать его для вычисления нового столбца с помощью withColumn() ?

Ответ №1:

Вы можете использовать оконную функцию в сочетании с функцией среднего значения Spark:

 import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.avg

input.withColumn("job_mean", avg("target").over(Window.partitionBy("job")))
 

этот код создает следующий фрейм данных (переупорядоченный по идентификатору):

  --- -------- --- ------ -------- 
|id |job     |age|target|job_mean|
 --- -------- --- ------ -------- 
|1  |Doctor  |54 |1     |0.5     |
|2  |Doctor  |35 |0     |0.5     |
|3  |Doctor  |28 |1     |0.5     |
|4  |Doctor  |75 |0     |0.5     |
|5  |Teacher |29 |1     |1.0     |
|6  |Teacher |37 |1     |1.0     |
|7  |Engineer|60 |0     |0.5     |
|8  |Engineer|38 |1     |0.5     |
|9  |Waiter  |31 |1     |1.0     |
|10 |Driver  |31 |0     |0.0     |
 --- -------- --- ------ --------