线段树初级——《操作格子》

先贴上题目:

问题描述
有n个格子,从左到右放成一排,编号为1-n。
共有m次操作,有3种操作类型:
1.修改一个格子的权值,
2.求连续一段格子权值和,
3.求连续一段格子的最大值。
对于每个2、3操作输出你所求出的结果。

输入格式
第一行2个整数n,m。
接下来一行n个整数表示n个格子的初始权值。
接下来m行,每行3个整数p,x,y,p表示操作类型,p=1时表示修改格子x的权值为y,p=2时表示求区间[x,y]内格子权值和,p=3时表示求区间[x,y]内格子最大的权值。

输出格式
有若干行,行数等于p=2或3的操作总数。
每行1个整数,对应了每个p=2或3操作的结果。

样例输入
4 3
1 2 3 4
2 1 3
1 4 3
3 1 4

样例输出
6
3

数据规模与约定
对于20%的数据n <= 100,m <= 200。
对于50%的数据n <= 5000,m <= 5000。
对于100%的数据1 <= n <= 100000,m <= 100000,0 <= 格子权值 <= 10000。

初看这道题目,求连续N个数的和、最大值。如果每次查询区间[x,y]上的和或最大值都遍历[x,y]区间一遍,按照n和m的数量级,肯定会超时。

对于求和,我以前知道,可以用sum[n]保存前n个数的和,那么区间[x,y]所有数的和即为sum[y]-sum[x-1]。

但是对于求区间[x,y]的最大值,这种办法却行不通,只有借助于一种数据结构——线段树。很奇怪,以前竟然没有听说过线段树,《数据结构》,《算法导论》也都没有讲过,我可真是孤陋寡闻。

下面来看看什么是线段树:

引用一篇别人的文章——《线段树入门》:http://hi.baidu.com/semluhiigubbqvq/item/be736a33a8864789f4e4ad18

若根节点表示[x,y],mid = (x+y)/2,那么左儿子表示[x,mid],右儿子表示[mid+1,y],这样递归定义。

有了线段树后,查询任意一个区间[s,t]上的所有数的和、最大值就就可以在logn的时间内完成,n为总节点个数。

该睡觉了,直接贴代码:

#include<iostream>
#include<stdio.h>
#include<string.h>
#include<malloc.h>

using namespace std;

const int MAXD = 100001;

typedef struct node{
	struct node *lchild;
	struct node *rchild;
	int left,right;
	int sum;
	int max;
}NODE,*LNODE;

int arr[MAXD];
int sum[MAXD];

int build(int left,int right,LNODE p)	//建立线段树并填充sum,max 
{
	p->left = left;
	p->right = right;
	p->sum = sum[right]-sum[left-1];
	if(left == right)
	{
		p->lchild = p->rchild = NULL;
		p->max = arr[left];
		return p->max;
	}
	int mid = (left+right)/2;
	p->lchild = (LNODE)malloc(sizeof(NODE));
	p->rchild = (LNODE)malloc(sizeof(NODE));
	int m = build(left,mid,p->lchild);
	int n = build(mid+1,right,p->rchild);
	p->max = m>n?m:n;
	return p->max;
}

int change(LNODE p,int x,int y)	//改变x格子值为y 
{
	int left = p->left;
	int right = p->right;
	
	
	//modify sum,max
	p->sum = p->sum-arr[x]+y;
	
	if(left == right)
	{
		p->max = y;
		return p->max;	//返回max好重新比较出max 
	}
	
	int mid = (left+right)/2;
	if(x <= mid)
	{
		int m = change(p->lchild,x,y);
		p->max = m>((p->rchild)->max)?m:((p->rchild)->max);	
	} 
	else
	{
		int m = change(p->rchild,x,y);
		p->max = m>((p->lchild)->max)?m:((p->lchild)->max);	
	} 
	return p->max;
}

int GetSum(int left,int right,LNODE p)
{
	if(left==p->left && right==p->right) return p->sum;
	
	int mid = (p->left+p->right)/2;
	
	if(left > mid) return GetSum(left,right,p->rchild);
	else if(right <= mid) return GetSum(left,right,p->lchild);
	else return GetSum(left,mid,p->lchild)+GetSum(mid+1,right,p->rchild);
}

int GetMax(int left,int right,LNODE p)
{
	if(left==p->left && right==p->right) return p->max;
	
	int mid = (p->left+p->right)/2;
	if(left > mid) return GetMax(left,right,p->rchild);
	else if(right <= mid) return GetMax(left,right,p->lchild);
	else
	{
		int m = GetMax(left,mid,p->lchild);
		int n = GetMax(mid+1,right,p->rchild);
		return m>n?m:n;
	}
}

int main()
{
	int n,m;
	int i,k;
	LNODE pnode;
	int p,x,y;
	
	cin>>n>>m;
	memset(sum,0,sizeof(int)*MAXD);
	for(i = 1;i <= n;i++)
	{
		cin>>arr[i];
		sum[i] = arr[i]+sum[i-1];
	}
	
	pnode = (LNODE)malloc(sizeof(NODE));
	build(1,n,pnode);
	
	for(k = 0;k < m;k++)	//m行命令
	{
		cin>>p>>x>>y;
		if(p == 1)
		{
			change(pnode,x,y);
			arr[x] = y;
		}
		else if(p == 2)
		{
			cout<<GetSum(x,y,pnode)<<endl;
		}
		else
		{
			cout<<GetMax(x,y,pnode)<<endl;
		}
		
	} 
	return 0;
}
点赞