Как я могу создать spark udf для интерполяции float в INT и как я могу написать лучшую логику, чем я сделал

#python #dataframe #pyspark #interpolation

#python #фрейм данных #pyspark #интерполяция

Вопрос:

Ниже приведен мой фрейм данных Spark, я хочу выполнить интерполяцию и написать Spark UDF для этого, я не уверен, как я могу написать лучшую логику и создать UDF сверху

Это для преобразования Position_float и интерполяции его в целое число для преобразования Position в соответствующее целочисленное значение

 def dirty_fill(df, id_col, y_cols):
    from pyspark.sql import types as T
    df = df.withColumn('position_plus', (df.position_float   0.5).cast(T.IntegerType()))
    df = df.withColumn('position_minus', (df.position_float - 0.5).cast(T.IntegerType()))
    df = df.withColumn('position', df.position_float.cast(T.IntegerType()))
    df1 = df.select([id_col, 'position_plus']   y_cols).withColumnRenamed('position_plus', 'position')
    df2 = df.select([id_col, 'position_minus']   y_cols).withColumnRenamed('position_minus', 'position')
    df3 = df.select([id_col, 'position']   y_cols)
    df123 = df1.union(df2).union(df3).sort([id_col, 'position']).dropDuplicates([id_col, 'position'])
    return df123
  
 y_cols = ['entry_temperature']
finish_mill_entry_filled = dirty_fill(finish_mill_entry, 'finish_mill_id', y_cols)

  

Это мой пример фрейма данных

 | Finishing_mill_id  | Sample  | Position_float | Entry_Temp |
|--------------------|---------|----------------|------------|
| 2015418529         | 1       | 0.000000       | 1986.0     |
| 2015418529         | 2       | 2.192982       | 1997.0     |
| 2015418529         | 3       | 4.385965       | 2003.0     |
| 2018171498         | 445     | 495.535714     | 1643.0     |
| 2018171498         | 446     | 496.651786     | 1734.0     |
| 2018171498         | 447     | 497.767857     | 1748.0     |
| 2018171498         | 448     | 498.883929     | 1755.0     |
  

Мне нужно интерполировать float в целое число

Чего я хочу, так это

 | Finishing_mill_id  | Sample  | Position_float | Entry_Temp |
|--------------------|---------|----------------|------------|
| 2015418529         | 1       | 0              | 1986.0     |
| 2015418529         | 2       | 1              | 1986       |
| 2015418529         | 3       | 2              | 1997.0     |
| 2015418529         | 4       | 3              | 1997       |
| 2015418529         | 5       | 4              | 2003.0     |
| 2018171498         | 445     | 496            | 1643.0     |
| 2018171498         | 446     | 497            | 1734.0     |
| 2018171498         | 447     | 498            | 1748.0     |
| 2018171498         | 448     | 499            | 1755.0     |
  

Для этого мне нужна функция spark user_defined, и не должно быть никаких пропущенных точек данных, поскольку у меня есть Position_float в диапазоне 0-500, мне также нужно позаботиться о том, чтобы были все точки, не пропуская ни одной точки. Необходимо соответствующим образом изменить мою логику интерполяции

Чтобы было немного понятно, скажем, у меня есть позиция 0.000 2.19, но у меня нет для этого точки данных, но то, что мне нужно, когда я делаю, мне нужно иметь позицию для 1.00..Мне нужно значение для позиции 1.00, даже если данных нет в виде линейной интерполяции.Надеюсь, это поможет

Ответ №1:

1. Оконные функции

Вы можете использовать оконные функции для заполнения пробелов и интерполяции значений.

Давайте начнем с примера фрейма данных:

 import pyspark.sql.functions as psf
import pyspark.sql.types as pst
from pyspark.sql import Window
import numpy as np

df = spark.createDataFrame(
        [[float(t)/10., float(v)] for t, v in zip(np.random.randint(0, 1000, 20), np.random.randint(100, 200, 20))], 
        schema=pst.StructType([pst.StructField(c, pst.FloatType()) for c in ['position', 'value']])) 
    .withColumn('position_round', psf.round('position'))

         -------- ----- -------------- 
        |position|value|position_round|
         -------- ----- -------------- 
        |    68.5|121.0|          69.0|
        |    76.3|126.0|          76.0|
        |    88.3|150.0|          88.0|
        |    59.0|197.0|          59.0|
        |    20.7|119.0|          21.0|
        |     0.1|167.0|           0.0|
        |    20.1|177.0|          20.0|
        |    81.9|199.0|          82.0|
        |    63.6|163.0|          64.0|
        |    32.4|115.0|          32.0|
        |    43.6|130.0|          44.0|
        |    11.9|175.0|          12.0|
        |    68.2|176.0|          68.0|
        |    28.9|184.0|          29.0|
        |    46.3|199.0|          46.0|
        |     9.7|155.0|          10.0|
        |    57.8|163.0|          58.0|
        |    83.6|173.0|          84.0|
        |    16.2|169.0|          16.0|
        |    87.1|127.0|          87.0|
         -------- ----- -------------- 
  

Чтобы заполнить пробелы, мы создадим диапазон целых чисел:

 start, end = list(df.agg(psf.min('position_round'), psf.max('position_round')).collect()[0])
pos_df = spark.range(start=start, end=end, step=1) 
    .withColumnRenamed('id', 'position_round')
  

Теперь мы можем объединить два фрейма данных:

 w1 = Window.orderBy('position_round')
w2 = Window.partitionBy('group').orderBy('position_round')

df_resample = df 
    .select(
        '*', 
        psf.lead('position_round', 1).over(w1).alias('next_position'), 
        psf.lead('value', 1).over(w1).alias('next_value')) 
    .join(pos_df, on='position_round', how='right') 
    .withColumn('group', psf.sum((~psf.isnull('position')).cast('int')).over(w1)) 
    .select(
        '*', 
        (psf.row_number().over(w2) - 1).alias('i'), 
        psf.first(psf.col('next_position') - psf.col('position_round')).over(w2).alias('dx'), 
        psf.first('value').over(w2).alias('value0'), 
        psf.first(psf.col('next_value') - psf.col('value')).over(w2).alias('dy')) 
    .withColumn(
        'value_round', 
        psf.when((psf.col('dx') > 0) | psf.isnull('next_value'), psf.col('value0')   psf.col('i') * psf.col('dy') / psf.col('dx')) 
            .otherwise(psf.col('value')))
  
  • Первая оконная функция заключается в сохранении next_value и next_position последующем вычислении наших dx и dy
  • Затем нам нужно идентифицировать каждый пробел с отдельным group идентификатором, чтобы мы могли интерполировать значения для каждого отдельного линейного сегмента
  • и последнее, но не менее важное: мы объединяем все элементы, которые нам нужны:
    • длина разрыва: dx
    • дельта в значениях: dy
    • индекс текущей строки в gap i

Теперь мы можем вычислить value_round интерполяцию value at position position_round

          -------------- -------- ----- ------------- ---------- ----- --- ---- ------ ----- ----------- 
        |position_round|position|value|next_position|next_value|group|  i|  dx|value0|   dy|value_round|
         -------------- -------- ----- ------------- ---------- ----- --- ---- ------ ----- ----------- 
        |             0|     0.1|167.0|         10.0|     155.0|    1|  0|10.0| 167.0|-12.0|      167.0|
        |             1|    null| null|         null|      null|    1|  1|10.0| 167.0|-12.0|      165.8|
        |             2|    null| null|         null|      null|    1|  2|10.0| 167.0|-12.0|      164.6|
        |             3|    null| null|         null|      null|    1|  3|10.0| 167.0|-12.0|      163.4|
        |             4|    null| null|         null|      null|    1|  4|10.0| 167.0|-12.0|      162.2|
        |             5|    null| null|         null|      null|    1|  5|10.0| 167.0|-12.0|      161.0|
        |             6|    null| null|         null|      null|    1|  6|10.0| 167.0|-12.0|      159.8|
        |             7|    null| null|         null|      null|    1|  7|10.0| 167.0|-12.0|      158.6|
        |             8|    null| null|         null|      null|    1|  8|10.0| 167.0|-12.0|      157.4|
        |             9|    null| null|         null|      null|    1|  9|10.0| 167.0|-12.0|      156.2|
        |            10|     9.7|155.0|         12.0|     175.0|    2|  0| 2.0| 155.0| 20.0|      155.0|
        |            11|    null| null|         null|      null|    2|  1| 2.0| 155.0| 20.0|      165.0|
        |            12|    11.9|175.0|         16.0|     169.0|    3|  0| 4.0| 175.0| -6.0|      175.0|
        |            13|    null| null|         null|      null|    3|  1| 4.0| 175.0| -6.0|      173.5|
        |            14|    null| null|         null|      null|    3|  2| 4.0| 175.0| -6.0|      172.0|
        |            15|    null| null|         null|      null|    3|  3| 4.0| 175.0| -6.0|      170.5|
        |            16|    16.2|169.0|         20.0|     177.0|    4|  0| 4.0| 169.0|  8.0|      169.0|
        |            17|    null| null|         null|      null|    4|  1| 4.0| 169.0|  8.0|      171.0|
        |            18|    null| null|         null|      null|    4|  2| 4.0| 169.0|  8.0|      173.0|
        |            19|    null| null|         null|      null|    4|  3| 4.0| 169.0|  8.0|      175.0|
         -------------- -------- ----- ------------- ---------- ----- --- ---- ------ ----- ----------- 
  

2. UDF

Если вы не хотите использовать оконные функции, вы можете написать a UDF для выполнения интерполяции, python а затем вернуть массив (позиция, значение) кортежей:

 def interpolate(pos, next_pos, value, next_value):
    if pos == next_pos or next_value is None:
        return [(pos, value)]
    return [[pos   i, value   i * (next_value - value) / (next_pos - pos)] for i in range(int(next_pos - pos))]
interpolate_udf = psf.udf(interpolate, pst.ArrayType(pst.StructType([pst.StructField(c, pst.FloatType()) for c in ['position_round', 'value_round']])))
  

Обратите внимание, что кортежи имеют тип StructType , чтобы упростить «сглаживание» кортежей в столбцы.

 w1 = Window.orderBy('position_round')
df_udf = df 
    .select(
        '*', 
        psf.lead('position_round', 1).over(w1).alias('next_position'), 
        psf.lead('value', 1).over(w1).alias('next_value')) 
    .withColumn('tmp', psf.explode(interpolate_udf('position_round', 'next_position', 'value', 'next_value'))) 
    .select('*', 'tmp.*').drop('tmp')
  

Вот что мы получаем:

          -------- ----- -------------- ------------- ---------- -------------- ---------- 
        |position|value|position_round|next_position|next_value|position_round|value_round|
         -------- ----- -------------- ------------- ---------- -------------- ---------- 
        |     0.1|167.0|           0.0|         10.0|     155.0|           0.0|     167.0|
        |     0.1|167.0|           0.0|         10.0|     155.0|           1.0|     165.8|
        |     0.1|167.0|           0.0|         10.0|     155.0|           2.0|     164.6|
        |     0.1|167.0|           0.0|         10.0|     155.0|           3.0|     163.4|
        |     0.1|167.0|           0.0|         10.0|     155.0|           4.0|     162.2|
        |     0.1|167.0|           0.0|         10.0|     155.0|           5.0|     161.0|
        |     0.1|167.0|           0.0|         10.0|     155.0|           6.0|     159.8|
        |     0.1|167.0|           0.0|         10.0|     155.0|           7.0|     158.6|
        |     0.1|167.0|           0.0|         10.0|     155.0|           8.0|     157.4|
        |     0.1|167.0|           0.0|         10.0|     155.0|           9.0|     156.2|
        |     9.7|155.0|          10.0|         12.0|     175.0|          10.0|     155.0|
        |     9.7|155.0|          10.0|         12.0|     175.0|          11.0|     165.0|
        |    11.9|175.0|          12.0|         16.0|     169.0|          12.0|     175.0|
        |    11.9|175.0|          12.0|         16.0|     169.0|          13.0|     173.5|
        |    11.9|175.0|          12.0|         16.0|     169.0|          14.0|     172.0|
        |    11.9|175.0|          12.0|         16.0|     169.0|          15.0|     170.5|
        |    16.2|169.0|          16.0|         20.0|     177.0|          16.0|     169.0|
        |    16.2|169.0|          16.0|         20.0|     177.0|          17.0|     171.0|
        |    16.2|169.0|          16.0|         20.0|     177.0|          18.0|     173.0|
        |    16.2|169.0|          16.0|         20.0|     177.0|          19.0|     175.0|
         -------- ----- -------------- ------------- ---------- -------------- ---------- 
  

Ответ №2:

Просто используйте round и введите приведение к IntegerType

 from pyspark.sql import functions as F
from pyspark.sql import types as T

df = df.withColumn('Position_float', F.round(F.col('Position_float')).cast(T.IntegerType()))
  

Комментарии:

1. Но у меня есть еще одна проблема, которую нужно решить, поэтому мне нужно интерполировать, скажем, у меня диапазон от 0 до 500, если я не получу такие точки, как мне это получить

2. не могли бы вы объяснить пример ввода и вывода.

3. Я добавил, пожалуйста, посмотрите