151 lines
4.9 KiB
C
151 lines
4.9 KiB
C
|
|
#include <stdio.h>
|
||
|
|
#include <stdbool.h>
|
||
|
|
#include <stdlib.h>
|
||
|
|
|
||
|
|
#define SIZE 4 // Dimension of the square matrices used in main() (each holds SIZE * SIZE ints)
|
||
|
|
|
||
|
|
// Returns true when num is a positive integral power of two (1, 2, 4, 8, ...).
// Zero and negative values are never powers of two. Rejecting them up front
// also keeps num - 1 from overflowing (undefined behaviour) when num == INT_MIN.
bool isPowerOfTwo(int num) {
    // A power of two has exactly one bit set, so clearing its lowest set bit
    // with num & (num - 1) leaves zero.
    return num > 0 && (num & (num - 1)) == 0;
}
|
||
|
|
|
||
|
|
// Element-wise matrix addition over flat arrays of length * length ints:
// c[k] = a[k] + b[k] for every k. Each element is read and written at the
// same index, so c may be the same array as a or b.
void add(int* a, int* b, int* c, int length) {
    const int total = length * length;
    int k = 0;
    while (k < total) {
        c[k] = a[k] + b[k];
        ++k;
    }
}
|
||
|
|
|
||
|
|
// Multiplies two 2x2 matrices with Strassen's seven-product formulas.
// a, b and result are flat row-major arrays of 4 ints:
// index 0 = (0,0), 1 = (0,1), 2 = (1,0), 3 = (1,1).
void getResult(int* a, int* b, int* result) {
    // Unpack the operands as [[a11 a12] [a21 a22]] and [[b11 b12] [b21 b22]].
    const int a11 = a[0], a12 = a[1], a21 = a[2], a22 = a[3];
    const int b11 = b[0], b12 = b[1], b21 = b[2], b22 = b[3];

    // The seven Strassen products (one multiplication each).
    const int m1 = a11 * (b12 - b22);
    const int m2 = (a11 + a12) * b22;
    const int m3 = (a21 + a22) * b11;
    const int m4 = a22 * (b21 - b11);
    const int m5 = (a11 + a22) * (b11 + b22);
    const int m6 = (a12 - a22) * (b21 + b22);
    const int m7 = (a11 - a21) * (b11 + b12);

    // Every operand has been read above, so result may alias a or b.
    result[0] = m5 + m4 - m2 + m6; // c11
    result[1] = m1 + m2;           // c12
    result[2] = m3 + m4;           // c21
    result[3] = m5 + m1 - m3 - m7; // c22
}
|
||
|
|
|
||
|
|
// Fallback product: the plain O(n^3) triple loop over row-major matrices.
static void smm_naive(const int *a, const int *b, int *result, int length) {
    for (int i = 0; i < length; i++) {
        for (int j = 0; j < length; j++) {
            int sum = 0;
            for (int k = 0; k < length; k++) {
                sum += a[i * length + k] * b[k * length + j];
            }
            result[i * length + j] = sum;
        }
    }
}

// Copies block (brow, bcol) of the length x length matrix src into the
// (length/2) x (length/2) matrix dst.
static void smm_get_quadrant(const int *src, int length, int brow, int bcol, int *dst) {
    int half = length / 2;
    for (int i = 0; i < half; i++) {
        for (int j = 0; j < half; j++) {
            dst[i * half + j] = src[(brow * half + i) * length + (bcol * half + j)];
        }
    }
}

// Writes the (length/2) x (length/2) matrix src into block (brow, bcol) of the
// length x length matrix dst.
static void smm_set_quadrant(int *dst, int length, int brow, int bcol, const int *src) {
    int half = length / 2;
    for (int i = 0; i < half; i++) {
        for (int j = 0; j < half; j++) {
            dst[(brow * half + i) * length + (bcol * half + j)] = src[i * half + j];
        }
    }
}

// Matrix multiplication result = a * b for length x length int matrices
// stored row-major in flat arrays.
//
// Even sizes are split into 2x2 blocks of half-size quadrants, and the eight
// block products are computed recursively. Strassen's seven-product formulas
// (getResult) are applied at the 2x2 base case. Odd sizes (including 1) fall
// back to the plain triple loop, as does the case where scratch memory cannot
// be allocated. So any positive length works, not only powers of two.
// A length <= 0 leaves result untouched.
//
// result should not overlap a or b: the triple-loop fallback writes result
// while still reading the operands.
void sMM(int* a, int* b, int* result, int length) {
    if (length <= 0) {
        return;
    }
    if (length == 2) {
        getResult(a, b, result);
        return;
    }
    if (length % 2 != 0) {
        // Odd sizes (including 1) cannot be split into equal quadrants.
        smm_naive(a, b, result, length);
        return;
    }

    int half = length / 2;
    size_t quad = (size_t)half * (size_t)half;

    // One allocation holds the four quadrants of a, the four of b, the four
    // output quadrants, and one scratch quadrant for the second product of
    // each block sum.
    int *buf = malloc(13 * quad * sizeof *buf);
    if (buf == NULL) {
        // The triple loop needs no scratch memory, so degrade instead of crashing.
        smm_naive(a, b, result, length);
        return;
    }
    int *qa[4], *qb[4], *qprod[4];
    for (int q = 0; q < 4; q++) {
        qa[q] = buf + (size_t)q * quad;
        qb[q] = buf + (size_t)(4 + q) * quad;
        qprod[q] = buf + (size_t)(8 + q) * quad;
    }
    int *scratch = buf + 12 * quad;

    // Quadrant q sits at block row q / 2, block column q % 2:
    // 0 = top-left, 1 = top-right, 2 = bottom-left, 3 = bottom-right.
    for (int q = 0; q < 4; q++) {
        smm_get_quadrant(a, length, q / 2, q % 2, qa[q]);
        smm_get_quadrant(b, length, q / 2, q % 2, qb[q]);
    }

    // C[row][col] = A[row][0] * B[0][col] + A[row][1] * B[1][col].
    // Each output block has its own buffer, so no finished block is
    // overwritten by a later product.
    for (int row = 0; row < 2; row++) {
        for (int col = 0; col < 2; col++) {
            int *out = qprod[2 * row + col];
            sMM(qa[2 * row], qb[col], out, half);
            sMM(qa[2 * row + 1], qb[2 + col], scratch, half);
            add(out, scratch, out, half);
        }
    }

    for (int q = 0; q < 4; q++) {
        smm_set_quadrant(result, length, q / 2, q % 2, qprod[q]);
    }
    free(buf);
}
|
||
|
|
|
||
|
|
int main() {
|
||
|
|
int a[SIZE * SIZE] = { 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4 };
|
||
|
|
int b[SIZE * SIZE] = { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 };
|
||
|
|
int c[SIZE * SIZE];
|
||
|
|
|
||
|
|
sMM(a, b, c, SIZE);
|
||
|
|
|
||
|
|
for (int i = 0; i < SIZE * SIZE; i++) {
|
||
|
|
printf("%d ", c[i]);
|
||
|
|
if ((i + 1) % SIZE == 0) // 换行
|
||
|
|
printf("\n");
|
||
|
|
}
|
||
|
|
|
||
|
|
printf("%d\n", isPowerOfTwo(1));
|
||
|
|
printf("%d\n", isPowerOfTwo(2));
|
||
|
|
printf("%d\n", isPowerOfTwo(4));
|
||
|
|
printf("%d\n", isPowerOfTwo(6));
|
||
|
|
printf("%d\n", isPowerOfTwo(443));
|
||
|
|
|
||
|
|
return 0;
|
||
|
|
}
|