在上一篇文章中,我們已經(jīng)帶大家了解了多輸入多輸出(MIMO)能力的架構(gòu)設(shè)計(jì)思路。
今天,小編將繼續(xù)深入解析如何將架構(gòu)設(shè)計(jì)真正落地到可運(yùn)行代碼,并帶來一套可復(fù)用的核心實(shí)現(xiàn)。會(huì)介紹多輸入多輸出支持框架的關(guān)鍵組成部分。通過清晰的結(jié)構(gòu)化設(shè)計(jì)、類型安全的接口抽象,為復(fù)雜的嵌入式 AI 模型建立一個(gè)高擴(kuò)展性、高可維護(hù)性的基礎(chǔ)底座。
下面,我們就將通過頭文件設(shè)計(jì)、基礎(chǔ)數(shù)據(jù)結(jié)構(gòu)構(gòu)建、生命周期管理等內(nèi)容一步步展示一個(gè)完整的MIMO支持框架是如何搭建起來的。
話不多說,上代碼?。ùa超載預(yù)警)
頭文件設(shè)計(jì):構(gòu)建類型安全的基礎(chǔ)
首先,我們需要一個(gè)類型和接口定義完備、可擴(kuò)展性強(qiáng)的頭文件model.h。
這一部分為后續(xù)的MIMO管理、張量訪問、預(yù)處理、模型統(tǒng)計(jì)等功能奠定了堅(jiān)實(shí)基礎(chǔ)。
#ifndefMODEL_H #defineMODEL_H #include"tensorflow/lite/c/common.h" // ============================================================================= // 配置常量 // ============================================================================= #defineMAX_INPUT_TENSORS 8 // 最大輸入張量數(shù)量 #defineMAX_OUTPUT_TENSORS 8 // 最大輸出張量數(shù)量 #defineMAX_TENSOR_DIMS 6 // 最大張量維度數(shù) #defineMODEL_NAME_MAX_LEN 64 // 模型名稱最大長(zhǎng)度 // ============================================================================= // 狀態(tài)碼定義 // ============================================================================= typedefenum{ kStatus_Success =0, kStatus_Fail = 1, kStatus_InvalidParam =2, kStatus_OutOfRange =3, kStatus_NotInitialized =4, kStatus_InsufficientMemory =5 }status_t; // ============================================================================= // 張量相關(guān)類型定義 // ============================================================================= typedefenum{ kTensorType_FLOAT32 =0, kTensorType_UINT8 =1, kTensorType_INT8 =2, kTensorType_INT32 =3, kTensorType_BOOL =4, kTensorType_UNKNOWN =255 }tensor_type_t; typedefstruct{ intsize; // 維度數(shù)量 int data[MAX_TENSOR_DIMS]; // 各維度的大小 }tensor_dims_t; // 單個(gè)張量的完整信息 typedefstruct{ intindex; // 張量索引 tensor_dims_t dims; // 維度信息 tensor_type_t type; // 數(shù)據(jù)類型 uint8_t* data; // 數(shù)據(jù)指針 size_t size_bytes; // 數(shù)據(jù)大?。ㄗ止?jié)) constchar* name; // 張量名稱(可選) }tensor_info_t; // 多張量信息結(jié)構(gòu) typedefstruct{ intcount; // 張量數(shù)量 tensor_info_t tensors[MAX_INPUT_TENSORS]; // 張量信息數(shù)組 }multi_tensor_info_t; // ============================================================================= // 模型統(tǒng)計(jì)信息 // ============================================================================= typedefstruct{ size_t arena_used_bytes; // 已使用的內(nèi)存 size_t arena_total_bytes; // 總內(nèi)存大小 int input_count; // 輸入張量數(shù)量 int output_count; // 輸出張量數(shù)量 constchar* model_name; // 模型名稱 }model_stats_t; // ============================================================================= // 核心接口聲明 // ============================================================================= // 模型生命周期管理 status_tMODEL_Init(void); status_tMODEL_Deinit(void); status_tMODEL_RunInference(void); // 模型信息查詢 intMODEL_GetInputTensorCount(void); intMODEL_GetOutputTensorCount(void); status_tMODEL_GetModelStats(model_stats_t* stats); constchar*MODEL_GetModelName(void); // 單張量操作接口 uint8_t*MODEL_GetInputTensorData(intindex, tensor_dims_t* dims,tensor_type_t* type); uint8_t*MODEL_GetOutputTensorData(intindex, tensor_dims_t* dims,tensor_type_t* type); // 增強(qiáng)的單張量接口 status_tMODEL_GetInputTensorInfo(intindex, tensor_info_t* info); status_tMODEL_GetOutputTensorInfo(intindex, tensor_info_t* info); // 批量操作接口 status_t MODEL_GetAllInputTensors(multi_tensor_info_t* input_info); status_t MODEL_GetAllOutputTensors(multi_tensor_info_t* output_info); // 數(shù)據(jù)預(yù)處理接口 status_t MODEL_ConvertInput(inttensor_index,uint8_t* data, const tensor_dims_t* dims,tensor_type_ttype); // 工具函數(shù) size_t MODEL_GetTensorSizeBytes(consttensor_dims_t* dims,tensor_type_ttype); constchar* MODEL_GetTensorTypeName(tensor_type_ttype); status_t MODEL_ValidateTensorDims(consttensor_dims_t* dims); #endif// MODEL_H核心實(shí)現(xiàn):從設(shè)計(jì)到代碼
接下來,進(jìn)入到實(shí)際實(shí)現(xiàn)部分。為了提高代碼可讀性,整體實(shí)現(xiàn)拆分為以下模塊:
全局變量與初始化
內(nèi)部工具函數(shù)
生命周期管理(Init / Deinit / Invoke)
全局變量和初始化:
#include "tensorflow/lite/micro/kernels/micro_ops.h" #include "tensorflow/lite/micro/micro_interpreter.h" #include "tensorflow/lite/micro/micro_op_resolver.h" #include "tensorflow/lite/schema/schema_generated.h" #include"fsl_debug_console.h" #include"model.h" #include"model_data.h" // ============================================================================= // 全局變量 // ============================================================================= staticconsttflite::Model* s_model =nullptr; statictflite::MicroInterpreter* s_interpreter =nullptr; staticbools_model_initialized =false; // 張量?jī)?nèi)存區(qū)域 - 根據(jù)具體模型調(diào)整大小 staticuint8_ts_tensorArena[kTensorArenaSize] __ALIGNED(16); // 外部函數(shù)聲明 externtflite::MicroOpResolver &MODEL_GetOpsResolver(); // ============================================================================= // 內(nèi)部輔助函數(shù) // ============================================================================= // 獲取數(shù)據(jù)類型的字節(jié)大小 staticsize_tGetTypeSize(tensor_type_ttype) { switch(type) { case kTensorType_FLOAT32: case kTensorType_INT32: return 4; case kTensorType_UINT8: case kTensorType_INT8: case kTensorType_BOOL: return 1; default: return 0; } } // TensorFlow Lite類型轉(zhuǎn)換為我們的類型 statictensor_type_tConvertTfLiteType(TfLiteType tf_type) { switch(tf_type) { case kTfLiteFloat32: return kTensorType_FLOAT32; case kTfLiteUInt8: return kTensorType_UINT8; case kTfLiteInt8: return kTensorType_INT8; case kTfLiteInt32: return kTensorType_INT32; case kTfLiteBool: return kTensorType_BOOL; default: return kTensorType_UNKNOWN; } } // 從TensorFlow Lite張量提取信息 staticstatus_tExtractTensorInfo(TfLiteTensor* tf_tensor, intindex,tensor_info_t* info) { if(tf_tensor == nullptr|| info ==nullptr) { return kStatus_InvalidParam; } // 基本信息 info->index = index; info->type = ConvertTfLiteType(tf_tensor->type); info->data = tf_tensor->data.uint8; if (info->type == kTensorType_UNKNOWN) { PRINTF("Unsupported tensor type: %d ", tf_tensor->type); return kStatus_Fail; } // 維度信息 info->dims.size = tf_tensor->dims->size; if (info->dims.size > MAX_TENSOR_DIMS) { PRINTF("Tensor dimensions exceed maximum: %d > %d ", info->dims.size, MAX_TENSOR_DIMS); return kStatus_OutOfRange; } size_t total_elements =1; for(inti =0; i < info->dims.size; i++) { info->dims.data[i] = tf_tensor->dims->data[i]; total_elements *= info->dims.data[i]; } // 計(jì)算數(shù)據(jù)大小 info->size_bytes = total_elements *GetTypeSize(info->type); // 張量名稱(如果可用) info->name = nullptr; // TensorFlow Lite Micro通常不保存名稱 return kStatus_Success; }模型生命周期管理
這部分主要包括:
模型初始化(加載模型 / 創(chuàng)建解釋器 / 分配張量?jī)?nèi)存)
模型反初始化
執(zhí)行推理(Invoke)
//
模型生命周期管理
//
status_tMODEL_Init(void)
{
if (s_model_initialized) {
PRINTF("Model already initialized
");
return kStatus_Success;
}
// 加載模型
s_model= tflite::GetModel(model_data);
if (s_model->version()!=TFLITE_SCHEMA_VERSION) {
PRINTF("Model schema version %d not supported (expected %d)
",
s_model->version(),TFLITE_SCHEMA_VERSION);
return kStatus_Fail;
}
// 獲取操作解析器
tflite::MicroOpResolverµ_op_resolver= MODEL_GetOpsResolver();
// 創(chuàng)建解釋器
static tflite::MicroInterpreterstatic_interpreter(
s_model, micro_op_resolver, s_tensorArena, kTensorArenaSize);
s_interpreter= &static_interpreter;
// 分配張量?jī)?nèi)存
TfLiteStatus allocate_status=s_interpreter->AllocateTensors();
if (allocate_status!=kTfLiteOk) {
PRINTF("AllocateTensors() failed with status: %d
", allocate_status);
return kStatus_InsufficientMemory;
}
s_model_initialized=true;
// 打印模型信息
PRINTF("Model '%s' initialized successfully:
", MODEL_GetModelName());
PRINTF("- Input tensors: %d
", s_interpreter->inputs_size());
PRINTF("- Output tensors: %d
", s_interpreter->outputs_size());
PRINTF("- Arena used: %zu bytes
", s_interpreter->arena_used_bytes());
return kStatus_Success;
}
status_tMODEL_Deinit(void)
{
if (!s_model_initialized) {
return kStatus_NotInitialized;
}
// TensorFlow Lite Micro使用靜態(tài)內(nèi)存,無需顯式釋放
s_model= nullptr;
s_interpreter= nullptr;
s_model_initialized=false;
PRINTF("Model deinitialized
");
return kStatus_Success;
}
status_tMODEL_RunInference(void)
{
if (!s_model_initialized||s_interpreter==nullptr) {
PRINTF("Model not initialized
");
return kStatus_NotInitialized;
}
TfLiteStatus invoke_status=s_interpreter->Invoke();
if (invoke_status!=kTfLiteOk) {
PRINTF("Model inference failed with status: %d
", invoke_status);
return kStatus_Fail;
}
return kStatus_Success;
}
信息查詢接口
包含:
輸入/輸出張量數(shù)量查詢
模型統(tǒng)計(jì)信息讀取
模型名稱查詢
//
模型信息查詢
//
intMODEL_GetInputTensorCount(void)
{
if (!s_model_initialized || s_interpreter ==nullptr) {
return0;
}
return s_interpreter->inputs_size();
}
intMODEL_GetOutputTensorCount(void)
{
if (!s_model_initialized || s_interpreter ==nullptr) {
return0;
}
return s_interpreter->outputs_size();
}
status_t MODEL_GetModelStats(model_stats_t* stats)
{
if(stats == nullptr) {
return kStatus_InvalidParam;
}
if (!s_model_initialized || s_interpreter ==nullptr) {
return kStatus_NotInitialized;
}
stats->arena_used_bytes = s_interpreter->arena_used_bytes();
stats->arena_total_bytes = kTensorArenaSize;
stats->input_count = s_interpreter->inputs_size();
stats->output_count = s_interpreter->outputs_size();
stats->model_name =MODEL_GetModelName();
return kStatus_Success;
}
constchar*MODEL_GetModelName(void)
{
return MODEL_NAME;
}
下期預(yù)告
由于篇幅有限,本篇重點(diǎn)展示了:
頭文件設(shè)計(jì):類型安全、結(jié)構(gòu)清晰
核心實(shí)現(xiàn)框架:生命周期管理 + 內(nèi)部工具函數(shù)
基本模型信息查詢接口
在下一篇(系列最終章)中,我們將重點(diǎn)講解:
張量數(shù)據(jù)訪問接口(Input/Output Data APIs)完整實(shí)現(xiàn)
批量張量操作的高效實(shí)現(xiàn)方案
更實(shí)際的代碼示例與最佳實(shí)踐
-
嵌入式
+關(guān)注
關(guān)注
5202文章
20516瀏覽量
335059 -
模型
+關(guān)注
關(guān)注
1文章
3772瀏覽量
52160 -
代碼
+關(guān)注
關(guān)注
30文章
4972瀏覽量
74095 -
tensorflow
+關(guān)注
關(guān)注
13文章
336瀏覽量
62252
原文標(biāo)題:突破限制!為TensorFlow Lite Micro添加多輸入多輸出的完整方案解析(二)
文章出處:【微信號(hào):NXP_SMART_HARDWARE,微信公眾號(hào):恩智浦MCU加油站】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
多輸入多輸出收發(fā)器系統(tǒng)的無線通信設(shè)計(jì)
如何在TensorFlow Lite Micro中添加自定義操作符(1)
請(qǐng)問simulink的s-function模塊如何添加多輸入輸出接口
支持不同的輸出軌的多輸出電源參考設(shè)計(jì)
存在信道估計(jì)誤差的有限反饋多用戶多輸入多輸出傳輸性能分析
多輸入-多輸出線性系統(tǒng)有限時(shí)間觀測(cè)器設(shè)計(jì)方法
8發(fā)8收多輸入多輸出正交頻分多址系統(tǒng)平臺(tái)
多輸入多輸出天線系統(tǒng)MIMO分析
簡(jiǎn)介多輸入多輸出(Multiple-input Multiple-output)雷達(dá)
多輸出數(shù)據(jù)支持向量回歸學(xué)習(xí)算法
多輸入多輸出雷達(dá)信號(hào)與目標(biāo)干擾優(yōu)化
多輸入多輸出無線終端的在空中無線測(cè)試方法
如何為TensorFlow Lite Micro添加多輸入多輸出支持(一)
如何為TensorFlow Lite Micro添加多輸入多輸出支持(二)
評(píng)論