线段树 Segtree

原题 https://www.luogu.org/problemnew/show/P3373

(手写线段树 wKw)

题目描述

如题,已知一个数列,你需要进行下面三种操作:

1.将某区间每一个数乘上x

2.将某区间每一个数加上x

3.求出某区间每一个数的和

输入输出格式

输入格式:

第一行包含三个整数N、M、P,分别表示该数列数字的个数、操作的总个数和模数。

第二行包含N个用空格分隔的整数,其中第i个数字表示数列第i项的初始值。

接下来M行每行包含3或4个整数,表示一个操作,具体如下:
操作1: 格式:1 x y k 含义:将区间[x,y]内每个数乘上k
操作2: 格式:2 x y k 含义:将区间[x,y]内每个数加上k
操作3: 格式:3 x y 含义:输出区间[x,y]内每个数的和对P取模所得的结果

输出格式:

输出包含若干行整数,即为所有操作3的结果。

输入输出样例

输入样例#1:

5 5 38
1 5 4 2 3
2 1 4 1
3 2 5
1 2 4 2
2 3 5 5
3 1 4

输出样例#1:

17
2

说明

时空限制:1000ms,128M
数据规模:
对于30%的数据:N<=8,M<=10
对于70%的数据:N<=1000,M<=10000
对于100%的数据:N<=100000,M<=100000

题解

使用两个lazytag分别表示加法和乘法
在pushdown操作之后发现必须在向下传递lazytag的时候人为地为这两个lazytag规定一个先后顺序。只有两种情况:

  • ①先算加法,规定st[root*2].value=((st[root*2].value+st[root].add)*st[root].mul)%p,但是这样非常不容易进行更新操作,如果改变add的数值,mul也要联动变成分数小数损失精度,这不是我们希望的。
  • ②先算乘法,规定st[root*2].value=(st[root*2].value*st[root].mul+st[root].add*(本区间长度))%p,这样如果改变add不会影响mul,改变mul的时候把add也对应的乘一下就可以了,没有精度损失,这是可行的。
// P3372 【模板】线段树 1.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。
//

#include <iostream>
#include<stdio.h>

int n, m, p;

long long a[100000 + 7];

struct node {
    long long v, mul, add;
}st[400000 + 6];

void init(int root, int l, int r) {
    st[root].mul = 1;
    st[root].add = 0;
    if (l == r) {
        st[root].v = a[l] % p;
    }
    else {
        int m = (l + r) >> 1;
        init(root * 2, l, m);
        init(root * 2 + 1, m + 1, r);
        st[root].v = (st[root * 2].v + st[root * 2 + 1].v) % p;
    }
}

void pushdown(int root, int l, int r) {
    int m = (r + l) >> 1;
    st[root * 2].v = (st[root * 2].v*st[root].mul + (m - l + 1)*st[root].add) % p;
    st[root * 2 + 1].v = (st[root * 2 + 1].v*st[root].mul + (r - m)*st[root].add) % p;

    st[root * 2].mul = (st[root].mul*st[root * 2].mul) % p;
    st[root * 2 + 1].mul = (st[root].mul*st[root * 2 + 1].mul) % p;
    st[root * 2].add = (st[root * 2].add*st[root].mul + st[root].add) % p;
    st[root * 2 + 1].add = (st[root * 2 + 1].add*st[root].mul + st[root].add) % p;
    st[root].mul = 1;
    st[root].add = 0;
}

//mul
void update1(int root, int stdl, int stdr, int l, int r, int k) {
    if (stdl > r || stdr < l)
        return;
    if ((l <= stdl && stdr <= r)) {
        st[root].v = st[root].v*k%p;
        st[root].mul = st[root].mul*k%p;
        st[root].add = st[root].add*k%p;
        return;
    }
    pushdown(root, stdl, stdr);
    int m = (stdl + stdr) /2;
    update1(root * 2, stdl, m, l, r, k);
    update1(root * 2+1,m+1, stdr, l, r, k);
    st[root].v = (st[root * 2].v + st[root * 2 + 1].v)%p;
}

//add
void update2(int root, int stdl, int stdr, int l, int r, int k) {
    if (stdl > r || stdr < l)
        return;
    if (l <= stdl && stdr <= r) {
        st[root].v = (st[root].v+(stdr-stdl+1)*k)%p;
        st[root].add = (st[root].add+k)%p;
        return;
    }
    pushdown(root, stdl, stdr);
    int m = (stdl + stdr)/2.;
    update2(root * 2, stdl, m, l, r, k);
    update2(root * 2 + 1, m+1, stdr, l, r, k);
    st[root].v = (st[root * 2].v + st[root * 2 + 1].v) % p;
}

long long query(int root, int stdl, int stdr, int l, int r) {
    if (stdl > r || stdr < l)
        return 0LL;
    if (l <= stdl && stdr <= r) {
        return st[root].v;
    }
    pushdown(root, stdl, stdr);
    int m = (stdl + stdr) >> 1;
    return (query(root * 2, stdl, m, l, r)+query(root * 2 + 1, m+1, stdr, l, r))%p;
}

int main()
{
    scanf("%d%d%d", &n, &m, &p);
    for (int i = 1; i <= n; i++) {
        scanf("%lld", &a[i]);
    }
    init(1, 1, n);
    while (m--)
    {
        int k,a,b,x;
        scanf("%d", &k);
        if (k == 1) {
            scanf("%d%d%d", &a, &b,&x);//if x==1 continue;
            update1(1, 1, n, a, b, x);
        }
        else if (k == 2) {
            scanf("%d%d%d", &a, &b, &x);
            update2(1, 1, n, a, b, x);
        }
        else {
            scanf("%d%d", &a, &b);
            printf("%d\n", query(1, 1, n, a, b));
        }
    }
    //system("pause");
}

    原文作者:HAPPYers
    原文地址: https://www.jianshu.com/p/5c33e634fd81
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞