概要
yolov8を転移学習したモデルをtflite形式に変換し、UnitV2で動かすことができた。
- 概要
- 背景と目的
- 詳細
- まとめ
背景と目的
UnitV2で、自前で学習させた物体検知モデルを動かす必要が出た。以前にedge impluseを用いてモデルを作成した動かしたものを参考に、モデル作成~組込みを行い、動かしてみる。
詳細
0. 環境
1. データセットとアノテーション
データセットは、4種類の工具の写真。枚数にして、全体で20枚程度。枚数が少ないが、ひとまず全体の作業の流れがつかめればよいので気にしない。
1.1 アノテーションツール
labelImgを用いることにした。
Windows側にて、以下の環境を作成。
mkdir -p labelImg cd labelImg python -m venv .venv .venv/Scripts/activate pip install labelImg
なお、pipでインストールした状態のままだと変数型のエラーがいたるところで出て動かないので、こちらあたりを参考にint型が必要とされる部分を修正した。
そのうえで、yolo形式で1つ1つアノテーションしていく。20枚程度なので、10分かかるかどうかというところ。
2. 転移学習
2.0 環境作成
転移学習、tflite形式の作成をWSL Linux (Ubuntu)環境(※)で行う。
※tflite形式作成の際に必要となるai-edge-litertというpythonパッケージのWindows用 .whlが用意されておらず、自分でビルドが必要となり面倒。
mkdir create_tflite python -m venv .venv source .venv/bin/activate
2.1 yolov8のインストール
pip install ultralytics
2.2 データセット作成
yoloの推奨形式に倣い、labelImgで出力したyolo形式のファイル(.txt)と元画像を以下のような構造のデータセットを準備。
dataset1/
dataset.yaml
images/
train/ --- 学習用に15枚分の元画像を入れる
val/ --- 検証用に5枚分の元画像を入れる
labels/
train/ --- 学習用に15枚分の.txtを入れる
val/ --- 検証用に5枚分の.txtを入れる
dataset.yamlは、学習用データへのパス(絶対パスが無難とのこと)と、クラス数、クラス名のリストを記載。
train: D:/labelImg/yolo_dataset/images/train val: D:/labelImg/yolo_dataset/images/val nc: 4 names: ['driver', 'nipper', 'needle nose plier', 'wire stripper']
2.3 モデル選択
UnitV2で実行するので、できるだけ小さなモデルが良いと思い、yolov8nを選択。こちらからダウンロード。
https://huggingface.co/Ultralytics/YOLOv8/blob/main/yolov8n.pt
2.4 学習
ultralyticsパッケージを使用したPythonスクリプトを作成した。実行すると、tools_inference_model/yolov8n_320_custom/weightsというディレクトリに、.ptファイルが出力された。
from ultralytics import YOLO import argparse parser = argparse.ArgumentParser() parser.add_argument('-e', '--epochs', type=int, default=10, help='エポック数') parser.add_argument('-b', '--batch', type=str, default=16, help='バッチサイズ') args = parser.parse_args() # モデルサイズを指定(例:'yolov8n' / 'yolov8s' / 'yolov8m' / 'yolov8l' / 'yolov8x') model = YOLO('yolov8n.pt') # ここを変えるとモデルの大きさが変わる # 転移学習を実行 model.train( data='dataset1/dataset.yaml', # データセットのyamlファイル epochs=args.epochs, # エポック数 imgsz=320, # 入力画像サイズ(例:320x320) batch=args.batch, # バッチサイズ(環境に応じて調整) workers=2, # dataloader用のworker数 project='tools_inference_model', # 結果保存先ディレクトリ name='yolov8n_320_custom', # 実行名 pretrained=True # 事前学習済みモデルで初期化 )
3. tflite形式の作成
作成されたモデルデータである.ptファイルをtflite形式に変換する。
3.1 環境作成
tensorflowバージョンに注意。
pip install tensorflow==2.12.0
3.2 変換を実行
以下のPythonスクリプトを作成して実行。tools_inference_model/yolov8n_320_custom/weights/best_saved_modelというディレクトリにyolov8n_float16.tfliteができた。
from ultralytics import YOLO # YOLOv8nモデルをロードしてONNX形式で保存 model = YOLO('tools_inference_model/yolov8n_320_custom/weights/best.pt') model.export(format='tflite')
4. tfliteのモデルを組み込んでプログラムを作成
以前利用したexample-standalone-inferencing-linuxを修正して使う。カメラで撮影した画像をモデルで判定し、結果をコマンドラインに出力するという動作は変わらないので、camera.cppが変更対象となる。
4.1 UnitV2用クロスコンパイル環境準備
下記と同様に、WSL環境に作成。
4.2 OpenCVライブラリのビルド
下記と同様に実施。
4.3 cpp-mjpeg-streamerの修正
このライブラリをアプリに組み込むことは必須ではないが、デバッグ目的で画像をリアルタイムで確認したいと思うので、下記と同様に実施。
4.4 camera.cpp修正
edge-impluse-sdk関連を使用しないので、それらに関連するコードはコメントアウトする前提で、必要なコードを追加していく。ここでは、重要な部分のみを記載する。
4.4.1 モデルの読み込み
main関数で、カメラをオープンした後に上記で転移学習させたモデルの読み込みを行い、メモリの確保などを行う。
// モデル読み込み const char* model_path = "yolov8n.tflite"; auto model = tflite::FlatBufferModel::BuildFromFile(model_path); if (!model) { std::cerr << "ERROR: failed to load model." << std::endl; return -1; } // インタープリタ作成 tflite::ops::builtin::BuiltinOpResolver resolver; std::unique_ptr<tflite::Interpreter> interpreter; tflite::InterpreterBuilder(*model, resolver)(&interpreter); // メモリ確保 interpreter->AllocateTensors(); // 入力テンソルの情報取得 int input = interpreter->inputs()[0]; TfLiteIntArray* dims = interpreter->tensor(input)->dims; int height = dims->data[1]; // EI_CLASSIFIER_INPUT_HEIGHTと同じになるはずのもの int width = dims->data[2]; // EI_CLASSIFIER_INPUT_WIDTHと同じになるはずのもの int channels = dims->data[3]; std::cout << "Input tensor shape: " << height << "x" << width << "x" << channels << std::endl;
4.4.3 推論実行と結果の出力部
whileループの中で、カメラで取得した画像をリサイズ後、モデルへの入力と推論実行、出力の取得、検出物体の情報(クラス、バウンディングボックス座標など)の取得、それらの描画、コンソールへの出力を行う。
// 画像をモデルに入力 float* input_data = interpreter->typed_tensor<float>(input); for (int i = 0; i < height * width * channels; ++i) { input_data[i] = cropped.data[i] / 255.0f; // 0〜1正規化された画像を想定 } // 推論実行 int64_t t0 = (int64_t)(get_current_time()); interpreter->Invoke(); // 出力テンソル取得 int output_index = interpreter->outputs()[0]; TfLiteTensor* output_tensor = interpreter->tensor(output_index); const float* output = output_tensor->data.f; int num_classes = output_tensor->dims->data[1] - 4; // クラス数分の次元 = data[1] - BB座標4次元 int num_proposals = output_tensor->dims->data[2]; // BB候補数: 入力画像サイズに依存する // 検出物体の情報(クラス、バウンディングボックス座標など) std::vector<Detection> detections = get_result(frame, width, height, output, num_classes, num_proposals); // frameにバウンディングボックスとクラス名を描画 for (int k = 0; k < detections.size(); k++) { Detection detection = detections[k]; cv::Rect bb = detection.box; cv::rectangle(frame, cv::Point(bb.x, bb.y), cv::Point(bb.x + bb.width, bb.y + bb.height), cv::Scalar(255, 0, 0), 2); cv::putText(frame, CLASS_NAMES[detection.class_id], cv::Point(bb.x, bb.y-5), cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(255, 0, 255), 1); } // コンソール出力 // std::cout を用いるだけなので、コードは省略
4.4.4 検出物体の情報
上記のget_resultという関数は、モデルの出力テンソルからクラス、バウンディングボックス座標、信頼度などを取り出すもの。また、バウンディングボックスの重複を排除する処理(NMS、Non-Maximum Suppression)を行う。NMSは、OpenCVのdnnモジュールに実装されているようだが、エラーになってしまいうまく使えなかったので、(上記OpenCVビルドの問題?)自前で簡単なものを実装した。
#include "opencv2/opencv.hpp" // 後処理パラメータ const float NMS_THRESH = 0.5f; // NMSにおける重複度合いの閾値 // 検出結果を格納するための構造体 struct Detection { cv::Rect box; // バウンディングボックス (x, y, width, height) float confidence; // 信頼度 int class_id; // クラスID }; // IOU、Intersection over Union算出関数 // IoUは、2つのBBの重複度合いを表す指標 float iou(const cv::Rect2f& a, const cv::Rect2f& b) { float inter = (a & b).area(); float union_ = a.area() + b.area() - inter; return inter / union_; } std::vector<Detection> get_result( cv::Mat frame, // 判別対象画像 const int input_width, // モデル入力サイズ const int input_height, // モデル入力サイズ const float* output, // 出力テンソルのポインタ int num_classes, // クラス数 int num_proposals, // BB候補数 const float score_thresh // スコア閾値 ) { // NMSで必要なデータを一時的に格納するベクター std::vector<cv::Rect> boxes; std::vector<float> confidences; std::vector<int> class_ids; // 画像サイズと入力サイズの比率(座標のスケーリング用) float x_factor = frame.cols / (input_width * 1.0f); float y_factor = frame.rows / (input_height * 1.0f); // printf("x_factor: %f, y_factor: %f\n", x_factor, y_factor); for (int i = 0; i < num_proposals; ++i) { const float* data = output + (4 * num_proposals) + i; // 最も信頼度の高いクラスとそのスコアを見つける float max_score = 0.0f; int best_class_id = -1; for (int j = 0; j < num_classes; ++j) { // クラスjのスコアは、ストライドnum_proposalsでアクセス if (data[j * num_proposals] > max_score) { max_score = data[j * num_proposals]; best_class_id = j; } } if (max_score < score_thresh) continue; // ボックス座標を取得 (cx, cy, w, h) float cx = output[0 * num_proposals + i]; float cy = output[1 * num_proposals + i]; float w = output[2 * num_proposals + i]; float h = output[3 * num_proposals + i]; // (cx, cy, w, h) から (x, y, w, h) へ変換 int left = static_cast<int>((cx - 0.5 * w) * x_factor * input_width); int top = static_cast<int>((cy - 0.5 * h) * y_factor * input_height); int width = static_cast<int>(w * x_factor * input_width); int height = static_cast<int>(h * y_factor * input_height); boxes.emplace_back(left, top, width, height); confidences.push_back(max_score); class_ids.push_back(best_class_id); printf("class=%d, score=%f, bb=[%d, %d, %d, %d]\n", best_class_id, max_score, left, top, width, height); } // NMS(Non-Maximum Suppression)処理 // 2つのBBのIOUが閾値以上である場合、信頼度の小さいBBを削除 bool nms[boxes.size()] = {false}; for (int index = 0; index < boxes.size(); index++) { if (nms[index]) continue; cv::Rect& a = boxes[index]; for (int other = index + 1; other < boxes.size(); other++) { cv::Rect& b = boxes[other]; float _iou = iou(a, b); printf("NMS %d vs %d: score=%f vs %f, iou=%f\n", index, other, confidences[index], confidences[other], _iou); if (_iou > NMS_THRESH) { if (confidences[index] > confidences[other]) { nms[other] = true; printf("NMS: %d removed\n", other); } else { nms[index] = true; printf("NMS: %d removed\n", index); break; } } } } std::vector<Detection> detections; for (int index = 0; index < boxes.size(); index++) { if (nms[index]) continue; Detection det; det.box = boxes[index]; det.confidence = confidences[index]; det.class_id = class_ids[index]; detections.push_back(det); } return detections; }
4.5 シリアル出力
camera.cppがコンソール出力したものを、UnitV2のシリアル端子に出力するには、以下と同様にPythonスクリプトを用いる。
5. UnitV2で動作確認
出来上がった実行ファイル、.tfliteファイル等をUnitV2に転送し実行したところ、Grove端子からのシリアル出力およびMJPEGストリームをWebブラウザで表示することができた。
まとめ
yolov8で転移学習したモデルをtflite形式に変換し、UnitV2で動かすことができた。工作品に組み込んで、活用していきたい。