#java #machine-learning #seq2seq #dl4j #computation-graph
#java #машинное обучение #seq2seq #dl4j #вычисление-график
Вопрос:
Я пытаюсь реализовать модель прогнозирования Seq2Seq в DL4J. В конечном итоге я хочу использовать временной ряд INPUT_SIZE
точек данных для прогнозирования следующего временного ряда OUTPUT_SIZE
точек данных, используя этот тип модели. Каждая точка данных имеет numFeatures
особенности. Теперь в DL4J есть некоторый пример кода, объясняющий, как реализовать очень простую модель Seq2Seq. Я добился некоторого прогресса в расширении их примера для своих собственных нужд; приведенная ниже модель компилируется, но сделанные ею прогнозы бессмысленны.
ComputationGraphConfiguration configuration = new
NeuralNetConfiguration.Builder()
.weightInit(WeightInit.XAVIER)
.updater(new Adam(0.25))
.seed(42)
.graphBuilder()
.addInputs("in_data", "last_in")
.setInputTypes(InputType.recurrent(numFeatures), InputType.recurrent(numFeatures))
//The inputs to the encoder will have size = minibatch x featuresize x timesteps
//Note that the network only knows of the feature vector size. It does not know how many time steps unless it sees an instance of the data
.addLayer("encoder", new LSTM.Builder().nIn(numFeatures).nOut(hiddenLayerWidth).activation(Activation.LEAKYRELU).build(), "in_data")
//Create a vertex indicating the very last time step of the encoder layer needs to be directed to other places in the comp graph
.addVertex("lastTimeStep", new LastTimeStepVertex("in_data"), "encoder")
//Create a vertex that allows the duplication of 2d input to a 3d input
//In this case the last time step of the encoder layer (viz. 2d) is duplicated to the length of the timeseries "sumOut" which is an input to the comp graph
//Refer to the javadoc for more detail
.addVertex("duplicateTimeStep", new DuplicateToTimeSeriesVertex("last_in"), "lastTimeStep")
//The inputs to the decoder will have size = size of output of last timestep of encoder (numHiddenNodes) size of the other input to the comp graph,sumOut (feature vector size)
.addLayer("decoder", new LSTM.Builder().nIn(numFeatures hiddenLayerWidth).nOut(hiddenLayerWidth).activation(Activation.LEAKYRELU).build(), "last_in","duplicateTimeStep")
.addLayer("output", new RnnOutputLayer.Builder().nIn(hiddenLayerWidth).nOut(numFeatures).activation(Activation.LEAKYRELU).lossFunction(LossFunctions.LossFunction.MSE).build(), "decoder")
.setOutputs("output")
.build();
ComputationGraph net = new ComputationGraph(configuration);
net.init();
net.setListeners(new ScoreIterationListener(1));
Способ, которым я структурирую свои входные / помеченные данные, заключается в том, что у меня входные данные разделены между первыми INPUT_SIZE - 1
наблюдениями за временными рядами (соответствующими in_data
входным данным в ComputationGraph), а затем последним наблюдением за временными рядами (соответствующим lastIn
входным данным). Метки — это один временной шаг в будущее; чтобы делать прогнозы, я просто вызываю net.output()
OUTPUT_SIZE
times, чтобы получить все нужные мне прогнозы. Чтобы лучше видеть это, вот как я инициализирую свой ввод / метки:
INDArray[] input = new INDArray[] {Nd4j.zeros(batchSize, numFeatures, INPUT_SIZE - 1), Nd4j.zeros(batchSize, numFeatures, 1)};
INDArray[] labels = new INDArray[] {Nd4j.zeros(batchSize, numFeatures, 1)};
Я полагаю, что моя ошибка связана с ошибкой в архитектуре моего графика вычислений, а не с тем, как я готовлю свои данные / делаю прогнозы / что-то еще, поскольку я выполнял другие мини-проекты с более простыми архитектурами и у меня не было проблем.
Мои данные нормализованы, чтобы иметь среднее значение 0 и std. отклонение на 1. Таким образом, большинство записей должно быть около 0, однако большинство прогнозов, которые я получаю, представляют собой значения с абсолютным значением, намного большим нуля (порядка 10-100 секунд). Это явно неверно. Я работал над этим в течение некоторого времени и не смог найти проблему; любые предложения о том, как это исправить, были бы высоко оценены.
Другие ресурсы, которые я использовал: Пример модели Seq2Seq можно найти здесь, начиная со строки 88. Документацию ComputationGraph можно найти здесь; Я подробно прочитал это, чтобы посмотреть, смогу ли я найти ошибку, но безрезультатно.
Комментарии:
1. вы определили ошибку?