Deeplearning4J RNN training: "Expected 3D input for RNN layer, got 2" exception


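With the following code (which I have spent hours tweaking with different parameters), I keep getting the exception java.lang.IllegalStateException: Expected 3D input for RNN layer, got 2. What I am trying to do is train an RNN to predict the next value (a double) in a sequence, based on a set of training sequences. I generate the features with a simple random data generator and use the last value of each sequence as the training label (that is, the value to be predicted).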

My code:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

import java.util.Random;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class RnnPredictionExample {

  public static void main(String[] args) {
    //generate 100 rows of data that have 50 columns/features each
    DataSet trainingdata = getRandomDataset(100, 51, 1);
    // Train the RNN model...
    MultiLayerNetwork trainedModel = trainRnnModel(trainingdata, 50, 10, 1);

    // Generate a sequence and perform next-value prediction on it
    double[] inputSequence = randomData(50, 1);
    double predictedValue = predictNextValue(trainedModel, inputSequence);
    System.out.println("Predicted Next Value: " + predictedValue);
  }

  public static MultiLayerNetwork trainRnnModel(DataSet trainingdataandlabels, int sequenceLength, int numHiddenUnits, int numEpochs) {
    // ... Create network configuration ...

    // Create and initialize the network
    MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
            //.seed(123)
            .list()
            .layer(new LSTM.Builder()
                    .nIn(1)
                    .nOut(50)
                    .build()
            )
            .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .activation(Activation.IDENTITY)
                    .nIn(50)
                    .nOut(1) // Set nOut to 1
                    .build()
            )
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(config);
    net.init();

    for (int i = 0; i < numEpochs; i++) {
      net.fit(trainingdataandlabels);
    }

    return net;
  }

  public static double predictNextValue(MultiLayerNetwork trainedModel, double[] inputSequence) {
    INDArray inputArray = Nd4j.create(inputSequence);
    INDArray predicted = trainedModel.rnnTimeStep(inputArray);

    // Predicted value is the last element of the predicted sequence
    return predicted.getDouble(predicted.length() - 1);
  }

  static Random random = new Random();

  public static double[] randomData(int length, int rangeMultiplier) {

    double[] out = new double[length];
    for (int i = 0; i < out.length; i++) {
      out[i] = random.nextDouble() * rangeMultiplier;
    }
    return out;
  }

  // Assumes the label is the last value in each sequence
  public static DataSet getRandomDataset(int numRows, int lengthEach, int rangeMultiplier) {
    INDArray training = Nd4j.zeros(numRows, lengthEach - 1);
    INDArray labels = Nd4j.zeros(numRows, 1);

    for (int i = 0; i < numRows; i++) {
      double[] randomData = randomData(lengthEach, rangeMultiplier);
      for (int j = 0; j < randomData.length - 1; j++) {
        training.putScalar(new int[]{i, j}, randomData[j]);
      }
      labels.putScalar(new int[]{i, 0}, randomData[randomData.length - 1]);

    }

    return new DataSet(training, labels);

  }
}

Thanks

1 Answer

In DL4J, a DataSet object always holds a batch of data. This means that if your training data has shape (n, f), it is interpreted as n examples with f features each.
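To make the rank mismatch concrete, here is a plain-Java sketch with no DL4J dependency. RnnInputCheck, requireRnnInput and describe are hypothetical names, not DL4J API, and this is not DL4J's actual validation code: it only reads a rank-2 shape as (examples, features), a rank-3 shape as (examples, features, time steps), and rejects anything but rank 3, mirroring the exception text quoted in the question.

```java
import java.util.Arrays;

// Hypothetical helper (not DL4J code): mimics the kind of rank check an RNN
// layer applies to its input and names each axis of a shape.
class RnnInputCheck {

    // Throws if the shape is not rank 3, mirroring the message in the question.
    static void requireRnnInput(long[] shape) {
        if (shape.length != 3) {
            throw new IllegalStateException(
                    "Expected 3D input for RNN layer, got " + shape.length);
        }
    }

    // Rank 2 is read as (examples, features); rank 3 as
    // (examples, features, time steps).
    static String describe(long[] shape) {
        if (shape.length == 2) {
            return shape[0] + " examples x " + shape[1] + " features";
        }
        if (shape.length == 3) {
            return shape[0] + " examples x " + shape[1] + " features x "
                    + shape[2] + " time steps";
        }
        return "unsupported shape " + Arrays.toString(shape);
    }

    public static void main(String[] args) {
        long[] asGiven = {100, 50};       // shape of "training" in the question
        long[] asExpected = {100, 1, 50}; // 100 sequences, 1 feature, 50 steps
        System.out.println(describe(asGiven));
        System.out.println(describe(asExpected));
        requireRnnInput(asExpected);      // rank 3: passes
        try {
            requireRnnInput(asGiven);     // rank 2: throws
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

The point is that the 100 x 50 matrix is perfectly valid data, but read as 100 examples with 50 features and no time axis at all.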

An RNN expects several time steps per example, which means your data needs to have shape (n, f, t), so that you have n examples, f features and t time steps.
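Since the question's network uses nIn(1), each of the 100 rows is really a univariate sequence of 50 steps, so f = 1 and t = 50. A minimal plain-Java sketch of that (n, f, t) layout, assuming each row of the 2D data is one sequence (SequenceLayout and toRnnLayout are hypothetical names):

```java
import java.util.Arrays;

// Sketch: turn univariate sequences stored one per row, double[n][t],
// into the (n, f, t) layout with f = 1, i.e. double[n][1][t].
class SequenceLayout {

    static double[][][] toRnnLayout(double[][] rows) {
        int n = rows.length;
        double[][][] out = new double[n][1][];
        for (int i = 0; i < n; i++) {
            // Row i becomes the single feature channel of example i.
            out[i][0] = Arrays.copyOf(rows[i], rows[i].length);
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] rows = {
                {0.1, 0.2, 0.3},
                {0.4, 0.5, 0.6}
        };
        double[][][] seq = toRnnLayout(rows);
        // Indexing is seq[example][feature][timeStep].
        System.out.println(seq.length + " x " + seq[0].length + " x " + seq[0][0].length);
        System.out.println(seq[1][0][2]);
    }
}
```

With ND4J, the same change for the question's [100, 50] training matrix should amount to training.reshape(100, 1, 50), because inserting a size-1 axis does not reorder the elements; check this against the ND4J version you use.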

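I assume you want a batch size of n = 1 example. So the simplest solution to your problem is to call Nd4j.expandDims(arr, 0), which gives the array an extra singleton dimension.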
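Nd4j.expandDims only changes the shape, so the position of the new size-1 axis decides how the dimensions are read. The shape-only sketch below uses insertSingletonAxis, a hypothetical helper rather than ND4J API, to show the effect: a single 50-step sequence stored as [1, 50] becomes [1, 1, 50] (n = 1, f = 1, t = 50), while for the question's [100, 50] training matrix axis 0 yields [1, 100, 50] (one example with 100 features) and axis 1 yields [100, 1, 50], the layout that matches nIn(1).

```java
import java.util.Arrays;

// Hypothetical shape-only helper: returns a copy of "shape" with a new
// dimension of size 1 inserted at position "axis" (0 <= axis <= rank).
class ExpandDimsSketch {

    static long[] insertSingletonAxis(long[] shape, int axis) {
        if (axis < 0 || axis > shape.length) {
            throw new IllegalArgumentException("axis out of range: " + axis);
        }
        long[] out = new long[shape.length + 1];
        for (int i = 0, j = 0; i < out.length; i++) {
            // Put 1 at the requested axis, otherwise copy the next old dimension.
            out[i] = (i == axis) ? 1 : shape[j++];
        }
        return out;
    }

    public static void main(String[] args) {
        // A single 50-step sequence stored as a [1, 50] row vector.
        System.out.println(Arrays.toString(insertSingletonAxis(new long[]{1, 50}, 0)));
        // The question's [100, 50] training matrix, new axis at 0 and at 1.
        System.out.println(Arrays.toString(insertSingletonAxis(new long[]{100, 50}, 0)));
        System.out.println(Arrays.toString(insertSingletonAxis(new long[]{100, 50}, 1)));
    }
}
```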
