Files
2026-06-15 09:00:38 +08:00

151 lines
4.9 KiB
C

#include <stdio.h>
#include <stdbool.h>
#include <stdlib.h>
#define SIZE 4 // Side length of the square demo matrices (each holds SIZE * SIZE ints)
/**
 * Reports whether num is a positive power of two (1, 2, 4, 8, ...).
 *
 * A power of two has exactly one bit set, so clearing the lowest set bit
 * with num & (num - 1) leaves zero. The num > 0 guard rejects zero, which
 * the bit trick alone would misclassify (0 & -1 == 0), rejects negative
 * values, and keeps num - 1 from overflowing when num is INT_MIN.
 *
 * @param num value to test
 * @return true if num is a power of two, false otherwise
 */
bool isPowerOfTwo(int num) {
    return num > 0 && (num & (num - 1)) == 0;
}
/**
 * Element-wise matrix addition: c = a + b.
 *
 * Each array is treated as a flat buffer of length * length ints. Every
 * element of a and b is read before the matching element of c is written,
 * so c may be the same array as a or b (in-place accumulation).
 *
 * @param a      first operand
 * @param b      second operand
 * @param c      destination for the sums
 * @param length side length of the square matrices
 */
void add(int* a, int* b, int* c, int length) {
    const int *lhs = a;
    const int *rhs = b;
    int *out = c;

    for (int remaining = length * length; remaining > 0; remaining--) {
        *out++ = *lhs++ + *rhs++;
    }
}
/**
 * Multiplies two 2x2 matrices with Strassen's seven-product scheme.
 *
 * Each matrix is a flat array of four ints in row-major order:
 * index 0 = top-left, 1 = top-right, 2 = bottom-left, 3 = bottom-right.
 * All inputs are read before result is written, so result may be the same
 * array as a or b.
 *
 * @param a      left operand (4 ints)
 * @param b      right operand (4 ints)
 * @param result receives a * b (4 ints)
 */
void getResult(int* a, int* b, int* result) {
    /* Name the entries so the formulas read like the textbook ones. */
    const int a11 = a[0], a12 = a[1], a21 = a[2], a22 = a[3];
    const int b11 = b[0], b12 = b[1], b21 = b[2], b22 = b[3];

    /* The seven Strassen products. */
    const int m1 = a11 * (b12 - b22);
    const int m2 = (a11 + a12) * b22;
    const int m3 = (a21 + a22) * b11;
    const int m4 = a22 * (b21 - b11);
    const int m5 = (a11 + a22) * (b11 + b22);
    const int m6 = (a12 - a22) * (b21 + b22);
    const int m7 = (a11 - a21) * (b11 + b12);

    /* Recombine the products into the four output entries. */
    result[0] = m5 + m4 - m2 + m6;
    result[1] = m1 + m2;
    result[2] = m3 + m4;
    result[3] = m5 + m1 - m3 - m7;
}
/**
 * Recursive block matrix multiplication: result = a * b.
 *
 * All three matrices are length x length ints stored as flat row-major
 * arrays. The strategy depends on length:
 *   - length <= 0: nothing is done.
 *   - length == 2: delegated to getResult (Strassen's seven products).
 *   - odd length (including 1): plain triple-loop multiplication, because an
 *     odd-sized matrix cannot be split into four equal quadrants.
 *   - even length > 2: each operand is split into four quadrants and the
 *     product is assembled from eight recursive quadrant products. Only the
 *     2x2 base case uses the seven-product trick.
 *
 * result must not overlap a or b. If scratch memory cannot be allocated,
 * an error is printed to stderr and the process exits with EXIT_FAILURE,
 * since the void return type leaves no way to report failure to the caller.
 *
 * @param a      left operand, length * length ints
 * @param b      right operand, length * length ints
 * @param result receives the product, length * length ints
 * @param length side length of the square matrices
 */
void sMM(int* a, int* b, int* result, int length) {
    if (length <= 0) {
        return;
    }

    if (length == 2) {
        getResult(a, b, result);
        return;
    }

    if (length % 2 != 0) {
        /* Odd size: cannot be halved evenly, so multiply directly. */
        for (int i = 0; i < length; i++) {
            for (int j = 0; j < length; j++) {
                int sum = 0;
                for (int k = 0; k < length; k++) {
                    sum += a[i * length + k] * b[k * length + j];
                }
                result[i * length + j] = sum;
            }
        }
        return;
    }

    int half = length / 2;
    size_t quad = (size_t)half * (size_t)half;

    /*
     * One pool holds every scratch quadrant: 4 for a, 4 for b, 4 for the
     * output, and 1 for the second product of each output quadrant.
     * calloc also rejects an element count whose byte size would overflow.
     */
    int *pool = calloc(quad, 13 * sizeof *pool);
    if (pool == NULL) {
        fprintf(stderr, "sMM: out of memory\n");
        exit(EXIT_FAILURE);
    }

    int *a11 = pool;
    int *a12 = a11 + quad;
    int *a21 = a12 + quad;
    int *a22 = a21 + quad;
    int *b11 = a22 + quad;
    int *b12 = b11 + quad;
    int *b21 = b12 + quad;
    int *b22 = b21 + quad;
    int *c11 = b22 + quad;
    int *c12 = c11 + quad;
    int *c21 = c12 + quad;
    int *c22 = c21 + quad;
    int *prod = c22 + quad;

    /* Copy the four quadrants of each operand into contiguous buffers. */
    for (int i = 0; i < half; i++) {
        for (int j = 0; j < half; j++) {
            size_t q = (size_t)i * (size_t)half + (size_t)j;
            size_t top = (size_t)i * (size_t)length + (size_t)j;
            size_t bottom = (size_t)(i + half) * (size_t)length + (size_t)j;

            a11[q] = a[top];
            a12[q] = a[top + (size_t)half];
            a21[q] = a[bottom];
            a22[q] = a[bottom + (size_t)half];

            b11[q] = b[top];
            b12[q] = b[top + (size_t)half];
            b21[q] = b[bottom];
            b22[q] = b[bottom + (size_t)half];
        }
    }

    /*
     * Each output quadrant is the sum of two quadrant products. The second
     * product always goes to the dedicated prod buffer so that no finished
     * output quadrant is ever reused as scratch space.
     */
    sMM(a11, b11, c11, half);
    sMM(a12, b21, prod, half);
    add(c11, prod, c11, half);

    sMM(a11, b12, c12, half);
    sMM(a12, b22, prod, half);
    add(c12, prod, c12, half);

    sMM(a21, b11, c21, half);
    sMM(a22, b21, prod, half);
    add(c21, prod, c21, half);

    sMM(a21, b12, c22, half);
    sMM(a22, b22, prod, half);
    add(c22, prod, c22, half);

    /* Stitch the four computed quadrants back into the full result. */
    for (int i = 0; i < half; i++) {
        for (int j = 0; j < half; j++) {
            size_t q = (size_t)i * (size_t)half + (size_t)j;
            size_t top = (size_t)i * (size_t)length + (size_t)j;
            size_t bottom = (size_t)(i + half) * (size_t)length + (size_t)j;

            result[top] = c11[q];
            result[top + (size_t)half] = c12[q];
            result[bottom] = c21[q];
            result[bottom + (size_t)half] = c22[q];
        }
    }

    free(pool);
}
/**
 * Demo driver: multiplies two fixed SIZE x SIZE matrices with sMM, prints
 * the product one row per line, then prints the result of isPowerOfTwo for
 * a handful of sample values (one per line, as 0 or 1).
 */
int main() {
    int a[SIZE * SIZE] = { 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4 };
    int b[SIZE * SIZE] = { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 };
    int c[SIZE * SIZE];

    sMM(a, b, c, SIZE);

    /* Print the product matrix, ending each row with a newline. */
    for (int row = 0; row < SIZE; row++) {
        for (int col = 0; col < SIZE; col++) {
            printf("%d ", c[row * SIZE + col]);
        }
        printf("\n");
    }

    /* Spot-check the power-of-two predicate. */
    const int samples[] = { 1, 2, 4, 6, 443 };
    for (size_t k = 0; k < sizeof samples / sizeof samples[0]; k++) {
        printf("%d\n", isPowerOfTwo(samples[k]));
    }

    return 0;
}