Sunday, 28 February 2016

Word2Vec Lightweight C# Port

In January 2013, Tomas Mikolov and a team from Google published a paper titled Efficient Estimation of Word Representations in Vector Space. This proposed two new architectures for computing continuous vector representations of words from very large data sets. Whilst word embedding (as the technique is more generally known) was nothing new, the approach taken by the team demonstrated large improvements in accuracy at much lower computational cost.

Later that year, the word2vec C language source code supporting the paper was open sourced and is now available on Google code.

The word2vec tool is generally trained on a very large text corpus and subsequently learns vector representations (or embeddings) of words. The resulting word vectors can be used as features in natural language processing or machine learning applications.

The word2vec code has been hugely popular and as a result been ported to other languages including Python and Java. However, to the best of my knowledge there has not been a lightweight C# port of word2vec – so I decided to make one!

For my own purposes I chose to implement a Continuous Bag of Words model, rather than Skipgram, which works just fine for my needs.

The code below contains three classes:
  • Word2vec.cs – this is where the vector representations are learned.
  • Model.cs – a simple class showing how to query the word vectors.
  • Program.cs – a console application tying it all together.

I train my model on a 100MB extract from Wikipedia, which yields really nice results.

A simple way to visualise the learned representations is to list the closest words for a user input. The console application provided displays the closest words and their cosine similarity to the user input.

For example, if you enter 'france', you should see an output similar to this:

Input> france
1.000  france
0.681  spain
0.679  belgium
0.661  netherlands
0.654  italy
0.642  england
0.627  switzerland
0.611  luxembourg
0.569  portugal
0.560  russia
0.542  germany


Once words are represented as vectors it’s easy to perform standard vector operations, such as addition and subtraction. Research has shown that word vectors capture many linguistic regularities. A couple of famous examples often cited are:

vector('paris') - vector('france') + vector('italy') is close to vector('rome')

vector('king') - vector('man') + vector('woman') is close to vector('queen').

I’ve not included this in the code, but it’s really easy to implement and I leave it for my readers to do so should they wish.



I hope you find this interesting – if you have any questions, please post on my Google+ page.

word2vec.cs
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Linq;
  5.  
// Learns CBOW (continuous bag-of-words) word embeddings with negative sampling,
// a lightweight single-threaded port of Mikolov et al.'s word2vec. Reads a
// whitespace-tokenised corpus and writes the learned vectors to a binary file
// (the same layout Model.LoadVectors reads back).
class Word2Vec
{
    // Words seen fewer than MinCount times are discarded from the vocabulary.
    public static int MinCount = 10;

    private const float sample = 1e-3f;             // Sub-sampling threshold for frequent words.
    private const float starting_alpha = 0.05f;     // Starting learning rate.
    private const int dimensions = 75;              // Word vector dimensions.
    private const int exp_table_size = 1000;        // Resolution of the precomputed sigmoid table.
    private const int iter = 5;                        // Training iterations.
    private const int max_exp = 6;                     // Sigmoid table spans [-max_exp, +max_exp].
    private const int negative = 5;                    // Number of negative examples.
    private const int window = 5;                    // Window size.

    private Dictionary<string, float[]> syn0 = new Dictionary<string, float[]>();   // Input (word) vectors — the embeddings that get saved.
    private Dictionary<string, float[]> syn1 = new Dictionary<string, float[]>();   // Output (context) vectors used only during training.
    private Dictionary<string, int> vocab = new Dictionary<string, int>();          // Word -> corpus frequency.
    private float[] expTable = new float[exp_table_size];                           // Precomputed sigmoid values.
    private long train_words = 0;                                                   // Total token count seen in the corpus.
    private Random rnd = new Random();
    private string[] roulette;                                                      // Unigram table for drawing negative samples.

    // Runs the full pipeline: build lookup tables, learn the vocabulary,
    // initialise and train the vectors, then persist them to model_file.
    public void Train(string train_file, string model_file)
    {
        BuildExpTable();
        LearnVocab(train_file);
        InitVectors();
        InitUnigramTable();
        TrainModel(train_file);
        WriteVectorsToFile(model_file);
    }

    // Precomputes sigmoid(x) for x in [-max_exp, +max_exp] so the inner
    // training loop can use a table lookup instead of calling Math.Exp.
    private void BuildExpTable()
    {
        for (int i = 0; i < exp_table_size; i++)
        {
            expTable[i] = (float)Math.Exp((i / (double)exp_table_size * 2 - 1) * max_exp);
            expTable[i] = expTable[i] / (expTable[i] + 1);  // e^x / (e^x + 1) = sigmoid(x)
        }
    }

    // Allocates a vector pair per vocabulary word: syn0 starts with small
    // random values in [-0.5, 0.5); syn1 starts at zero (default float[]).
    private void InitVectors()
    {
        foreach (var key in vocab.Keys)
        {
            syn0.Add(key, new float[dimensions]);
            syn1.Add(key, new float[dimensions]);
            for (int i = 0; i < dimensions; i++)
                syn0[key][i] = (float)rnd.NextDouble() - 0.5f;
        }
    }

    // Serialises the embeddings: vocab count, dimensions, then each word
    // followed by its raw (unnormalised) vector components.
    private void WriteVectorsToFile(string output_file)
    {
        using (BinaryWriter bw = new BinaryWriter(File.Open(output_file, FileMode.Create)))
        {
            bw.Write(vocab.Count);
            bw.Write(dimensions);
            foreach (var vec in syn0)
            {
                bw.Write(vec.Key);
                for (int i = 0; i < dimensions; i++)
                    bw.Write(vec.Value[i]);
            }
        }
    }

    // First pass over the corpus: counts every token's frequency, then
    // prunes words occurring fewer than MinCount times.
    private void LearnVocab(string train_file)
    {
        using (StreamReader sr = new StreamReader(train_file))
        {
            string line;
            while ((line = sr.ReadLine()) != null)
            {
                foreach (var word in line.Split(' '))
                {
                    if (word.Length == 0) continue;
                    train_words++;
                    if (train_words % 100000 == 0) Console.Write("\r{0}k words read", train_words / 1000);
                    if (!vocab.ContainsKey(word)) vocab.Add(word, 1);
                    else vocab[word]++;
                }
            }
        }
        Console.WriteLine();

        // Drop rare words; they have too few contexts to learn from.
        // NOTE(review): train_words is NOT reduced here, so sub-sampling in
        // TrainModel uses the pre-pruning total — the reference C
        // implementation subtracts pruned counts; confirm this is intended.
        var tmp = (from w in vocab
                   where w.Value < MinCount
                   select w.Key).ToList();

        foreach (var key in tmp)
            vocab.Remove(key);

        Console.WriteLine("Vocab size: {0}", vocab.Count);
    }

    // Builds the "roulette wheel" for negative sampling: each word appears
    // count^0.75 times, so frequent words are drawn more often but with
    // dampened dominance (the 3/4 power from the word2vec paper).
    private void InitUnigramTable()
    {
        List<string> tmp = new List<string>();
        foreach (var word in vocab)
        {
            int count = (int)Math.Pow(word.Value, 0.75);
            for (int i = 0; i < count; i++) tmp.Add(word.Key);
        }
        roulette = tmp.ToArray();
    }

    // Trains the CBOW model with negative sampling over `iter` passes of the
    // corpus. Streams the file line by line (one "sentence" per line) so the
    // corpus never has to fit in memory.
    private void TrainModel(string train_file)
    {
        float alpha = starting_alpha;                // Current learning rate (decays linearly).
        float[] neu1 = new float[dimensions];        // Averaged context vector (hidden layer).
        float[] neu1e = new float[dimensions];       // Accumulated error to push back to context words.
        int last_word_count = 0;
        int sentence_position = 0;
        int word_count = 0;
        List<string> sentence = new List<string>();
        long word_count_actual = 0;                  // Words processed across all iterations.
        DateTime start = DateTime.Now;

        for (int local_iter = 0; local_iter < iter; local_iter++)
        {
            using (StreamReader sr = new StreamReader(train_file))
            {
                while (true)
                {
                    // Every ~10k words: report progress and decay the learning
                    // rate linearly, floored at 0.01% of the starting rate.
                    if (word_count - last_word_count > 10000)
                    {
                        word_count_actual += word_count - last_word_count;
                        last_word_count = word_count;
                        int seconds = (int)(DateTime.Now - start).TotalSeconds + 1;
                        float prog = (float)word_count_actual * 100 / (iter * train_words);
                        float rate = (float)word_count_actual / seconds / 1000;
                        Console.Write("\rProgress: {0:0.00}%  Words/sec: {1:0.00}k", prog, rate);
                        alpha = starting_alpha * (1 - word_count_actual / (float)(iter * train_words + 1));
                        if (alpha < starting_alpha * 0.0001) alpha = starting_alpha * 0.0001f;
                    }

                    // Need a new sentence: read the next line, keeping only
                    // in-vocab words that survive frequency sub-sampling.
                    if (sentence.Count == 0)
                    {
                        if (sr.EndOfStream)
                        {
                            // End of this pass: snap the global counter to an exact
                            // multiple of train_words, reset per-pass counters.
                            word_count_actual = train_words * (local_iter + 1);
                            word_count = 0;
                            last_word_count = 0;
                            sentence.Clear();
                            break;
                        }

                        sentence.Clear();
                        sentence_position = 0;
                        string line = sr.ReadLine();
                        foreach (var key in line.Split(' '))
                        {
                            if (key.Length == 0) continue;
                            if (!vocab.ContainsKey(key)) continue;
                            word_count++;
                            if (sample > 0)
                            {
                                // Sub-sample frequent words (same formula as the
                                // reference C code): keep with probability `ran`.
                                double ran = (Math.Sqrt(vocab[key] / (sample * train_words)) + 1) * (sample * train_words) / vocab[key];
                                if (ran < rnd.NextDouble()) continue;
                            }
                            sentence.Add(key);
                        }
                    }

                    // Line produced no usable words — fetch another.
                    if (sentence.Count == 0) continue;

                    string word = sentence[sentence_position];   // Current target word.
                    for (int i = 0; i < dimensions; i++) neu1[i] = 0;
                    for (int i = 0; i < dimensions; i++) neu1e[i] = 0;

                    // CBOW input: sum the syn0 vectors of the words inside the
                    // window around the target (the target itself is excluded).
                    int cw = 0;
                    for (int w = 0; w < window * 2 + 1; w++)
                    {
                        if (w != window)
                        {
                            int p = sentence_position - window + w;
                            if (p < 0) continue;
                            if (p >= sentence.Count) continue;
                            string last_word = sentence[p];
                            float[] tmp0 = syn0[last_word];
                            for (int i = 0; i < dimensions; i++) neu1[i] += tmp0[i];
                            cw++;
                        }
                    }

                    if (cw > 0)
                    {
                        // Hidden layer = average of the context vectors.
                        for (int i = 0; i < dimensions; i++) neu1[i] /= cw;

                        // Negative sampling: one positive example (the true
                        // target, label 1) plus `negative` random words drawn
                        // from the unigram table (label 0).
                        for (int w = 0; w < negative + 1; w++)
                        {
                            string target;
                            int label;
                            if (w == 0)
                            {
                                target = word;
                                label = 1;
                            }
                            else
                            {
                                target = roulette[rnd.Next(roulette.Length)];
                                if (target == word) continue;
                                label = 0;
                            }
                            float a = 0;
                            float g = 0;
                            float[] tmp1 = syn1[target];
                            for (int i = 0; i < dimensions; i++) a += neu1[i] * tmp1[i];
                            // g = (label - sigmoid(a)) * alpha; the sigmoid is
                            // clamped outside [-max_exp, max_exp], else looked
                            // up in the precomputed table.
                            if (a > max_exp) g = (label - 1) * alpha;
                            else if (a < -max_exp) g = (label - 0) * alpha;
                            else g = (label - expTable[(int)((a + max_exp) * (exp_table_size / max_exp / 2))]) * alpha;
                            for (int i = 0; i < dimensions; i++) neu1e[i] += g * tmp1[i];   // Accumulate error for the context words.
                            for (int i = 0; i < dimensions; i++) tmp1[i] += g * neu1[i];    // Update the output vector in place.
                        }

                        // Propagate the accumulated error back to every context
                        // word's input vector.
                        for (int w = 0; w < window * 2 + 1; w++)
                        {
                            if (w != window)
                            {
                                int p = sentence_position - window + w;
                                if (p < 0) continue;
                                if (p >= sentence.Count) continue;
                                string last_word = sentence[p];
                                float[] tmp0 = syn0[last_word];
                                for (int i = 0; i < dimensions; i++) tmp0[i] += neu1e[i];
                            }
                        }
                    }

                    // Advance; once the sentence is exhausted, drop it so the
                    // next loop iteration reads a fresh line.
                    sentence_position++;
                    if (sentence_position >= sentence.Count)
                    {
                        sentence.Clear();
                        continue;
                    }
                }
            }
        }
        Console.WriteLine();
    }
}

model.cs
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4.  
  5. public class Model
  6. {
  7.     public int Dimensions;
  8.  
  9.     private Dictionary<string, float[]> model = new Dictionary<string, float[]>();
  10.     private int wordCount;
  11.  
  12.     public Dictionary<string, float> NearestWords(string word, int count)
  13.     {
  14.         var vec = WordVector(word);
  15.         if (vec == null) return new Dictionary<string, float>();
  16.         var bestd = new float[count];
  17.         var bestw = new string[count];
  18.  
  19.         for (var n = 0; n < count; n++) bestd[n] = -1;
  20.  
  21.         foreach (var key in model.Keys)
  22.         {
  23.             var dist = 0f;
  24.             for (var i = 0; i < Dimensions; i++) dist += vec[i] * model[key][i];
  25.             for (var c = 0; c < count; c++)
  26.                 if (dist > bestd[c])
  27.                 {
  28.                     for (var i = count - 1; i > c; i--)
  29.                     {
  30.                         bestd[i] = bestd[i - 1];
  31.                         bestw[i] = bestw[i - 1];
  32.                     }
  33.                     bestd[c] = dist;
  34.                     bestw[c] = key;
  35.                     break;
  36.                 }
  37.         }
  38.  
  39.         var result = new Dictionary<string, float>();
  40.         for (var i = 0; i < count; i++) result.Add(bestw[i], bestd[i]);
  41.         return result;
  42.     }
  43.  
  44.     public float[] WordVector(string word)
  45.     {
  46.         if (!model.ContainsKey(word)) return null;
  47.         return model[word];
  48.     }
  49.  
  50.     public void LoadVectors(string model_file)
  51.     {
  52.         var file = model_file;
  53.         using (var br = new BinaryReader(File.Open(file, FileMode.Open)))
  54.         {
  55.             wordCount = br.ReadInt32();
  56.             Dimensions = br.ReadInt32();
  57.             for (var w = 0; w < wordCount; w++)
  58.             {
  59.                 var word = br.ReadString();
  60.                 var vec = new float[Dimensions];
  61.                 for (var d = 0; d < Dimensions; d++) vec[d] = br.ReadSingle();
  62.                 Normalise(vec);
  63.                 model[word] = vec;
  64.             }
  65.         }
  66.     }
  67.  
  68.     private void Normalise(float[] vec)
  69.     {
  70.         var len = 0f;
  71.         for (var i = 0; i < Dimensions; i++) len += vec[i] * vec[i];
  72.         len = (float)Math.Sqrt(len);
  73.         for (var i = 0; i < Dimensions; i++) vec[i] /= len;
  74.     }
  75. }

program.cs
  1. using System;
  2.  
  3. class Program
  4. {
  5.     static void Main(string[] args)
  6.     {
  7.         Console.Write("Train? {Y/N} ");
  8.         if (Console.ReadKey().Key == ConsoleKey.Y) Train();
  9.         Console.WriteLine();
  10.  
  11.         // Load model from file.
  12.         var model = new Model();
  13.         model.LoadVectors("model.bin");
  14.  
  15.         // Get 10 nearest words to user input
  16.         while (true)
  17.         {
  18.             Console.Write("Input> ");
  19.             var word = Console.ReadLine();
  20.             foreach (var item in model.NearestWords(word, 10))
  21.                 Console.WriteLine("{0:0.000}\t{1}", item.Value, item.Key);
  22.             Console.WriteLine();
  23.         }
  24.     }
  25.  
  26.     private static void Train()
  27.     {
  28.         // Train vector model and save to file.
  29.         var word2vec = new Word2Vec();
  30.         word2vec.Train("corpus.txt", "model.bin");
  31.     }
  32. }