#scala #apache-spark #apache-spark-sql
#scala #apache-spark #apache-spark-sql
Вопрос:
У меня есть df:
joined.printSchema
root
|-- cc_num: long (nullable = true)
|-- lat: double (nullable = true)
|-- long: double (nullable = true)
|-- merch_lat: double (nullable = true)
|-- merch_long: double (nullable = true)
У меня есть udf:
def getDistance (lat1:Double, lon1:Double, lat2:Double, lon2:Double) = {
val r : Int = 6371 //Earth radius
val latDistance : Double = Math.toRadians(lat2 - lat1)
val lonDistance : Double = Math.toRadians(lon2 - lon1)
val a : Double = Math.sin(latDistance / 2) * Math.sin(latDistance / 2) Math.cos(Math.toRadians(lat1)) * Math.cos(Math.toRadians(lat2)) * Math.sin(lonDistance / 2) * Math.sin(lonDistance / 2)
val c : Double = 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1 - a))
val distance : Double = r * c
distance
}
Мне нужно сгенерировать новый столбец в df с:
joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
Я получил сообщение об ошибке ниже:
Name: Unknown Error
Message: <console>:35: error: type mismatch;
found : String("lat")
required: Double
joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
^
<console>:35: error: type mismatch;
found : String("long")
required: Double
joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
^
<console>:35: error: type mismatch;
found : String("merch_lat")
required: Double
joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
^
<console>:35: error: type mismatch;
found : String("merch_long")
required: Double
joined = joined.withColumn("distance", getDistance("lat", "long", "merch_lat", "merch_long"))
^
Как вы можете видеть из схемы, все задействованные поля имеют тип double
, который соответствует определению типа параметра udf, почему я вижу ошибку несоответствия типа данных?
Кто-нибудь может пояснить здесь, что не так и как это исправить?
Большое вам спасибо.
Ответ №1:
Ваш getDistance
метод НЕ является UDF, это метод Scala, ожидающий 4 Double
аргумента, и вместо этого вы передаете 4 строки.
Чтобы исправить это, вам необходимо:
- «Оберните» свой метод UDF и
- При применении UDF передавайте аргументы столбца, а не строки, что вы можете сделать, добавив к имени столбца префикс
$
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import spark.implicits._ // assuming "spark" is your SparkSession
val distanceUdf: UserDefinedFunction = udf(getDistance _)
joined.withColumn("distance", distanceUdf($"lat", $"long", $"merch_lat", $"merch_long"))
Комментарии:
1. Вам
org.apache.spark.sql.expressions.UserDefinedFunction
тоже нужно будет импортировать; И я предположилspark
, что именно так называется ваша SparkSession; Если она названа по-другому, заменитеspark
ссылкой на сеанс. Импорт импликаций сеанса не является обязательным, но очень полезен.