Исключение IllegalArgumentException: не выполнено требование: векторы rawPredictionCol должны иметь длину=2, но получили 3 при тестировании модели в Apache Spark

#java #apache-spark #apache-spark-ml #apache-spark-dataset

Вопрос:

Я пытаюсь создать модель и оценить ее в Apache Spark 3.1.1 с помощью алгоритма OneR. У меня есть .csv файл с нормализованными данными (все значения есть double , но некоторые значения очень близки к 0).

Я читал основное руководство MLlib по OneVsRest, и код очень похож на этот:

 SparkSession session = SparkSession
                .builder()
                .appName("Spark test")
                .master("local")
                .getOrCreate();

JavaRDD<LabeledPoint> data = loadData(session, "path.csv");

LogisticRegression logisticRegression = new LogisticRegression().setMaxIter(20);
OneVsRest oneR = new OneVsRest().setClassifier(logisticRegression);

BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator();
MulticlassMetrics metrics = new MulticlassMetrics(data.rdd());
MulticlassClassificationEvaluator multiEvaluator = new MulticlassClassificationEvaluator()
        .setMetricName("accuracy");
        
JavaRDD<LabeledPoint>[] javaRDDS = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingRDD = javaRDDS[0], testRDD = javaRDDS[1];
Dataset<Row> trainingDataset = session.createDataFrame(trainingRDD, LabeledPoint.class);
Dataset<Row> testDataset = session.createDataFrame(testRDD, LabeledPoint.class);

OneVsRestModel oneRModel = oneR.fit(trainingDataset);
Dataset<Row> oneRPredictions = oneRModel.transform(testDataset).select("prediction", "label");

double oneRAcc = evaluator.evaluate(oneRPredictions);
System.out.println("OneR: rn");
System.out.println("Accuracy: "   oneRAcc);
System.out.println("--------------------------------");

session.close();
 

Этот код создает исключение:

 Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: rawPredictionCol vectors must have length=2, but got 3
    at scala.Predef$.require(Predef.scala:281)
    at org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.$anonfun$getMetrics$1(BinaryClassificationEvaluator.scala:126)
    at scala.runtime.java8.JFunction1$mcVI$sp.apply(JFunction1$mcVI$sp.java:23)
    at scala.Option.foreach(Option.scala:407)
    at org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.getMetrics(BinaryClassificationEvaluator.scala:126)
    at org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate(BinaryClassificationEvaluator.scala:100)
    at Classification.main(Classification.java:64)
 

Почему этот код не работает? Я думал , что проблема в .select("prediction", "label") том, что я не знаю, что у меня будет Dataset<Row> после transform , но vectors must have length of 2 not 3 это странно. Я пытаюсь классифицировать в нескольких классах с 3 классами.

Редактировать

Я использовал BinaryClassificationEvaluator evaluator вместо MulticlassClassificationEvaluator multiEvaluator этого по ошибке. Теперь сообщение об ошибке имеет смысл.