#python #pyspark #decision-tree
Question:
I have found many questions about this error, but since I am not doing any explicit comparison myself, I decided to ask for additional help. I am trying to fit a decision tree model, but I run into the following error:
org.apache.spark.SparkException: Job aborted due to stage failure: Task 4 in stage 3568.0 failed 4 times, most recent failure: Lost task 4.3 in stage 3568.0 (TID 5012) (10.132.136.68 executor driver): java.lang.IllegalArgumentException: Comparison method violates its general contract!
Here is the full log:
<command-4092656683152053> in <module>
1 dt = DecisionTreeClassifier(featuresCol = 'features', labelCol = 'label', seed=2021)
----> 2 dtModel = dt.fit(training)
3 predictions = dtModel.transform(testing)
4 predictions.select('segmento', 'qtd_cat_1_antes', 'label', 'rawPrediction', 'prediction', 'probability').show(10)
/databricks/python_shell/dbruntime/MLWorkloadsInstrumentation/_pyspark.py in patched_method(self, *args, **kwargs)
28 call_succeeded = False
29 try:
---> 30 result = original_method(self, *args, **kwargs)
31 call_succeeded = True
32 return result
/databricks/spark/python/pyspark/ml/base.py in fit(self, dataset, params)
159 return self.copy(params)._fit(dataset)
160 else:
--> 161 return self._fit(dataset)
162 else:
163 raise ValueError("Params must be either a param map or a list/tuple of param maps, "
/databricks/spark/python/pyspark/ml/wrapper.py in _fit(self, dataset)
333
334 def _fit(self, dataset):
--> 335 java_model = self._fit_java(dataset)
336 model = self._create_model(java_model)
337 return self._copyValues(model)
/databricks/spark/python/pyspark/ml/wrapper.py in _fit_java(self, dataset)
330 """
331 self._transfer_params_to_java()
--> 332 return self._java_obj.fit(dataset._jdf)
333
334 def _fit(self, dataset):
/databricks/spark/python/lib/py4j-0.10.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
1302
1303 answer = self.gateway_client.send_command(command)
-> 1304 return_value = get_return_value(
1305 answer, self.gateway_client, self.target_id, self.name)
1306
/databricks/spark/python/pyspark/sql/utils.py in deco(*a, **kw)
115 def deco(*a, **kw):
116 try:
--> 117 return f(*a, **kw)
118 except py4j.protocol.Py4JJavaError as e:
119 converted = convert_exception(e.java_exception)
/databricks/spark/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
324 value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
325 if answer[1] == REFERENCE_TYPE:
--> 326 raise Py4JJavaError(
327 "An error occurred while calling {0}{1}{2}.n".
328 format(target_id, ".", name), value)
Py4JJavaError: An error occurred while calling o10045.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 4 in stage 3568.0 failed 4 times, most recent failure: Lost task 4.3 in stage 3568.0 (TID 5012) (10.132.136.68 executor driver): java.lang.IllegalArgumentException: Comparison method violates its general contract!
at org.apache.spark.util.collection.TimSort$SortState.mergeLo(TimSort.java:803)
at org.apache.spark.util.collection.TimSort$SortState.mergeAt(TimSort.java:534)
at org.apache.spark.util.collection.TimSort$SortState.mergeCollapse(TimSort.java:462)
at org.apache.spark.util.collection.TimSort$SortState.access$200(TimSort.java:325)
at org.apache.spark.util.collection.TimSort.sort(TimSort.java:153)
at org.apache.spark.util.collection.Sorter.sort(Sorter.scala:37)
at org.apache.spark.ml.tree.impl.RandomForest$.findSplitsForContinuousFeatureEdge(RandomForest.scala:1357)
at org.apache.spark.ml.tree.impl.RandomForest$.findSplitsForContinuousFeature(RandomForest.scala:1177)
at org.apache.spark.ml.tree.impl.RandomForest$.$anonfun$findSplitsBySorting$11(RandomForest.scala:1073)
at scala.collection.Iterator$anon$10.next(Iterator.scala:459)
at scala.collection.Iterator.foreach(Iterator.scala:941)
at scala.collection.Iterator.foreach$(Iterator.scala:941)
at scala.collection.AbstractIterator.foreach(Iterator.scala:1429)
at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
at scala.collection.TraversableOnce.to(TraversableOnce.scala:315)
at scala.collection.TraversableOnce.to$(TraversableOnce.scala:313)
at scala.collection.AbstractIterator.to(Iterator.scala:1429)
at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:307)
at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:307)
at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1429)
at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:294)
at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:288)
at scala.collection.AbstractIterator.toArray(Iterator.scala:1429)
at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1038)
at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$3(ResultTask.scala:75)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$1(ResultTask.scala:75)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:55)
at org.apache.spark.scheduler.Task.doRunTask(Task.scala:150)
at org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:119)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.Task.run(Task.scala:91)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$13(Executor.scala:803)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1643)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:806)
at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:662)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2765)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2712)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2706)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2706)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1255)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1255)
at scala.Option.foreach(Option.scala:407)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1255)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2973)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2914)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2902)
at org.apache.spark.util.EventLoop$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:1028)
at org.apache.spark.SparkContext.runJobInternal(SparkContext.scala:2446)
at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1036)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:125)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:419)
at org.apache.spark.rdd.RDD.collect(RDD.scala:1034)
at org.apache.spark.rdd.PairRDDFunctions.$anonfun$collectAsMap$1(PairRDDFunctions.scala:737)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:125)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:419)
at org.apache.spark.rdd.PairRDDFunctions.collectAsMap(PairRDDFunctions.scala:736)
at org.apache.spark.ml.tree.impl.RandomForest$.findSplitsBySorting(RandomForest.scala:1072)
at org.apache.spark.ml.tree.impl.RandomForest$.findSplits(RandomForest.scala:1040)
at org.apache.spark.ml.tree.impl.RandomForest$.run(RandomForest.scala:284)
at org.apache.spark.ml.classification.DecisionTreeClassifier.$anonfun$train$1(DecisionTreeClassifier.scala:136)
at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:284)
at scala.util.Try$.apply(Try.scala:213)
at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:284)
at org.apache.spark.ml.classification.DecisionTreeClassifier.train(DecisionTreeClassifier.scala:115)
at org.apache.spark.ml.classification.DecisionTreeClassifier.train(DecisionTreeClassifier.scala:47)
at org.apache.spark.ml.Predictor.fit(Predictor.scala:151)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
at py4j.Gateway.invoke(Gateway.java:295)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:251)
at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.IllegalArgumentException: Comparison method violates its general contract!
at org.apache.spark.util.collection.TimSort$SortState.mergeLo(TimSort.java:803)
at org.apache.spark.util.collection.TimSort$SortState.mergeAt(TimSort.java:534)
at org.apache.spark.util.collection.TimSort$SortState.mergeCollapse(TimSort.java:462)
at org.apache.spark.util.collection.TimSort$SortState.access$200(TimSort.java:325)
at org.apache.spark.util.collection.TimSort.sort(TimSort.java:153)
at org.apache.spark.util.collection.Sorter.sort(Sorter.scala:37)
at org.apache.spark.ml.tree.impl.RandomForest$.findSplitsForContinuousFeatureEdge(RandomForest.scala:1357)
at org.apache.spark.ml.tree.impl.RandomForest$.findSplitsForContinuousFeature(RandomForest.scala:1177)
at org.apache.spark.ml.tree.impl.RandomForest$.$anonfun$findSplitsBySorting$11(RandomForest.scala:1073)
at scala.collection.Iterator$anon$10.next(Iterator.scala:459)
at scala.collection.Iterator.foreach(Iterator.scala:941)
at scala.collection.Iterator.foreach$(Iterator.scala:941)
at scala.collection.AbstractIterator.foreach(Iterator.scala:1429)
at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
at scala.collection.TraversableOnce.to(TraversableOnce.scala:315)
at scala.collection.TraversableOnce.to$(TraversableOnce.scala:313)
at scala.collection.AbstractIterator.to(Iterator.scala:1429)
at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:307)
at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:307)
at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1429)
at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:294)
at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:288)
at scala.collection.AbstractIterator.toArray(Iterator.scala:1429)
at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1038)
at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$3(ResultTask.scala:75)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.ResultTask.$anonfun$runTask$1(ResultTask.scala:75)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:55)
at org.apache.spark.scheduler.Task.doRunTask(Task.scala:150)
at org.apache.spark.scheduler.Task.$anonfun$run$1(Task.scala:119)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.scheduler.Task.run(Task.scala:91)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$13(Executor.scala:803)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1643)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:806)
at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
at com.databricks.spark.util.ExecutorFrameProfiler$.record(ExecutorFrameProfiler.scala:110)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:662)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
... 1 more
Here is my code:
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier

base = conta_contrato_final.drop('qtd_notas_at_dps')
base = base.na.replace(['', '*'], 'VAZIO')
cols = base.columns

# Index and one-hot encode the categorical columns
categoricalColumns = ["status_campo", "segmento", "classe_social"]
stages = []
for categoricalCol in categoricalColumns:
    stringIndexer = StringIndexer(inputCol=categoricalCol, outputCol=categoricalCol + 'Index', handleInvalid="skip")
    encoder = OneHotEncoder(inputCols=[stringIndexer.getOutputCol()], outputCols=[categoricalCol + "classVec"])
    stages += [stringIndexer, encoder]

# Index the target column as 'label'
label_stringIdx = StringIndexer(inputCol='flag_servico_dps', outputCol='label')
stages += [label_stringIdx]

# Assemble the encoded categoricals plus the numeric columns into 'features'
numericCols = ['qtd_residencias', 'tempo_de_casa', 'idade', 'qtd_inadimplente']
assemblerInputs = [c + "classVec" for c in categoricalColumns] + numericCols
assembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features").setHandleInvalid("keep")
stages += [assembler]

pipeline = Pipeline(stages=stages)
pipelineModel = pipeline.fit(base)
base = pipelineModel.transform(base)
selectedCols = ['label', 'features'] + cols
base = base.select(selectedCols)
base.printSchema()

(training, testing) = base.randomSplit([0.7, 0.3], seed=1234)
dt = DecisionTreeClassifier(featuresCol='features', labelCol='label', seed=2021)
dtModel = dt.fit(training)
But the error only occurs on the last line, the call to fit. Does anyone have any idea why this is happening?
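One plausible culprit, given that the stack trace fails inside RandomForest.findSplitsForContinuousFeature while sorting feature values: NaN values in a continuous feature break the total order TimSort requires (comparisons involving NaN are not transitive), and VectorAssembler with setHandleInvalid("keep") passes NaN through into the feature vector. A minimal diagnostic sketch, assuming the numeric column names from the code above:
from pyspark.sql import functions as F

# Count NaN and null values in each numeric input column; a nonzero count
# here would point to bad values reaching the continuous-feature sort.
numericCols = ['qtd_residencias', 'tempo_de_casa', 'idade', 'qtd_inadimplente']
training.select([
    F.count(F.when(F.isnan(F.col(c).cast("double")) | F.col(c).isNull(), c)).alias(c)
    for c in numericCols
]).show()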
Comments:
1. Where does the training variable come from?
2. Oh, sorry, I forgot to include the line where I define that variable. I have added it now.
3. Put something like print(type(training)) right before you use it, or inspect it with your debugger.
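A short sketch of the check suggested in comment 3 (hypothetical; it assumes the pipeline above produced training):
# Sanity checks on the training split before calling fit()
print(type(training))    # expect <class 'pyspark.sql.dataframe.DataFrame'>
training.printSchema()   # confirm the 'label' and 'features' columns exist
print(training.count())  # confirm the 70% split is not empty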