"Comparison method violates its general contract!" when trying to fit a decision tree model

#python #pyspark #decision-tree

Question:

I have found many questions about this error, but since I am not doing any explicit comparison myself, I decided to ask for help. I am trying to fit a decision tree model, but I run into the following error:

org.apache.spark.SparkException: Job aborted due to stage failure: Task 4 in stage 3568.0 failed 4 times, most recent failure: Lost task 4.3 in stage 3568.0 (TID 5012) (10.132.136.68 executor driver): java.lang.IllegalArgumentException: Comparison method violates its general contract!

Here is the full log:

 <command-4092656683152053> in <module>
      1 dt = DecisionTreeClassifier(featuresCol = 'features', labelCol = 'label', seed=2021)
----> 2 dtModel = dt.fit(training)
      3 predictions = dtModel.transform(testing)
      4 predictions.select('segmento', 'qtd_cat_1_antes', 'label', 'rawPrediction', 'prediction', 'probability').show(10)

/databricks/python_shell/dbruntime/MLWorkloadsInstrumentation/_pyspark.py in patched_method(self, *args, **kwargs)
     28             call_succeeded = False
     29             try:
---> 30                 result = original_method(self, *args, **kwargs)
     31                 call_succeeded = True
     32                 return result

/databricks/spark/python/pyspark/ml/base.py in fit(self, dataset, params)
    159                 return self.copy(params)._fit(dataset)
    160             else:
--> 161                 return self._fit(dataset)
    162         else:
    163             raise ValueError("Params must be either a param map or a list/tuple of param maps, "

/databricks/spark/python/pyspark/ml/wrapper.py in _fit(self, dataset)
    333 
    334     def _fit(self, dataset):
--> 335         java_model = self._fit_java(dataset)
    336         model = self._create_model(java_model)
    337         return self._copyValues(model)

/databricks/spark/python/pyspark/ml/wrapper.py in _fit_java(self, dataset)
    330         """
    331         self._transfer_params_to_java()
--> 332         return self._java_obj.fit(dataset._jdf)
    333 
    334     def _fit(self, dataset):

/databricks/spark/python/lib/py4j-0.10.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1302 
   1303         answer = self.gateway_client.send_command(command)
-> 1304         return_value = get_return_value(
   1305             answer, self.gateway_client, self.target_id, self.name)
   1306 

/databricks/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
    115     def deco(*a, **kw):
    116         try:
--> 117             return f(*a, **kw)
    118         except py4j.protocol.Py4JJavaError as e:
    119             converted = convert_exception(e.java_exception)

/databricks/spark/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    324             value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
    325             if answer[1] == REFERENCE_TYPE:
--> 326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.n".
    328                     format(target_id, ".", name), value)

Py4JJavaError: An error occurred while calling o10045.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 4 in stage 3568.0 failed 4 times, most recent failure: Lost task 4.3 in stage 3568.0 (TID 5012) (10.132.136.68 executor driver): java.lang.IllegalArgumentException: Comparison method violates its general contract!
    at org.apache.spark.util.collection.TimSort$SortState.mergeLo(TimSort.java:803)
    at org.apache.spark.util.collection.TimSort$SortState.mergeAt(TimSort.java:534)
    at org.apache.spark.util.collection.TimSort$SortState.mergeCollapse(TimSort.java:462)
    at org.apache.spark.util.collection.TimSort$SortState.access$200(TimSort.java:325)
    at org.apache.spark.util.collection.TimSort.sort(TimSort.java:153)
    at org.apache.spark.util.collection.Sorter.sort(Sorter.scala:37)
    at org.apache.spark.ml.tree.impl.RandomForest$.findSplitsForContinuousFeatureEdge(RandomForest.scala:1357)
    at org.apache.spark.ml.tree.impl.RandomForest$.findSplitsForContinuousFeature(RandomForest.scala:1177)
    at org.apache.spark.ml.tree.impl.RandomForest$.$anonfun$findSplitsBySorting$11(RandomForest.scala:1073)
    at scala.collection.Iterator$anon$10.next(Iterator.scala:459)
    at scala.collection.Iterator.foreach(Iterator.scala:941)
    at scala.collection.Iterator.foreach$(Iterator.scala:941)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1429)
    at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
    at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
    at scala.collection.TraversableOnce.to(TraversableOnce.scala:315)
    at scala.collection.TraversableOnce.to$(TraversableOnce.scala:313)
    at scala.collection.AbstractIterator.to(Iterator.scala:1429)
    at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:307)
    at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:307)
    at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1429)
    at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:294)
    at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:288)
    at scala.collection.AbstractIterator.toArray(Iterator.scala:1429)
    at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1038)
    at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$3(ResultTask.scala:75)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$1(ResultTask.scala:75)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:55)
    at org.apache.spark.scheduler.Task.doRunTask(Task.scala:150)
    at org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:119)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.scheduler.Task.run(Task.scala:91)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$13(Executor.scala:803)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1643)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:806)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:662)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
    at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2765)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2712)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2706)
    at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
    at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2706)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1255)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1255)
    at scala.Option.foreach(Option.scala:407)
    at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1255)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2973)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2914)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2902)
    at org.apache.spark.util.EventLoop$anon$1.run(EventLoop.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:1028)
    at org.apache.spark.SparkContext.runJobInternal(SparkContext.scala:2446)
    at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1036)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:125)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:419)
    at org.apache.spark.rdd.RDD.collect(RDD.scala:1034)
    at org.apache.spark.rdd.PairRDDFunctions.$anonfun$collectAsMap$1(PairRDDFunctions.scala:737)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:125)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:419)
    at org.apache.spark.rdd.PairRDDFunctions.collectAsMap(PairRDDFunctions.scala:736)
    at org.apache.spark.ml.tree.impl.RandomForest$.findSplitsBySorting(RandomForest.scala:1072)
    at org.apache.spark.ml.tree.impl.RandomForest$.findSplits(RandomForest.scala:1040)
    at org.apache.spark.ml.tree.impl.RandomForest$.run(RandomForest.scala:284)
    at org.apache.spark.ml.classification.DecisionTreeClassifier.$anonfun$train$1(DecisionTreeClassifier.scala:136)
    at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:284)
    at scala.util.Try$.apply(Try.scala:213)
    at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:284)
    at org.apache.spark.ml.classification.DecisionTreeClassifier.train(DecisionTreeClassifier.scala:115)
    at org.apache.spark.ml.classification.DecisionTreeClassifier.train(DecisionTreeClassifier.scala:47)
    at org.apache.spark.ml.Predictor.fit(Predictor.scala:151)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
    at py4j.Gateway.invoke(Gateway.java:295)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:251)
    at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.IllegalArgumentException: Comparison method violates its general contract!
    at org.apache.spark.util.collection.TimSort$SortState.mergeLo(TimSort.java:803)
    at org.apache.spark.util.collection.TimSort$SortState.mergeAt(TimSort.java:534)
    at org.apache.spark.util.collection.TimSort$SortState.mergeCollapse(TimSort.java:462)
    at org.apache.spark.util.collection.TimSort$SortState.access$200(TimSort.java:325)
    at org.apache.spark.util.collection.TimSort.sort(TimSort.java:153)
    at org.apache.spark.util.collection.Sorter.sort(Sorter.scala:37)
    at org.apache.spark.ml.tree.impl.RandomForest$.findSplitsForContinuousFeatureEdge(RandomForest.scala:1357)
    at org.apache.spark.ml.tree.impl.RandomForest$.findSplitsForContinuousFeature(RandomForest.scala:1177)
    at org.apache.spark.ml.tree.impl.RandomForest$.$anonfun$findSplitsBySorting$11(RandomForest.scala:1073)
    at scala.collection.Iterator$anon$10.next(Iterator.scala:459)
    at scala.collection.Iterator.foreach(Iterator.scala:941)
    at scala.collection.Iterator.foreach$(Iterator.scala:941)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1429)
    at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
    at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
    at scala.collection.TraversableOnce.to(TraversableOnce.scala:315)
    at scala.collection.TraversableOnce.to$(TraversableOnce.scala:313)
    at scala.collection.AbstractIterator.to(Iterator.scala:1429)
    at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:307)
    at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:307)
    at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1429)
    at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:294)
    at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:288)
    at scala.collection.AbstractIterator.toArray(Iterator.scala:1429)
    at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1038)
    at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$3(ResultTask.scala:75)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$1(ResultTask.scala:75)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:55)
    at org.apache.spark.scheduler.Task.doRunTask(Task.scala:150)
    at org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:119)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.scheduler.Task.run(Task.scala:91)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$13(Executor.scala:803)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1643)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:806)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
    at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:662)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    ... 1 more
 

Here is my code:

 from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier

base = conta_contrato_final.drop('qtd_notas_at_dps')

# Replace empty strings and '*' placeholders with the literal 'VAZIO'
base = base.na.replace(['', '*'], 'VAZIO')

cols = base.columns
categoricalColumns = ["status_campo", "segmento", "classe_social"]
stages = []

# Index and one-hot encode each categorical column
for categoricalCol in categoricalColumns:
    stringIndexer = StringIndexer(inputCol=categoricalCol, outputCol=categoricalCol + 'Index', handleInvalid="skip")
    encoder = OneHotEncoder(inputCols=[stringIndexer.getOutputCol()], outputCols=[categoricalCol + "classVec"])
    stages += [stringIndexer, encoder]

# Index the target column as 'label'
label_stringIdx = StringIndexer(inputCol='flag_servico_dps', outputCol='label')
stages += [label_stringIdx]

numericCols = ['qtd_residencias', 'tempo_de_casa', 'idade', 'qtd_inadimplente']
assemblerInputs = [c + "classVec" for c in categoricalColumns] + numericCols
assembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features").setHandleInvalid("keep")
stages += [assembler]

pipeline = Pipeline(stages=stages)
pipelineModel = pipeline.fit(base)
base = pipelineModel.transform(base)
selectedCols = ['label', 'features'] + cols
base = base.select(selectedCols)
base.printSchema()

(training, testing) = base.randomSplit([0.7, 0.3], seed=1234)

dt = DecisionTreeClassifier(featuresCol='features', labelCol='label', seed=2021)
dtModel = dt.fit(training)

But the error is only raised on the last line, dt.fit(training). Does anyone have any idea why this happens?
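
The bottom of the executor stack trace points at RandomForest.findSplitsForContinuousFeature, which sorts the values of each continuous feature with TimSort. One plausible (unconfirmed) culprit is NaN values reaching the feature vector: setHandleInvalid("keep") on the VectorAssembler passes invalid entries through rather than dropping them, and NaN values are a known way to break the sort-order assumptions TimSort enforces. A minimal pre-fit check, a sketch reusing the column names from the pipeline above:

from pyspark.sql import functions as F

# Sketch: count NaN and null values in each numeric input column of the
# training set before fitting the tree.
numericCols = ['qtd_residencias', 'tempo_de_casa', 'idade', 'qtd_inadimplente']
training.select([
    F.count(F.when(F.isnan(c) | F.col(c).isNull(), c)).alias(c)
    for c in numericCols
]).show()

If any of these counts are non-zero, imputing or filtering those rows (for example with na.fill or handleInvalid="skip") before assembling the features would be a reasonable next thing to try.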

Comments:

1. Where does the training variable come from?

2. Oh, sorry, I forgot to include the line where I define that variable. I have added it now.

3. Put something like print(type(training)) before you use it, or inspect it with your debugger.
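
A minimal version of that check might look like this (assuming the cells above have already run):

print(type(training))                  # expect <class 'pyspark.sql.dataframe.DataFrame'>
training.printSchema()                 # confirm the 'label' and 'features' columns exist
training.select('label', 'features').show(5, truncate=False)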