遗传算法简介:
一直都在收听《卓老板聊科技》这个节目,最近播出了一期关于人工智能的节目,主要讲的是由霍兰提出的遗传算法,在节目中详细阐述了一个有趣的小实验:吃豆人。
首先简单介绍下遗传算法:
1:为了解决某个具体的问题,先随机生成若干个解决问题的实体,每个实体解决问题的方式都用“基因”来表示,也就是说,不同的实体拥有不同的基因,那么也对应着不同的解决问题的方案。
2:有了若干实体之后,接下来就是让这些实体来完成这个任务,根据任务的完成情况用相同标准打分。
3:接下来是进化环节,按照得分的高低,得出每个个体被选出的概率,得分越高越容易被选出,先选出两个个体,对其基因进行交叉,再按照设定的概率对其基因进行突变,来生成新个体,不停重复直到生成足够数量的新个体,这便是一次进化过程。按照这个方法不停的进化,若干代之后就能得到理想的个体。
下面简单介绍下吃豆人实验:
吃豆人首先生存在一个10*10个格子组成的矩形空间中,将50个豆子随机放在这100个格子中,每个格子要嘛为空,要嘛就有一颗豆子。吃豆人出生的时候随机出现在一个任意方格中,接下来吃豆人需要通过自己的策略来吃豆子,一共只有200步,吃到一颗+10分,撞墙-5分,发出吃豆子的动作却没吃到豆子-1分。另外吃豆人只能看到自己所在格子和上下左右一共5个格子的情况。
整理一下
吃豆人的所有动作:上移、下移、左移、右移、吃豆、不动、随机移动,一共7种
吃豆人所能观察到的状态:每个格子有“有豆子”、“无豆子”、“墙”3种状态,而吃豆人一共能看到5个格子,那就是3^5=243种状态。
为此,吃豆人个体的基因可以用一条长度为243的基因序列表示,序列中的每一位分别对应243种状态中的一种,每一位有7种取值,分别表示在该状态下做出的动作。
代码
Main.java
/**
 * Entry point that runs the genetic algorithm on a population of bean eaters.
 */
public class Main {

    /**
     * Creates a random population of 1000 individuals, prints it, and then
     * evolves it generation after generation without ever stopping. Every
     * fifth generation the freshly evolved population is printed.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        Population current = new Population(1000, false);
        System.out.println(current);

        // The loop has no exit condition: the program runs until it is killed.
        for (long generation = 1; ; generation++) {
            Population next = Algorithm.evolve(current);
            boolean reportThisGeneration = generation % 5 == 0;
            if (reportThisGeneration) {
                System.out.println("The " + generation + "'s evolve");
                System.out.println(next);
            }
            current = next;
        }
    }
}
Individual.java
/**
 * One bean eater. Its genome is a lookup table holding one action code per
 * observable state.
 */
public class Individual {
    // Genome length. The bean eater sees 5 cells (current, up, right, down,
    // left) and each cell is wall, bean or no bean, giving 3^5 = 243 states.
    private static int length = 243;

    /*
     * Number of different actions a gene can encode:
     *   0 : up       4 : random move
     *   1 : left     5 : eat
     *   2 : down     6 : stay
     *   3 : right
     */
    private static byte actionNum = 7;

    private byte[] genes = null;

    // Cached fitness; Integer.MIN_VALUE means "not computed yet".
    private int fitness = Integer.MIN_VALUE;

    /** Creates an individual whose genes are all zero. */
    public Individual() {
        genes = new byte[length];
    }

    /** Overwrites every gene with a random action code in [0, actionNum). */
    public void generateGenes() {
        for (int pos = 0; pos < length; pos++) {
            // Math.random() * actionNum is non-negative, so the narrowing cast
            // truncates exactly like Math.floor would.
            genes[pos] = (byte) (Math.random() * actionNum);
        }
    }

    /**
     * Returns the fitness of this individual, computing it on first use and
     * returning the cached value afterwards.
     *
     * @return the fitness score
     */
    public int getFitness() {
        if (fitness != Integer.MIN_VALUE) {
            return fitness;
        }
        fitness = FitnessCalc.getFitnessPall(this);
        return fitness;
    }

    /** @return the number of genes */
    public int getLength() {
        return length;
    }

    /**
     * @param index position in the genome
     * @return the action code stored at that position
     */
    public byte getGene(int index) {
        return genes[index];
    }

    /**
     * Stores an action code in the genome and discards the cached fitness.
     *
     * @param index position in the genome
     * @param gene  action code to store
     */
    public void setGene(int index, byte gene) {
        fitness = Integer.MIN_VALUE;
        this.genes[index] = gene;
    }

    /**
     * Looks up the action for an observed state. The state is encoded as a
     * five-digit base-3 number whose digits are, from most to least
     * significant: middle, up, right, down, left.
     *
     * @param state the observed surroundings
     * @return the action code stored for that state
     */
    public byte getActionCode(State state) {
        // Horner evaluation of middle*81 + up*27 + right*9 + down*3 + left.
        int stateCode = state.getMiddle();
        stateCode = stateCode * 3 + state.getUp();
        stateCode = stateCode * 3 + state.getRight();
        stateCode = stateCode * 3 + state.getDown();
        stateCode = stateCode * 3 + state.getLeft();
        return genes[stateCode];
    }

    /** @return all gene values concatenated in genome order */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder(length);
        for (byte gene : genes) {
            text.append(gene);
        }
        return text.toString();
    }

    /**
     * Demo: prints a random individual, its cached fitness and a second,
     * freshly computed fitness.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        Individual sample = new Individual();
        sample.generateGenes();
        System.out.println(sample);
        System.out.println(sample.getFitness());
        System.out.println(FitnessCalc.getFitnessPall(sample));
    }
}
State.java
/**
 * What a bean eater can observe: the cell it stands on plus its four direct
 * neighbours. Each cell value is 0 for wall, 1 for bean, 2 for no bean.
 */
public class State {
    private byte middle;
    private byte up;
    private byte right;
    private byte down;
    private byte left;

    /**
     * @param middle value of the cell the bean eater stands on
     * @param up     value of the cell above
     * @param right  value of the cell to the right
     * @param down   value of the cell below
     * @param left   value of the cell to the left
     */
    public State(byte middle, byte up, byte right, byte down, byte left) {
        this.middle = middle;
        this.up = up;
        this.right = right;
        this.down = down;
        this.left = left;
    }

    // ---- readers ----

    /** @return value of the cell the bean eater stands on */
    public byte getMiddle() {
        return this.middle;
    }

    /** @return value of the cell above */
    public byte getUp() {
        return this.up;
    }

    /** @return value of the cell to the right */
    public byte getRight() {
        return this.right;
    }

    /** @return value of the cell below */
    public byte getDown() {
        return this.down;
    }

    /** @return value of the cell to the left */
    public byte getLeft() {
        return this.left;
    }

    // ---- writers ----

    /** @param middle new value of the cell the bean eater stands on */
    public void setMiddle(byte middle) {
        this.middle = middle;
    }

    /** @param up new value of the cell above */
    public void setUp(byte up) {
        this.up = up;
    }

    /** @param right new value of the cell to the right */
    public void setRight(byte right) {
        this.right = right;
    }

    /** @param down new value of the cell below */
    public void setDown(byte down) {
        this.down = down;
    }

    /** @param left new value of the cell to the left */
    public void setLeft(byte left) {
        this.left = left;
    }
}
Algorithm.java
/**
 * Genetic-algorithm operators: tournament selection, uniform crossover and
 * per-gene mutation.
 */
public class Algorithm {
    /* GA parameters */
    // Probability that a child gene is copied from the first parent.
    private static final double uniformRate = 0.5;
    // Probability that a single gene is replaced by a random action.
    private static final double mutationRate = 0.0001;
    // Number of candidates drawn for each tournament.
    private static final int tournamentSize = 3;

    /**
     * Produces the next generation: every slot is filled with the crossover
     * of two tournament winners, then every new individual is mutated.
     *
     * @param pop the current generation
     * @return a new population of the same size
     */
    public static Population evolve(Population pop) {
        final int slots = pop.size();
        Population offspring = new Population(slots, true);

        // Breed first ...
        for (int slot = 0; slot < slots; slot++) {
            Individual firstParent = tournamentSelection(pop);
            Individual secondParent = tournamentSelection(pop);
            offspring.saveIndividual(slot, crossover(firstParent, secondParent));
        }

        // ... and mutate only after the whole generation has been bred.
        for (int slot = 0; slot < offspring.size(); slot++) {
            mutate(offspring.getIndividual(slot));
        }
        return offspring;
    }

    /**
     * Draws tournamentSize individuals at random (with replacement) and
     * returns the fittest of them.
     */
    private static Individual tournamentSelection(Population pop) {
        Population contestants = new Population(tournamentSize, true);
        int filled = 0;
        while (filled < tournamentSize) {
            int pick = (int) (Math.random() * pop.size());
            contestants.saveIndividual(filled, pop.getIndividual(pick));
            filled++;
        }
        return contestants.getFittest();
    }

    /**
     * Builds a child by choosing, gene by gene, from which parent to copy;
     * the first parent is chosen with probability uniformRate.
     */
    private static Individual crossover(Individual indiv1, Individual indiv2) {
        Individual child = new Individual();
        final int geneCount = indiv1.getLength();
        for (int pos = 0; pos < geneCount; pos++) {
            Individual donor = (Math.random() <= uniformRate) ? indiv1 : indiv2;
            child.setGene(pos, donor.getGene(pos));
        }
        return child;
    }

    /**
     * Replaces each gene, with probability mutationRate, by a random action
     * code in the range 0-6.
     */
    private static void mutate(Individual indiv) {
        for (int pos = 0; pos < indiv.getLength(); pos++) {
            if (Math.random() > mutationRate) {
                continue;
            }
            // Non-negative value, so the cast truncates like Math.floor.
            byte replacement = (byte) (Math.random() * 7);
            indiv.setGene(pos, replacement);
        }
    }
}
Population.java
/**
 * A fixed-size collection of individuals with fitness statistics.
 */
public class Population {
    private Individual[] individuals;

    /**
     * @param size number of slots in the population
     * @param lazy when {@code true} the slots are left empty ({@code null});
     *             otherwise every slot receives a randomly generated individual
     */
    public Population(int size, boolean lazy) {
        individuals = new Individual[size];
        if (lazy) {
            return;
        }
        for (int slot = 0; slot < individuals.length; slot++) {
            Individual fresh = new Individual();
            fresh.generateGenes();
            individuals[slot] = fresh;
        }
    }

    /** Stores {@code ind} in slot {@code index}. */
    public void saveIndividual(int index, Individual ind) {
        individuals[index] = ind;
    }

    /** @return the individual stored in slot {@code index} */
    public Individual getIndividual(int index) {
        return individuals[index];
    }

    /**
     * @return the individual with the highest fitness; on ties the one in the
     *         highest slot wins
     */
    public Individual getFittest() {
        Individual best = individuals[0];
        for (int slot = 1; slot < individuals.length; slot++) {
            Individual candidate = individuals[slot];
            if (best.getFitness() <= candidate.getFitness()) {
                best = candidate;
            }
        }
        return best;
    }

    /**
     * @return the individual with the lowest fitness; on ties the one in the
     *         lowest slot wins
     */
    public Individual getLeastFittest() {
        Individual worst = individuals[0];
        for (int slot = 1; slot < individuals.length; slot++) {
            Individual candidate = individuals[slot];
            if (worst.getFitness() > candidate.getFitness()) {
                worst = candidate;
            }
        }
        return worst;
    }

    /** @return the arithmetic mean of all fitness values */
    public double getAverageFitness() {
        double total = 0;
        for (Individual member : individuals) {
            total += member.getFitness();
        }
        return total / size();
    }

    /** @return the number of slots */
    public int size() {
        return individuals.length;
    }

    /** @return a four-line summary: size, maximum, minimum and average fitness */
    @Override
    public String toString() {
        StringBuilder report = new StringBuilder();
        report.append("Population size: ").append(size()).append("\n");
        report.append("Max Fitnewss: ").append(getFittest().getFitness()).append("\n");
        report.append("Least Fitness: ").append(getLeastFittest().getFitness()).append("\n");
        report.append("Average Fitness: ").append(getAverageFitness()).append("\n");
        return report.toString();
    }

    /**
     * Demo: builds a random population of 8000 individuals and prints it.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        Population demo = new Population(8000, false);
        System.out.println(demo);
    }
}
MapMgr.java
/**
 * Lazily created singleton that owns a fixed set of pre-generated maps and
 * hands out copies of them.
 */
public class MapMgr {
    // Map width and height.
    private static int x = 10;
    private static int y = 10;
    // Beans placed on every map.
    private static int beanNum = 50;
    // Number of maps generated up front.
    private static int mapNum = 100;

    private static MapMgr manager = null;

    private Map[] maps = null;

    /** Generates mapNum maps of x by y cells, each seeded with beanNum beans. */
    private MapMgr() {
        maps = new Map[mapNum];
        for (int slot = 0; slot < mapNum; slot++) {
            Map generated = new Map(x, y);
            generated.setBeans(beanNum);
            maps[slot] = generated;
        }
    }

    /**
     * @return the single shared manager, creating it on the first call
     */
    public static synchronized MapMgr getInstance() {
        if (manager != null) {
            return manager;
        }
        manager = new MapMgr();
        return manager;
    }

    /**
     * Returns a copy of one of the stored maps, so the caller can modify it
     * without touching the stored one.
     *
     * @param index map number; reduced modulo the number of stored maps
     * @return a clone of the selected map, or {@code null} if cloning failed
     */
    public Map getMap(int index) {
        int slot = index % mapNum;
        try {
            return maps[slot].clone();
        } catch (CloneNotSupportedException e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Demo: prints maps 1 and 2 separated by a dashed line.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        MapMgr instance = MapMgr.getInstance();
        instance.getMap(1).print();
        System.out.println("--------------");
        instance.getMap(2).print();
    }
}
Map.java
import java.awt.Point;
/**
 * Rectangular grid in which every cell is either empty (0) or holds one bean
 * (1). The grid is indexed as {@code mapGrid[x][y]}; positions outside the
 * grid are reported as wall.
 */
public class Map implements Cloneable {
    private int x = -1;
    private int y = -1;
    // Total number of cells (x * y).
    private int total = -1;
    private byte[][] mapGrid = null;

    /**
     * Creates an empty grid.
     *
     * @param x number of columns
     * @param y number of rows
     */
    public Map(int x, int y) {
        this.x = x;
        this.y = y;
        mapGrid = new byte[x][y];
        total = x * y;
    }

    /**
     * Places up to {@code num} additional beans on randomly chosen empty
     * cells. The request is clamped to the number of cells that are still
     * empty, so the method always terminates, even when the map already
     * holds beans from an earlier call.
     *
     * @param num number of beans to place; values of zero or less place nothing
     */
    public void setBeans(int num) {
        // Clamping to the total cell count alone is not enough: if fewer
        // empty cells remain than beans requested, the rejection loop below
        // could never find a free cell and would spin forever.
        int freeCells = 0;
        for (byte[] column : mapGrid) {
            for (byte cell : column) {
                if (cell == 0) {
                    freeCells++;
                }
            }
        }
        int toPlace = Math.min(num, freeCells);

        for (int placed = 0; placed < toPlace; placed++) {
            int xp;
            int yp;
            do {
                // Random cell number in 0 .. total-1, split into coordinates.
                int address = (int) Math.floor(Math.random() * total);
                xp = address / y;
                yp = address % y;
            } while (mapGrid[xp][yp] != 0);
            mapGrid[xp][yp] = 1;
        }
    }

    /**
     * @return {@code true} if the position lies inside the grid
     */
    public boolean isInMap(int x, int y) {
        return x >= 0 && x < this.x && y >= 0 && y < this.y;
    }

    /**
     * @return {@code true} if the cell holds a bean; the position must be inside the grid
     */
    public boolean hasBean(int x, int y) {
        return mapGrid[x][y] != 0;
    }

    /**
     * Removes the bean from the cell if there is one.
     *
     * @return {@code true} if a bean was eaten, {@code false} if the cell was empty
     */
    public boolean eatBean(int x, int y) {
        if (!hasBean(x, y)) {
            return false;
        }
        mapGrid[x][y] = 0;
        return true;
    }

    /**
     * @return a random position inside the grid
     */
    public Point getStartPoint() {
        int startX = (int) Math.floor(Math.random() * this.x);
        int startY = (int) Math.floor(Math.random() * this.y);
        return new Point(startX, startY);
    }

    /**
     * Describes the cell at {@code p} and its four neighbours. "Up" is
     * {@code y - 1}, "down" is {@code y + 1}, "left" is {@code x - 1} and
     * "right" is {@code x + 1}.
     *
     * @param p the position to observe
     * @return the observed state
     */
    public State getState(Point p) {
        byte middle = stateOfPoint(p);
        byte up = stateOfPoint(new Point(p.x, p.y - 1));
        byte right = stateOfPoint(new Point(p.x + 1, p.y));
        byte down = stateOfPoint(new Point(p.x, p.y + 1));
        byte left = stateOfPoint(new Point(p.x - 1, p.y));
        return new State(middle, up, right, down, left);
    }

    // Cell code: 0 = wall (outside the grid), 1 = bean, 2 = no bean.
    private byte stateOfPoint(Point p) {
        if (!isInMap(p.x, p.y)) {
            return 0;
        }
        return (byte) (mapGrid[p.x][p.y] == 0 ? 2 : 1);
    }

    /**
     * @return a copy whose grid is independent of this map's grid
     * @throws CloneNotSupportedException if {@code Object.clone()} refuses to clone
     */
    @Override
    public Map clone() throws CloneNotSupportedException {
        Map copy = (Map) super.clone();
        // super.clone() shares the grid array, so every column is copied.
        byte[][] gridCopy = new byte[x][];
        for (int i = 0; i < x; i++) {
            gridCopy[i] = this.mapGrid[i].clone();
        }
        copy.mapGrid = gridCopy;
        return copy;
    }

    /** Prints the grid to standard output, one text line per y value. */
    public void print() {
        for (int row = 0; row < y; row++) {
            for (int col = 0; col < x; col++) {
                System.out.print(mapGrid[col][row]);
            }
            System.out.println();
        }
    }

    /**
     * Demo: clones an empty 10x5 map, then fills the original and the clone
     * with different bean counts and prints both.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        Map m = new Map(10, 5);
        Map m1 = null;
        try {
            m1 = m.clone();
        } catch (CloneNotSupportedException e) {
            e.printStackTrace();
        }
        m.setBeans(40);
        m.print();
        m1.setBeans(15);
        m1.print();
    }
}
FitnessCalc
import java.awt.Point;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.FutureTask;
/**
 * Scores an individual by simulating its strategy on maps from MapMgr.
 *
 * Score per action: hitting a wall -5, eating a bean +10, eating on an empty
 * cell -1, anything else 0.
 */
public class FitnessCalc {
    // Number of simulated games per fitness evaluation.
    private static int DefaultSimTimes = 1000;
    // Number of steps per simulated game.
    private static int simSteps = 200;
    // Number of worker threads used by getFitnessPall.
    private static int cores = 4;

    private static final int WALL_PENALTY = -5;
    private static final int BEAN_REWARD = 10;
    private static final int EMPTY_EAT_PENALTY = -1;

    /**
     * Single-threaded fitness using the default number of simulations.
     *
     * @param ind the individual to score
     * @return the average score per simulated game
     */
    public static int getFitness(Individual ind) {
        return getFitness(ind, DefaultSimTimes);
    }

    /**
     * Plays {@code simTimes} games of simSteps steps each, game {@code i} on
     * a copy of map {@code i} from a random start position, and averages the
     * scores.
     *
     * @param ind      the individual to score
     * @param simTimes number of games to play
     * @return the total score divided by {@code simTimes} (integer division),
     *         or 0 when {@code simTimes} is zero or negative
     */
    public static int getFitness(Individual ind, int simTimes) {
        if (simTimes <= 0) {
            // Nothing was simulated; without this guard simTimes == 0 would
            // divide by zero below.
            return 0;
        }
        int fitness = 0;
        MapMgr mgr = MapMgr.getInstance();
        for (int game = 0; game < simTimes; game++) {
            Map map = mgr.getMap(game);
            Point position = map.getStartPoint();
            for (int step = 0; step < simSteps; step++) {
                State state = map.getState(position);
                byte actionCode = ind.getActionCode(state);
                fitness += action(position, map, actionCode);
            }
        }
        return fitness / simTimes;
    }

    /**
     * Multi-threaded fitness. The default number of simulations is split
     * across {@code cores} worker threads (the first worker also takes the
     * remainder) and the per-worker averages are combined, weighted by the
     * number of games each worker played. With fewer than 100 default
     * simulations the work is done on the calling thread instead.
     *
     * @param ind the individual to score
     * @return the average score per simulated game
     * @throws IllegalStateException if the calling thread is interrupted while
     *         waiting (the interrupt flag is restored first) or a worker fails
     */
    public static int getFitnessPall(Individual ind) {
        if (DefaultSimTimes < 100) {
            return getFitness(ind);
        }

        int baseShare = DefaultSimTimes / cores;
        int[] shares = new int[cores];
        // Generic array creation is impossible in Java; the array only ever
        // holds FutureTask<Integer> instances created below.
        @SuppressWarnings("unchecked")
        FutureTask<Integer>[] tasks = new FutureTask[cores];
        for (int i = 0; i < cores; i++) {
            shares[i] = (i == 0) ? baseShare + DefaultSimTimes % cores : baseShare;
            tasks[i] = new FutureTask<Integer>(new FitnessPall(ind, shares[i]));
            new Thread(tasks[i]).start();
        }

        // Each worker reports an average over its own share, so weight it by
        // that share; a plain mean of the averages is wrong when the shares
        // differ. For equal shares this yields the same value as the mean.
        long weightedSum = 0;
        for (int i = 0; i < cores; i++) {
            try {
                weightedSum += (long) tasks[i].get() * shares[i];
            } catch (InterruptedException e) {
                // Keep the interrupt visible to callers instead of losing it.
                Thread.currentThread().interrupt();
                throw new IllegalStateException(
                        "Interrupted while waiting for fitness worker " + i, e);
            } catch (ExecutionException e) {
                // A failed worker must not silently count as score 0.
                throw new IllegalStateException("Fitness worker " + i + " failed", e.getCause());
            }
        }
        return (int) (weightedSum / DefaultSimTimes);
    }

    /**
     * Executes one action and returns its score. Action codes: 0 up, 1 left,
     * 2 down, 3 right, 4 random move, 5 eat, 6 stay. Unknown codes do nothing
     * and score 0.
     *
     * @param point      current position; updated in place by a successful move
     * @param map        the map being played; a bean is removed when eaten
     * @param actionCode the action to perform
     * @return the score earned by the action
     */
    private static int action(Point point, Map map, int actionCode) {
        switch (actionCode) {
            case 0:
                return move(point, map, 0, -1);
            case 1:
                return move(point, map, -1, 0);
            case 2:
                return move(point, map, 0, 1);
            case 3:
                return move(point, map, 1, 0);
            case 4:
                // Random move: pick one of the four directions (codes 0-3).
                return action(point, map, (int) Math.floor(Math.random() * 4));
            case 5:
                return map.eatBean(point.x, point.y) ? BEAN_REWARD : EMPTY_EAT_PENALTY;
            case 6:
            default:
                return 0;
        }
    }

    // Moves the point by (dx, dy) if the target is inside the map; otherwise
    // the point stays put and the wall penalty is returned.
    private static int move(Point point, Map map, int dx, int dy) {
        int targetX = point.x + dx;
        int targetY = point.y + dy;
        if (!map.isInMap(targetX, targetY)) {
            return WALL_PENALTY;
        }
        point.x = targetX;
        point.y = targetY;
        return 0;
    }
}
/**
 * Callable that evaluates one individual over a given number of simulations
 * by delegating to FitnessCalc.getFitness.
 */
class FitnessPall implements Callable<Integer> {
    private Individual ind;
    private int simTimes;

    /**
     * @param ind      the individual to score
     * @param simTimes number of simulations this task should run
     */
    public FitnessPall(Individual ind, int simTimes) {
        this.simTimes = simTimes;
        this.ind = ind;
    }

    /**
     * @return the fitness reported by FitnessCalc for this task's share
     */
    @Override
    public Integer call() throws Exception {
        int score = FitnessCalc.getFitness(this.ind, this.simTimes);
        return score;
    }
}