#scala #apache-spark #apache-spark-sql
#scala #apache-spark #apache-spark-sql
Вопрос:
Я пытаюсь улучшить свою модель логистической регрессии, и у меня возникают проблемы с вычислением среднего целевого столбца (https://towardsdatascience.com/why-you-should-try-mean-encoding-17057262cd0 ) функции с помощью Spark Scala
Допустим, у меня есть этот набор данных:
--- -------- --- ------
|id |job |age|target|
--- -------- --- ------
|1 |Doctor |54 |1 |
|2 |Doctor |35 |0 |
|3 |Doctor |28 |1 |
|4 |Doctor |75 |0 |
|5 |Teacher |29 |1 |
|6 |Teacher |37 |1 |
|7 |Engineer|60 |0 |
|8 |Engineer|38 |1 |
|9 |Waiter |31 |1 |
|10 |Driver |31 |0 |
--- -------- --- ------
И я хочу вычислить среднее целевое значение между заданием и целью и получить:
--- -------- -------- ------
|id |job |job_mean|target|
--- -------- -------- ------
|1 |Doctor |0.5 |1 |
|2 |Doctor |0.5 |0 |
|3 |Doctor |0.5 |1 |
|4 |Doctor |0.5 |0 |
|5 |Teacher |1.0 |1 |
|6 |Teacher |1.0 |1 |
|7 |Engineer|0.5 |0 |
|8 |Engineer|0.5 |1 |
|9 |Waiter |1.0 |1 |
|10 |Driver |0.0 |0 |
--- -------- -------- ------
Я не могу найти способ сделать это эффективно с agg()
помощью and withColumn()
.
Это то, что я делал до сих пор, но я изо всех сил пытаюсь воспользоваться этим, я уверен, что есть лучший способ сделать это и заставить его работать :
def targetAverage(colName: String, colValue: String): Double = {
val targetCounts = df
.groupBy(colName, "target")
.count()
.orderBy(col(colName).desc)
val totalCounts = targetCounts
.groupBy(colName)
.agg(sum("count").as("count"))
.orderBy(col(colName).desc)
val targetCount = targetCounts.where(s"target == 1 AND $colName == $colValue").first().getAs[Long]("count")
val totalCount = totalCounts.where(s"$colName == $colValue").first().getAs[Long]("count")
targetCount.floatValue()/totalCount.floatValue()
}
Как я могу использовать его для вычисления нового столбца с помощью withColumn()
?
Ответ №1:
Вы можете использовать оконную функцию в сочетании с функцией среднего значения Spark:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.avg
input.withColumn("job_mean", avg("target").over(Window.partitionBy("job")))
этот код создает следующий фрейм данных (переупорядоченный по идентификатору):
--- -------- --- ------ --------
|id |job |age|target|job_mean|
--- -------- --- ------ --------
|1 |Doctor |54 |1 |0.5 |
|2 |Doctor |35 |0 |0.5 |
|3 |Doctor |28 |1 |0.5 |
|4 |Doctor |75 |0 |0.5 |
|5 |Teacher |29 |1 |1.0 |
|6 |Teacher |37 |1 |1.0 |
|7 |Engineer|60 |0 |0.5 |
|8 |Engineer|38 |1 |0.5 |
|9 |Waiter |31 |1 |1.0 |
|10 |Driver |31 |0 |0.0 |
--- -------- --- ------ --------