CCF CAT- 全国算法精英大赛(2024第二场)往届真题练习 4 | 珂学家


前言



餐馆

思路:可撤销的0-1背包

考察了多个知识点,包括

  • 差分技巧
  • 离线思路
  • 0-1背包

    不过这题卡语言,尤其卡python

    import java.io.*;
    import java.util.*;
    import java.util.stream.Collectors;
    import java.util.stream.IntStream;
    /**
     * Offline solver built on an "undoable" counting 0-1 knapsack.
     *
     * <p>Input (whitespace-separated tokens on stdin): {@code n v}, then {@code n} triples
     * {@code s e w}, then {@code q}, then {@code q} query times. An item is treated as
     * <em>active</em> at time {@code t} when {@code s <= t <= e}. For every query the program
     * prints {@code fz bz}, where {@code fz} is the largest subset sum of active weights that
     * does not exceed {@code v}, and {@code bz} is the largest subset sum of the remaining
     * (inactive) weights that does not exceed {@code v - fz}.
     *
     * <p>Two tables of subset counts (modulo {@link #mod}) are maintained: one for the active
     * items and one for the inactive items. Moving an item between the groups is an "add" on one
     * table and a "remove" on the other, so the tables never have to be rebuilt.
     *
     * <p>NOTE(review): a sum is considered reachable when its count is non-zero modulo
     * {@code mod}; a count that is an exact multiple of {@code mod} would be misread as
     * unreachable.
     */
    public class Main {
        /** Modulus applied to every subset count stored in the knapsack tables. */
        static final long mod = (long) 1e9 + 7;

        /**
         * Reads the whole input from stdin, answers all queries offline and prints one
         * {@code "fz bz"} line per query in the original query order.
         *
         * @param args unused
         */
        public static void main(String[] args) {
            AReader scanner = new AReader();
            PrintWriter out = new PrintWriter(new BufferedOutputStream(System.out));

            int n = scanner.nextInt();
            int v = scanner.nextInt();

            // Each item produces two events: it becomes active at s and inactive again at e + 1.
            int[] weights = new int[n];
            List<int[]> ops = new ArrayList<>();
            for (int i = 0; i < n; i++) {
                int s = scanner.nextInt();
                int e = scanner.nextInt();
                int w = scanner.nextInt();
                weights[i] = w;
                ops.add(new int[]{s, 1, w});
                ops.add(new int[]{e + 1, -1, w});
            }
            ops.sort(Comparator.comparingInt(op -> op[0]));

            // Queries are paired with their input position so answers can be emitted in order.
            int q = scanner.nextInt();
            List<int[]> qs = new ArrayList<>();
            for (int i = 0; i < q; i++) {
                qs.add(new int[]{scanner.nextInt(), i});
            }
            qs.sort(Comparator.comparingInt(query -> query[0]));

            // dp1: subset counts of the active items; dp2: subset counts of the inactive items.
            long[] dp1 = new long[v + 1];
            long[] dp2 = new long[v + 1];
            dp1[0] = 1;
            dp2[0] = 1;
            // Before any event is applied every item is inactive.
            for (int w : weights) {
                add01(dp2, w, v);
            }

            int[][] res = new int[q][2];
            int p1 = 0;
            int fz = 0;
            int bz = 0;
            boolean dirty = true; // true while fz/bz do not reflect the current tables

            // Sweep the queries in time order and replay every event that is due. Driving the
            // sweep by the queries (instead of a fixed range of time values) answers queries at
            // any time value.
            for (int[] query : qs) {
                int time = query[0];
                while (p1 < ops.size() && ops.get(p1)[0] <= time) {
                    int[] op = ops.get(p1);
                    int w = op[2];
                    if (op[1] == 1) {
                        add01(dp1, w, v);
                        remove01(dp2, w, v);
                    } else {
                        remove01(dp1, w, v);
                        add01(dp2, w, v);
                    }
                    dirty = true;
                    p1++;
                }
                if (dirty) {
                    fz = largestReachable(dp1, v);
                    bz = largestReachable(dp2, v - fz);
                    dirty = false;
                }
                res[query[1]][0] = fz;
                res[query[1]][1] = bz;
            }

            for (int[] pair : res) {
                out.println(pair[0] + " " + pair[1]);
            }
            out.flush();
            out.close();
        }

        /**
         * Returns the largest index {@code j <= limit} with a non-zero count in {@code dp},
         * or 0 when there is none.
         */
        private static int largestReachable(long[] dp, int limit) {
            for (int j = limit; j >= 0; j--) {
                if (dp[j] > 0) {
                    return j;
                }
            }
            return 0;
        }

        /**
         * Adds one item of weight {@code w} to the counting table {@code dp} (capacity {@code v}).
         * Iterates downwards so the item is used at most once per subset.
         */
        private static void add01(long[] dp, int w, int v) {
            for (int u = v - w; u >= 0; u--) {
                dp[u + w] = (dp[u + w] + dp[u]) % mod;
            }
        }

        /**
         * Undoes a previous {@link #add01} of weight {@code w} on {@code dp}. Iterates upwards,
         * the reverse of the order used when adding.
         */
        private static void remove01(long[] dp, int w, int v) {
            for (int u = 0; u <= v - w; u++) {
                // The subtraction may go negative; fold it back into [0, mod).
                dp[u + w] = ((dp[u + w] - dp[u]) % mod + mod) % mod;
            }
        }

        /** Token reader over stdin: a {@link BufferedReader} plus a {@link StringTokenizer}. */
        static class AReader {
            private final BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
            private StringTokenizer tokenizer = new StringTokenizer("");

            /** Reads one raw line; an I/O error is reported as {@code null}, like end of input. */
            private String innerNextLine() {
                try {
                    return reader.readLine();
                } catch (IOException ex) {
                    return null;
                }
            }

            /** Loads lines until a token is available; returns {@code false} at end of input. */
            public boolean hasNext() {
                while (!tokenizer.hasMoreTokens()) {
                    String nextLine = innerNextLine();
                    if (nextLine == null) {
                        return false;
                    }
                    tokenizer = new StringTokenizer(nextLine);
                }
                return true;
            }

            /** Drops any tokens still buffered from the current line and returns the next raw line. */
            public String nextLine() {
                tokenizer = new StringTokenizer("");
                return innerNextLine();
            }

            /**
             * Returns the next token.
             *
             * @throws NoSuchElementException if the input is exhausted
             */
            public String next() {
                hasNext();
                return tokenizer.nextToken();
            }

            /** Parses the next token as an {@code int}. */
            public int nextInt() {
                return Integer.parseInt(next());
            }

            /** Parses the next token as a {@code long}. */
            public long nextLong() {
                return Long.parseLong(next());
            }
        }
    }
    

    python 版本被卡常

    # coding=utf-8
    # coding=utf-8
    import sys


    def add01(dp, w):
        """Add one item of weight ``w`` to the subset-count table ``dp`` in place.

        ``dp[j]`` holds the number of subsets whose weights sum to ``j``; the
        capacity is ``len(dp) - 1``.
        """
        # Walk downwards so the item is used at most once per subset.
        for u in range(len(dp) - 1 - w, -1, -1):
            dp[u + w] += dp[u]


    def remove01(dp, w):
        """Undo a previous ``add01(dp, w)`` in place."""
        # Walk upwards: the reverse of the order used when adding.
        for u in range(0, len(dp) - w):
            dp[u + w] -= dp[u]


    def _largest_reachable(dp, limit):
        """Return the largest ``j <= limit`` with ``dp[j] > 0`` (0 if there is none)."""
        for j in range(limit, -1, -1):
            if dp[j] > 0:
                return j
        return 0


    def solve(v, packs, queries):
        """Answer every query offline with an undoable 0-1 knapsack.

        Args:
            v: knapsack capacity.
            packs: iterable of ``(s, e, w)`` triples; an item is treated as
                active at time ``t`` when ``s <= t <= e``.
            queries: list of query times, in input order.

        Returns:
            A list of ``(fz, bz)`` tuples aligned with ``queries``. ``fz`` is
            the largest subset sum of active weights not exceeding ``v``;
            ``bz`` is the largest subset sum of inactive weights not exceeding
            ``v - fz``.
        """
        packs = list(packs)

        # Each item produces two events: active from s, inactive again from e + 1.
        ops = []
        for s, e, w in packs:
            ops.append((s, 1, w))
            ops.append((e + 1, -1, w))
        ops.sort(key=lambda op: op[0])

        # Sort by query time; the sweep below relies on this order.
        qs = sorted(((t, i) for i, t in enumerate(queries)), key=lambda x: x[0])

        # dp1 counts subsets of the active items, dp2 of the inactive items.
        dp1 = [0] * (v + 1)
        dp2 = [0] * (v + 1)
        dp1[0] = 1
        dp2[0] = 1
        # Before any event is applied every item is inactive.
        for _, _, w in packs:
            add01(dp2, w)

        res = [(0, 0)] * len(queries)
        p1 = 0
        fz = bz = 0
        dirty = True  # True while fz/bz do not reflect the current tables
        for t, idx in qs:
            # Replay every event that is due at or before this query time.
            while p1 < len(ops) and ops[p1][0] <= t:
                _, d, w = ops[p1]
                if d == 1:
                    add01(dp1, w)
                    remove01(dp2, w)
                else:
                    remove01(dp1, w)
                    add01(dp2, w)
                dirty = True
                p1 += 1
            if dirty:
                fz = _largest_reachable(dp1, v)
                bz = _largest_reachable(dp2, v - fz)
                dirty = False
            res[idx] = (fz, bz)
        return res


    def main():
        """Read the task input from stdin and print one ``fz bz`` line per query."""
        readline = sys.stdin.buffer.readline
        n, v = map(int, readline().split())
        packs = [tuple(map(int, readline().split())) for _ in range(n)]
        q = int(readline())
        # All query times are expected on a single line.
        queries = list(map(int, readline().split()))[:q]
        res = solve(v, packs, queries)
        print("\n".join(str(fz) + " " + str(bz) for fz, bz in res))


    if __name__ == "__main__":
        main()
    

    写在最后