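// 算法导论-第九章-中位数和顺序统计量:最坏情况为线性时间的选择算法C++实现
// (Introduction to Algorithms, Chapter 9 — Medians and Order Statistics:
//  C++ implementation of the worst-case linear-time selection algorithm)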

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <random>
#include <stdexcept>
#include <vector>
using namespace std;

// Returns the lower median of the first `size` elements of M, i.e. the value
// that would occupy index (size - 1) / 2 if those elements were sorted.
// M itself is left untouched: the elements are copied before reordering.
// Expects size >= 1.
int find_mid(int M[], int size = 5) {
	vector<int> scratch(M, M + size);
	const auto middle = scratch.begin() + (size - 1) / 2;
	// Only the middle position needs to be correct, not the whole order.
	nth_element(scratch.begin(), middle, scratch.end());
	return *middle;
}

// Fills mids_array[i] with the lower median of the i-th group of consecutive
// elements of A, where groups start at index `begin` and are 5 elements wide,
// except the final group, which holds last_array_size elements.
//
// @param A                source array; not modified by this function
// @param mids_array       output; must have room for sub_array_num ints
// @param begin            index of the first element of the first group
// @param end              last index of the range; not needed here, kept so
//                         the signature stays unchanged
// @param sub_array_num    number of groups
// @param last_array_size  number of elements in the final group
void bult_the_mids(int A[], int mids_array[], int begin, int end, int sub_array_num, int last_array_size) {
	(void)end;
	for (int i = 0; i < sub_array_num; ++i) {
		const int group_size = (i == sub_array_num - 1) ? last_array_size : 5;
		// find_mid() copies its input before sorting, so the group can be
		// passed in place; no temporary new[] buffer (which leaked if
		// find_mid threw) is needed.
		mids_array[i] = find_mid(A + begin + i * 5, group_size);
	}
}

// Linear scan of A[begin..end] for the first element equal to x.
// Returns its index, or the first index past the scanned range (end + 1 for
// a non-empty range) when x does not occur there.
int find_index_by_value(int A[], int begin, int end, int x) {
	int pos = begin;
	while (pos <= end && A[pos] != x)
		++pos;
	return pos;
}

// Lomuto partition of A[begin..end] around the element currently at index
// `choice`. That element is first moved to A[end] and used as the pivot.
// Afterwards every element <= pivot lies to the pivot's left and every larger
// element to its right. Returns the pivot's final index.
// Expects begin <= choice <= end.
int partition(int A[], int begin, int end, int choice) {
	swap(A[choice], A[end]);
	// A[end] is never touched inside the loop (k < end), so caching is safe.
	const int pivot_value = A[end];
	int store = begin;  // next slot for an element <= pivot
	for (int k = begin; k < end; ++k) {
		if (A[k] <= pivot_value) {
			swap(A[store], A[k]);
			++store;
		}
	}
	swap(A[store], A[end]);
	return store;
}

// Worst-case linear-time selection ("median of medians").
// Returns the element that would occupy index `xth` if A[begin..end] were
// sorted in ascending order. A[begin..end] is reordered in place.
//
// @param A      array to select from; elements in [begin, end] are permuted
// @param begin  first index of the active range (inclusive)
// @param end    last index of the active range (inclusive)
// @param xth    absolute index (not an offset from begin) of the wanted
//               order statistic; must satisfy begin <= xth <= end
// @return       the element of rank xth within A[begin..end]
// @throws std::out_of_range if xth lies outside [begin, end]
int select(int A[], int begin, int end, int xth) {
	// Without this check an out-of-range xth drove the recursion past the
	// range, ending in a read of A[end + 1] (undefined behavior). For a valid
	// xth every recursive call below keeps begin <= xth <= end.
	if (xth < begin || xth > end) {
		throw out_of_range("select: xth must lie within [begin, end]");
	}
	if (begin == end) {
		return A[begin];
	}
	const int size_of_A = end - begin + 1;
	// Integer ceiling division; the former float ceil(size / 5.0f) can round
	// incorrectly once sizes exceed float precision.
	const int sub_array_num = (size_of_A + 4) / 5;
	const int last_array_size = (size_of_A % 5) == 0 ? 5 : size_of_A % 5;
	int the_mid_of_mids{};
	{
		// Scoped vector instead of new[]/delete[]: freed before recursing on
		// A (as before), and also freed if anything throws.
		vector<int> mids_array(sub_array_num);
		bult_the_mids(A, mids_array.data(), begin, end, sub_array_num, last_array_size);
		the_mid_of_mids = select(mids_array.data(), 0, sub_array_num - 1, (sub_array_num - 1) / 2);
	}
	const int index_of_final_mid = find_index_by_value(A, begin, end, the_mid_of_mids);
	const int pivot = partition(A, begin, end, index_of_final_mid);
	if (pivot == xth) {
		return A[pivot];
	}
	return pivot > xth ? select(A, begin, pivot - 1, xth)
	                   : select(A, pivot + 1, end, xth);
}

// Despite its name, returns true when x does NOT occur in v, meaning it may
// still be added, and false when it is already present. The inverted meaning
// is kept because main() pushes the value only when this returns true.
//
// @param v  values collected so far; taken by const reference so the vector
//           is not copied on every call
// @param x  candidate value
// @return   true if x is absent from v
bool check_exist(const vector<int>& v, int x) {
	return find(v.begin(), v.end(), x) == v.end();
}

// Interactive driver. It:
// 1. generates n distinct random integers in [-100, 500];
// 2. asks for an index and prints the element of that rank found by select();
// 3. prints the sorted array so the answer can be checked by eye.
int main(int argc, const char * argv[]) {
	constexpr int kMinValue = -100;
	constexpr int kMaxValue = 500;
	// Only this many distinct values exist in the range; asking for more
	// made the generation loop below spin forever.
	constexpr int kMaxDistinct = kMaxValue - kMinValue + 1;

	int n{};
	cout << "please enter the sum of elements you want to produce :" << endl;
	if (!(cin >> n) || n < 1 || n > kMaxDistinct) {
		cerr << "the number of elements must be an integer in [1, " << kMaxDistinct << "]" << endl;
		return 1;
	}
	random_device rd{};
	default_random_engine e{ rd() };
	uniform_int_distribution<int> u{ kMinValue, kMaxValue };
	vector<int> v{};
	v.reserve(n);
	while (static_cast<int>(v.size()) < n) {
		const int element = u(e);  // already an int; no rounding needed
		if (check_exist(v, element))  // true means "not yet in v"
			v.push_back(element);
	}
	cout << "the array is :" << endl;
	for (auto i : v) {
		cout << i << " ";
	}
	cout << endl;
	// select() permutes its input, so it works on a copy and v keeps the
	// generated order.
	vector<int> A(v);
	int xth{};
	cout << "please choose the index you want to select:" << endl;
	if (!(cin >> xth) || xth < 0 || xth >= n) {
		cerr << "the index must be an integer in [0, " << n - 1 << "]" << endl;
		return 1;
	}
	cout << "the result is :" << endl;
	const int result = select(A.data(), 0, n - 1, xth);
	cout << result << endl;
	cout << "(cheat here) sorted array :" << endl;
	sort(v.begin(), v.end());
	for (auto i : v) {
		cout << i << " ";
	}
	cout << endl;
#ifdef _WIN32
	// `pause` is a Windows cmd built-in; on other systems the call just fails.
	system("pause");
#endif
	return 0;
}
