今回はML.NET による画像分類をおこないます。
ここでは犬、猫、鳩の画像を読み込ませて、それがどの画像なのかを推論させます。推論のための学習用の画像データと実際に推論させる画像データは検索して適当に集めます。
上は学習用のデータの一部です。下はテスト用のデータです。
分類結果は以下のように出力されます。
Contents
Microsoft.MLで検索してもMicrosoft.MLがでてこない
WindowsForms(.NET Framework)でプロジェクトを作成するのですが、いきなりハマりどころがあります。ML.NETをNuGetでインストールしたいのですが、Microsoft.MLで検索してもMicrosoft.MLがでてこないのです。
解決法は対象プラットフォームをAny CPUからX64に変更します。これでMicrosoft.MLが表示されます。
インストールするのは以下のものです。
Microsoft.ML
Microsoft.ML.Vision
Microsoft.ML.ImageAnalytics
SciSharp.TensorFlow.Redist V2.3.1
最新のSciSharp.TensorFlow.Redistでは例外が発生する
ここもハマりどころがあってSciSharp.TensorFlow.Redistの最新のものをいれると実行時に例外が発生します。バージョン2.3.1のものをインストールします。
WindowsForms(.NET Framework)でアプリを作成する
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
// 先頭に以下4行を追加する using Microsoft.ML; using Microsoft.ML.Data; using Microsoft.ML.Transforms; using Microsoft.ML.Vision; public partial class Form1 : Form { // モデルを保存する場所 string _modelFilePath = Application.StartupPath + @"\model.zip"; // 学習用の画像ファイルがあるフォルダ const string _trainFilesFolder = @"D:\PictureTrain"; // 推論させる画像ファイルがあるフォルダ const string _testFilesFolder = @"D:\PictureTest"; public Form1() { InitializeComponent(); label1.Text = ""; } } |
それから以下のようなクラスを定義します。
1 2 3 4 5 6 7 8 9 10 11 |
public class PetData { public string Breed { get; set; } public string ImageFilePath { get; set; } } public class PetDataPrediction : PetData { public string PredictedBreed { get; set; } public float[] Score { get; set; } } |
学習させる
button1がクリックされたら学習用の画像ファイルで学習させます。ファイル名は cat_XX.png とか dog_XX.png のような名前にしてファイル名から答えがわかるようにしておきます。
1 2 3 4 5 6 7 |
public partial class Form1 : Form { Task.Run(()=> { MLContext mlContext = new MLContext(seed: 1); CreateTrainedModel(mlContext, _trainFilesFolder, _modelFilePath); }); } |
第二引数にフォルダのパスを指定しているので、そこにあるファイルをすべて取得します。
PetDataオブジェクトを生成してファイルのパスと品種名を格納します。品種名はファイル名に書かれているのでそれを切り出して使います。オブジェクトをpetDataSetに格納したらMLContext.Data.LoadFromEnumerable(petDataSet)を実行してデータのロード、シャッフルします。
そのあと若干の加工を加えます。品種文字列を数値に変換して列名を Label とします。パスから画像をロードしてデータセット用に Transformer を生成します。そのあとデータセットを学習データ(70%)と検証データ(30%)に分割します。
どのように学習するのかを定義して学習させ、学習モデルをファイルに保存します。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
public partial class Form1 : Form { void CreateTrainedModel(MLContext mlContext, string trainFilesFolder, string modelFilePath) { label1.Text = "処理中"; string[] imagefilePaths = Directory.GetFiles(trainFilesFolder); // パス一覧の読み込み var petDataSet = imagefilePaths.Select(path => new PetData() { // ファイル名に品種名が含まれるので、品種名を切り出して Breed に設定 Breed = path.Substring(path.LastIndexOf('\\') + 1, path.LastIndexOf('_') - path.LastIndexOf('\\') - 1), ImageFilePath = path } ); // データのロード IDataView petDataView = mlContext.Data.LoadFromEnumerable(petDataSet); // データセットをシャッフル IDataView shuffledPetDataView = mlContext.Data.ShuffleRows(petDataView); // データセットの加工 IDataView transformedDataView = ProcessingDataset(mlContext, shuffledPetDataView); // データセットを学習データ(70%)と検証データ(30%)に分割 var trainValidationTestSplit = mlContext.Data.TrainTestSplit(transformedDataView, testFraction: 0.3); IDataView trainDataView = trainValidationTestSplit.TrainSet; IDataView validationDataView = trainValidationTestSplit.TestSet; // 学習の定義 var option = new ImageClassificationTrainer.Options() { LabelColumnName = "Label", //ラベル列 FeatureColumnName = "RawImageBytes", // 特徴列 Arch = ImageClassificationTrainer.Architecture.ResnetV250, //転移学習モデルの選択 Epoch = 200, BatchSize = 10, LearningRate = 0.01f, ValidationSet = validationDataView, // 検証データを設定 MetricsCallback = (metrics) => label1.Text = metrics.ToString(), WorkspacePath = Application.StartupPath + "\\Workspace", }; var val = mlContext.Transforms.Conversion.MapKeyToValue( // 推論結果のラベルを数値から品種文字列に変換 inputColumnName: "PredictedLabel", outputColumnName: "PredictedBreed" ); var trainer = mlContext.MulticlassClassification.Trainers.ImageClassification(option).Append(val); ITransformer trainedModel = trainer.Fit(trainDataView); // 学習モデルをファイルに保存 mlContext.Model.Save(trainedModel, trainDataView.Schema, modelFilePath); label1.Text = "完了"; } IDataView ProcessingDataset(MLContext mlContext, IDataView dataView) { return mlContext.Transforms.Conversion.MapValueToKey( // 品種文字列を数値に変換して列名を Label とする inputColumnName: nameof(PetData.Breed), outputColumnName: "Label", keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue) .Append(mlContext.Transforms.LoadRawImageBytes( // パスから画像をロード inputColumnName: nameof(PetData.ImageFilePath), imageFolder: null, outputColumnName: "RawImageBytes")) .Fit(dataView) // データセット用に Transformer を生成 .Transform(dataView); // Transformer をデータセットに適用 } } |
推論させる
学習モデルが生成されたらこれをもとに推論させます。button2がクリックされたら、_testFilesFolderで指定されたフォルダ内にある画像データが犬なのか猫なのか鳩なのかを推論させ、その結果をHTMLファイルとして出力します。
ここではフォルダ内のファイルから自作メソッドのCreateTestDataでIDataViewを取得して、自作メソッドのCreateHTMLに渡して結果を出力しています。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
public partial class Form1 : Form { private void button2_Click(object sender, EventArgs e) { Task.Run(() => { label1.Text = "処理中"; // コンテキストの生成 MLContext mlContext = new MLContext(seed: 1); IDataView testDataView = CreateTestData(mlContext, _testFilesFolder); string htmlPath = CreateHTML(mlContext, testDataView, _modelFilePath); label1.Text = "完了"; if (htmlPath != "") System.Diagnostics.Process.Start(htmlPath); }); } } |
フォルダ内のファイルからデータをロードしてIDataViewを取得しています。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
public partial class Form1 : Form { IDataView CreateTestData(MLContext mlContext, string folderPath) { List<PetData> petDatas = new List<PetData>(); string[] paths = Directory.GetFiles(folderPath); foreach (string path in paths) { PetData petData = new PetData() { ImageFilePath = path }; petDatas.Add(petData); } IDataView dataView = mlContext.Data.LoadFromEnumerable(petDatas); // データセットの加工 return ProcessingDataset(mlContext, dataView); } } |
CreateHTMLメソッドはモデルデータを読み出し、テストデータを推論しています。また犬、猫、鳩のスコアも出力しています。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
public partial class Form1 : Form { string CreateHTML(MLContext mlContext, IDataView testDataView, string modelFilePath) { if (!File.Exists(_modelFilePath)) { MessageBox.Show($"{modelFilePath}が存在しません"); return ""; } ITransformer trainedModel = mlContext.Model.Load(modelFilePath, out DataViewSchema dataViewSchema); // テストデータで推論を実行 IDataView testDataPredictionsDataView = trainedModel.Transform(testDataView); // ラベルと品種文字列のキーバリューを取得 VBuffer<ReadOnlyMemory<char>> keyValues = default; dataViewSchema["Label"].GetKeyValues(ref keyValues); IEnumerable<PetDataPrediction> predictions = mlContext.Data.CreateEnumerable<PetDataPrediction>(testDataPredictionsDataView, reuseRowObject: true).Take(10); string testFilePath = $@"{Application.StartupPath}\{DateTimeOffset.Now:yyyyMMddHHmmss}.html"; // HTML で評価結果を書き出し using (var writer = new StreamWriter(testFilePath)) { writer.WriteLine($"<html>\n<head>\n<title>推論結果</title>"); writer.WriteLine("<link rel=\"stylesheet\" href=\"https://cdn.jsdelivr.net/npm/bootstrap@4.5.3/dist/css/bootstrap.min.css\">"); writer.WriteLine("</head>\n<body>"); // 評価データ毎の分類結果 writer.WriteLine($"<h1>推論結果</h1>"); writer.WriteLine($"<div>\n<table class=\"table table-bordered\">"); foreach (var prediction in predictions) { writer.WriteLine($"<tr><td>"); // 画像ファイル名 writer.WriteLine($"画像ファイル名:{Path.GetFileName(prediction.ImageFilePath)}<br />"); // 画像 if(prediction.ImageFilePath.IndexOf("\\") == -1) writer.WriteLine($"<img class=\"img-fluid\" src=\"{prediction.ImageFilePath}\" width = \"200\" /><br />"); else writer.WriteLine($"<img class=\"img-fluid\" src=\"file:///{prediction.ImageFilePath}\" width = \"200\" /><br />"); // 推論結果 writer.WriteLine($"推論結果: {prediction.PredictedBreed}<br />"); // クラス毎の推論結果 writer.WriteLine($"クラス毎の推論結果<br />"); prediction.Score.Select((s, i) => (Index: i, Label: keyValues.GetItemOrDefault(i), Score: s)) .ToList().ForEach(c => writer.WriteLine($"{c.Label}: {c.Score:P}<br />")); writer.WriteLine("</td></tr>"); } writer.WriteLine("</table>\n</div>"); writer.WriteLine("</body>\n</html>"); } return testFilePath; } } |