本文來自 Flutter 社群的投稿
tflite_flutter 外掛的核心特性:
外掛提供了與 TFLite Java 和 Swift API 相似的 Dart API,所以其靈活性和在這些平臺上的效果是完全一樣的;外掛通過 dart:ffi 直接與 TensorFlow Lite C API 相繫結,所以它比其它平臺整合方式更加高效;無需編寫特定平臺的程式碼;通過 NNAPI 提供加速支援,在 Android 上使用 GPU Delegate,在 iOS 上使用 Metal Delegate。本文中,我們將使用 tflite_flutter 構建一個文字分類 Flutter 應用,帶您體驗 tflite_flutter 外掛。首先從新建一個 Flutter 專案text_classification_app開始。
初始化配置
Linux 和 Mac 使用者
將 install.sh 拷貝到您應用的根目錄,然後在根目錄執行 sh install.sh,本例中就是目錄 text_classification_app/。
Windows 使用者
將 install.bat 檔案拷貝到應用根目錄,並在根目錄執行批處理檔案 install.bat,本例中就是目錄 text_classification_app/。
它會自動從 GitHub 倉庫的 Releases 裡下載最新的二進位制資源,然後把它放到指定的目錄下。
tflite_flutter 的 GitHub 倉庫https://github.com/am15h/tflite_flutter_plugin
獲取外掛
在 pubspec.yaml 新增 tflite_flutter: ^<latest_version>
最新版本情況參考外掛的釋出地址https://pub.flutter-io.cn/packages/tflite_flutter下載模型
要在移動端上執行 TensorFlow 訓練模型,我們需要使用 .tflite 格式。如果需要了解如何將 TensorFlow 訓練的模型轉換為 .tflite 格式,請參閱官方指南。
這裡我們準備使用 TensorFlow 官方站點上預訓練的文字分類模型。
該預訓練的模型可以預測當前段落的情感是積極還是消極。它是基於來自 Mass 等人的 Large Movie Review Dataset v1.0 資料集進行訓練的。資料集由基於 IMDB 電影評論所標記的積極或消極標籤組成,檢視更多資訊。
將 text_classification.tflite 和 text_classification_vocab.txt 檔案拷貝到 text_classification_app/assets/ 目錄下。
在 pubspec.yaml 檔案中新增 assets/。
assets: - assets/
現在萬事俱備,我們可以開始寫程式碼了。
模型轉換器(Converter)的 Python API 指南https://tensorflow.google.cn/lite/convert/python_api預訓練的文字分類模型 (text_classification.tflite)https://files.flutter-io.cn/posts/flutter-cn/2020/tensorflow-lite-plugin/text_classification.tflite資料集 (text_classification_vocab.txt)https://files.flutter-io.cn/posts/flutter-cn/2020/tensorflow-lite-plugin/text_classification_vocab.txt
實現分類器
預處理
正如文字分類模型頁面裡所提到的。可以按照下面的步驟使用模型對段落進行分類:
對段落文字進行分詞,然後使用預定義的詞彙集將它轉換為一組詞彙 ID;將生成的這組詞彙 ID 輸入 TensorFlow Lite 模型裡;從模型的輸出裡獲取當前段落是積極或者是消極的概率值。我們首先寫一個方法對原始字串進行分詞,其中使用 text_classification_vocab.txt作為詞彙集。
在 lib/ 資料夾下建立一個新檔案 classifier.dart。
這裡先寫程式碼載入 text_classification_vocab.txt 到字典裡。
import 'package:flutter/services.dart';class Classifier { final _vocabFile = 'text_classification_vocab.txt'; Map<String, int> _dict; Classifier() { _loadDictionary(); } void _loadDictionary() async { final vocab = await rootBundle.loadString('assets/$_vocabFile'); var dict = <String, int>{}; final vocabList = vocab.split('\\n'); for (var i = 0; i < vocabList.length; i++) { var entry = vocabList[i].trim().split(' '); dict[entry[0]] = int.parse(entry[1]); } _dict = dict; print('Dictionary loaded successfully'); }}
△ 載入字典
現在我們來編寫一個函式對原始字串進行分詞。
import 'package:flutter/services.dart';class Classifier { final _vocabFile = 'text_classification_vocab.txt'; // 單句的最大長度 final int _sentenceLen = 256; final String start = '<START>'; final String pad = '<PAD>'; final String unk = '<UNKNOWN>'; Map<String, int> _dict; List<List<double>> tokenizeInputText(String text) { // 使用空格進行分詞 final toks = text.split(' '); // 建立一個列表,它的長度等於 _sentenceLen,並且使用 <pad> 的對應的字典值來填充 var vec = List<double>.filled(_sentenceLen, _dict[pad].toDouble()); var index = 0; if (_dict.containsKey(start)) { vec[index++] = _dict[start].toDouble(); } // 對於句子裡的每個單詞,在對映裡找到相應的索引值 for (var tok in toks) { if (index > _sentenceLen) { break; } vec[index++] = _dict.containsKey(tok) ? _dict[tok].toDouble() : _dict[unk].toDouble(); } // 按照我們的直譯器輸入 tensor 所需的格式 [1, 256] 返回 List<List<double>> return [vec]; }}
△ 分詞程式碼
使用 tflite_flutter 進行分析
此處的分析指的是在裝置上基於輸入的資料,使用 TensorFlow Lite 模型的處理過程。要使用 TensorFlow Lite 模型進行分析,需要通過直譯器來執行它.
建立直譯器,載入模型
tflite_flutter 提供了一個方法直接通過資源建立直譯器。
static Future<Interpreter> fromAsset(String assetName, {InterpreterOptions options})
由於我們的模型在 assets/資料夾下,需要使用上面的方法來建立解析器。對於 InterpreterOptions 的相關說明,請參考這裡。
import 'package:flutter/services.dart';// 引入 tflite_flutterimport 'package:tflite_flutter/tflite_flutter.dart';class Classifier { // 模型檔案的名稱 final _modelFile = 'text_classification.tflite'; // TensorFlow Lite 直譯器物件 Interpreter _interpreter; Classifier() { // 當分類器初始化以後載入模型 _loadModel(); } void _loadModel() async { // 使用 Interpreter.fromAsset 建立直譯器 _interpreter = await Interpreter.fromAsset(_modelFile); print('Interpreter loaded successfully'); }}
△ 建立直譯器的程式碼
如果您不希望將模型放在 assets/ 目錄下,tflite_flutter 還提供了工廠建構函式建立直譯器,更多資訊。
我們開始進行分析!
現在用下面方法啟動分析:
void run(Object input, Object output);
注意這裡的方法和 Java API 中的是一樣的。
Object input 和 Object output 必須是與 Input Tensor 和 Output Tensor 維度相同的列表。
要檢視 input tensor 和 output tensor 的維度,可以使用如下程式碼:
_interpreter.allocateTensors();// 列印 input tensor 列表print(_interpreter.getInputTensors());// 列印 output tensor 列表print(_interpreter.getOutputTensors());
在本例中 text_classification 模型的輸出如下:
InputTensorList:[Tensor{_tensor: Pointer<TfLiteTensor>: address=0xbffcf280, name: embedding_input, type: TfLiteType.float32, shape: [1, 256], data: 1024]OutputTensorList:[Tensor{_tensor: Pointer<TfLiteTensor>: address=0xbffcf140, name: dense_1/Softmax, type: TfLiteType.float32, shape: [1, 2], data: 8]
現在,我們實現分類方法,該方法返回值為 1 表示積極,返回值為 0 表示消極。
int classify(String rawText) { // tokenizeInputText 返回形狀為 [1, 256] 的 List<List<double>> List<List<double>> input = tokenizeInputText(rawText); // [1,2] 形狀的輸出 var output = List<double>(2).reshape([1, 2]); // run 方法會執行分析並且儲存輸出的值 _interpreter.run(input, output); var result = 0; // 如果輸出中第一個元素的值比第二個大,那麼句子就是消極的 if ((output[0][0] as double) > (output[0][1] as double)) { result = 0; } else { result = 1; } return result; }
△ 用於分析的程式碼
在 tflite_flutter 的 extension ListShape on List 下面定義了一些使用的擴充套件:
// 將提供的列表進行矩陣變形,輸入引數為元素總數並保持相等// 用法:List(400).reshape([2,10,20])// 返回 List<dynamic>List reshape(List<int> shape)// 返回列表的形狀List<int> get shape// 返回列表任意形狀的元素數量int get computeNumElements
最終的 classifier.dart 應該是這樣的:
import 'package:flutter/services.dart';// 引入 tflite_flutterimport 'package:tflite_flutter/tflite_flutter.dart';class Classifier { // 模型檔案的名稱 final _modelFile = 'text_classification.tflite'; final _vocabFile = 'text_classification_vocab.txt'; // 語句的最大長度 final int _sentenceLen = 256; final String start = '<START>'; final String pad = '<PAD>'; final String unk = '<UNKNOWN>'; Map<String, int> _dict; // TensorFlow Lite 直譯器物件 Interpreter _interpreter; Classifier() { // 當分類器初始化的時候載入模型 _loadModel(); _loadDictionary(); } void _loadModel() async { // 使用 Intepreter.fromAsset 建立解析器 _interpreter = await Interpreter.fromAsset(_modelFile); print('Interpreter loaded successfully'); } void _loadDictionary() async { final vocab = await rootBundle.loadString('assets/$_vocabFile'); var dict = <String, int>{}; final vocabList = vocab.split('\\n'); for (var i = 0; i < vocabList.length; i++) { var entry = vocabList[i].trim().split(' '); dict[entry[0]] = int.parse(entry[1]); } _dict = dict; print('Dictionary loaded successfully'); } int classify(String rawText) { // tokenizeInputText 返回形狀為 [1, 256] 的 List<List<double>> List<List<double>> input = tokenizeInputText(rawText); //輸出形狀為 [1, 2] 的矩陣 var output = List<double>(2).reshape([1, 2]); // run 方法會執行分析並且將結果儲存在 output 中。 _interpreter.run(input, output); var result = 0; // 如果第一個元素的輸出比第二個大,那麼當前語句是消極的 if ((output[0][0] as double) > (output[0][1] as double)) { result = 0; } else { result = 1; } return result; } List<List<double>> tokenizeInputText(String text) { // 用空格分詞 final toks = text.split(' '); // 建立一個列表,它的長度等於 _sentenceLen,並且使用 <pad> 對應的字典值來填充 var vec = List<double>.filled(_sentenceLen, _dict[pad].toDouble()); var index = 0; if (_dict.containsKey(start)) { vec[index++] = _dict[start].toDouble(); } // 對於句子中的每個單詞,在 dict 中找到相應的 index 值 for (var tok in toks) { if (index > _sentenceLen) { break; } vec[index++] = _dict.containsKey(tok) ? _dict[tok].toDouble() : _dict[unk].toDouble(); } // 按照我們的直譯器輸入 tensor 所需的形狀 [1,256] 返回 List<List<double>> return [vec]; }}
現在,可以根據您的喜好實現 UI 的程式碼,分類器的用法比較簡單。
// 建立 Classifier 物件Classifer _classifier = Classifier();// 將目標語句作為引數,呼叫 classify 方法_classifier.classify("I liked the movie");// 返回 1 (積極的)_classifier.classify("I didn't liked the movie");// 返回 0 (消極的)
請在這裡查閱完整程式碼:
△ 文字分類示例應用
了解更多關於 tflite_flutter 外掛的資訊,請訪問 GitHub repo: am15h/tflite_flutter_plugin。
你問我答
問:tflite_flutter 和 tflite v1.0.5 有哪些區別?
tflite v1.0.5 側重於為特定用途的應用場景提供高階特性,比如圖片分類、物體檢測等等。而新的 tflite_flutter 則提供了與 Java API 相同的特性和靈活性,而且可以用於任何 tflite 模型中,它還支援 delegate。
由於使用 dart:ffi (dart ↔️ (ffi) ↔️ C),tflite_flutter 非常快 (擁有低延時)。而 tflite 使用平臺整合 (dart ↔️ platform-channel ↔️ (Java/Swift) ↔️ JNI ↔️ C)。
問:如何使用 tflite_flutter 建立圖片分類應用?有沒有類似 TensorFlow Lite Android Support Library 的依賴包?
TensorFlow Lite Flutter Helper Library 為處理和控制輸入及輸出的 TFLite 模型提供了易用的架構。它的 API 設計和文件與 TensorFlow Lite Android Support Library 是一樣的。更多資訊請參考 TFLite Flutter Helper 的 GitHub 。
TFLite Flutter Helper 開發庫 GitHub 倉庫地址https://github.com/am15h/tflite_flutter_helper
以上是本文的全部內容,歡迎大家對 tflite_flutter 外掛進行反饋,請在 GitHub 報 bug 或提出功能需求。謝謝關注,感謝 Flutter 團隊的 Michael Thomsen。
向 tflite_flutter 外掛提出建議和反饋https://github.com/am15h/tflite_flutter_plugin/issues閱讀文中的連結,請點選閱讀原文或者下面 URL 檢視:https://flutter.cn/community/tutorials/text-classification-using-tensorflow-lite-plugin-for-flutter
譯者:Yuan,谷創字幕組審校:Xinlei、Lynn Wang、Alex,CFUG 社群