前言
餐馆
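思路：可撤销的 0-1 背包。把每个物品的时间区间 [s, e] 拆成两个事件——s 时刻加入、e+1 时刻移出（差分）；再把所有查询按时间排序，离线地依次处理。同时维护两张互斥的表：当前在区间内的物品（dp1）和不在区间内的物品（dp2）。表里记录的是方案数：加入重量为 w 的物品时，j 从大到小执行 dp[j] += dp[j-w]；移出时，j 从小到大执行 dp[j] -= dp[j-w]。方案数不为 0 即表示重量和 j 可以被恰好凑出。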
考察了多个知识点，包括：
- 差分技巧
- 离线思路
- 0-1背包
不过这题对常数要求很高（卡语言），尤其卡 Python。
import java.io.*;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Solution to the "Restaurant" problem using an undoable (add/remove) 0-1 knapsack, answered
 * offline over time-sorted queries.
 *
 * <p>Each item {@code {s, e, w}} is active for times {@code s <= t <= e}. For a query time
 * {@code t} the items split into two mutually exclusive groups, active and inactive. The answer is
 * {@code fz}, the largest subset sum of active weights that is at most {@code v}, followed by
 * {@code bz}, the largest subset sum of inactive weights that is at most {@code v - fz}.
 */
public class Main {
    /**
     * Modulus for the subset-count tables. A sum is treated as reachable when its count is non-zero
     * modulo this value; a true count that happens to be a multiple of it would be missed.
     */
    static final long mod = (long) 1e9 + 7;

    /**
     * Reads {@code n v}, then {@code n} triples {@code s e w}, then {@code q} and {@code q} query
     * times, and prints one {@code "fz bz"} line per query in input order.
     */
    public static void main(String[] args) {
        AReader scanner = new AReader();
        PrintWriter out = new PrintWriter(new BufferedOutputStream(System.out));
        int n = scanner.nextInt();
        int v = scanner.nextInt();
        int[][] packs = new int[n][];
        for (int i = 0; i < n; i++) {
            int s = scanner.nextInt();
            int e = scanner.nextInt();
            int w = scanner.nextInt();
            packs[i] = new int[] {s, e, w};
        }
        int q = scanner.nextInt();
        int[] times = new int[q];
        for (int i = 0; i < q; i++) {
            times[i] = scanner.nextInt();
        }
        for (int[] pair : solve(v, packs, times)) {
            out.println(pair[0] + " " + pair[1]);
        }
        out.flush();
        out.close();
    }

    /**
     * Answers every query offline.
     *
     * <p>Only the query times are evaluated (in sorted order), so the running time is
     * {@code O((n + q) * v)} regardless of how large or small the time values are.
     *
     * @param v knapsack capacity
     * @param packs items as {@code {start, end, weight}}; active when {@code start <= t <= end}
     * @param times query times, in any order
     * @return {@code res[i] = {fz, bz}} for {@code times[i]}
     */
    static int[][] solve(int v, int[][] packs, int[] times) {
        // Difference-style events: an item joins the active group at time s and leaves at e + 1.
        // Each op is {time, +1 (activate) or -1 (deactivate), weight}.
        List<int[]> ops = new ArrayList<>(2 * packs.length);
        // dp1 counts subsets of the active items, dp2 subsets of the inactive items (mod `mod`).
        long[] dp1 = new long[v + 1];
        long[] dp2 = new long[v + 1];
        dp1[0] = 1;
        dp2[0] = 1;
        for (int[] pack : packs) {
            ops.add(new int[] {pack[0], 1, pack[2]});
            ops.add(new int[] {pack[1] + 1, -1, pack[2]});
            // Before the first event every item is inactive.
            add01(dp2, pack[2], v);
        }
        ops.sort(Comparator.comparingInt(a -> a[0]));

        // Pair each query time with its original index, then visit them in time order.
        List<int[]> qs = new ArrayList<>(times.length);
        for (int i = 0; i < times.length; i++) {
            qs.add(new int[] {times[i], i});
        }
        qs.sort(Comparator.comparingInt(a -> a[0]));

        int[][] res = new int[times.length][];
        int p = 0;
        for (int[] query : qs) {
            int t = query[0];
            // Apply every event whose time is <= t (two-pointer sweep over the sorted events).
            while (p < ops.size() && ops.get(p)[0] <= t) {
                int[] op = ops.get(p++);
                int w = op[2];
                if (op[1] == 1) {
                    add01(dp1, w, v);
                    remove01(dp2, w, v);
                } else {
                    remove01(dp1, w, v);
                    add01(dp2, w, v);
                }
            }
            int fz = maxReachable(dp1, v);
            int bz = maxReachable(dp2, v - fz);
            res[query[1]] = new int[] {fz, bz};
        }
        return res;
    }

    /** Returns the largest {@code j <= limit} with a non-zero count in {@code dp}, or 0. */
    private static int maxReachable(long[] dp, int limit) {
        for (int j = limit; j > 0; j--) {
            if (dp[j] > 0) {
                return j;
            }
        }
        return 0;
    }

    /**
     * Adds an item of weight {@code w}: multiplies the count table by {@code (1 + x^w)}, truncated
     * at degree {@code v}. Iterating {@code u} downwards uses the item at most once.
     */
    private static void add01(long[] dp, int w, int v) {
        if (w == 0) {
            // A zero-weight item only doubles every count and never changes which sums are
            // reachable; skipping it keeps add01/remove01 exact inverses (remove01 cannot undo a
            // doubling).
            return;
        }
        for (int u = v - w; u >= 0; u--) {
            dp[u + w] = (dp[u + w] + dp[u]) % mod;
        }
    }

    /**
     * Removes an item of weight {@code w} that was previously added with add01: divides the count
     * table by {@code (1 + x^w)}. Iterating {@code u} upwards means {@code dp[u]} already excludes
     * the item when it is subtracted from {@code dp[u + w]}.
     */
    private static void remove01(long[] dp, int w, int v) {
        if (w == 0) {
            return; // zero-weight items are never added (see add01)
        }
        for (int u = 0; u <= v - w; u++) {
            long x = dp[u + w] - dp[u];
            dp[u + w] = x < 0 ? x + mod : x; // both operands are in [0, mod)
        }
    }

    /** Whitespace-separated token reader over standard input. */
    static class AReader {
        private BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
        private StringTokenizer tokenizer = new StringTokenizer("");

        /** Reads one raw line; an IOException is treated like end of input (returns null). */
        private String innerNextLine() {
            try {
                return reader.readLine();
            } catch (IOException ex) {
                return null;
            }
        }

        /** Returns true if another token is available, pulling further lines as needed. */
        public boolean hasNext() {
            while (!tokenizer.hasMoreTokens()) {
                String nextLine = innerNextLine();
                if (nextLine == null) {
                    return false;
                }
                tokenizer = new StringTokenizer(nextLine);
            }
            return true;
        }

        /** Discards any buffered tokens and returns the next raw line. */
        public String nextLine() {
            tokenizer = new StringTokenizer("");
            return innerNextLine();
        }

        /** Returns the next token; throws NoSuchElementException at end of input. */
        public String next() {
            hasNext();
            return tokenizer.nextToken();
        }

        public int nextInt() {
            return Integer.parseInt(next());
        }

        public long nextLong() {
            return Long.parseLong(next());
        }
    }
}
// python 版本被卡常 (the Python version below is too slow: it exceeds the time limit by a constant factor)
# coding=utf-8
import sys


def add01(dp, w, v):
    """Add an item of weight w to the subset-count table dp (capacity v).

    dp[j] is the number of subsets whose weights sum to exactly j.  Iterating
    downwards uses the new item at most once (0-1 knapsack).  A zero-weight
    item is skipped: it only doubles every count, never changes which sums are
    reachable, and remove01 could not undo that doubling.
    """
    if w == 0:
        return
    for u in range(v - w, -1, -1):
        dp[u + w] += dp[u]


def remove01(dp, w, v):
    """Undo add01(dp, w, v) for an item of weight w that is currently in dp.

    Iterating upwards means dp[u] already excludes the removed item when it
    is subtracted from dp[u + w].
    """
    if w == 0:
        return
    for u in range(0, v - w + 1):
        dp[u + w] -= dp[u]


def _max_reachable(dp, limit):
    """Return the largest j <= limit with dp[j] > 0 (0 if there is none)."""
    for j in range(limit, 0, -1):
        if dp[j] > 0:
            return j
    return 0


def solve(v, packs, times):
    """Answer every query offline with an undoable 0-1 knapsack.

    packs is a sequence of (s, e, w): an item of weight w that is active for
    times s <= t <= e.  times holds the query times in input order.  Returns a
    list of (fz, bz) in that same order, where fz is the largest subset sum of
    the active weights that is <= v, and bz is the largest subset sum of the
    inactive weights that is <= v - fz.

    Python ints are unbounded, so counts are kept exactly (no modulus).
    """
    # Difference-style events: (time, +1 activate / -1 deactivate, weight).
    ops = []
    dp1 = [0] * (v + 1)  # subset counts of the active items
    dp2 = [0] * (v + 1)  # subset counts of the inactive items (mutually exclusive with dp1)
    dp1[0] = dp2[0] = 1
    for s, e, w in packs:
        ops.append((s, 1, w))
        ops.append((e + 1, -1, w))
        add01(dp2, w, v)  # before any event every item is inactive
    ops.sort(key=lambda op: op[0])

    res = [None] * len(times)
    p = 0
    # Visit queries in time order (the original sorted with a constant key,
    # so queries were never ordered).  Only query times are evaluated, so the
    # cost does not depend on the range of the time values.
    for idx in sorted(range(len(times)), key=times.__getitem__):
        t = times[idx]
        while p < len(ops) and ops[p][0] <= t:
            _, d, w = ops[p]
            if d == 1:
                add01(dp1, w, v)
                remove01(dp2, w, v)
            else:
                remove01(dp1, w, v)
                add01(dp2, w, v)
            p += 1
        fz = _max_reachable(dp1, v)
        bz = _max_reachable(dp2, v - fz)
        res[idx] = (fz, bz)
    return res


def main():
    """Read all input as whitespace-separated integers; print one "fz bz" line per query."""
    it = iter(map(int, sys.stdin.buffer.read().split()))
    n, v = next(it), next(it)
    packs = [(next(it), next(it), next(it)) for _ in range(n)]
    q = next(it)
    times = [next(it) for _ in range(q)]
    print("\n".join("%d %d" % pair for pair in solve(v, packs, times)))


if __name__ == "__main__":
    main()
写在最后