Sunday, 19 October 2008

Back Propagation Multi-Output Neural Network C#

In this post, I publish an updated version of my multi-layer perceptron. The motivation for this comes from the great comments received on this blog. New features and improvements include:
  • A graphical representation of network activity
  • The ability to handle multiple network outputs
  • The ability to define any network architecture in a single line of code
  • A cleaner separation of responsibility within the code
To put this new network through its paces, I downloaded the classic Iris Plants Database from the UCI Machine Learning Repository. The data set contains three classes of fifty instances each, where each class refers to a type of iris plant. One class is linearly separable from the other two; the latter are not linearly separable from each other.

To use this data, save the web page as a text file with a csv file extension. Open it with Excel and add three new columns on the right, filling them with 0,0,1 or 0,1,0 or 1,0,0 depending on the species of plant. Delete the species name column and you now have a file which can be fed into your network. Each row should then contain seven comma-separated values: the four measurements followed by the three class flags. The video below shows the graphical output and clearly displays how the network converges on a solution.
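
As an alternative to the Excel steps described above, the conversion could also be scripted. The following is only a sketch: it assumes the raw rows look like "5.1,3.5,1.4,0.2,Iris-setosa", and the class name IrisPrep, its method names and the choice of which species receives which flag pattern are my own assumptions, since the post does not fix that mapping.

```csharp
using System;
using System.Collections.Generic;

public static class IrisPrep
{
    // Assumed ordering: the position in this array decides which flag is set to 1.
    private static readonly string[] Species =
        { "Iris-setosa", "Iris-versicolor", "Iris-virginica" };

    // Turns one raw row (four measurements plus species name) into
    // four measurements plus three 0/1 flags.
    // Returns null for blank or unrecognised rows so the caller can skip them.
    public static string Encode(string line)
    {
        if (line == null) return null;
        string trimmed = line.Trim();
        if (trimmed.Length == 0) return null;

        string[] parts = trimmed.Split(',');
        if (parts.Length != 5) return null;

        int index = Array.IndexOf(Species, parts[4]);
        if (index < 0) return null;

        string[] row = new string[7];
        Array.Copy(parts, row, 4);              // Keep the four measurements as they are.
        for (int i = 0; i < 3; i++)
        {
            row[4 + i] = (i == index) ? "1" : "0";
        }
        return string.Join(",", row);
    }

    // Encodes every usable row and drops the rest.
    public static List<string> EncodeAll(IEnumerable<string> lines)
    {
        List<string> rows = new List<string>();
        foreach (string line in lines)
        {
            string row = Encode(line);
            if (row != null) rows.Add(row);
        }
        return rows;
    }
}
```

With this in place, something like File.WriteAllLines("iris_data.csv", IrisPrep.EncodeAll(File.ReadAllLines("iris.data")).ToArray()) would produce the input file, where "iris.data" stands for wherever you saved the download.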



There is quite a lot of code so I will cover each class individually.



Unlike my previous back propagation project, this one is a Windows Forms Application, so in Visual Studio begin by creating a new project and deleting Form1.cs and Program.cs, which were automatically created for you. Create a new class file for each of the code blocks below.

Graph.cs
This is the container for the project and provides a graphical display of how the network is proceeding with training. In order to use the graphing feature, the penultimate layer of your network must contain exactly two neurons, which is not as restrictive as it sounds. I have found that being able to visualise network activity helps enormously when tuning your network. Things to look out for are inactivity, meaning the network has become stuck in a local minimum, or erratic activity, meaning your learning rate is too high. If at any time you want to restart training, simply hit the space key and the network will reset. I guarantee that once you have tried this, you will never go back to gazing at streams of numbers…

    using System;
    using System.Drawing;
    using System.Windows.Forms;

    public class Graph : Form
    {
        private int iteration;
        private Network network;
        // One brush per output class; add more if your network has more than three outputs.
        private Brush[] brushes = { Brushes.Red, Brushes.Green, Brushes.Blue };

        [STAThread]
        static void Main()
        {
            Application.Run(new Graph());
        }

        public Graph()
        {
            int[] dims = { 4, 4, 2, 3 };    // Replace with your network dimensions.
            string file = "iris_data.csv";  // Replace with your input file location.
            network = new Network(dims, file);
            Initialise();
        }

        private void Initialise()
        {
            ClientSize = new Size(400, 400);
            SetStyle(ControlStyles.AllPaintingInWmPaint | ControlStyles.Opaque, true);
            FormBorderStyle = FormBorderStyle.FixedToolWindow;
            StartPosition = FormStartPosition.CenterScreen;
            KeyDown += new KeyEventHandler(Graph_KeyDown);
        }

        // Each repaint runs one training pass, draws the result and requests the next repaint.
        protected override void OnPaint(PaintEventArgs e)
        {
            double error = network.Train();
            UpdatePlotArea(e.Graphics, error);
            iteration++;
            Invalidate();
        }

        private void UpdatePlotArea(Graphics g, double error)
        {
            string banner = "Iteration={0}  Error={1:0.00}";
            Text = string.Format(banner, iteration, error);
            g.FillRectangle(Brushes.White, 0, 0, 400, 400);
            // Plot each pattern at the outputs of the two penultimate neurons, coloured by class.
            foreach (float[] point in network.Points2D())
            {
                g.FillRectangle(brushes[(int)point[2]], point[0] * 395, point[1] * 395, 5, 5);
            }
            // Each output neuron defines the line w0*x + w1*y + bias = 0 in the plot area.
            foreach (double[] line in network.HyperPlanes())
            {
                if (Math.Abs(line[1]) < 1e-9) continue;   // Vertical line: avoid dividing by zero.
                double a = -line[0] / line[1];
                double c = -line[2] / line[1];
                Point left = new Point(0, (int)(c * 400));
                Point right = new Point(400, (int)((a + c) * 400));
                g.DrawLine(Pens.Gray, left, right);       // Shared pen, nothing to dispose.
            }
        }

        private void Graph_KeyDown(object sender, KeyEventArgs e)
        {
            if (e.KeyCode == Keys.Space)
            {
                network.Initialise();
                iteration = 0;
            }
        }
    }
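
The grey lines drawn by UpdatePlotArea come from treating each output neuron as the line w0*x + w1*y + bias = 0 across the unit square and scaling it to the 400 pixel plot. Below is a minimal sketch of that arithmetic on its own. BoundaryLine and TryEndpoints are names I have made up for illustration and are not part of the project.

```csharp
using System;

public static class BoundaryLine
{
    // Computes where the line w0*x + w1*y + b = 0 meets the left (x = 0) and
    // right (x = 1) edges of the plot, in pixels for a square of the given size.
    // Returns false when w1 is (almost) zero, because the line is then vertical
    // and has no single y value at either edge.
    public static bool TryEndpoints(double w0, double w1, double b, int size,
                                    out int yLeft, out int yRight)
    {
        yLeft = 0;
        yRight = 0;
        if (Math.Abs(w1) < 1e-9) return false;

        double slope = -w0 / w1;        // Rearranged to y = slope * x + intercept.
        double intercept = -b / w1;
        yLeft = (int)(intercept * size);
        yRight = (int)((slope + intercept) * size);
        return true;
    }
}
```

For example, weights (1, 1) with bias -1 describe x + y = 1, which runs from the bottom-left corner (y = 400 at x = 0) to the top-right corner (y = 0 at x = 400), remembering that y grows downwards on screen.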

Network.cs
As the name suggests, this is where the logic of the network resides. The big difference between this and my previous project is that it can handle multiple outputs. In addition, the new constructor lets you define any network architecture using an integer array. Thus, { 4, 4, 2, 3 } represents a network with four inputs, a first hidden layer with four neurons, a second hidden layer with two neurons, and a final output layer with three neurons.

    using System;
    using System.Collections.Generic;
    using System.IO;

    public class Network : List<Layer>
    {
        private int[] dimensions;
        private List<Pattern> patterns;
        // One generator shared by all layers. Creating a new Random per layer in quick
        // succession can give every layer the same seed and therefore the same weights.
        private Random rnd = new Random();

        public Network(int[] dimensions, string file)
        {
            this.dimensions = dimensions;
            Initialise();
            LoadPatterns(file);
        }

        // Rebuilds all layers with fresh random weights.
        public void Initialise()
        {
            base.Clear();
            base.Add(new Layer(dimensions[0]));
            for (int i = 1; i < dimensions.Length; i++)
            {
                base.Add(new Layer(dimensions[i], base[i - 1], rnd));
            }
        }

        private void LoadPatterns(string file)
        {
            patterns = new List<Pattern>();
            using (StreamReader reader = File.OpenText(file))
            {
                while (!reader.EndOfStream)
                {
                    string line = reader.ReadLine();
                    if (line.Trim().Length == 0) continue;   // Skip blank lines.
                    patterns.Add(new Pattern(line, Inputs.Count, Outputs.Count));
                }
            }
        }

        // Runs one pass over all patterns and returns the summed squared error.
        public double Train()
        {
            double error = 0;
            foreach (Pattern pattern in patterns)
            {
                Activate(pattern);
                for (int i = 0; i < Outputs.Count; i++)
                {
                    double delta = pattern.Outputs[i] - Outputs[i].Output;
                    Outputs[i].CollectError(delta);
                    error += Math.Pow(delta, 2);
                }
                AdjustWeights();
            }
            return error;
        }

        // Feeds a pattern forward through the network, layer by layer.
        private void Activate(Pattern pattern)
        {
            for (int i = 0; i < Inputs.Count; i++)
            {
                Inputs[i].Output = pattern.Inputs[i];
            }
            for (int i = 1; i < base.Count; i++)
            {
                foreach (Neuron neuron in base[i])
                {
                    neuron.Activate();
                }
            }
        }

        // Updates weights from the output layer back towards the inputs.
        private void AdjustWeights()
        {
            for (int i = base.Count - 1; i > 0; i--)
            {
                foreach (Neuron neuron in base[i])
                {
                    neuron.AdjustWeights();
                }
            }
        }

        public List<double[]> HyperPlanes()
        {
            List<double[]> lines = new List<double[]>();
            foreach (Neuron n in Outputs)
            {
                lines.Add(n.HyperPlane);
            }
            return lines;
        }

        // Returns { x, y, class } for every pattern, where x and y are the outputs
        // of the two neurons in the penultimate layer.
        public List<float[]> Points2D()
        {
            int penultimate = base.Count - 2;
            if (base[penultimate].Count != 2)
            {
                throw new Exception("Penultimate layer must be 2D for graphing.");
            }
            List<float[]> points = new List<float[]>();
            for (int i = 0; i < patterns.Count; i++)
            {
                Activate(patterns[i]);
                float[] point = new float[3];
                point[0] = (float)base[penultimate][0].Output;
                point[1] = (float)base[penultimate][1].Output;
                if (Outputs.Count > 1)
                {
                    point[2] = patterns[i].MaxOutput;
                }
                else
                {
                    point[2] = (patterns[i].Outputs[0] >= 0.5) ? 1 : 0;
                }
                points.Add(point);
            }
            return points;
        }

        private Layer Inputs
        {
            get { return base[0]; }
        }

        private Layer Outputs
        {
            get { return base[base.Count - 1]; }
        }
    }
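
To make the integer-array notation more concrete, here is a small sketch that counts what an architecture such as { 4, 4, 2, 3 } implies, given that every non-input neuron holds one weight per neuron in the previous layer plus a bias. The Architecture class and its method names are hypothetical helpers for illustration, not part of the project.

```csharp
using System;

public static class Architecture
{
    // Number of weights feeding each non-input layer:
    // (neurons in the layer) * (neurons in the previous layer).
    public static int[] WeightsPerLayer(int[] dims)
    {
        if (dims == null || dims.Length < 2)
        {
            throw new ArgumentException("Need at least an input and an output layer.");
        }
        int[] counts = new int[dims.Length - 1];
        for (int i = 1; i < dims.Length; i++)
        {
            counts[i - 1] = dims[i] * dims[i - 1];
        }
        return counts;
    }

    // All weights plus one bias per non-input neuron.
    public static int TrainableParameters(int[] dims)
    {
        int total = 0;
        foreach (int count in WeightsPerLayer(dims))
        {
            total += count;
        }
        for (int i = 1; i < dims.Length; i++)
        {
            total += dims[i];
        }
        return total;
    }
}
```

For { 4, 4, 2, 3 } this gives 16, 8 and 6 weights for the three non-input layers, and 39 adjustable values once the nine biases are included.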

Neuron.cs
This class represents the neuron, and contains activation and training logic.

    using System;
    using System.Collections.Generic;

    public class Neuron
    {
        private double bias;                       // Bias value.
        private double error;                      // Sum of error.
        private double input;                      // Sum of inputs.
        private double gradient = 5;               // Steepness of sigmoid curve.
        private double learnRate = 0.01;           // Learning rate.
        private double output = double.MinValue;   // Preset output (input layer); MinValue means unset.
        private List<Weight> weights;              // Collection of weights to inputs.

        public Neuron() { }

        public Neuron(Layer inputs, Random rnd)
        {
            weights = new List<Weight>();
            foreach (Neuron source in inputs)
            {
                Weight w = new Weight();
                w.Input = source;
                w.Value = rnd.NextDouble() * 2 - 1;   // Random start in [-1, 1).
                weights.Add(w);
            }
        }

        // Resets the collected error and recomputes the weighted sum of inputs.
        public void Activate()
        {
            error = 0;
            input = 0;
            foreach (Weight w in weights)
            {
                input += w.Value * w.Input.Output;
            }
        }

        // Accumulates the error for this neuron and passes it back to its inputs.
        public void CollectError(double delta)
        {
            if (weights != null)
            {
                error += delta;
                // Pass back only this call's contribution, scaled by the local slope.
                // Passing the running total would count earlier contributions again.
                double signal = delta * Derivative;
                foreach (Weight w in weights)
                {
                    w.Input.CollectError(signal * w.Value);
                }
            }
        }

        public void AdjustWeights()
        {
            for (int i = 0; i < weights.Count; i++)
            {
                weights[i].Value += error * Derivative * learnRate * weights[i].Input.Output;
            }
            bias += error * Derivative * learnRate;
        }

        // Slope of the sigmoid at the current activation, including the steepness factor.
        private double Derivative
        {
            get
            {
                double activation = Output;
                return gradient * activation * (1 - activation);
            }
        }

        public double Output
        {
            get
            {
                if (output != double.MinValue)
                {
                    return output;
                }
                return 1 / (1 + Math.Exp(-gradient * (input + bias)));
            }
            set
            {
                output = value;
            }
        }

        // Coefficients { w0, w1, bias } of the line w0*x + w1*y + bias = 0.
        // Only meaningful when the neuron has exactly two inputs.
        public double[] HyperPlane
        {
            get
            {
                double[] line = new double[3];
                line[0] = weights[0].Value;
                line[1] = weights[1].Value;
                line[2] = bias;
                return line;
            }
        }
    }
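
The activation used in the Output property is a sigmoid of the summed input scaled by a steepness factor. As a way to sanity-check the maths behind the training step, the sketch below writes that function and its slope as standalone methods and compares the analytic slope with a numerical estimate. The Sigmoid class and its method names are illustrative assumptions, and bias is folded into x for brevity.

```csharp
using System;

public static class Sigmoid
{
    // Logistic function with an adjustable steepness ("gradient" in Neuron.cs).
    public static double Value(double x, double gradient)
    {
        return 1 / (1 + Math.Exp(-gradient * x));
    }

    // Analytic slope with respect to x: gradient * a * (1 - a),
    // where a is the activation at x.
    public static double Slope(double x, double gradient)
    {
        double a = Value(x, gradient);
        return gradient * a * (1 - a);
    }

    // Central finite-difference estimate of the same slope, for checking.
    public static double NumericSlope(double x, double gradient, double h)
    {
        return (Value(x + h, gradient) - Value(x - h, gradient)) / (2 * h);
    }
}
```

At x = 0 with a steepness of 5, the activation is 0.5 and the slope is 5 * 0.5 * 0.5 = 1.25, which the finite-difference estimate should reproduce closely for a small step such as 1e-5.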

Layer.cs
This class inherits from List<Neuron> and simply holds a collection of neurons representing a discrete layer within the network.

    using System;
    using System.Collections.Generic;

    public class Layer : List<Neuron>
    {
        public Layer(int size)
        {
            for (int i = 0; i < size; i++)
                base.Add(new Neuron());
        }

        public Layer(int size, Layer layer, Random rnd)
        {
            for (int i = 0; i < size; i++)
                base.Add(new Neuron(layer, rnd));
        }
    }

Weight.cs
This class holds a reference to a contributing neuron and the associated weight value.

    public class Weight
    {
        public Neuron Input;
        public double Value;
    }

Pattern.cs
This class holds a single line of data from your input file in a form that can be consumed by the network.

    using System;
    using System.Globalization;

    public class Pattern
    {
        private double[] inputs;
        private double[] outputs;

        public Pattern(string value, int inputDims, int outputDims)
        {
            string[] line = value.Split(',');
            if (line.Length != inputDims + outputDims)
                throw new Exception("Input does not match network configuration");
            // Parse with the invariant culture so "5.1" is read the same way
            // regardless of the regional settings of the machine.
            inputs = new double[inputDims];
            for (int i = 0; i < inputDims; i++)
            {
                inputs[i] = double.Parse(line[i], CultureInfo.InvariantCulture);
            }
            outputs = new double[outputDims];
            for (int i = 0; i < outputDims; i++)
            {
                outputs[i] = double.Parse(line[i + inputDims], CultureInfo.InvariantCulture);
            }
        }

        // Index of the largest target output, i.e. the class of this pattern.
        public int MaxOutput
        {
            get
            {
                int item = -1;
                double max = double.MinValue;
                for (int i = 0; i < Outputs.Length; i++)
                {
                    if (Outputs[i] > max)
                    {
                        max = Outputs[i];
                        item = i;
                    }
                }
                return item;
            }
        }

        public double[] Inputs
        {
            get { return inputs; }
        }

        public double[] Outputs
        {
            get { return outputs; }
        }
    }
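
To illustrate what a single row of the input file turns into, here is a standalone sketch that splits a row into input and target values and picks the class index, much as Pattern does. RowParser and its members are hypothetical names used only for this example.

```csharp
using System;
using System.Globalization;

public static class RowParser
{
    // Splits one CSV row into its first inputDims values and the remaining targets.
    public static void Parse(string row, int inputDims,
                             out double[] inputs, out double[] targets)
    {
        string[] cells = row.Split(',');
        if (cells.Length <= inputDims)
        {
            throw new FormatException("Row has no target columns.");
        }
        inputs = new double[inputDims];
        targets = new double[cells.Length - inputDims];
        for (int i = 0; i < cells.Length; i++)
        {
            // Invariant culture: always treat '.' as the decimal separator.
            double v = double.Parse(cells[i], CultureInfo.InvariantCulture);
            if (i < inputDims) inputs[i] = v;
            else targets[i - inputDims] = v;
        }
    }

    // Index of the largest target value, i.e. the class the row belongs to.
    public static int ClassIndex(double[] targets)
    {
        int best = 0;
        for (int i = 1; i < targets.Length; i++)
        {
            if (targets[i] > targets[best]) best = i;
        }
        return best;
    }
}
```

Parsing "5.1,3.5,1.4,0.2,0,0,1" with four inputs yields the measurements 5.1, 3.5, 1.4 and 0.2, the targets 0, 0 and 1, and a class index of 2, which is the value Graph uses to choose a brush.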

I hope you find this network useful. If you have any questions or suggestions, I would be happy to hear from you.