pyspark, когда оператор в противном случае возвращает неверный вывод

#apache-spark #pyspark #apache-spark-sql #case

Вопрос:

Я вставил свой код ниже. Я ожидаю , что когда col2 = 7 он должен вернуться 1 , но он возвращается 1 раз, а 2 — в другое время. Я не выполняю никаких операций col2 , как только он установлен. Кто-нибудь когда-нибудь сталкивался с таким странным поведением? Или проблема связана с тем, что ограничения для каждого условия перекрываются?

  df = df.withColumn('col1', F.when(F.col('col2').between(1,7), 1)
                             .when(F.col('col2').between(7,14), 2)
                             .when(F.col('col2').between(14,21), 3)
                             .when(F.col('col2').between(21,28), 4)
                             .otherwise(5))
 

Ответ №1:

Я бы сказал, что это нечто неожиданное, потому case-when что код будет преобразован в последовательность if s с помощью CodeGen. Следовательно, вы всегда должны видеть 'col2' значение 1.

Вы можете просмотреть фактический код , созданный с помощью Spark QueryExecution.debug.codegen , примерно так:

 >>> df = spark.range(1000)
>>> from pyspark.sql.functions import *
>>> dff = df.withColumn('col1',when(col('id').between(1,7),1).when(col('id').between(7,14),2).otherwise(3))

>>> dff._jdf.queryExecution().debug().codegen()

Found 1 WholeStageCodegen subtrees.
== Subtree 1 / 1 ==
*(1) Project [id#4L, CASE WHEN ((id#4L >= 1) amp;amp; (id#4L <= 7)) THEN 1 WHEN ((id#4L >= 7) amp;amp; (id#4L <= 14)) THEN 2 ELSE 3 END AS col1#6]
 - *(1) Range (0, 1000, step=1, splits=2)

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage1(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=1
/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private boolean range_initRange_0;
/* 010 */   private long range_number_0;
/* 011 */   private TaskContext range_taskContext_0;
/* 012 */   private InputMetrics range_inputMetrics_0;
/* 013 */   private long range_batchEnd_0;
/* 014 */   private long range_numElementsTodo_0;
/* 015 */   private int project_project_value_2_0;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[3];
/* 017 */
/* 018 */   public GeneratedIteratorForCodegenStage1(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */
/* 026 */     range_taskContext_0 = TaskContext.get();
/* 027 */     range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
/* 028 */     range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 029 */     range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 030 */     range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
/* 031 */
/* 032 */   }
/* 033 */
/* 034 */   private void project_doConsume_0(long project_expr_0_0) throws java.io.IOException {
/* 035 */     byte project_caseWhenResultState_0 = -1;
/* 036 */     do {
/* 037 */       boolean project_value_4 = false;
/* 038 */       project_value_4 = project_expr_0_0 >= 1L;
/* 039 */       boolean project_value_3 = false;
/* 040 */
/* 041 */       if (project_value_4) {
/* 042 */         boolean project_value_7 = false;
/* 043 */         project_value_7 = project_expr_0_0 <= 7L;
/* 044 */         project_value_3 = project_value_7;
/* 045 */       }
/* 046 */       if (!false amp;amp; project_value_3) {
/* 047 */         project_caseWhenResultState_0 = (byte)(false ? 1 : 0);
/* 048 */         project_project_value_2_0 = 1;
/* 049 */         continue;
/* 050 */       }
/* 051 */
/* 052 */       boolean project_value_12 = false;
/* 053 */       project_value_12 = project_expr_0_0 >= 7L;
/* 054 */       boolean project_value_11 = false;
/* 055 */
/* 056 */       if (project_value_12) {
/* 057 */         boolean project_value_15 = false;
/* 058 */         project_value_15 = project_expr_0_0 <= 14L;
/* 059 */         project_value_11 = project_value_15;
/* 060 */       }
/* 061 */       if (!false amp;amp; project_value_11) {
/* 062 */         project_caseWhenResultState_0 = (byte)(false ? 1 : 0);
/* 063 */         project_project_value_2_0 = 2;
/* 064 */         continue;
/* 065 */       }
/* 066 */
/* 067 */       project_caseWhenResultState_0 = (byte)(false ? 1 : 0);
/* 068 */       project_project_value_2_0 = 3;
/* 069 */
/* 070 */     } while (false);
/* 071 */     // TRUE if any condition is met and the result is null, or no any condition is met.
/* 072 */     final boolean project_isNull_2 = (project_caseWhenResultState_0 != 0);
/* 073 */     range_mutableStateArray_0[2].reset();
/* 074 */
/* 075 */     range_mutableStateArray_0[2].zeroOutNullBytes();
/* 076 */
/* 077 */     range_mutableStateArray_0[2].write(0, project_expr_0_0);
/* 078 */
/* 079 */     range_mutableStateArray_0[2].write(1, project_project_value_2_0);
/* 080 */     append((range_mutableStateArray_0[2].getRow()));
/* 081 */
/* 082 */   }
/* 083 */
...
 

Нас интересует метод private void project_doConsume_0(... (начиная со строки 34).

Комментарии:

1. Спасибо за подробное объяснение. Теперь мы понимаем, в чем дело, и немного углубились в вывод и поняли, что десятичные дроби с 7 цифрами были причиной этой проблемы. Теперь мы исправили это, и код работает нормально.

Ответ №2:

Первый пункт: между включительно, и у вас есть некоторое перекрытие в вашем интервале (7 может быть как истинным в первом, так и во втором интервале, поскольку оба они содержат 7)

Так что это должно улучшиться:

  df = df.withColumn('col1', F.when(F.col('col2').between(1,7), 1)
                             .when(F.col('col2').between(8,14), 2)
                             .when(F.col('col2').between(15,21), 3)
                             .when(F.col('col2').between(22,28), 4)
                             .otherwise(5))
 

Но также при работе с несколькими F.when() у меня меньше проблем, если вложить их внутрь .otherwise(F.when()) , как показано ниже:

  df = df.withColumn('col1', F.when(F.col('col2').between(1,7), 1)
                             .otherwise(F.when(F.col('col2').between(8,14), 2)
                             .otherwise(F.when(F.col('col2').between(15,21), 3)
                             .otherwise(F.when(F.col('col2').between(22,28), 4)
                             .otherwise(5)))))
 

Комментарии:

1. Спасибо, что ответили. Мы обновили код, удалив «между» и используя > и > Мы также поняли, что числа возвращали 8-значное десятичное число, которое не попадало в вывод. Когда вывод равен 7.000009, условие переходит к следующему оператору. Мы округлили результаты, и теперь код ведет себя так, как ожидалось.