动态规划的一些经典算法

文章目录

动态规划

采用动态规划的两要素:存在最优子结构和重叠子问题

  • 最优子结构:
    如果一个问题中包含子问题的最优解,则该问题具有最优子结构,注意子问题之间应当互不影响。一个问题可以有多个子问题,要求一个问题的最优解,则要先求多个子问题中的最优子问题的解。
    动态规划用自底向上的的方式来利用最优子结构,也就是说先求子问题的最优解,解决子问题,再找到上级问题的最优解。贪心算法也适用于最优子结构,但它是以自顶向下的方式使用最优子结构,它会先做选择,在当时看来是最优的选择,然后再求解一个结果子问题,而不是先寻找子问题的最优解再做选择。
  • 重叠子问题:
    用来解原问题的递归算法可反复地解同样的子问题,也就是说当一个递归算法不断地调用同一问题时,我们说该问题包含重叠子问题。
    相反的,适合用分治法的问题往往在递归的每一步都会产生全新的问题。而动态规划总是充分利用重叠子问题,即通过每个子问题只解一次,然后把解保存在一个可以随时访问的表中(查表时间为常数),每次遇到直接查表中是否存在,再决定是否递归。

eg. 计算斐波拉契序列递归时会有很多数字被重复在递归,明显出现了重叠子问题,我们可以用动态规划思想优化它。

装配线调度

《动态规划的一些经典算法》
《动态规划的一些经典算法》

求解一个制造问题。汽车公司在有两条装配线的工厂内生产汽车,如图所示。一个汽车底盘在进入每一条装配线后,在一些装配站中会在底盘上安装部件,然后,完成的汽车在装配线的末端离开。每一条装配线上有n个装配站,编号为j=1,2,⋯,n。将装配线i(i为1或2)的第j个装配站表示为Si,j,装配线1的第j个站(S1,j)和装配线2的第j个站(S2,j)执行相同的功能。然而,这些装配站是在不同的时间建造的,并且采用了不同的技术,因此,每个站上所需的时间是不同的,即使是在两条不同装配线相同位置的装配站上也是这样。我们把在装配站Si,j上所需的装配时间记为ai,j。如图15-1所示,一个汽车底盘进入其中一条装配线,然后从每一站进行到下一站。底盘进人装配线i的进入时间为ei,装配完的汽车离开装配线i的离开时间为xi。
在正常情况下一旦一个底盘进入一条装配线后,它只会经过该条装配线,在相同的装配线中,从一个装配站到下一个装配站所花的时间可以忽略。偶尔会来一个特别急的订单,客户要求尽可能快地制造这些汽车。对这些加急的订单,底盘仍然依序经过n个装配站,但是工厂经理可以将部分完成的汽车在任何装配站上从一条装配线移到另一条装配线上。把已经通过装配站Si,j的一个底盘从Si,j移走到另一条线上的时间为ti,j,其中i=1,2,j=1,2,……,n-1(因为在第n个装配站后,装配已经完成)。
问题是:要确定在装配线1内选择哪些站以及在装配线2内选择哪些站,以使汽车通过的总时间最小。

寻找最优子结构,即用子问题的最优解寻找原问题的最优解

观察一条通过装配站S1,j的最快路线,会发现它必定是经过装配线1或2上的装配站j一1。因此,通过装配站S1,j的最快路线只能是以下二者之一:

  • 通过装配站S1,j-1的最快路线,然后直接通过装配站S1,j;
  • 通过装配站S2,j-1的最快路线,从装配线2移动到装配线1然后通过装配站S1,j;

利用对称的推理思想,通过装配站S2,j的最快路线也只能是以下二者之一:

  • 通过装配站S2,j-1的最快路线,然后直接通过装配站S2,j;
  • 通过装配站S1,j-1的最快路线,从装配线1移动到装配线2然后通过装配站S2,j;

为了解决这个问题,即寻找通过任一条装配线上的装配站i的最快路线,我们解决它的子问题,即寻找通过两条装配线上的装配站j一1的最快路线。

在这里我定义有n条装配线,每条线有m个站。因此通过装配站Si,j的最快路线如下:

  • 通过装配站Si,j-1的最快路线,然后直接通过装配站Si,j;
  • 通过装配站S非i,j-1的最快路线,从装配线非i移动到装配线i然后通过装配站Si,j;

所以对于装配线调度问题,通过建立子问题的最优解,就可以建立原问题某个实例的一个最优解了。

算法如下:
源码:https://github.com/yangbijia/algorithm.git

/**
 * 动态规划--装配线调度
 * m条装配线,每条装配线下n个装配站,计算从进入装配线到出去的最短时间和路径
 * @author bijiayang
 */
public class Main {

//    enter[i],exit[i],station[i][j],move[i][j]
//
//
//    f(i,1) = min(enter[i] + station[i][1]) (i=1...m,j=1)
//
//    f(i,2) = min(min(f(1,1),...,f(i,1)) + station[i][2])	(i=1...m,j=2)
//
//    f(i,j) = min(min(f(1,j - 1),...,f(i,j - 1)) + station[i][j])	(i=1...m,j=2...n)
//
//    f(i,n) = min(min(f(1,n - 1),...,f(i,n - 1)) + station[i][n] + exit[i]) (i-1...m,j=n)
//
//    l[i][j] = result i of min

    /**
     * 输入
     */
    static List<Integer> enter = Stream.of(new Integer[]{2, 4}).collect(Collectors.toList());
    static List<Integer> exit = Stream.of(new Integer[]{3, 2}).collect(Collectors.toList());
    static List<List<Integer>> station = Arrays.stream(new Integer[][]{{7, 9, 3, 4, 8, 4},{8, 5, 6, 4, 5, 7}})
            .map(integers -> Arrays.asList(integers))
            .collect(Collectors.toList());
    static List<List<Integer>> move = Arrays.stream(new Integer[][]{{2, 3, 1, 3, 4},{2, 1, 2, 2, 1}})
            .map(integers -> Arrays.asList(integers))
            .collect(Collectors.toList());
    static Integer line_size = enter.size(), station_size = station.get(0).size();

    /**
     * 输出
     */
    static List<List<Integer>> f = Arrays.stream(new Integer[][]{{-1, -1, -1, -1, -1, -1},{-1, -1, -1, -1, -1, -1}})
            .map(integers -> Arrays.asList(integers))
            .collect(Collectors.toList());
    static List<List<Integer>> l = Arrays.stream(new Integer[][]{{-1, -1, -1, -1, -1},{-1, -1, -1, -1, -1}})
            .map(integers -> Arrays.asList(integers))
            .collect(Collectors.toList());
    static List<Station> way = new LinkedList<>();


    public static void main(String[] args) {
        // 非递归求最短时间和路径
        fastway();

        // 递归求最短时间和路径
//        List<Integer> lastTimeList  = new ArrayList<>();
//        for (int i = 0; i < line_size; i++) {
//            lastTimeList.add(fastestwayRecursion(station_size - 1, i));
//        }
//        LineAndTime min = min(lastTimeList);
//        System.out.println("fatestTime = " + min.getTime());
//        printWay(min);


        // 时间表打印
        for (int i = 0; i < f.size(); i++) {
            List<Integer> list = f.get(i);
            for (Integer e : list) {
                System.out.print(String.format("%" + 4 + "s", e.toString()));
            }
            System.out.println();
        }

        // 线路表打印
        for (int i = 0; i < l.size(); i++) {
            List<Integer> list = l.get(i);
            for (Integer e : list) {
                System.out.print(String.format("%" + 4 + "s", e.toString()));
            }
            System.out.println();
        }
    }

    /**
     * 递归计算计算到给定节点的最短路径
     * @param stationPos
     * @param linePos
     * @return
     */
    public static Integer fastestwayRecursion(Integer stationPos, Integer linePos) {
        LineAndTime lineAndTime;
        // 中间节点
        if (stationPos > 0 && stationPos < station_size - 1) {
            Map<Integer, Integer> map = new HashMap<Integer, Integer>();
            for (int k = 0; k < line_size; k++) {
                Integer time = k == linePos ? f.get(linePos).get(stationPos) : -1, line = linePos;
                if (time == -1) {
                    Integer currentTime = fastestwayRecursion(stationPos - 1, k) + (k == linePos ?
                            0 : move.get(k).get(stationPos - 1)) + station.get(linePos).get(stationPos);
                    time = currentTime;
                    line = k;
                }
                map.put(line, time);
            }
            lineAndTime = min(map);
        // 开始节点
        } else if (stationPos == 0){
            Integer time = f.get(linePos).get(stationPos);
            if (time == -1) {
                Integer currentTime = enter.get(linePos) + station.get(linePos).get(stationPos);
                time = currentTime;
            }
            lineAndTime = new LineAndTime(linePos, time);
        // 结束节点
        } else {
            Map<Integer, Integer> map = new HashMap<Integer, Integer>();
            for (int k = 0; k < line_size; k++) {
                Integer time = k == linePos ? f.get(linePos).get(stationPos) : -1, line = linePos;
                if (time == -1) {
                    time = fastestwayRecursion(stationPos - 1, k) + (k == linePos ?
                            0 : move.get(k).get(stationPos - 1)) + station.get(linePos).get(stationPos) + exit.get(linePos);
                    line = k;
                }
                map.put(line, time);
            }
            lineAndTime = min(map);
        }

        Integer line = lineAndTime.getLine();
        Integer time = lineAndTime.getTime();

        if (f.get(linePos).get(stationPos) == -1) {
            f.get(linePos).set(stationPos, time);
        }

        if (stationPos > 0 && l.get(linePos).get(stationPos - 1) == -1) {
            l.get(linePos).set(stationPos - 1, line + 1);
        }

        return time;
    }

    /**
     * 非递归最短时间和路径
     */
    public static void fastway() {
        int i,j;
        for (i = 0; i < line_size; i++) {
            Integer fi1 = enter.get(i) + station.get(i).get(0);
            f.get(i).set(0, fi1);
        }
        LineAndTime currentLineAndTime;
        for (j = 1; j < station_size; j++) {
            for (i = 0; i < line_size; i++) {
                List<Integer> times = new ArrayList<Integer>();
                for (int k = 0; k < line_size; k++) {
                    times.add(f.get(k).get(j - 1) + (k == i ? 0 : move.get(k).get(j - 1)) + station.get(i).get(j));
                }
                currentLineAndTime = min(times);
                f.get(i).set(j, currentLineAndTime.getTime());
                l.get(i).set(j - 1, currentLineAndTime.getLine() + 1);
            }
        }
        List<Integer> times = new ArrayList<Integer>();
        for (i = 0; i < line_size; i++) {
            Integer exitTime = exit.get(i) + f.get(i).get(station_size - 1);
            times.add(exitTime);
        }
        currentLineAndTime = min(times);
        System.out.println("fasttime = " + currentLineAndTime.getTime().toString());

        printWay(currentLineAndTime);
    }

    public static void printWay(LineAndTime lastStation) {
        way.add(new Station(lastStation.getLine() + 1, station_size));
        for (int i = station_size - 2; i >= 0; i--) {
            way.add(new Station(l.get(way.get(station_size - 2 - i).getLinePos() - 1).get(i), i + 1));
        }
        System.out.println("fastway = " + way.toString());
    }

    public static LineAndTime min(List<Integer> times) {
        if (times.size() == 0) {
            return null;
        }
        Integer min = times.get(0),line = 0;
        for (int i = 0; i < line_size; i++) {
            Integer time = times.get(i);
            if (time < min) {
                min = time;
                line = i;
            }
        }
        return new LineAndTime(line, min);
    }

    public static LineAndTime min(Map<Integer, Integer> times) {
        if (times.size() == 0) {
            return null;
        }
        Integer min = times.get(0), line = 0;
        if (times.size() == 1) {
            line = times.keySet().stream().findFirst().get();
            return new LineAndTime(line, times.get(line));
        }
        for (int i = 0; i < line_size; i++) {
            Integer time = times.get(i);
            if (time < min) {
                min = time;
                line = i;
            }
        }
        return new LineAndTime(line, min);
    }
}

矩阵链相乘

这是解决矩阵链相乘问题的一个动态规划经典算法。给定由n个要相乘的矩阵构成的序列(链)<A1,A2,…,An>,要计算乘积

A1A2…An(15.10)

为计算式(15.10),可将两个矩阵相乘的标准算法作为一个子程序,根据括号给出的计算顺序做全部的矩阵乘法。一组矩阵的乘积是加全部括号的(fullyparenthesized),如果它是单个的矩阵,或是两个加全部括号的矩阵的乘积外加括号而成。矩阵的乘法满足结合率,故无论怎样加括号都会产生相同的结果。例如,如果矩阵链为<A1,A2,A3,A4>,乘积A1A2A3A4可用五种不同方式加全部括号:

(A1(A2(A3A4))),
(A1((A2A3)A4)),
((A1A2)(A3A4)),
((A1(A2A3))A4),
(((A1A2)A3)A4)。

仅当两个矩阵A和B相容(即A的列数等于B的行数)时,才可以进行相乘运算。如果A是pxq矩阵,B是qxr矩阵,则结果矩阵C是一个pxr矩阵,计算C的时间由矩阵的标量乘法的次数决定,这里为pqr。下面对时间代价的计算均按照乘法次数来表示。

问题是:求矩阵(A1A2……An)的最优加全括号,及最优时间代价。

这里用Ai,j表示对乘积AiAi+1……Aj求值的结果,其中i<=j
分为以下两种情况:

  • 如果i = j(即只有一个矩阵A1),则代价为0
  • 如果i < j,则必有乘积Ai,j的任何加全括号形式都将在Ak与Ak+1之间分开(即在Ak与Ak+1之间必定存在括号,称k为裂变点,当然我们现在还不知道k的具体值,i <= k < j),这样加全括号的代价就是计算Ai,k和Ak+1,j的代价之和,再加上两者相乘的代价。再从分解出来的子问题继续向下分解,直到分解到最小子问题。

根据子问题的最优解来递归定义一个最优解的代价。
这里用m[i,j]表示最优时间代价,得出结论:

  • m[i, j] = 0 ,i = j
  • m[i, j] = min{m[i, k] + m[k + 1, j] + Pi-1PkPj} ,i < j

下方左图为m[i, j](最优时间代价),右图为s[i, j](最优裂变位置)
《动态规划的一些经典算法》
下方为输入矩阵
《动态规划的一些经典算法》

算法如下:
源码:https://github.com/yangbijia/algorithm.git

/**
 * 计算矩阵A1*A2*……*An乘积的最优乘法,即最优加全括号
 * @author ellin
 * @since 2019/03/19
 */
public class Main {

    static List<Integer> p = Stream.of(new Integer[]{30, 35, 15, 5, 10, 20, 25}).collect(Collectors.toList());
    static int len = p.size();
    /**
     * 存放最优代价结果
     */
    static int[][] m = new int[len][len];

    /**
     * 存放裂变位置
     */
    static int[][] s = new int[len][len];

    public static void main(String[] args) {
        // 初始化结果数组
        for (int i = 0; i < len; i++) {
            for (int j = 0; j < len; j++) {
                m[i][j] = -1;
            }
        }
        for (int i = 0; i < len; i++) {
            for (int j = 0; j < len; j++) {
                s[i][j] = -1;
            }
        }

        //计算矩阵A1*A2*……*An乘积的最优代价,两个矩阵相乘 A1mxn * A2nxr 的代价为:m*n*r
        // 计算矩阵从i到j的最优代价
        int i = 1, j = len - 1;
        int min = m[i][j] == -1 ? optimalSolution(i, j, i) : m[i][j];
        int s_all = i;
        for (int pos = i + 1; pos < j; pos++) {
            int res = m[i][j] == -1 ? optimalSolution(i, j, pos) : m[i][j];
            if (res < min) {
                s_all = pos;
                min = res;
            }
        }
        if (s[i][j] == -1) {
            s[i][j] = s_all;
        }
        m[i][j] = min;
        System.out.println("min = " + min);

        for (i = 0; i < len; i++) {
            for (j = 0; j < len; j++) {
                System.out.print(String.format("%8s", m[i][j]));
            }
            System.out.println();
        }
        System.out.println();
        System.out.println();
        for (i = 0; i < len; i++) {
            for (j = 0; j < len; j++) {
                System.out.print(String.format("%8s", s[i][j]));
            }
            System.out.println();
        }
        System.out.println();

        printOptimalParens(1, len - 1);
        System.out.println();
    }

    /**
     * 递归计算矩阵i-j从k处开始裂变的标量值
     * @param i 矩阵下标Ai
     * @param j 矩阵下标Aj
     * @param k 裂变位置
     * @return
     */
    public static Integer optimalSolution(int i, int j, int k) {
        if (i == j) {
            m[i][j] = 0;
            return 0;
        } else {
            int min_left = m[i][k] == -1 ? optimalSolution(i, k, i) : m[i][k],
                min_right = m[k + 1][j] == -1 ? optimalSolution(k + 1, j, k + 1) : m[k + 1][j];

            int s_left = i, s_right = k + 1;
            // 计算k裂变点及其左边矩阵乘积的最优代价
            for (int pos = i + 1; pos < k; pos++) {
                int res = m[i][k] == -1 ? optimalSolution(i, k, pos) : m[i][k];
                if (res < min_left) {
                    s_left = pos;
                    min_left = res;
                }
            }
            if (s[i][k] == -1){
                s[i][k] = s_left;
            }
            m[i][k] = min_left;
            // 计算k右边矩阵乘积的最优代价
            for (int pos = k + 2; pos < j; pos++) {
                int res = m[k + 1][j] == -1 ? optimalSolution(k + 1, j, pos) : m[k + 1][j];
                if (res < min_right) {
                    s_right = pos;
                    min_right = res;
                }
            }
            if (s[k + 1][j] == -1){
                s[k + 1][j] = s_right;
            }
            m[k + 1][j] = min_right;
            int m_ijk = min_left + min_right + p.get(i - 1) * p.get(j) * p.get(k);
            return m_ijk;
        }

    }


    /**
     * 打印出相乘矩阵的最优加全括号
     * @param i 开始矩阵下标
     * @param j 末尾矩阵下标
     */
    public static void printOptimalParens(int i, int j) {
        if (i == j) {
            System.out.print("A" + i);
            return;
        }
        else {
            System.out.print("(");
            printOptimalParens(i, s[i][j]);
            printOptimalParens(s[i][j] + 1, j);
            System.out.print(")");
        }
    }

}
点赞