需要在矩阵中查找元素。这个矩阵的排列如下:
每一行都是有序的。每一列都是有序的。
比如一个小矩阵。
10 30
20 80
现在,我们需要在一个这样N * M的矩阵中找到一个元素,并返回其位置。
思路
首先,这个题不太适合二分查找。因为并不能保证第二行的元素都一定比第一行的任意元素大。
所以应该是很难找到O(lgN)的算法。
每次都是取右上角的元素v与x(我们要查找的元素)进行比较较:
如果发现v > x,由于v所在列肯定比x大。所以v所在列可以舍弃。
如果发现v < x,由于v所在行肯定比x小。所以v所在行可能舍弃。
最后余下的,可能是一行,也可能是一列。总之可以利用二分查找来实现了。效率O(min(N, M)) + log(abs(M-N)).
解题
这里直接粘出代码。
int find(int **a, const int row, const int col, const int x, int *rpos, int *cpos) {
//右上角为起始点.
int from_row = 0, from_col = col - 1, v;
int b, e, mid;
*rpos = *cpos = -1;
while (from_row < row && from_col >= 0) {
v = a[from_row][from_col];
if (v == x) {
*rpos = from_row;
*cpos = from_col;
return 1;
}
from_row += x > v;
from_col -= v > x;
}
//最后剩下一行
if (from_row == (row - 1) && from_col != 0) {
b = 0, e = from_col + 1;
while (b < e) {
mid = b + ((e-b)>>1);
v = a[from_row][mid];
if (x == v) {
*rpos = from_row;
*cpos = mid;
return 1;
} else if (v > x) e = mid;
else b = mid + 1;
}
return 0;
}
//最后剩下一列
if (from_col == 0 && from_row != (row - 1)) {
b = from_row, e = row;
while (b < e) {
mid = b + ((e-b)>>1);
v = a[mid][0];
if (v == x) {
*rpos = mid;
*cpos = from_col;
return 1;
} else if (v > x) e = mid;
else b = mid + 1;
}
return 0;
}
//最后只剩下一个点
if (a[from_row][from_col] == x) {
*rpos = from_row;
*cpos = from_col;
return 1;
}
return 0;
}
这里写一个测试程序,如果有错,会输出Error。
无限循环.
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
int _cmp(const void *a, const void *b) {
int x = (*(const int *)a);
int y = (*(const int *)b);
return x - y;
}
int **create_matrix(int row, int col) {
int *value = (int *) malloc(sizeof(int)*(row*col));
int **a = (int **) malloc(sizeof(int *)*row);
int i = 0, j = 0, iter = 0;
const int numbers = row * col;
for (i = 0; i < numbers ; ++i) {
value[i] = rand() % 4096;
}
qsort(value, numbers, sizeof(int), _cmp);
for (iter = i = 0; i < row; ++i) {
a[i] = (int *) malloc(sizeof(int)*col);
for (j = 0; j < col; ++j) {
a[i][j] = value[iter++];
}
}
free(value);
return a;
}
void destroy_matrix(int **a, const int row) {
int i = 0;
for (i = 0; i < row; ++i) {
free(a[i]);
}
free(a);
}
void print_matrix(int **a, int row, int col) {
int i, j;
for (i = 0; i < row; ++i) {
for (j = 0; j < col; ++j) {
printf("%d ", a[i][j]);
}
printf("\n");
}
}
int _binary_search(int *a, int from, int to, int x) {
int *mid = NULL, *b = a + from, *e = a + to;
while (b < e) {
mid = b + ((e-b) >> 1);
if (*mid == x) return mid - b;
else if (*mid > x) e = mid;
else b = mid + 1;
}
return -1;
}
int find(int **a, const int row, const int col, const int x, int *rpos, int *cpos) {
//右上角为起始点.
int from_row = 0, from_col = col - 1, v;
int b, e, mid;
*rpos = *cpos = -1;
while (from_row < row && from_col >= 0) {
v = a[from_row][from_col];
if (v == x) {
*rpos = from_row;
*cpos = from_col;
return 1;
}
from_row += x > v;
from_col -= v > x;
}
//最后剩下一行
if (from_row == (row - 1) && from_col != 0) {
b = 0, e = from_col + 1;
while (b < e) {
mid = b + ((e-b)>>1);
v = a[from_row][mid];
if (x == v) {
*rpos = from_row;
*cpos = mid;
return 1;
} else if (v > x) e = mid;
else b = mid + 1;
}
return 0;
}
//最后剩下一列
if (from_col == 0 && from_row != (row - 1)) {
b = from_row, e = row;
while (b < e) {
mid = b + ((e-b)>>1);
v = a[mid][0];
if (v == x) {
*rpos = mid;
*cpos = from_col;
return 1;
} else if (v > x) e = mid;
else b = mid + 1;
}
return 0;
}
//最后只剩下一个点
if (a[from_row][from_col] == x) {
*rpos = from_row;
*cpos = from_col;
return 1;
}
return 0;
}
int main(void) {
int row = 10, col = 10;
int r = -1, c = -1;
int x = rand() % row, y = rand() % col;
int **a = NULL;
while (true) {
row = rand() % 100 + 1;
col = rand() % 100 + 1;
a = create_matrix(row, col);
x = rand() % row, y = rand() % col;
if (find(a, row, col, a[x][y] , &r, &c)) {
if (a[x][y] != a[r][c]) {
printf("Error\n");
}
} else {
printf("Error\n");
}
if (!find(a, row, col, a[0][0], &r, &c)) {
printf("Error\n");
}
if (!find(a, row, col, a[row-1][0], &r, &c)) {
printf("Error\n");
}
if (!find(a, row, col, a[0][col-1], &r, &c)) {
printf("Error\n");
}
if (!find(a, row, col, a[row-1][col-1], &r, &c)) {
printf("Error\n");
}
destroy_matrix(a, row);
}
return 0;
}