
Classify BBC news headlines with Microsoft ML.NET


This tutorial illustrates how to use ML.NET to create a multiclass classifier in a .NET Core console application, using C# in Visual Studio 2017.

The sample is a console app that uses ML.NET to train a model that classifies and predicts the category of a news headline. It also evaluates the model with a second dataset for quality analysis. The news article datasets come from BBC News.

Prerequisites

  • Visual Studio 2017
  • The BBC News dataset as a zip file

Create a console application

  1. Open Visual Studio 2017. Select File > New > Project from the menu bar. In the New Project dialog, select the Visual C# node followed by the .NET Core node. Then select the Console App (.NET Core) project template. In the Name text box, type "BbcNewsClassifier" and then select the OK button.
  2. Extract the dataset from the zip file into the project folder.
  3. Install the Microsoft.ML NuGet Package:
    In Solution Explorer, right-click on your project and select Manage NuGet Packages. Choose "nuget.org" as the Package source, select the Browse tab, search for Microsoft.ML, select that package in the list and select the Install button. Select the OK button on the Preview Changes dialog and then select the I Accept button on the License Acceptance dialog if you agree with the license terms for the packages listed.

Create classes and define paths

Add the following using statements to the top of the Program.cs file:
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Models;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;

You need to create a few static fields: the paths to the training and test set files, the list of news categories, and a Random instance that is used later to shuffle the data.

Add the following code to the line right above the Main method to specify these fields:
static string _testSet = "news-test.txt";
static string _trainingSet = "news-train.txt";

static List<string> _categories = new List<string> { "business", "entertainment", "politics", "sport", "tech" };

static Random _random = new Random();

You need to create some classes for your input data and predictions:
using Microsoft.ML.Runtime.Api;

public class NewsData
{
    [Column(ordinal: "0")]
    public string Text;

    [Column(ordinal: "1", name: "Label")]
    public string Label;
}
public class NewsPrediction
{
    [ColumnName("Score")]
    public float[] Score;
}

NewsData is the input dataset class. It has a string for the news headline (Text) and a string for the news category (Label). Both fields have Column attributes attached to them; this attribute describes the order of each field in the data file and marks which field is the Label. NewsPrediction is the class used for prediction after the model has been trained. It has a float[] array (Score) with a ColumnName attribute of Score; the array holds one score per category. The Label is used to create and train the model, and it's also used with a second dataset to evaluate the model. The Score is used during prediction and evaluation. For evaluation, an input with test data, the predicted values, and the model are used.

Data preparation

The BBC News dataset consists of five folders (one per category: business, entertainment, politics, sport, tech). Each folder contains text files with news articles. We need to pre-process this data before we can continue.
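The dataset is expected to be extracted into a bbc folder inside the project folder, matching the basePath used in the code below; the layout looks like this (file names are illustrative):

bbc/
    business/
        001.txt
        002.txt
        ...
    entertainment/
    politics/
    sport/
    tech/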

In the Program.cs file, replace the Console.WriteLine("Hello World!") line with the following code in the Main method:
PrepareData();

The PrepareData method executes the following tasks:
  • Read all news articles by category
  • Take the first two paragraphs of each article (the headline and a summary)
  • Clean up newlines and tabs
  • Split the articles in each category into training and test sets (80% training, 20% test)
  • Save the training and test datasets in separate files
Create the PrepareData method, just after the Main method, using the following code:
private static void PrepareData()
{
    File.Delete(_trainingSet);
    File.Delete(_testSet);

    var basePath = "bbc/";
    var training = new List<NewsData>();
    var test = new List<NewsData>();

    for (var i = 0; i < _categories.Count; i++)
    {
        var category = _categories[i];
        var path = basePath + category + "/";
        var files = Directory.GetFiles(path);

        var texts = new List<string>();
        foreach (var file in files)
        {
            // Keep only the first two paragraphs: the headline and the summary.
            var text = File.ReadAllText(file);
            var textParts = text.Split('\n').ToList();
            textParts.RemoveAll(s => string.IsNullOrEmpty(s));
            text = textParts[0] + " " + textParts[1];

            // Clean up newlines and redundant whitespace.
            text = text.Replace(Environment.NewLine, " ");
            text = text.Replace("\n", " ");
            text = text.Replace("\r", " ");
            text = text.Replace("   ", " ");

            texts.Add(text);
        }

        // Shuffle the articles, then split them 80/20 into training and test sets.
        texts = texts.OrderBy(s => _random.Next()).ToList();
        var trainingTextsCount = texts.Count * 80 / 100;
        var trainingTexts = texts.GetRange(0, trainingTextsCount);
        training.AddRange(trainingTexts.Select(s => new NewsData { Text = s, Label = category }).ToList());

        var testTexts = texts.GetRange(trainingTextsCount, texts.Count - trainingTextsCount);
        test.AddRange(testTexts.Select(s => new NewsData { Text = s, Label = category }).ToList());
    }

    // Each output line is tab-separated: the article text, then the category label.
    File.AppendAllLines(_testSet, test.Select(s => $"{s.Text}\t{s.Label}"));
    File.AppendAllLines(_trainingSet, training.Select(s => $"{s.Text}\t{s.Label}"));
}
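Each line in the generated files contains the article text, then a tab character, then the category label. A line in news-train.txt might look like this (the article text here is made up for illustration):

Company shares jump after record results Shares in the firm rose sharply on Monday after it reported record quarterly profits...	business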

Train the model

In the Main method, add the following line after the PrepareData call:
var model = Train();

The Train method executes the following tasks:
  • Load the training data
  • Preprocess and featurize the data
  • Train the model

Create the Train method, just after the PrepareData method, using the following code:
private static PredictionModel<NewsData, NewsPrediction> Train()
{
    var pipeline = new LearningPipeline();

    // Load the tab-separated training set.
    pipeline.Add(new TextLoader<NewsData>(_trainingSet, useHeader: false, separator: "tab"));

    // Convert the headline text into a numeric feature vector.
    pipeline.Add(new TextFeaturizer("Features", "Text")
    {
        KeepDiacritics = false,
        KeepPunctuations = false,
        TextCase = TextNormalizerTransformCaseNormalizationMode.Lower,
        OutputTokens = true,
        Language = TextTransformLanguage.English,
        StopWordsRemover = new PredefinedStopWordsRemover(),
        VectorNormalizer = TextTransformTextNormKind.L2,
        CharFeatureExtractor = new NGramNgramExtractor() { NgramLength = 3, AllLengths = false },
        WordFeatureExtractor = new NGramNgramExtractor() { NgramLength = 3, AllLengths = true }
    });

    // Convert the string labels to numeric keys and train a multiclass classifier.
    pipeline.Add(new Dictionarizer("Label"));
    pipeline.Add(new StochasticDualCoordinateAscentClassifier());

    return pipeline.Train<NewsData, NewsPrediction>();
}

Here we use StochasticDualCoordinateAscentClassifier as the classification algorithm, but you can also try the other multiclass learners available in ML.NET and see how they affect the results.
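Swapping in a different learner is a one-line change in the pipeline. For example, the following sketch assumes the multiclass LogisticRegressionClassifier trainer is available in your version of Microsoft.ML:

// Replace the SDCA learner with multiclass logistic regression (assumed available).
pipeline.Add(new LogisticRegressionClassifier());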

Evaluate the model

Now that you've created and trained the model, you need to evaluate it with a different dataset for quality assurance and validation. In the Evaluate method, the model created in Train is passed in to be evaluated. Create the Evaluate method, just after Train, as in the following code:
public static void Evaluate(PredictionModel<NewsData, NewsPrediction> model)
{
    // Load the test set (not the training set) so the metrics reflect unseen data.
    var testData = new TextLoader<NewsData>(_testSet, useHeader: false, separator: "tab");
    var evaluator = new ClassificationEvaluator();
    var metrics = evaluator.Evaluate(model, testData);

    Console.WriteLine();
    Console.WriteLine("PredictionModel quality metrics evaluation");
    Console.WriteLine("------------------------------------------");
    Console.WriteLine($"AccuracyMacro: {metrics.AccuracyMacro:P2}");
    Console.WriteLine($"AccuracyMicro: {metrics.AccuracyMicro:P2}");
    Console.WriteLine($"LogLoss: {metrics.LogLoss:P2}");
}

Add a call to the new method from the Main method, right under the Train method call, using the following code:
Evaluate(model);
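At this point, your Main method should look like this:

static void Main(string[] args)
{
    PrepareData();
    var model = Train();
    Evaluate(model);
}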

Results

Your results should be similar to the following. As the pipeline runs, it displays messages, and you may see warnings or processing messages; these have been removed from the following results for clarity.

PredictionModel quality metrics evaluation
------------------------------------------
AccuracyMacro: 99.71%
AccuracyMicro: 99.70%
LogLoss: 13.33%

Congratulations! You've now successfully built a machine learning model for classifying news headlines. You can find the source code for this tutorial in the GitHub repository.
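To try the trained model on a single headline, you can use the model's Predict method. Here's a minimal sketch; the headline text is made up, and it assumes the score indices follow the order in which the categories were written to the training file (that is, the order of _categories):

// Predict the category of a single (hypothetical) headline.
var prediction = model.Predict(new NewsData { Text = "Stocks rally as markets react to rate cut" });

// Pick the category with the highest score (index order assumed to match _categories).
var scores = prediction.Score;
var best = Array.IndexOf(scores, scores.Max());
Console.WriteLine($"Predicted category: {_categories[best]} ({scores[best]:P2})");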

P.S.

In some cases, you don't want to train the model every time the application starts. Instead, it makes sense to save the model to a file and load it at startup. To do this in ML.NET, you can use the following (note that both calls are asynchronous):

await model.WriteAsync(_modelPath);
model = await PredictionModel.ReadAsync<NewsData, NewsPrediction>(_modelPath);

Here we write the model to the file at _modelPath, but ML.NET also supports reading from and writing to streams.
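Putting it together, a minimal train-or-load flow at application start might look like the following sketch (it assumes a _modelPath field such as "Model.zip" and an async context, for example an async Main method):

PredictionModel<NewsData, NewsPrediction> model;
if (File.Exists(_modelPath))
{
    // Reuse the previously saved model.
    model = await PredictionModel.ReadAsync<NewsData, NewsPrediction>(_modelPath);
}
else
{
    // Train from scratch and save the result for the next run.
    model = Train();
    await model.WriteAsync(_modelPath);
}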
