先贴上题目:
问题描述 有n个格子,从左到右放成一排,编号为1-n。 共有m次操作,有3种操作类型: 1.修改一个格子的权值, 2.求连续一段格子权值和, 3.求连续一段格子的最大值。 对于每个2、3操作输出你所求出的结果。 输入格式 第一行2个整数n,m。 接下来一行n个整数表示n个格子的初始权值。 接下来m行,每行3个整数p,x,y,p表示操作类型,p=1时表示修改格子x的权值为y,p=2时表示求区间[x,y]内格子权值和,p=3时表示求区间[x,y]内格子最大的权值。 输出格式 有若干行,行数等于p=2或3的操作总数。 每行1个整数,对应了每个p=2或3操作的结果。 样例输入 4 3 1 2 3 4 2 1 3 1 4 3 3 1 4 样例输出 6 3 数据规模与约定 对于20%的数据n <= 100,m <= 200。 对于50%的数据n <= 5000,m <= 5000。 对于100%的数据1 <= n <= 100000,m <= 100000,0 <= 格子权值 <= 10000。
初看这道题目,求连续N个数的和、最大值。如果每次查询区间[x,y]上的和或最大值都遍历[x,y]区间一遍,按照n和m的数量级,肯定会超时。
对于求和,我以前知道,可以用sum[n]保存前n个数的和,那么区间[x,y]所有数的和即为sum[y]-sum[x-1]。
但是对于求区间[x,y]的最大值,这种办法却行不通,只有借助于一种数据结构——线段树。很奇怪,以前竟然没有听说过线段树,《数据结构》,《算法导论》也都没有讲过,我可真是孤陋寡闻。
下面来看看什么是线段树:
引用一篇别人的文章——《线段树入门》:http://hi.baidu.com/semluhiigubbqvq/item/be736a33a8864789f4e4ad18
若根节点表示[x,y],mid = (x+y)/2,那么左儿子表示[x,mid],右儿子表示[mid+1,y],这样递归定义。
有了线段树后,查询任意一个区间[s,t]上的所有数的和、最大值就就可以在logn的时间内完成,n为总节点个数。
该睡觉了,直接贴代码:
#include<iostream>
#include<stdio.h>
#include<string.h>
#include<malloc.h>
using namespace std;
const int MAXD = 100001;
typedef struct node{
struct node *lchild;
struct node *rchild;
int left,right;
int sum;
int max;
}NODE,*LNODE;
int arr[MAXD];
int sum[MAXD];
int build(int left,int right,LNODE p) //建立线段树并填充sum,max
{
p->left = left;
p->right = right;
p->sum = sum[right]-sum[left-1];
if(left == right)
{
p->lchild = p->rchild = NULL;
p->max = arr[left];
return p->max;
}
int mid = (left+right)/2;
p->lchild = (LNODE)malloc(sizeof(NODE));
p->rchild = (LNODE)malloc(sizeof(NODE));
int m = build(left,mid,p->lchild);
int n = build(mid+1,right,p->rchild);
p->max = m>n?m:n;
return p->max;
}
int change(LNODE p,int x,int y) //改变x格子值为y
{
int left = p->left;
int right = p->right;
//modify sum,max
p->sum = p->sum-arr[x]+y;
if(left == right)
{
p->max = y;
return p->max; //返回max好重新比较出max
}
int mid = (left+right)/2;
if(x <= mid)
{
int m = change(p->lchild,x,y);
p->max = m>((p->rchild)->max)?m:((p->rchild)->max);
}
else
{
int m = change(p->rchild,x,y);
p->max = m>((p->lchild)->max)?m:((p->lchild)->max);
}
return p->max;
}
int GetSum(int left,int right,LNODE p)
{
if(left==p->left && right==p->right) return p->sum;
int mid = (p->left+p->right)/2;
if(left > mid) return GetSum(left,right,p->rchild);
else if(right <= mid) return GetSum(left,right,p->lchild);
else return GetSum(left,mid,p->lchild)+GetSum(mid+1,right,p->rchild);
}
int GetMax(int left,int right,LNODE p)
{
if(left==p->left && right==p->right) return p->max;
int mid = (p->left+p->right)/2;
if(left > mid) return GetMax(left,right,p->rchild);
else if(right <= mid) return GetMax(left,right,p->lchild);
else
{
int m = GetMax(left,mid,p->lchild);
int n = GetMax(mid+1,right,p->rchild);
return m>n?m:n;
}
}
int main()
{
int n,m;
int i,k;
LNODE pnode;
int p,x,y;
cin>>n>>m;
memset(sum,0,sizeof(int)*MAXD);
for(i = 1;i <= n;i++)
{
cin>>arr[i];
sum[i] = arr[i]+sum[i-1];
}
pnode = (LNODE)malloc(sizeof(NODE));
build(1,n,pnode);
for(k = 0;k < m;k++) //m行命令
{
cin>>p>>x>>y;
if(p == 1)
{
change(pnode,x,y);
arr[x] = y;
}
else if(p == 2)
{
cout<<GetSum(x,y,pnode)<<endl;
}
else
{
cout<<GetMax(x,y,pnode)<<endl;
}
}
return 0;
}