# Neural Network Forward and Backward Propagation (C Implementation)

## 🔗 Project Link

👉 Source code and runnable examples are in the GitHub project: [🔗 tinyDL](https://github.com/Cedricyu/tinyDL)

> Contents:
> - `tensor.cuh / tensor.cu`: tensor creation and backward-propagation logic
> - `linear.cuh / linear.cu`: Linear layer implementation (including forward)
> - `test_main.c`: forward pass + simulated loss + backward example
> - `Makefile`: build rules with `nsys` integration

## Profiling CUDA Neural-Network Inference and Backpropagation with Nsight Systems

This document records how to profile the CUDA program with the `nsys` tool, together with a complete example run and its output.

[🔗 nsys](https://developer.nvidia.com/nsight-systems)

## Prerequisites

Make sure the following are installed:

- NVIDIA Nsight Systems (`nsys`)
- NVIDIA driver and CUDA Toolkit
- A CUDA-capable GPU (a GeForce MX350 on this machine)

## Running the Profiler

```bash
$ make
$ nsys profile -o profile_report ./test_main

Collecting data...
Device 0: NVIDIA GeForce MX350
Max Grid Dimensions: 2147483647 x 65535 x 65535
Max Threads per Block: 1024
Max Threads per SM: 2048
Total Global Memory: 1994 MB
Number of SMs: 5
Running tests...
Tensor data:
0.205817 0.832004 0.097134 0.461740 0.290924 0.068599 0.413493 0.314878
linear1: Tensor data:
0.127502 0.020209 0.602682 0.158443 0.960665 0.296314 0.458430 0.275413 0.855358 0.953153 0.478732 0.333036 0.730359 0.381478 0.094883 0.085736 0.411295 0.646642 0.818586 0.493656
linear2: Tensor data:
0.787915 0.054083 0.807755 0.104963 0.192709 0.704022 0.658168 0.365205 0.052052 0.820926 0.896487 0.588777 0.631892 0.632408 0.871375
dout[0] = 0.358866
dout[1] = 0.607835
dout[2] = 0.722711
dout[3] = 1.159301
dout[4] = 1.227906
dout[5] = 0.282369
dout[6] = 0.304543
dout[7] = 0.699840
dout[8] = 0.520265
dout[9] = 0.539541
dout[0] = 2.549826
dout[1] = 2.216318
dout[2] = 2.507960
dout[3] = 1.483091
dout[4] = 1.137166
dout[5] = 1.255380
Final Output:
2.549826 2.216318 2.507960 1.483091 1.137166 1.255380
linear2: Tensor grad:
1.986892 1.617203 1.807298 3.269224 2.665474 2.979790 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
linear1: Tensor grad:
0.875782 0.313628 0.749618 0.099318 0.483881 3.716241 1.327807 3.182652 0.421649 2.050058 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
```

- `-o profile_report`: name of the output report (produces `profile_report.qdrep` and related files).
- `./test_main`: the target executable.

The code below shows how to build a simple feed-forward network by hand in C/C++, run inference, set the loss gradient manually, and call `tensor_backward()` to perform backpropagation automatically. (A reference sketch of the backward math follows the listing.)

```cpp
Tensor *x = tensor_create(batch_size, in_features, 1); // requires_grad = 1
fill_tensor_with_random(x);
tensor_print(x);

Linear linear1 = Linear(in_features, hidden_features);
Linear linear2 = Linear(hidden_features, out_features);
linear1.print_weight("linear1");
linear2.print_weight("linear2");

// Forward pass
Tensor *h = linear1.forward(x); // h = x @ W1 + b1
Tensor *y = linear2.forward(h); // y = h @ W2 + b2

printf("Final Output:\n");
for (int i = 0; i < batch_size * out_features; ++i) {
    printf("%f ", y->data[i]);
}
printf("\n");

// Simulate the output gradient (set dL/dy by hand)
y->requires_grad = 1;
y->grad = (float *)calloc(batch_size * out_features, sizeof(float));
memcpy(y->grad, y->data, sizeof(float) * y->batch_size * y->features);

tensor_backward(y); // backpropagates through linear2, linear1 and x

linear2.print_grad("linear2");
linear1.print_grad("linear1");
```
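To make the call to `tensor_backward()` more concrete, here is a minimal CPU reference sketch of the backward pass of a single Linear layer, `y = x @ W` (bias omitted). The function name and signature are illustrative, not part of the tinyDL API; they only spell out the two products the backward pass needs, `dX = dY @ Wᵀ` and `dW = Xᵀ @ dY`. The second product has exactly the Aᵀ·B shape computed by the `matrixTransposeMulKernel` shown later in this document.

```c
/*
 * Illustrative CPU reference for the backward pass of Y = X @ W (bias omitted).
 * X: M x N, W: N x K, dY: M x K  ->  dX: M x N, dW: N x K.
 * Names are hypothetical; the GPU version dispatches these products to kernels.
 */
static void linear_backward_ref(const float *X, const float *W, const float *dY,
                                float *dX, float *dW, int M, int N, int K)
{
    /* dX = dY @ Wᵀ : plain matrix product against the transposed weights. */
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j) {
            float s = 0.0f;
            for (int k = 0; k < K; ++k)
                s += dY[i * K + k] * W[j * K + k];
            dX[i * N + j] = s;
        }

    /* dW = Xᵀ @ dY : the Aᵀ·B shape handled by matrixTransposeMulKernel. */
    for (int j = 0; j < N; ++j)
        for (int k = 0; k < K; ++k) {
            float s = 0.0f;
            for (int i = 0; i < M; ++i)
                s += X[i * N + j] * dY[i * K + k];
            dW[j * K + k] = s;
        }
}
```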
This project uses the NVIDIA Nsight Systems report to inspect runtime behaviour. It covers host OS runtime call statistics, a CUDA API call summary, GPU kernel execution times, and the time and size of GPU memory transfers.

```bash=
** OS Runtime Summary (osrt_sum):

 Time (%)  Total Time (ns)  Num Calls      Avg (ns)     Med (ns)  Min (ns)     Max (ns)   StdDev (ns)            Name
 --------  ---------------  ---------  ------------  -----------  --------  -----------  ------------  ----------------------
     43.9      196,905,094          8  24,613,136.8    832,972.0   377,615  134,344,351  48,664,084.4  sem_wait
     43.6      195,673,297         11  17,788,481.5    855,853.0     3,061  151,422,233  45,006,052.8  poll
      8.1       36,484,476        509      71,678.7      7,061.0     1,006    7,558,509     516,801.1  ioctl
      3.7       16,452,368         27     609,347.0      2,649.0     1,081   16,289,009   3,133,625.5  fopen
      0.3        1,289,036          9     143,226.2    143,879.0    51,896      372,587      93,067.6  sem_timedwait
      0.2        1,036,101         39      26,566.7      3,278.0     2,489      455,984      86,567.0  mmap64
      0.1          237,655          6      39,609.2     27,866.0    21,196       85,230      25,353.4  pthread_create
      0.0          208,716          1     208,716.0    208,716.0   208,716      208,716           0.0  pthread_cond_wait
      0.0          168,838         55       3,069.8      2,463.0     1,284        7,815       1,722.6  open64
      0.0           55,344         10       5,534.4      3,063.0     1,018       25,985       7,542.2  mmap
      0.0           48,854          9       5,428.2      4,034.0     1,442        9,925       3,208.4  fread
      0.0           40,424         11       3,674.9      1,747.0     1,003       10,967       3,795.9  fclose
      0.0           26,326          1      26,326.0     26,326.0    26,326       26,326           0.0  fgets
      0.0           21,161          6       3,526.8      3,356.5     1,138        7,030       1,968.0  open
      0.0           16,980          3       5,660.0      4,270.0     3,237        9,473       3,342.3  pipe2
      0.0           14,964          9       1,662.7      1,291.0     1,013        4,330       1,035.6  read
      0.0           14,657         10       1,465.7      1,355.0     1,141        2,317         353.0  write
      0.0           10,297          5       2,059.4      2,080.0     1,061        3,327         898.7  close
      0.0           10,063          2       5,031.5      5,031.5     3,522        6,541       2,134.8  socket
      0.0            7,856          3       2,618.7      3,159.0     1,357        3,340       1,096.4  fwrite
      0.0            6,560          1       6,560.0      6,560.0     6,560        6,560           0.0  connect
      0.0            6,133          2       3,066.5      3,066.5     2,240        3,893       1,168.8  fcntl
      0.0            5,928          2       2,964.0      2,964.0     2,883        3,045         114.6  munmap
      0.0            4,795          2       2,397.5      2,397.5     1,616        3,179       1,105.2  putc
      0.0            4,580          1       4,580.0      4,580.0     4,580        4,580           0.0  pthread_cond_broadcast
      0.0            1,253          1       1,253.0      1,253.0     1,253        1,253           0.0  bind

Processing [/tmp/nsys-report-77e7.sqlite] with [/opt/nvidia/nsight-systems/2025.3.1/host-linux-x64/reports/cuda_api_sum.py]...

** CUDA API Summary (cuda_api_sum):

 Time (%)  Total Time (ns)  Num Calls     Avg (ns)    Med (ns)  Min (ns)    Max (ns)   StdDev (ns)                Name
 --------  ---------------  ---------  -----------  ----------  --------  ----------  ------------  ---------------------------------
     98.8       58,382,980         18  3,243,498.9     2,996.5     1,324  57,998,660  13,665,122.3  cudaMalloc
      0.5          270,345          6     45,057.5     5,962.5     3,750     233,452      92,440.7  cudaLaunchKernel
      0.4          253,566         18     14,087.0     2,289.5     1,285      81,739      20,961.0  cudaFree
      0.2          127,892         18      7,105.1     4,471.5     3,026      19,612       4,720.4  cudaMemcpy
      0.1           49,116          1     49,116.0    49,116.0    49,116      49,116           0.0  cudaGetDeviceProperties_v2_v12000
      0.0           10,183          2      5,091.5     5,091.5     4,570       5,613         737.5  cudaDeviceSynchronize
      0.0              584          1        584.0       584.0       584         584           0.0  cuModuleGetLoadingMode

Processing [/tmp/nsys-report-77e7.sqlite] with [/opt/nvidia/nsight-systems/2025.3.1/host-linux-x64/reports/cuda_gpu_kern_sum.py]...

** CUDA GPU Kernel Summary (cuda_gpu_kern_sum):

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ------------------------------------------------------------------
     57.1            9,408          4   2,352.0   2,336.0     1,920     2,816        409.4  matrixTransposeMulKernel(float *, float *, float *, int, int, int)
     42.9            7,072          2   3,536.0   3,536.0     3,424     3,648        158.4  matrixMultiplyKernel(float *, float *, float *, int, int, int)

Processing [/tmp/nsys-report-77e7.sqlite] with [/opt/nvidia/nsight-systems/2025.3.1/host-linux-x64/reports/cuda_gpu_mem_time_sum.py]...

** CUDA GPU MemOps Summary (by Time) (cuda_gpu_mem_time_sum):

 Time (%)  Total Time (ns)  Count  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)           Operation
 --------  ---------------  -----  --------  --------  --------  --------  -----------  ----------------------------
     63.0            7,136      6   1,189.3   1,104.0     1,056     1,504        173.4  [CUDA memcpy Device-to-Host]
     37.0            4,192     12     349.3     304.0       224       512        131.9  [CUDA memcpy Host-to-Device]

Processing [/tmp/nsys-report-77e7.sqlite] with [/opt/nvidia/nsight-systems/2025.3.1/host-linux-x64/reports/cuda_gpu_mem_size_sum.py]...

** CUDA GPU MemOps Summary (by Size) (cuda_gpu_mem_size_sum):

 Total (MB)  Count  Avg (MB)  Med (MB)  Min (MB)  Max (MB)  StdDev (MB)           Operation
 ----------  -----  --------  --------  --------  --------  -----------  ----------------------------
      0.001     12     0.000     0.000     0.000     0.000        0.000  [CUDA memcpy Host-to-Device]
      0.000      6     0.000     0.000     0.000     0.000        0.000  [CUDA memcpy Device-to-Host]
```
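The summaries above are produced by post-processing the captured trace. They can be regenerated at any time from the saved report file with `nsys stats`, which runs the same report scripts (`osrt_sum`, `cuda_api_sum`, `cuda_gpu_kern_sum`, ...). The exact file extension depends on the `nsys` version (`.nsys-rep` on recent releases, `.qdrep` on older ones), so adjust it to whatever the profile step actually produced:

```bash
# Re-run all default summaries over the saved report
nsys stats profile_report.nsys-rep

# Or pick a single report, e.g. only the CUDA API summary
nsys stats --report cuda_api_sum profile_report.nsys-rep
```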
## Host OS Runtime Call Summary

✨ Key observations:

- `sem_wait` and `poll` together account for nearly 87% of the total time, which means the program spends most of its time waiting on synchronization or I/O events.

📌 Possible optimizations:

- If this synchronization is not strictly required, check whether the `sem_wait` usage can be reduced or replaced with a non-blocking approach.

```bash=
Time (%)  Total Time (ns)   Num Calls   Avg (ns)       Name
--------  ----------------  ----------  ------------   -------------------
43.9      196,905,094       8           24,613,136.8   sem_wait
43.6      195,673,297       11          17,788,481.5   poll
```

## CUDA API Call Summary

```bash=
Time (%)  Total Time (ns)   Num Calls   Avg (ns)       Name
--------  ----------------  ----------  ------------   ---------------------------------
98.8      58,382,980        18          3,243,498.9    cudaMalloc
0.5       270,345           6           45,057.5       cudaLaunchKernel
0.4       253,566           18          14,087.0       cudaFree
```

`cudaMalloc` takes almost 99% of the CUDA API time, which shows how expensive dynamic device-memory allocation is here (a memory-pool sketch is given at the end of this document).

---

## CUDA Kernel Execution Summary

| Time (%) | Total Time (ns) | Calls | Avg (ns) | Kernel |
|----------|-----------------|-------|----------|--------|
| 57.1% | 9,408 | 4 | 2,352.0 | `matrixTransposeMulKernel` |
| 42.9% | 7,072 | 2 | 3,536.0 | `matrixMultiplyKernel` |

### ✅ Possible optimizations:

- Both kernels are candidates for memory coalescing, shared-memory tiling, and similar optimizations; in particular, `matrixTransposeMulKernel` does not yet use shared memory.

```cu
__global__ void matrixMultiplyKernel(float *A, float *B, float *C, int M, int N, int K) {
    __shared__ float shared_A[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ float shared_B[BLOCK_SIZE][BLOCK_SIZE];

    int tx = threadIdx.x;
    int ty = threadIdx.y;

    int row = blockIdx.y * BLOCK_SIZE + ty;
    int col = blockIdx.x * BLOCK_SIZE + tx;

    float sum = 0.0f;

    for (int t = 0; t < (N + BLOCK_SIZE - 1) / BLOCK_SIZE; ++t) {
        if (row < M && t * BLOCK_SIZE + tx < N)
            shared_A[ty][tx] = A[row * N + t * BLOCK_SIZE + tx];
        else
            shared_A[ty][tx] = 0.0f;

        if (col < K && t * BLOCK_SIZE + ty < N)
            shared_B[ty][tx] = B[(t * BLOCK_SIZE + ty) * K + col];
        else
            shared_B[ty][tx] = 0.0f;

        __syncthreads();

        for (int i = 0; i < BLOCK_SIZE; ++i)
            sum += shared_A[ty][i] * shared_B[i][tx];

        __syncthreads();
    }

    if (row < M && col < K)
        C[row * K + col] = sum;
}

__global__ void matrixTransposeMulKernel(float *A, float *B, float *C, int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y; // over N (row of Aᵀ)
    int col = blockIdx.x * blockDim.x + threadIdx.x; // over K (column of B)

    if (row < N && col < K) {
        float sum = 0.0f;
        for (int i = 0; i < M; ++i) {
            float a = A[i * N + row]; // Aᵀ[row, i] = A[i, row]
            float b = B[i * K + col]; // B[i, col]
            sum += a * b;
        }
        C[row * K + col] = sum;
    }
}
```

---

## CUDA Memory Transfer Summary (by Time)

| Time (%) | Total Time (ns) | Count | Avg (ns) | Operation |
|----------|-----------------|-------|----------|-----------|
| 63.0% | 7,136 | 6 | 1,189.3 | `CUDA memcpy Device-to-Host` |
| 37.0% | 4,192 | 12 | 349.3 | `CUDA memcpy Host-to-Device` |

---

## CUDA Memory Transfer Summary (by Size)

| Total (MB) | Count | Avg (MB) | Operation |
|------------|-------|----------|-----------|
| 0.001 | 12 | 0.000 | `Host-to-Device` |
| 0.000 | 6 | 0.000 | `Device-to-Host` |

### ✅ Possible optimizations:

- The transfer volumes are tiny, but the frequent small copies still add overhead.
- Batch the small transfers, or switch to asynchronous copies with `cudaMemcpyAsync` (see the sketch after this list).
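The following is a minimal sketch of the asynchronous-transfer suggestion, not tinyDL code: the helper name `async_upload_example` and its arguments are illustrative. The key points are that truly asynchronous copies require page-locked (pinned) host memory, and that batching many tiny copies into one larger `cudaMemcpyAsync` removes most of the per-call overhead.

```c
#include <cuda_runtime.h>
#include <stddef.h>

/*
 * Illustrative sketch: pinned host memory + cudaMemcpyAsync on a stream,
 * so small transfers can be batched and overlapped with kernel work.
 * Error handling is omitted for brevity.
 */
void async_upload_example(const float *src, float *d_dst, size_t n_floats)
{
    float *h_pinned = NULL;
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    /* Pinned (page-locked) memory is required for truly asynchronous copies. */
    cudaMallocHost((void **)&h_pinned, n_floats * sizeof(float));
    for (size_t i = 0; i < n_floats; ++i)
        h_pinned[i] = src[i];

    /* One larger async copy instead of many small synchronous ones. */
    cudaMemcpyAsync(d_dst, h_pinned, n_floats * sizeof(float),
                    cudaMemcpyHostToDevice, stream);

    /* ... launch kernels on the same stream here ... */

    cudaStreamSynchronize(stream);   /* wait before reusing/freeing the host buffer */
    cudaFreeHost(h_pinned);
    cudaStreamDestroy(stream);
}
```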
---

## 🧠 Summary of Recommendations

| Problem area | Suggested fix |
|--------------|---------------|
| High wait time in `sem_wait` / `poll` | Rework the synchronization strategy; use non-blocking notification |
| `cudaMalloc` called too frequently | Allocate once at initialization; use a memory pool (see the sketch below) |
| Kernel efficiency can still improve | Optimize memory-access patterns and the thread/block configuration |
| Many small memory transfers | Batch transfers; use asynchronous copies |
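To make the "use a memory pool" suggestion concrete, here is a minimal bump-allocator sketch, written under the assumption that the total amount of device memory needed can be bounded up front. The type and function names (`DevicePool`, `pool_init`, `pool_alloc`, ...) are illustrative and not part of tinyDL.

```c
#include <cuda_runtime.h>
#include <stddef.h>

/* Minimal bump-allocator sketch over a single pre-allocated device buffer. */
typedef struct {
    char  *base;      /* start of the device buffer reserved with one cudaMalloc */
    size_t capacity;  /* total bytes reserved                                    */
    size_t offset;    /* bytes handed out so far                                 */
} DevicePool;

static int pool_init(DevicePool *p, size_t bytes)
{
    p->capacity = bytes;
    p->offset   = 0;
    return cudaMalloc((void **)&p->base, bytes) == cudaSuccess ? 0 : -1;
}

static void *pool_alloc(DevicePool *p, size_t bytes)
{
    size_t aligned = (bytes + 255) & ~(size_t)255;   /* keep 256-byte alignment */
    if (p->offset + aligned > p->capacity)
        return NULL;                                  /* pool exhausted */
    void *ptr = p->base + p->offset;
    p->offset += aligned;
    return ptr;
}

static void pool_reset(DevicePool *p)   { p->offset = 0; }     /* reuse between iterations */
static void pool_destroy(DevicePool *p) { cudaFree(p->base); } /* single cudaFree at shutdown */
```

The one remaining `cudaMalloc` still pays a one-time startup cost (the ~58 ms maximum in the trace is likely dominated by CUDA context initialization on the first allocation), but with this pattern allocation disappears from the per-iteration path.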