2024/09/29

LightGBMを直接

Copilotに質問ぶつけながら、まあ、何故ここまで嘘を教えるのか疑問がありますが、間違いを指摘すると修正はしてくれます。しかし、その間違いを見つけるのに当然苦労します。LightGBMをC APIの利用で使おうとしてメソッドの引数の数がそもそも違ったり、順番が違ったりでどうにか動かす事が出来る様にはなりましたが、RMSEがもう一つだし、安定した結果が得られません。

大分苦労してどうにか動く所まではこぎ着けましたが、これが実際にモノになるのかが疑問です。ここまで苦労して結局AutoMLに及ばなかったりModel Builderに及ばなかったり、まあ、そもそもAutoMLで最新LightGBMを採用してくれたりModel Builderで採用してくれたりしたら結果的にはこの苦労が無駄になるのかも😓

  1. // LightGBM.dllの関数をインポート
  2. [DllImport("lib_lightgbm.dll", CallingConvention = CallingConvention.Cdecl)]
  3. // public static extern int LGBM_DatasetCreateFromFile(string filename, string parameters, ref IntPtr handle);
  4. public static extern int LGBM_DatasetCreateFromFile(
  5. string filename,
  6. string parameters,
  7. IntPtr reference,
  8. out IntPtr dataset);
  9.  
  10. [DllImport("lib_lightgbm.dll", CallingConvention = CallingConvention.Cdecl)]
  11. // public static extern int LGBM_BoosterCreate(IntPtr trainData, string parameters, ref IntPtr handle);
  12. public static extern int LGBM_BoosterCreate(
  13. IntPtr trainData,
  14. string parameters,
  15. out IntPtr booster);
  16.  
  17. [DllImport("lib_lightgbm.dll", CallingConvention = CallingConvention.Cdecl)]
  18. public static extern int LGBM_BoosterUpdateOneIter(
  19. IntPtr booster,
  20. out int isFinished);
  21.  
  22. [DllImport("lib_lightgbm.dll", CallingConvention = CallingConvention.Cdecl)]
  23. public static extern int LGBM_BoosterPredictForMat(
  24. IntPtr booster,
  25. double[] data,
  26. int dataType,
  27. int numRows,
  28. int numCols,
  29. int isRowMajor,
  30. int predictType,
  31. int startIteration,
  32. int numIteration,
  33. string parameters,
  34. out int outLen,
  35. IntPtr outResult);
  36.  
  37. [DllImport("lib_lightgbm.dll", CallingConvention = CallingConvention.Cdecl)]
  38. public static extern IntPtr LGBM_GetLastError();

LightGBMの最新バージョンのソースをダウンロードしてDLLをビルド。プロジェクトにlib_lightgbm.dllとlib_lightbgm.pdbを追加する。

  1. string filename = this.CSV_name; // データファイルのパス
  2. var lines = File.ReadAllLines(filename);
  3. var header = lines[0].Split(',');
  4. var data = lines.Skip(1).Select(line => line.Split(',').Select(double.Parse).ToArray()).ToArray();
  5.  
  6. // データのランダム分割
  7. var rand = new Random();
  8. data = data.OrderBy(x => rand.Next()).ToArray();
  9. int trainSize = (int)(data.Length * 0.8);
  10. var trainData = data.Take(trainSize).ToArray();
  11. var testData = data.Skip(trainSize).ToArray();
  12.  
  13. // トレーニングデータの保存
  14. string trainFilename = "train_data.csv";
  15. File.WriteAllLines(trainFilename, new[] { string.Join(",", header) }.Concat(trainData.Select(row => string.Join(",", row))));
  16.  
  17. // テストデータの保存
  18. string testFilename = "test_data.csv";
  19. File.WriteAllLines(testFilename, new[] { string.Join(",", header) }.Concat(testData.Select(row => string.Join(",", row))));

学習データはこんな感じで準備して

  1. // データセットの作成
  2. IntPtr reference = IntPtr.Zero;
  3. IntPtr dataset;
  4. string datasetParameters = "max_bin=255 header=true label_column=name:Souha";
  5. int datasetResult = LGBM_DatasetCreateFromFile(trainFilename, datasetParameters, reference, out dataset);

データセットを作成して

  1. // ブースターの作成
  2. string boosterParameters = "objective=regression;metric=rmse;num_leaves=31;learning_rate=0.05;feature_fraction=0.8";
  3. IntPtr booster;
  4. int boosterResult = LGBM_BoosterCreate(dataset, boosterParameters, out booster);

ブースターの作成? これ学習モデルなんだと思います。

  1. // ブースターのトレーニング
  2. int numIterations = (int)nudRepeat.Value;
  3. int earlyStoppingRounds = 10; // 早期停止のラウンド数
  4. double bestScore = double.MaxValue;
  5. int bestIteration = 0;
  6.  
  7. for (int i = 0; i < numIterations; i++)
  8. {
  9. int isFinished;
  10. int updateResult = LGBM_BoosterUpdateOneIter(booster, out isFinished);
  11. if (updateResult != 0)
  12. {
  13. rtbLog.AppendText("Failed to update booster at iteration " + i + Environment.NewLine);
  14. break;
  15. }
  16.  
  17. // 早期停止のチェック
  18. IntPtr outResult = Marshal.AllocHGlobal(sizeof(double) * trainSize);
  19. int outLength;
  20. LGBM_BoosterPredictForMat(booster, trainData.SelectMany(x => x).ToArray(),
  21. 1, trainSize, trainData[0].Length,
  22. 1, 0, 0, -1, "predict_type=normal", out outLength, outResult);
  23. double[] resultArray = new double[trainSize];
  24. Marshal.Copy(outResult, resultArray, 0, trainSize);
  25. double rmse = CalculateRMSE(trainData.Select(row => row.Last()).ToArray(), resultArray);
  26.  
  27. if (rmse < bestScore)
  28. {
  29. bestScore = rmse;
  30. bestIteration = i;
  31. }
  32. else if (i - bestIteration >= earlyStoppingRounds)
  33. {
  34. rtbLog.AppendText("Early stopping at iteration " + i + Environment.NewLine);
  35. break;
  36. }
  37. }

過剰学習を防ぐ為(?)に早期停止とかの確認入れてって事で

  1. // テストデータの準備
  2. var testFeatures = testData.Select(row => row.Take(row.Length - 1).ToArray()).ToArray();
  3. var actualValues = testData.Select(row => row.Last()).ToArray();
  4. double[] testDataFlat = testFeatures.SelectMany(x => x).ToArray();
  5. int numRows = testFeatures.Length;
  6. int numCols = testFeatures[0].Length;
  7. int isRowMajor = 1; // 1: 行優先, 0: 列優先
  8. string predictParameters = "predict_type=normal";
  9. IntPtr outResultPtr = Marshal.AllocHGlobal(sizeof(double) * numRows);
  10. int outLen;
  11.  
  12. // 予測の実行
  13. int predictResult = LGBM_BoosterPredictForMat(booster, testDataFlat, 1, numRows, numCols, isRowMajor, 0, 0, -1, predictParameters, out outLen, outResultPtr);
  14.  
  15. if (predictResult == 0)
  16. {
  17. // Console.WriteLine("Prediction completed successfully.");
  18. rtbLog.AppendText("Prediction completed successfully." + Environment.NewLine);
  19. // 予測結果を表示
  20. double[] resultArray = new double[numRows];
  21. Marshal.Copy(outResultPtr, resultArray, 0, numRows);
  22.  
  23. // RMSEの計算
  24. double rmse = CalculateRMSE(actualValues, resultArray);
  25. rtbLog.AppendText("RMSE: " + rmse + Environment.NewLine);
  26. }
  27. else
  28. {
  29. rtbLog.AppendText("Failed to predict." + Environment.NewLine);
  30. }

こんな感じにRMSEも表示出来る。

追記 2024.10.1
早期停止判断時のRMSE用の予想とRMSEの計算時に渡しているものが微妙らしく、Copilotのコードでは判断時のRMSEを表示させてみるとおかしな値になっていたので、終了時にひょうじされるRMSEがそれっぽかったので、そちらを採用して書き直してみた所、やっとそれっぽくなってきた😁

0 件のコメント:

コメントを投稿