2021杭电多校第二场

I love tree

题意:给一棵树,每次操作选择一条链,从起点到终点点权依次增加$1^2,2^2,…n^2$,单点询问当前点权​

容易想到树剖,然后考虑在线段树上维护这个问题

发现可以转化成选择一段区间$[l,r]$,对于每个点$x$,该位置增加的值为$(x-h)^2$​​,把平方项展开后得到$x^2+h^2-2xh$​,对这三项分别线段树维护即可(分别维护$x^2,-2x$​出现的次数和$h^2$的和)

有时候dfs序和修改顺序是反的,就是对于区间$[l,r]$​,每个点增加量分别是$n^2,(n-1)^2,…,1^2$​,实际上就是每个位置增加量变为$(h-x)^2$​​,h为r+1,而显然$(x-h)^2=(h-x)^2$​,因此这种情况下线段树不需要进行任何修改,只要调整传入的h参数即可

树剖时还需要稍微注意的是,这里对链的修改存在先后顺序,常规的树剖对于两个端点x,y是一起往上跳的,那么这里只在x往上跳时修改,y往上跳时先用vector记录哪些点需要修改,等x修改完再去修改y

#include <bits/stdc++.h>
#define debug(x) cerr << #x << " = " << x << ' '
#define debugl(x) cerr << #x << " = " << x << '\n'
#define all(x) x.begin(), x.end()
#define mem(x, y) memset(x, y, sizeof(x))
#define rep(x,y,z) for (int x=y;x<=z;x++)
#define per(x,y,z) for (int x=y;x>=z;x--)

using namespace std;
typedef long long ll;
typedef long long LL;
typedef pair<int, int> pii;

const int N = 2e5 + 5, M = 2e5 + 5, INF = 0x3f3f3f3f;
const ll INF64 = 0x3f3f3f3f3f3f3f3f;
struct Segment_tree {
    #define lson (o << 1)
    #define rson ((o << 1) | 1)
    #define mid ((l + r) >> 1)
    struct Node {
        ll v[3]; // v[0] : x^2, v[1] : h^2, v[2] : -2x
    }t[N << 2];
    void build(int o, int l, int r) {
        if (l == r) {
            t[o].v[0] = t[o].v[1] = t[o].v[2] = 0;
            return;
        }
        build(lson, l, mid); build(rson, mid + 1, r);
    }
    void modify(int o, int l, int r, int tl, int tr, int h) {
        if (r < tl || l > tr) return;
        if (tl <= l && r <= tr) {
            t[o].v[0] ++;
            t[o].v[1] += 1ll * h * h;
            t[o].v[2] += -2 * h;
            return;
        }
        modify(lson, l, mid, tl, tr, h); modify(rson, mid + 1, r, tl, tr, h);
    }
    ll query(int o, int l, int r, int pos, ll v0, ll v1, ll v2) {
        v0 += t[o].v[0];
        v1 += t[o].v[1];
        v2 += t[o].v[2];
        if (l == r) return v0 * pos * pos + v1 + pos * v2;
        if (pos <= mid) return query(lson, l, mid, pos, v0, v1, v2);
        else return query(rson, mid + 1, r , pos, v0, v1, v2);
    }
}ST;

ll q[N];

int n,m,r,p,ord[N];
int sz[N],dep[N],son[N],fa[N],top[N],num[N],cnt;

vector<int>G[N];

void init() {
    r = 1;
    memset(sz,0,sizeof(sz));
    cnt=0;
    fa[r]=-1;
    dep[r]=1;
    top[r]=r;
}

void dfs1(int x) {
    int mxson=-1;
    son[x]=-1;
    for (auto to:G[x]) {
        if (to==fa[x]) continue;
        dep[to]=dep[x]+1;
        fa[to]=x;
        dfs1(to);
        if (sz[to]>mxson) mxson=sz[to], son[x]=to;
        sz[x]+=sz[to];
    }
    sz[x]++;
}
void dfs2(int x,int tp) {
    num[x]=++cnt;
    ord[cnt]=q[x];
    top[x]=tp;
    if (~son[x]) dfs2(son[x],tp);
    for (auto to:G[x]) {
        if (to==fa[x] || to==son[x]) continue;
        dfs2(to,to);
    }
}
void chainModify(int x,int y) {
    int st = 0;
    vector<pii> vc;
    while(top[x]!=top[y]) {
        if (dep[top[x]] > dep[top[y]]) {
            ST.modify(1,1,n,num[top[x]],num[x], num[x] + st + 1);
            st += num[x] - num[top[x]] + 1;
            x = fa[top[x]];
        } else {
            vc.push_back({num[top[y]], num[y]});
            y = fa[top[y]];
        }
    }
    bool swp = 0;
    if (dep[x]>dep[y]) swap(x,y), swp = 1;
    ST.modify(1,1,n,num[x],num[y], swp ? num[y] + st + 1: num[x] - st - 1);
    st += num[y] - num[x] + 1;
    reverse(all(vc));
    for (auto X : vc) {
        int u = X.first, v = X.second;
        ST.modify(1, 1, n, u, v, u - st - 1);
        st += v - u + 1;
    }

}

ll spotQuery(int x) {
    return ST.query(1, 1, n, num[x], 0, 0, 0);
}

int main() {
    scanf("%d", &n);
    for (int i=1;i<=n-1;i++) {
        int u,v;
        scanf("%d%d",&u,&v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    init();
    dfs1(r);
    dfs2(r,r);
    ST.build(1,1,n);
    scanf("%d", &m);
    for (int i=1;i<=m;i++) {
        int opt,x,y;
        LL z;
        scanf("%d",&opt);
        if (opt==1) {
            scanf("%d%d",&x,&y);
            chainModify(x,y);
        }
        if (opt==2) {
            scanf("%d",&x);
            printf("%lld\n",spotQuery(x));
        }
    }
    return 0;
}

I love counting

题意:给一个序列A,每次查询区间$[l_i,r_i]$​中满足$c xor a\le b$​​​的​不同数c的个数(重复不计)

首先在字典树上思考,满足$c xor a\le b$​​的c​至多有$\log$​​​个区间:

从二进制最高位开始考虑,用u表示a的当前位,v表示b的当前位,分成以下四种情况讨论

  1. u=0, v=0 c向0子树搜索
  2. u=0, v=1 统计0子树下的所有答案,c向1子树搜索
  3. u=1, v=0 c向1子树搜索
  4. u=1, v=1 统计1子树下的所有答案,c向0子树搜索

可以发现最多统计$\log$次答案,并且每次统计的区间至多是上一次的一半

问题就转化成了和上一场1010几乎一样的问题,可以用莫队+分块解决,只不过每次查询需要统计$\log$次,复杂度上界是$O(n\sqrt n\log n)$,但是由于大多数查询区间较小,区间长度小于$2^{10}$的时间可以忽略不计,因此常数是非常小的,事实上也的确可以通过此题

#include <bits/stdc++.h>
#define debug(x) cerr << #x << " = " << x << ' '
#define debugl(x) cerr << #x << " = " << x << '\n'
#define all(x) x.begin(), x.end()
#define mem(x, y) memset(x, y, sizeof(x))
#define rep(x,y,z) for (int x=y;x<=z;x++)
#define per(x,y,z) for (int x=y;x>=z;x--)

using namespace std;
typedef long long ll;
typedef pair<int, int> pii;

const int N = 1e5 + 5, M = 2e5 + 5, INF = 0x3f3f3f3f;
const ll INF64 = 0x3f3f3f3f3f3f3f3f;

int n, m, a[M], ans[M], BB, sum[M], BB2, block[M];
int num[N];

struct Query {
    int l, r, a, b, idx;
    bool operator < (Query &Q) const {
        if (block[l] != block[Q.l]) return l < Q.l;
        if (block[l] & 1) return r < Q.r;
        else return r > Q.r;
    }
}q[N];

void modify(int X, int type) {
    int x = a[X];
    if (type == 1) {
        if (num[x] == 0) {
            assert(x / BB2 <= 100000);
            sum[x / BB2] += 1;
        }
        num[x]++;
    } else {
        if (num[x] == 1) {
            assert(x / BB2 <= 100000);
            sum[x / BB2] -= 1;
        }
        num[x]--;
    }
}

int query(int l, int r) {
    if (l > 100000) return 0;
    if (r > 100000) r = 100000;
    int res = 0;
    int b1 = l / BB2, b2 = r / BB2;
    if (b1 == b2) {
        for (int i = l; i <= r; i++) {
            assert(i <= 100000);
            assert(i >= 0);
            res += num[i] > 0;
        }
    } else {
        for (int i = l; i <= min(n, (b1 + 1) * BB2 - 1); i++) assert(i <= 100000),assert(i >= 0), res += num[i] > 0;
        for (int i = b2 * BB2; i <= r; i++) assert(i <= 100000),assert(i >= 0), res += num[i] > 0;
        for (int i = b1 + 1; i <= b2 - 1; i++) assert(i <= 100000),assert(i >= 0), res += sum[i];
    }
    return res;
}

int main() {
    scanf("%d", &n);
    BB = (int)sqrt(n);
    for (int i = 1; i <= n; i++) {
        assert(i <= 100000);
        block[i] = i <= BB ? 1 : block[i - BB] + 1;
    }
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
        assert(a[i] <= n);
    }

    scanf("%d", &m);
    BB2 = (int)sqrt(m);
    for (int i = 1; i <= m; i++) {
        scanf("%d %d %d %d", &q[i].l, &q[i].r, &q[i].a, &q[i].b);
        q[i].idx = i;
    }
    sort(q + 1, q + 1 + m);
    int L = 1, R = 0;
    for (int i = 1; i <= m; i++) {
        while(R < q[i].r) modify(++R, 1);
        while(L > q[i].l) modify(--L, 1);
        while(R > q[i].r) modify(R--, -1);
        while(L < q[i].l) modify(L++, -1);
        int now = 0, aans = 0;
        for (int j = (1 << 17); j; j >>= 1) {
            int u = q[i].a & j, v = q[i].b & j;
            if (u == 0 && v == 0) {
                continue;
            }
            if (u == 0 && v != 0) {
                aans += query(now, now + j - 1);
                now += j;
                continue;
            }
            if (u != 0 && v == 0) {
                now += j;
                continue;
            }
            if (u != 0 && v != 0) {
                aans += query(now + j, now + j + j - 1);
                continue;
            }
        }
        if (now <= 100000) aans += num[now] > 0;
        ans[q[i].idx] = aans;
    }
    for (int i = 1; i <= m; i++) printf("%d\n", ans[i]);
    return 0;
}

标算是一个$O(n\log^2n)$的做法,对于每个区间跑一次“二维数点”,原理就是将查询区间按照右端点从大到小排序,预处理出每个数右边第一个和它相同的数的位置,对于每个数,只有当【当前查询的区间右端点<其右边第一个和它相同的数的位置】时,才将该数记录到树状数组中。

正确性可以这样理解,考虑一个位置为$x_i$,右边第一个和它相同的数位置为$nxt_i$的数在什么时候能对答案产生贡献,即查询区间左端点$l_i\leqslant x_i$,且查询区间右端点$r_i\leqslant nxt_i-1$时能产生贡献,否则$l_i>x_i$时显然不行,$r_i\geqslant nxt_i$时,贡献由其右边第一个点与它相同的点产生。这样一来,树状数组限制了$l_i\leqslant x_i$,根据右端点排序记录限制了$r_i\leqslant nxt_i-1$

然而比较讽刺的是这个$O(n\log^2n)$的标算跑的没莫队$O(n\sqrt n\log n)$快…

#include<bits/stdc++.h>
#define N 100009
using namespace std;
typedef long long ll;
int nxt[N],c[N],head[N],Q,ans[N];
int n,tot,tr[N],ch[N*20][2],sum[N*20];
inline ll rd(){
    ll x=0;char c=getchar();bool f=0;
    while(!isdigit(c)){if(c=='-')f=1;c=getchar();}
    while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
    return f?-x:x;
}
inline void add(int x,int y){while(x<=n)tr[x]+=y,x+=x&-x;}
inline int query(int x){int ans=0;while(x)ans+=tr[x],x-=x&-x;return ans;}
struct node{
    int l,r,x,id;
    inline bool operator <(const node &b)const{
        if(x!=b.x)return x>b.x;
        if(id!=b.id)return id<b.id;
        return l<b.l; 
    }
};
struct qq{
    int l,r,a,b;
}q[N];
vector<node>vec[N*20];
vector<node>::iterator it; 
void query(int rt,int now,int a,int b,int id){
    if(!rt&&now!=16)return;
    if(now<0){
        vec[rt].push_back(node{q[id].l,q[id].r,q[id].r-1,id});
        return;
    }
    int xx=(a&(1<<now))!=0,yy=(b&(1<<now))!=0;
    if(yy) vec[ch[rt][xx]].push_back(node{q[id].l,q[id].r,q[id].r-1,id});
    xx^=yy;
    query(ch[rt][xx],now-1,a,b,id);
}
void ins(int x,int id){
    int now=0;
    for(int i=16;i>=0;--i){
        int xx=(x&(1<<i))!=0;
        if(!ch[now][xx])ch[now][xx]=++tot;
        now=ch[now][xx];
        sum[now]++;
        vec[now].push_back(node{id,0,nxt[id]-2,0});
    }
}
void solve(){
    n=rd();
    for(int i=1;i<=n;++i){
        c[i]=rd();
    }
    for(int i=n;i>=1;--i){
        if(!head[c[i]])nxt[i]=n+1;
        else nxt[i]=head[c[i]];
        head[c[i]]=i;
    }
    for(int i=1;i<=n;++i)ins(c[i],i);
    Q=rd();
    for(int i=1;i<=Q;++i){
        q[i].l=rd();q[i].r=rd();q[i].a=rd();q[i].b=rd();
        query(0,16,q[i].a,q[i].b,i);
    }
    for(int i=1;i<=tot;++i){
        sort(vec[i].begin(),vec[i].end());
        for(it=vec[i].begin();it!=vec[i].end();++it){
            if(it->id==0){
                add(it->l,1);      
            }
            else {
              ans[it->id]+=query(it->r)-query(it->l-1);
            }
        }
        for(it=vec[i].begin();it!=vec[i].end();++it){
            if(it->id==0){
              add(it->l,-1);
            }
        }
    }
    for(int i=1;i<=Q;++i)printf("%d\n",ans[i]);
}
int main(){
    int T=1;
    while(T--){
        solve(); 
    }
    return 0;
}

I love exam

01背包处理出$f[i][j]$表示第i门课,考j分,需要花费的最小天数

分组背包$f2[i][j][k]$表示考虑前i门课(每门课考0-100分这101项作为一个分组),花费j天,挂了k门课的最高分数

#include <bits/stdc++.h>
#define debug(x) cerr << #x << " = " << x << ' '
#define debugl(x) cerr << #x << " = " << x << '\n'
#define all(x) x.begin(), x.end()
#define mem(x, y) memset(x, y, sizeof(x))
#define rep(x,y,z) for (int x=y;x<=z;x++)
#define per(x,y,z) for (int x=y;x>=z;x--)

using namespace std;
typedef long long ll;
typedef pair<int, int> pii;

const int N = 5e1 + 5, M = 1e2 + 5, MOD = 1e9 + 7, INF = 0x3f3f3f3f;
const ll INF64 = 0x3f3f3f3f3f3f3f3f;

int add(int a, int b) { return a + b >= MOD ? a + b - MOD : a + b; }
int sub(int a, int b) { return a - b < 0 ? a - b + MOD : a - b; }
int mul(int a, int b) { return 1ll * a * b % MOD; }
int bin(int x, int p) {
    int res = 1;
    for (int b = x; p; p >>= 1, b = mul(b, b))
        if (p & 1) res = mul(res, b);
    return res;
}
const int maxn=111111,maxm=5555555;
const double pi=3.1415926535897932384626433832795,eps=1e-14;

int T, n, m, f[N][M], f2[N][505][5];

map<string ,int> mp;
vector<pii> vc[N];

string st;

void init() {
    memset(f, 0x3f, sizeof(f));
    memset(f2, -0x3f, sizeof(f2));
    for (int i = 0; i <= n; i++) f[i][0] = 0;
    f2[0][0][0] = 0;
    mp.clear();
}
void solve() {
    scanf("%d", &n);
    init();
    for (int i = 1; i <= n; i++) {
        cin >> st;
        mp[st] = i;
    }
    scanf("%d", &m);
    for (int i = 1; i <= m; i++) vc[i].clear();
    for (int i = 1; i <= m; i++) {
        cin >> st;
        int u, v; scanf("%d %d", &u, &v);
        vc[mp[st]].push_back({u, v});
    }
    for (int i = 1; i <= n; i++) {
        for (auto X : vc[i]) {
            int u = X.first, v = X.second;
            for (int j = 100 - u; j >= 0; j--) {
                f[i][j + u] = min(f[i][j + u], f[i][j] + v);
            }
        }
    }
    int t, p; scanf("%d %d", &t, &p);
    for (int i = 1; i <= n; i++) {
        for (int l = 0; l <= t; l++) {
            for (int k = 0; k <= p; k++) {
                for (int j = 0; j <= 100; j++) {
                    if (k == 0 && j < 60) continue;
                    if (l - f[i][j] >= 0)
                        f2[i][l][k] = max(f2[i][l][k], f2[i - 1][l - f[i][j]][k - (j < 60)] + j);
                }
            }
        }
    }
    int ans = -1;
    for (int l = 0; l <= t; l++) {
        for (int k = 0; k <= p; k++) {
            ans = max(ans, f2[n][l][k]);
        }
    }
    cout << ans << endl;
}
int main() {
    scanf("%d", &T);
    while(T--) {
        solve();
    }
    return 0;
}

I love permutation