App下載

C語言實現(xiàn)矩陣乘法的高效方法

被風(fēng)吹過灼思 2023-06-30 10:03:11 瀏覽數(shù) (3530)
反饋

本文將介紹一種使用C語言實現(xiàn)矩陣乘法的高效方法,即分塊算法。分塊算法的基本思想是將兩個大矩陣分成若干個小矩陣,然后對每對小矩陣進行乘法運算,最后將結(jié)果合并成一個大矩陣。這樣可以減少緩存失效的次數(shù),提高運算速度。下面給出具體的代碼實現(xiàn)。

#include <stdio.h>
#include <stdlib.h>
#include <time.h>


#define N 1000 // 矩陣的大小
#define B 100 // 分塊的大小


// 生成一個隨機矩陣
void generate_matrix(double *A) {
    srand(time(NULL));
    for (int i = 0; i < N * N; i++) {
        A[i] = rand() % 10;
    }
}


// 打印一個矩陣
void print_matrix(double *A) {
    for (int i = 0; i < N; i++) {
        for (int j = 0; j < N; j++) {
            printf("%.2f ", A[i * N + j]);
        }
        printf("\n");
    }
}


// 普通的矩陣乘法
void normal_multiply(double *A, double *B, double *C) {
    for (int i = 0; i < N; i++) {
        for (int j = 0; j < N; j++) {
            double sum = 0;
            for (int k = 0; k < N; k++) {
                sum += A[i * N + k] * B[k * N + j];
            }
            C[i * N + j] = sum;
        }
    }
}


// 分塊的矩陣乘法
void block_multiply(double *A, double *B, double *C) {
    for (int i = 0; i < N; i += B) {
        for (int j = 0; j < N; j += B) {
            for (int k = 0; k < N; k += B) {
                // 對每個小矩陣進行乘法運算
                for (int ii = i; ii < i + B && ii < N; ii++) {
                    for (int jj = j; jj < j + B && jj < N; jj++) {
                        double sum = 0;
                        for (int kk = k; kk < k + B && kk < N; kk++) {
                            sum += A[ii * N + kk] * B[kk * N + jj];
                        }
                        C[ii * N + jj] += sum;
                    }
                }
            }
        }
    }
}


// 測試兩種方法的運行時間
void test_time() {
    double *A = malloc(sizeof(double) * N * N);
    double *B = malloc(sizeof(double) * N * N);
    double *C1 = malloc(sizeof(double) * N * N);
    double *C2 = malloc(sizeof(double) * N * N);


    generate_matrix(A);
    generate_matrix(B);


    clock_t start, end;


    start = clock();
    normal_multiply(A, B, C1);
    end = clock();
    printf("Normal multiply time: %.3f s\n", (double)(end - start) / CLOCKS_PER_SEC);


    start = clock();
    block_multiply(A, B, C2);
    end = clock();
    printf("Block multiply time: %.3f s\n", (double)(end - start) / CLOCKS_PER_SEC);


    free(A);
    free(B);
    free(C1);
    free(C2);
}


// 主函數(shù)
int main() {
    test_time();
    return 0;
}

C語言相關(guān)課程推薦:C語言相關(guān)課程

C

0 人點贊