#java #apache-spark #apache-spark-ml #apache-spark-dataset
Вопрос:
Я пытаюсь создать модель и оценить ее в Apache Spark 3.1.1 с помощью алгоритма OneR. У меня есть .csv
файл с нормализованными данными (все значения есть double
, но некоторые значения очень близки к 0).
Я читал основное руководство MLlib по OneVsRest, и код очень похож на этот:
SparkSession session = SparkSession
.builder()
.appName("Spark test")
.master("local")
.getOrCreate();
JavaRDD<LabeledPoint> data = loadData(session, "path.csv");
LogisticRegression logisticRegression = new LogisticRegression().setMaxIter(20);
OneVsRest oneR = new OneVsRest().setClassifier(logisticRegression);
BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator();
MulticlassMetrics metrics = new MulticlassMetrics(data.rdd());
MulticlassClassificationEvaluator multiEvaluator = new MulticlassClassificationEvaluator()
.setMetricName("accuracy");
JavaRDD<LabeledPoint>[] javaRDDS = data.randomSplit(new double[]{0.7, 0.3});
JavaRDD<LabeledPoint> trainingRDD = javaRDDS[0], testRDD = javaRDDS[1];
Dataset<Row> trainingDataset = session.createDataFrame(trainingRDD, LabeledPoint.class);
Dataset<Row> testDataset = session.createDataFrame(testRDD, LabeledPoint.class);
OneVsRestModel oneRModel = oneR.fit(trainingDataset);
Dataset<Row> oneRPredictions = oneRModel.transform(testDataset).select("prediction", "label");
double oneRAcc = evaluator.evaluate(oneRPredictions);
System.out.println("OneR: rn");
System.out.println("Accuracy: " oneRAcc);
System.out.println("--------------------------------");
session.close();
Этот код создает исключение:
Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: rawPredictionCol vectors must have length=2, but got 3
at scala.Predef$.require(Predef.scala:281)
at org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.$anonfun$getMetrics$1(BinaryClassificationEvaluator.scala:126)
at scala.runtime.java8.JFunction1$mcVI$sp.apply(JFunction1$mcVI$sp.java:23)
at scala.Option.foreach(Option.scala:407)
at org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.getMetrics(BinaryClassificationEvaluator.scala:126)
at org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate(BinaryClassificationEvaluator.scala:100)
at Classification.main(Classification.java:64)
Почему этот код не работает? Я думал , что проблема в .select("prediction", "label")
том, что я не знаю, что у меня будет Dataset<Row>
после transform
, но vectors must have length of 2 not 3
это странно. Я пытаюсь классифицировать в нескольких классах с 3 классами.
Редактировать
Я использовал BinaryClassificationEvaluator evaluator
вместо MulticlassClassificationEvaluator multiEvaluator
этого по ошибке. Теперь сообщение об ошибке имеет смысл.