Sunday, 7 February 2016

Fisher Iris Dataset Classification with Torch

Note: The code supporting this post can be found on GitHub.

For people interested in deep learning, there's never been a better selection of open-source frameworks available. I recently took some time to review the top five (Caffe, CNTK, TensorFlow, Theano & Torch) and my favourite by far is still Torch.

Torch is a collection of flexible and powerful neural network and optimisation libraries. It sits on top of the Lua programming language, so it has the potential to be highly portable. It scales well too, making full use of your CPU and GPU architectures.
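Moving a model onto the GPU, for example, is only a couple of lines. Here's a minimal sketch, assuming CUDA and the `cunn` package are installed (the layer sizes are arbitrary, just for illustration):

```lua
require "nn"
require "cunn"  -- pulls in cutorch and the CUDA-backed nn modules

-- Build a small model on the CPU as usual.
local mlp = nn.Sequential()
mlp:add(nn.Linear(4, 10))
mlp:add(nn.Tanh())
mlp:add(nn.Linear(10, 3))

-- One call moves all parameters to GPU memory.
mlp:cuda()

-- Inputs must be CUDA tensors too; the forward pass then runs on the GPU.
local x = torch.CudaTensor(4):uniform()
local y = mlp:forward(x)
print(y:size(1))  -- 3 outputs, computed on the GPU
```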

However, for me, the most powerful reason to choose Torch is its large and ever growing community. This platform is here to stay, with active research happening in machine learning, computer vision, signal processing and parallel processing, amongst others. Already, some big names like Facebook, Google and Twitter have adopted Torch and actively contribute to the community.

A few years ago, when I was still hand-crafting neural nets, I wrote a blog post about classifying the UCI Fisher Iris dataset using back propagation. I thought I'd update this example, showing how I would do the same thing today with Torch. I hope you find the comparison useful.

trainSet = {}
testSet = {}

-- nn.StochasticGradient expects the dataset to provide a size() method.
function trainSet:size()
    return trainCount
end

-- Download data if not local.
require "paths"
if not paths.filep('') then
    print("Getting data...")
    data = ""
    os.execute('wget ' .. data)
end

-- Load data.
trainCount = 0; testCount = 0
file ='')
for line in file:lines() do
    if (string.len(line) > 0) then

        -- Read line from file; the first four fields are numeric features.
        x1, x2, x3, x4, species = unpack(line:split(","))
        input = torch.Tensor({tonumber(x1), tonumber(x2), tonumber(x3), tonumber(x4)})
        output = torch.Tensor(3):zero()

        -- One-hot encode the output based on species.
        if (species == "Iris-setosa") then
            output[1] = 1
        elseif (species == "Iris-versicolor") then
            output[2] = 1
        else -- Iris-virginica
            output[3] = 1
        end

        -- Keep roughly 20% of the data aside for testing.
        if (math.random() > 0.2) then
            table.insert(trainSet, {input, output})
            trainCount = trainCount + 1
            table.insert(testSet, {input, output})
            testCount = testCount + 1
        end
    end
end
file:close()

-- Initialise the network: 4 inputs, one hidden layer, 3 outputs.
require "nn"
inputs = 4; outputs = 3; hidden = 10
mlp = nn.Sequential()
mlp:add(nn.Linear(inputs, hidden))
mlp:add(nn.Tanh()) -- non-linearity, without which the two layers collapse into one linear map
mlp:add(nn.Linear(hidden, outputs))

-- Train the network.
criterion = nn.MSECriterion()
trainer = nn.StochasticGradient(mlp, criterion)
trainer.learningRate = 0.01
trainer.maxIteration = 25
trainer:train(trainSet)

-- Test the network: a prediction is correct when the largest
-- network output lines up with the one-hot target class.
correct = 0
for i = 1, testCount do
    val = mlp:forward(testSet[i][1])
    out = testSet[i][2]
    _, predicted = torch.max(val, 1)
    _, expected = torch.max(out, 1)
    if (predicted[1] == expected[1]) then
        correct = correct + 1
    end
end
print(string.format("%.2f%s", 100 * correct / testCount, "% correct"))