Press "Enter" to skip to content

How to build a neural network in Java

Artificial neural networks are the foundation of deep learning and one of the mainstays of modern AI. The best way to really understand how these things work is to build one. This article is a hands-on introduction to building and training a neural network in Java.

See my previous article, Machine Learning Styles: An Introduction to Neural Networks, for an overview of how artificial neural networks work. Our example for this article is by no means a production-grade system; instead, it shows all the major components in a demo designed to be easy to understand.

A basic neural network

A neural network is a graph of nodes called neurons. The neuron is the basic unit of computation. It receives inputs, multiplies each input by a per-input weight, adds a per-node bias, and runs the result through a final function known as the activation function: output = activation(weight1 * input1 + weight2 * input2 + bias). You can see a two-input neuron illustrated in Figure 1.

Figure 1. A neuron with two inputs in a neural network.

This model has a wide range of variability, but we’ll use this exact setup for the demo.

Our first step is to model a Neuron class that will hold these values. You can see the Neuron class in Listing 1. Note that this is an early version of the class; it will change as we add functionality.

Listing 1. A simple Neuron class


import java.util.Random;

class Neuron {
  // Bounded nextDouble()/nextInt() overloads require Java 17 or later
  Random random = new Random();
  private Double bias = random.nextDouble(-1, 1);
  public Double weight1 = random.nextDouble(-1, 1);
  private Double weight2 = random.nextDouble(-1, 1);

  public double compute(double input1, double input2){
    // Weighted sum of the inputs plus the bias, run through the activation function
    double preActivation = (this.weight1 * input1) + (this.weight2 * input2) + this.bias;
    double output = Util.sigmoid(preActivation);
    return output;
  }
}
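
Before wiring neurons into a network, we can sanity-check a single one. This quick usage sketch is an aside (not one of the article's listings); the output varies from run to run because the weights and bias are randomly initialized:

Neuron neuron = new Neuron();
double output = neuron.compute(115, 66);
System.out.println(output); // some value between 0 and 1; varies with the random initialization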

You can see that the Neuron class is quite simple, with three members: bias, weight1, and weight2. Each member is initialized to a random double between -1 and 1.

When we calculate the output of the neuron, we follow the algorithm shown in Figure 1: we multiply each input by its weight and add the bias: input1 * weight1 + input2 * weight2 + bias. This gives us the raw value (i.e., preActivation) that we run through the activation function. In this case, we use the sigmoid activation function, which compresses values into the range from 0 to 1. Listing 2 shows the Util.sigmoid() static method.

Listing 2. Sigmoid activation function


public class Util {
  public static double sigmoid(double in){
    return 1 / (1 + Math.exp(-in));
  }
}
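
To get a feel for the squashing behavior, here are a few sample values (a quick aside, not one of the article's listings):

System.out.println(Util.sigmoid(0));   // 0.5
System.out.println(Util.sigmoid(2));   // ~0.8808
System.out.println(Util.sigmoid(-2));  // ~0.1192
System.out.println(Util.sigmoid(10));  // ~0.99995 -- large inputs saturate near 1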

Now that we’ve seen how neurons work, let’s put some neurons in a network. We will use a Network class with a list of neurons, as shown in Listing 3.

Listing 3. The neural network class


import java.util.Arrays;
import java.util.List;

class Network {
    List<Neuron> neurons = Arrays.asList(
      new Neuron(), new Neuron(), new Neuron(), /* input nodes */
      new Neuron(), new Neuron(),               /* hidden nodes */
      new Neuron());                            /* output node */
}

Although the list of neurons is one-dimensional, we will connect them during use so that they form a network. The first three neurons are inputs, the next two are hidden, and the last one is the output node.

Make a prediction

Now, let’s use the network to make a prediction. We are going to use a simple data set of two input integers and an answer in the range of 0 to 1. My example uses a combination of weight and height to guess a person’s gender, based on the assumption that more weight and height indicate that a person is male. We could use the same formula for any two-factor, single-output probability. We can think of the input as a vector, and therefore the overall function of the neurons as transforming a vector into a scalar value.

The network prediction phase looks like Listing 4.

Listing 4. Network prediction


public Double predict(Integer input1, Integer input2){
  return neurons.get(5).compute(            // output node
    neurons.get(4).compute(                 // hidden node
      neurons.get(2).compute(input1, input2),
      neurons.get(1).compute(input1, input2)
    ),
    neurons.get(3).compute(                 // hidden node
      neurons.get(1).compute(input1, input2),
      neurons.get(0).compute(input1, input2)
    )
  );
}

Listing 4 shows that the two inputs are fed to the first three neurons, whose outputs are then piped to the two hidden neurons, which in turn feed the output neuron. This process is known as a feedforward pass.
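
If the nested calls are hard to follow, here is the same wiring unrolled with named intermediate values. This is purely a readability sketch of Listing 4, not a change in behavior:

public Double predict(Integer input1, Integer input2){
  // input layer: each input neuron sees the raw (input1, input2) pair
  double in0 = neurons.get(0).compute(input1, input2);
  double in1 = neurons.get(1).compute(input1, input2);
  double in2 = neurons.get(2).compute(input1, input2);
  // hidden layer: same wiring as the nested calls in Listing 4
  double hidden1 = neurons.get(4).compute(in2, in1);
  double hidden2 = neurons.get(3).compute(in1, in0);
  // output node
  return neurons.get(5).compute(hidden1, hidden2);
}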

Now we can ask the network to make a prediction, as shown in Listing 5.

Listing 5. Get a prediction


Network network = new Network();
Double prediction = network.predict(115, 66);
System.out.println("prediction: " + prediction);

We would get something, but it would be the product of random weights and biases. For a real prediction, we must first train the network.

Train the network

The training of a neural network follows a process known as backpropagation, which I will present in more depth in my next article. Backpropagation basically pushes changes back through the network to cause the output to move toward the desired target.

We can do backpropagation using function differentiation, but for our example, we’re going to do something different. We will give each neuron the ability to “mutate.” In each round of training (known as an epoch), we choose a different neuron to make a small random adjustment to one of its properties (weight1, weight2, or bias) and then check whether the results improved. If the results improved, we keep that change with a remember() method. If the results worsened, we abandon the change with a forget() method.

We will add class members (old* versions of the weights and bias) to track the changes. You can see the mutate(), remember(), and forget() methods in Listing 6.

Listing 6. mutate(), remember(), forget()


class Neuron {
  private Double oldBias = random.nextDouble(-1, 1), bias = random.nextDouble(-1, 1);
  public Double oldWeight1 = random.nextDouble(-1, 1), weight1 = random.nextDouble(-1, 1);
  private Double oldWeight2 = random.nextDouble(-1, 1), weight2 = random.nextDouble(-1, 1);

  public void mutate(){
    // Pick one of the three properties at random and nudge it by a random amount
    int propertyToChange = random.nextInt(0, 3);
    Double changeFactor = random.nextDouble(-1, 1);
    if (propertyToChange == 0){
      this.bias += changeFactor;
    } else if (propertyToChange == 1){
      this.weight1 += changeFactor;
    } else {
      this.weight2 += changeFactor;
    }
  }

  public void forget(){
    // Roll back to the last known-good values
    bias = oldBias;
    weight1 = oldWeight1;
    weight2 = oldWeight2;
  }

  public void remember(){
    // Commit the current values as the new known-good values
    oldBias = bias;
    oldWeight1 = weight1;
    oldWeight2 = weight2;
  }
}

Simple enough: The mutate() method picks a property at random, picks a random value between -1 and 1, and changes the property. The forget() method rolls that change back to the previous value. The remember() method copies the new value into the buffer.
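
To make the pattern concrete before we see the full training loop, here is the shape of one training step (a hypothetical fragment: bestLoss and measureLoss() stand in for the loss bookkeeping that Listing 7 implements):

neuron.mutate();                  // try a random tweak to one property
double thisLoss = measureLoss();  // hypothetical helper: re-run predictions, compute the loss
if (thisLoss < bestLoss){
  neuron.remember();              // the tweak helped: commit it
} else {
  neuron.forget();                // the tweak hurt: roll it back
}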

Now, to make use of our Neuron’s new capabilities, we add a train() method to Network, as shown in Listing 7.

Listing 7. The Network.train() method


public void train(List<List<Integer>> data, List<Double> answers){
  Double bestEpochLoss = null;
  for (int epoch = 0; epoch < 1000; epoch++){
    // adapt a single neuron per epoch
    Neuron epochNeuron = neurons.get(epoch % 6);
    epochNeuron.mutate();

    List<Double> predictions = new ArrayList<Double>();
    for (int i = 0; i < data.size(); i++){
      predictions.add(i, this.predict(data.get(i).get(0), data.get(i).get(1)));
    }
    Double thisEpochLoss = Util.meanSquareLoss(answers, predictions);

    if (bestEpochLoss == null){
      bestEpochLoss = thisEpochLoss;
      epochNeuron.remember();
    } else {
      if (thisEpochLoss < bestEpochLoss){
        bestEpochLoss = thisEpochLoss;
        epochNeuron.remember();
      } else {
        epochNeuron.forget();
      }
    }
  }
}

The train() method iterates a thousand times over the data and answers lists passed in as arguments. These are training sets of the same size; data holds the input values and answers holds the known correct answers. In each epoch, the method mutates a random neuron, runs predictions over the whole training set, and measures how well the network guessed the results compared to the known answers, keeping the change only if it produced a better prediction.

Check the results

We can verify the results using the mean square error (MSE) formula, a common way to test a set of results in a neural network. You can see our MSE function in Listing 8.

Listing 8. MSE function


public static Double meanSquareLoss(List<Double> correctAnswers, List<Double> predictedAnswers){
  double sumSquare = 0;
  for (int i = 0; i < correctAnswers.size(); i++){
    double error = correctAnswers.get(i) - predictedAnswers.get(i);
    sumSquare += (error * error);
  }
  return sumSquare / (correctAnswers.size());
}
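
As a quick worked example (hypothetical numbers, not the article's data set), correct answers of 1.0 and 0.0 against predictions of 0.9 and 0.2 give ((0.1)^2 + (0.2)^2) / 2 = 0.025:

Double loss = Util.meanSquareLoss(
    Arrays.asList(1.0, 0.0),   // correct answers
    Arrays.asList(0.9, 0.2));  // predictions
System.out.println(loss);      // 0.025 (within floating-point rounding)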

Tune up the system

Now all that’s left is to put some training data in the network and test it with more predictions. Listing 9 shows how we provide training data.

Listing 9. Training data


List<List<Integer>> data = new ArrayList<List<Integer>>();
data.add(Arrays.asList(115, 66));
data.add(Arrays.asList(175, 78));
data.add(Arrays.asList(205, 72));
data.add(Arrays.asList(120, 67));
List<Double> answers = Arrays.asList(1.0,0.0,0.0,1.0);  

Network network = new Network();
network.train(data, answers);

In Listing 9, our training data is a list of two-element integer sets (we can think of them as weight and height), followed by a list of answers (where 1.0 is female and 0.0 is male).

If we add some logging to the training algorithm, running it will give a result similar to Listing 10.

Listing 10. Trainer registration


// Logging:
if (epoch % 10 == 0) System.out.println(String.format("Epoch: %s | bestEpochLoss: %.15f | thisEpochLoss: %.15f", epoch, bestEpochLoss, thisEpochLoss));

// output:
Epoch: 910 | bestEpochLoss: 0.034404863820424 | thisEpochLoss: 0.034437939546120
Epoch: 920 | bestEpochLoss: 0.033875954196897 | thisEpochLoss: 0.431451026477016
Epoch: 930 | bestEpochLoss: 0.032509260025490 | thisEpochLoss: 0.032509260025490
Epoch: 940 | bestEpochLoss: 0.003092720117159 | thisEpochLoss: 0.003098025397281
Epoch: 950 | bestEpochLoss: 0.002990128276146 | thisEpochLoss: 0.431062364628853
Epoch: 960 | bestEpochLoss: 0.001651762688346 | thisEpochLoss: 0.001651762688346
Epoch: 970 | bestEpochLoss: 0.001637709485751 | thisEpochLoss: 0.001636810460399
Epoch: 980 | bestEpochLoss: 0.001083365453009 | thisEpochLoss: 0.391527869500699
Epoch: 990 | bestEpochLoss: 0.001078338540452 | thisEpochLoss: 0.001078338540452

Listing 10 shows the loss (that is, the divergence from the exactly correct answers) slowly decreasing; the network is getting closer to making accurate predictions. All that remains is to see how well our model predicts with real data, as shown in Listing 11.

Listing 11. Prediction


System.out.println("");
System.out.println(String.format("  male, 167, 73: %.10f", network.predict(167, 73)));
System.out.println(String.format("female, 105, 67: %.10", network.predict(105, 67))); 
System.out.println(String.format("female, 120, 72: %.10f | network1000: %.10f", network.predict(120, 72))); 
System.out.println(String.format("  male, 143, 67: %.10f | network1000: %.10f", network.predict(143, 67)));
System.out.println(String.format(" male', 130, 66: %.10f | network: %.10f", network.predict(130, 66)));

In Listing 11, we take our trained network and feed it some data, generating the predictions. We get something like Listing 12.

Listing 12. Trained predictions


  male, 167, 73: 0.0279697143 
female, 105, 67: 0.9075809407 
female, 120, 72: 0.9075808235 
  male, 143, 67: 0.0305401413
  male, 130, 66: 0.9009811922

In Listing 12, we see that the network has done a good job with most value pairs (a.k.a. vectors). It gives the female data sets an estimate of around .907, which is pretty close to 1. The two male sets show .027 and .030, approaching 0. The outlier male data set (130, 66) is seen as probably female, though with less confidence at .900.

Conclusion

There are several ways to adjust the dials on this system. For one, the number of epochs in a training run is an important factor. More epochs make the model more tuned to the data. Running more epochs can improve accuracy against the training set, but it can also result in overtraining; that is, a model that confidently predicts wrong results for edge cases.
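
One simple way to watch for overtraining is to hold back a few labeled examples and measure the loss on them after training. This is a hedged sketch (the hold-out rows and expected answers here are made up for illustration), reusing the predict() and meanSquareLoss() methods we already have:

// Hypothetical hold-out rows not used during training
List<List<Integer>> holdOut = new ArrayList<List<Integer>>();
holdOut.add(Arrays.asList(160, 70));   // expected male   -> 0.0
holdOut.add(Arrays.asList(110, 62));   // expected female -> 1.0
List<Double> holdOutAnswers = Arrays.asList(0.0, 1.0);

List<Double> holdOutPredictions = new ArrayList<Double>();
for (List<Integer> row : holdOut){
  holdOutPredictions.add(network.predict(row.get(0), row.get(1)));
}
// If this loss is much worse than the training loss, the model may be overtrained
System.out.println("hold-out MSE: " + Util.meanSquareLoss(holdOutAnswers, holdOutPredictions));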

Please visit my GitHub repository for the full code for this tutorial, along with some additional bells and whistles.

Copyright © 2023 IDG Communications, Inc.
