原题 https://www.luogu.org/problemnew/show/P3373
(手写线段树 wKw)
题目描述
如题,已知一个数列,你需要进行下面三种操作:
1.将某区间每一个数乘上x
2.将某区间每一个数加上x
3.求出某区间每一个数的和
输入输出格式
输入格式:
第一行包含三个整数N、M、P,分别表示该数列数字的个数、操作的总个数和模数。
第二行包含N个用空格分隔的整数,其中第i个数字表示数列第i项的初始值。
接下来M行每行包含3或4个整数,表示一个操作,具体如下:
操作1: 格式:1 x y k 含义:将区间[x,y]内每个数乘上k
操作2: 格式:2 x y k 含义:将区间[x,y]内每个数加上k
操作3: 格式:3 x y 含义:输出区间[x,y]内每个数的和对P取模所得的结果
输出格式:
输出包含若干行整数,即为所有操作3的结果。
输入输出样例
输入样例#1:
5 5 38
1 5 4 2 3
2 1 4 1
3 2 5
1 2 4 2
2 3 5 5
3 1 4
输出样例#1:
17
2
说明
时空限制:1000ms,128M
数据规模:
对于30%的数据:N<=8,M<=10
对于70%的数据:N<=1000,M<=10000
对于100%的数据:N<=100000,M<=100000
题解
使用两个lazytag分别表示加法和乘法
在pushdown操作之后发现必须在向下传递lazytag的时候人为地为这两个lazytag规定一个先后顺序。只有两种情况:
- ①先算加法,规定
st[root*2].value=((st[root*2].value+st[root].add)*st[root].mul)%p
,但是这样非常不容易进行更新操作,如果改变add的数值,mul也要联动变成分数小数损失精度,这不是我们希望的。 - ②先算乘法,规定
st[root*2].value=(st[root*2].value*st[root].mul+st[root].add*(本区间长度))%p
,这样如果改变add不会影响mul,改变mul的时候把add也对应的乘一下就可以了,没有精度损失,这是可行的。
// P3372 【模板】线段树 1.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。
//
#include <iostream>
#include<stdio.h>
int n, m, p;
long long a[100000 + 7];
struct node {
long long v, mul, add;
}st[400000 + 6];
void init(int root, int l, int r) {
st[root].mul = 1;
st[root].add = 0;
if (l == r) {
st[root].v = a[l] % p;
}
else {
int m = (l + r) >> 1;
init(root * 2, l, m);
init(root * 2 + 1, m + 1, r);
st[root].v = (st[root * 2].v + st[root * 2 + 1].v) % p;
}
}
void pushdown(int root, int l, int r) {
int m = (r + l) >> 1;
st[root * 2].v = (st[root * 2].v*st[root].mul + (m - l + 1)*st[root].add) % p;
st[root * 2 + 1].v = (st[root * 2 + 1].v*st[root].mul + (r - m)*st[root].add) % p;
st[root * 2].mul = (st[root].mul*st[root * 2].mul) % p;
st[root * 2 + 1].mul = (st[root].mul*st[root * 2 + 1].mul) % p;
st[root * 2].add = (st[root * 2].add*st[root].mul + st[root].add) % p;
st[root * 2 + 1].add = (st[root * 2 + 1].add*st[root].mul + st[root].add) % p;
st[root].mul = 1;
st[root].add = 0;
}
//mul
void update1(int root, int stdl, int stdr, int l, int r, int k) {
if (stdl > r || stdr < l)
return;
if ((l <= stdl && stdr <= r)) {
st[root].v = st[root].v*k%p;
st[root].mul = st[root].mul*k%p;
st[root].add = st[root].add*k%p;
return;
}
pushdown(root, stdl, stdr);
int m = (stdl + stdr) /2;
update1(root * 2, stdl, m, l, r, k);
update1(root * 2+1,m+1, stdr, l, r, k);
st[root].v = (st[root * 2].v + st[root * 2 + 1].v)%p;
}
//add
void update2(int root, int stdl, int stdr, int l, int r, int k) {
if (stdl > r || stdr < l)
return;
if (l <= stdl && stdr <= r) {
st[root].v = (st[root].v+(stdr-stdl+1)*k)%p;
st[root].add = (st[root].add+k)%p;
return;
}
pushdown(root, stdl, stdr);
int m = (stdl + stdr)/2.;
update2(root * 2, stdl, m, l, r, k);
update2(root * 2 + 1, m+1, stdr, l, r, k);
st[root].v = (st[root * 2].v + st[root * 2 + 1].v) % p;
}
long long query(int root, int stdl, int stdr, int l, int r) {
if (stdl > r || stdr < l)
return 0LL;
if (l <= stdl && stdr <= r) {
return st[root].v;
}
pushdown(root, stdl, stdr);
int m = (stdl + stdr) >> 1;
return (query(root * 2, stdl, m, l, r)+query(root * 2 + 1, m+1, stdr, l, r))%p;
}
int main()
{
scanf("%d%d%d", &n, &m, &p);
for (int i = 1; i <= n; i++) {
scanf("%lld", &a[i]);
}
init(1, 1, n);
while (m--)
{
int k,a,b,x;
scanf("%d", &k);
if (k == 1) {
scanf("%d%d%d", &a, &b,&x);//if x==1 continue;
update1(1, 1, n, a, b, x);
}
else if (k == 2) {
scanf("%d%d%d", &a, &b, &x);
update2(1, 1, n, a, b, x);
}
else {
scanf("%d%d", &a, &b);
printf("%d\n", query(1, 1, n, a, b));
}
}
//system("pause");
}