POJ 4045 Power Station 2012金华邀请赛B题(树形DP)

很早就听说过树形DP了,只是一直没做过这方面的题。这次金华邀请赛出了这道树形DP,是zz_1215做出来的。发现这个考的挺多的,于是就学习了一下。


这个题大致意思是给你一颗树,让你求一点,使该点到其余各点的距离之和最小。如果这样的点有多个,则按升序依次输出。


我的做法:对于点x定义两个量dp[x]和node[x]。 dp[x]为以节点x为根的子树中x到它后代节点的距离之和;node[x]为以节点x为根的子树中节点的数目(含x)。则:

dp[x] = sum(dp[y] + node[y]);

node[x] = sum(node[y]) + 1;

其中y为x的孩子节点。

可以先任取一个节点作为根(如:1)进行一次dfs,求出每个节点i的dp[i]和node[i]。然后,再进行一次dfs,依次改变树根为每个节点,并更新和保存结果。其中,将树根由节点x变为其孩子节点y时,dp[y] = dpx] + n – 2 * node[y],n为节点总数。

#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <vector>

using namespace std;

const int maxn = 50010;
typedef __int64 LL;
vector<int> son[maxn], ans;
int t, n, I, R;
bool vis[maxn];
LL dp[maxn];   // dp[x]: x到x其后代节点的距离之和 
LL node[maxn]; // node[x]: 以x为根的子树的节点数(包括x) 
LL minDis;

/* 进行一次dfs, 求以节点x为根的子树有多少个节点, 
 * 以及节点x到其后代节点的距离之和
 */
void dfs1(int x)
{
    dp[x] = 0;
    node[x] = 1;
    vis[x] = true;
    int size = son[x].size();
    for (int i = 0; i < size; ++i)
    {
        int y = son[x][i];
        if (!vis[y])
        {
            dfs1(y);
            dp[x] += (dp[y] + node[y]);
            node[x] += node[y];
        }
    }
}

/* 分别以树中的每个节点为根, 
 * 求其到其它节点的距离之和 */ 
void dfs2(int x)
{
    vis[x] = true;
    if (dp[x] < minDis)
    {
        ans.clear();
        minDis = dp[x];
        ans.push_back(x);
    }
    else if (dp[x] == minDis)
    {
        ans.push_back(x);
    }
    int size = son[x].size();
    for (int i = 0; i < size; ++i)
    {
        int y = son[x][i];
        if (!vis[y])
        {
            // 树根由x变为y时,修改dp[y]的值
            dp[y] = dp[x] + n - 2 * node[y];
            dfs2(y);
        }
    }
}

void init()
{
    int x, y;
    for (int i = 1; i < maxn; ++i)
        son[i].clear();
    ans.clear();
         
    scanf("%d %d %d", &n, &I, &R);
    for (int i = 1; i < n; ++i)
    {
        scanf("%d %d", &x, &y);
        son[x].push_back(y);
        son[y].push_back(x);
    }
    minDis = 1 << 30;
}

void solve()
{
    memset(vis, false, sizeof(vis));
    dfs1(1); // 以1为根结点 
    memset(vis, false, sizeof(vis));
    dfs2(1);
    sort(ans.begin(), ans.end());
    
    printf("%I64d\n", minDis * I * I * R);
    for (int i = 0; i < ans.size() - 1; ++i)
        printf("%d ", ans[i]);
    printf("%d\n\n", ans[ans.size()-1]);
}

int main()
{
    //freopen("in.txt", "r", stdin);
    //freopen("out.txt", "w", stdout);
    scanf("%d", &t);
    while (t--)
    {
        init();
        solve();
    }
    return 0;
}

    原文作者:B树
    原文地址: https://blog.csdn.net/ahfywff/article/details/7565063
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞