Почему изменился тип данных при вызове UDF в scala

#scala #apache-spark #apache-spark-sql

#scala #apache-spark #apache-spark-sql

Вопрос:

У меня есть df:

 joined.printSchema
root
 |-- cc_num: long (nullable = true)
 |-- lat: double (nullable = true)
 |-- long: double (nullable = true)
 |-- merch_lat: double (nullable = true)
 |-- merch_long: double (nullable = true)
  

У меня есть udf:

 def getDistance (lat1:Double, lon1:Double, lat2:Double, lon2:Double) = {
    val r : Int = 6371 //Earth radius
    val latDistance : Double = Math.toRadians(lat2 - lat1)
    val lonDistance : Double = Math.toRadians(lon2 - lon1)
    val a : Double = Math.sin(latDistance / 2) * Math.sin(latDistance / 2)   Math.cos(Math.toRadians(lat1)) * Math.cos(Math.toRadians(lat2)) * Math.sin(lonDistance / 2) * Math.sin(lonDistance / 2)
    val c : Double = 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1 - a))
    val distance : Double = r * c
    distance
  }
  

Мне нужно сгенерировать новый столбец в df с:

 joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
  

Я получил сообщение об ошибке ниже:

 Name: Unknown Error
Message: <console>:35: error: type mismatch;
 found   : String("lat")
 required: Double
       joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
                                                          ^
<console>:35: error: type mismatch;
 found   : String("long")
 required: Double
       joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
                                                                 ^
<console>:35: error: type mismatch;
 found   : String("merch_lat")
 required: Double
       joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
                                                                         ^
<console>:35: error: type mismatch;
 found   : String("merch_long")
 required: Double
       joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
                                                                                      ^
  

Как вы можете видеть из схемы, все задействованные поля имеют тип double , который соответствует определению типа параметра udf, почему я вижу ошибку несоответствия типа данных?

Кто-нибудь может пояснить здесь, что не так и как это исправить?

Большое вам спасибо.

Ответ №1:

Ваш getDistance метод НЕ является UDF, это метод Scala, ожидающий 4 Double аргумента, и вместо этого вы передаете 4 строки.

Чтобы исправить это, вам необходимо:

  • «Оберните» свой метод UDF и
  • При применении UDF передавайте аргументы столбца, а не строки, что вы можете сделать, добавив к имени столбца префикс $
 import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import spark.implicits._ // assuming "spark" is your SparkSession

val distanceUdf: UserDefinedFunction = udf(getDistance _)

joined.withColumn("distance", distanceUdf($"lat", $"long", $"merch_lat", $"merch_long"))
  

Комментарии:

1. Вам org.apache.spark.sql.expressions.UserDefinedFunction тоже нужно будет импортировать; И я предположил spark , что именно так называется ваша SparkSession; Если она названа по-другому, замените spark ссылкой на сеанс. Импорт импликаций сеанса не является обязательным, но очень полезен.