Spark SQL — найти наибольшее количество стран, в которых побывал пассажир

#scala #apache-spark #count #apache-spark-sql #aggregate-functions

#scala #apache-spark #подсчет #apache-spark-sql #агрегатные функции

Вопрос:

У меня есть фрейм данных следующим образом…

  ----------- -------- ---- --- ---------- 
|passengerId|flightId|from| to|      date|
 ----------- -------- ---- --- ---------- 
|       3173|      41|  fr| cn|2017-01-11|
|       3173|      48|  cn| at|2017-01-13|
|       3173|      57|  at| pk|2017-01-17|
|       3173|      71|  pk| il|2017-01-21|
|       3173|     118|  il| se|2017-02-12|
|       3173|     137|  se| iq|2017-02-18|
|       3173|     154|  iq| at|2017-02-24|
|       3173|     231|  at| ar|2017-03-22|
|       3173|     245|  ar| cl|2017-03-28|
|       3173|     270|  cl| sg|2017-04-08|
|       3173|     287|  sg| iq|2017-04-14|
|       3173|     308|  iq| nl|2017-04-21|
|       3173|     317|  nl| dk|2017-04-24|
|       3173|     336|  dk| se|2017-04-29|
|       3173|     463|  se| th|2017-06-14|
|       3173|     480|  th| th|2017-06-20|
|       3173|     650|  th| th|2017-08-21|
|       3173|     660|  th| nl|2017-08-26|
|       3173|     670|  nl| sg|2017-09-01|
|       3173|     695|  sg| ca|2017-09-10|
 ----------- -------- ---- --- ---------- 
  

Я хотел бы найти наибольшее количество стран, в которых побывал пассажир, и не включать начальную страну. Например, если страны, в которых находился пассажир, были: at -> pk -> il -> se -> iq -> at, правильным ответом будет 4 страны.
Выходные данные должны быть в следующем формате:

 Passenger ID    Longest Run
3173                4
1234                n
…                   …
  

Комментарии:

1. Ваша проблема выглядит как найти количество стран, а не большое число. Почему вы используете слово greatest или это ошибка?

Ответ №1:

Начиная с spark 2.4: вы можете сделать это, используя комбинацию collect_list concat функций , array_remove , array_distinct и size spark.

 import org.apache.spark.sql.functions._
import spark.implicits._

val data = Seq(
  (3173, 41, "fr", "cn", "2017-01-11"),
  (3173, 48, "cn", "at", "2017-01-13"),
  (3173, 57, "at", "pk", "2017-01-17"),
  (3173, 71, "pk", "il", "2017-01-21"),
  (3173, 118, "il", "se", "2017-02-12"),
  (3173, 137, "se", "iq", "2017-02-18"),
  (3173, 154, "iq", "at", "2017-02-24"),
  (3173, 231, "at", "ar", "2017-03-22"),
  (3173, 245, "ar", "cl", "2017-03-28"),
  (3173, 270, "cl", "sg", "2017-04-08"),
  (3173, 287, "sg", "iq", "2017-04-14"),
  (3173, 308, "iq", "nl", "2017-04-21"),
  (3173, 317, "nl", "dk", "2017-04-24"),
  (3173, 336, "dk", "se", "2017-04-29"),
  (3173, 463, "se", "th", "2017-06-14"),
  (3173, 480, "th", "th", "2017-06-20"),
  (3173, 650, "th", "th", "2017-08-21"),
  (3173, 660, "th", "nl", "2017-08-26"),
  (3173, 670, "nl", "sg", "2017-09-01"),
  (3173, 695, "sg", "fr", "2017-09-10")
).toDF("passengerId", "flightId", "from", "to", "date")

// first we need to group by passenger to collect all his "from" and "to" countries
val dataWithCountries = data.groupBy("passengerId")
  .agg(
    // concat is for concatenate two lists of strings from columns "from" and "to"
    concat(
      // collect list gathers all values from the given column into array
      collect_list(col("from")),
      collect_list(col("to"))
    ).name("countries")
  )
  

После агрегирования у нас будет список всех стран для пассажира с дубликатами. Затем мы должны сначала удалить (см. array_remove Функцию) его страну (первое значение from столбца для пассажира) из списка отдельных стран (см. array_distinct ) и подсчитать страны с помощью size функции:

 val passengerLongestRuns = dataWithCountries.withColumn(
  "longest_run",
  size(array_remove(array_distinct(col("countries")), col("countries").getItem(0)))
)
passengerLongestRuns.show(false)
  

вывод:

  ----------- ----------- 
|passengerId|longest_run|
 ----------- ----------- 
|3173       |12         |
 ----------- ----------- 
  

Для spark < 2.4
вы можете определить remove и distinct как пользовательские функции:

 def removeAllFirstOccurrences(list: Seq[String]): Seq[String] = list.tail.filter(_ != list.head)

val removeFirstCountry = spark.udf.register[Seq[String], Seq[String]]("remove_first_country", removeAllFirstOccurrences)

def distinct(list: Seq[String]): Seq[String] = list.distinct

val distinctArray = spark.udf.register[Seq[String], Seq[String]]("array_distinct", distinct)

val passengerLongestRuns = dataWithCountries.withColumn(
  "longest_run",
  size(
    distinctArray(
      removeFirstCountry(
        col("countries")
      )
    )
  )
)
passengerLongestRuns.show(false)
  

вывод:

  ----------- --------- 
|passengerId|countries|
 ----------- --------- 
|3173       |12       |
 ----------- --------- 
  

Комментарии:

1. Спасибо за вашу помощь. Просто последующий вопрос: есть ли альтернатива array_remove и array_distinct в более старых версиях spark.