题意
给你一个大小为 $n$（$1 \le n \le 10^6$）的 int 数组，求异或和至少为 $k$（$1 \le k \le 10^9$）的连续子数组的个数。
思路
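字典树（01 Trie）+ 前缀异或。记 $s_i$ 为前 $i$ 个数的异或和（$s_0 = 0$），则子数组 $a_j \dots a_{i-1}$ 的异或和为 $s_j \oplus s_i$。按顺序把 $s_0, s_1, \dots$ 插入字典树；每得到一个新的 $s_i$，就沿 $k$ 的二进制位从高到低在字典树上走，统计已插入的 $s_j$ 中满足 $s_j \oplus s_i \ge k$ 的个数并累加。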
代码
// ---- Array-based (static node pool) trie version ----
#include <cstdio>
#include <iostream>
using namespace std;
// The answer can reach n*(n+1)/2 (~5e11 for n = 1e6), beyond int range.
typedef long long LL;

#define rep(i, n) for (int i = 0; i < (n); i++)
#define FOR(i, n, m) for (int i = (n); i <= (m); i++)
#define FORD(i, n, m) for (int i = (n); i >= (m); i--)

// L: bits stored per value; k <= 1e9 < 2^30.
// NOTE(review): this also assumes every array value is < 2^30 -- confirm
// against the problem's bound on a_i, otherwise bit 30 is silently ignored.
// N: node-pool size. Each add() creates at most L nodes, so up to 1e6
// insertions need at most 30e6 nodes plus the root.
const int N = 30e6 + 5, L = 30;
int n, m, e, d;  // n: array length, m: threshold k, e: next free node, d: running prefix XOR
int ch[N][2];    // ch[p][b]: child of node p along bit b, or -1 if absent
int wv[N];       // wv[p]: number of inserted values whose path passes through p
// Reset the pool-backed trie to a single empty root (node 0): no children,
// zero count, and node 1 is the next one to hand out.
void init() {
    ch[0][1] = -1;
    ch[0][0] = -1;
    wv[0] = 0;
    e = 1;
}
void add(int x) {
int p = 0, id;
wv[p]++;
//sc(x);
FORD (i, L - 1, 0) {
id = (x >> i) & 1;
//sc3(p, id, ch[p][id]);
if (ch[p][id] == -1) {
ch[e][0] = ch[e][1] = -1;
wv[e] = 0;
ch[p][id] = e++;
}
p = ch[p][id], wv[p]++;
//sc3(p, wv[p], e);
}
}
int cal() {
int ans = 0, p = 0, id, x;
FORD (i, L - 1, 0) {
id = (m >> i) & 1, x = (d >> i) & 1 ^ 1;
if (id) { // for bit = 1
if (ch[p][x] == -1) return ans;
p = ch[p][x];
} else { // for bit = 0
if (ch[p][x] != -1) ans += wv[ch[p][x]];
if (ch[p][x ^ 1] != -1) p = ch[p][x ^ 1];
else return ans;
}
}
return ans + wv[p];
}
// Array-version driver. Reads test cases "n k" followed by n integers until
// input ends. For each case, prints how many subarrays have XOR >= k.
int main() {
    int x;
    LL ans;
    // `~scanf(...)` is only false at EOF. A malformed header (scanf returning
    // 0 or 1) would keep looping on stale n/m forever, so require both fields.
    while (scanf("%d %d", &n, &m) == 2) {
        init();
        ans = 0, d = 0;
        rep (i, n) {
            if (scanf("%d", &x) != 1) return 1;  // truncated or malformed input
            add(d);        // insert the prefix XOR of the first i elements
            d ^= x;        // d becomes the prefix XOR of the first i+1 elements
            ans += cal();  // subarrays ending at element i with XOR >= m
        }
        cout << ans << '\n';
    }
    return 0;
}
再来个指针版本
// ---- Pointer-based (heap-allocated nodes) trie version ----
#include <cstdio>
// The answer can reach n*(n+1)/2 (~5e11 for n = 1e6), beyond int range.
typedef long long LL;

#define rep(i, n) for (int i = 0; i < (n); i++)
#define FOR(i, n, m) for (int i = (n); i <= (m); i++)
#define FORD(i, n, m) for (int i = (n); i >= (m); i--)

// L: bits stored per value; k <= 1e9 < 2^30.
// NOTE(review): this also assumes array values are < 2^30 -- confirm.
// N is not referenced by the pointer version (kept from the array version).
const int N = 30e6 + 5, L = 30;
// Heap-allocated trie node: ch[b] is the child along bit b (NULL if absent),
// and val counts the inserted values whose path passes through this node.
// `root` is the trie's entry point; as a global it starts out NULL.
struct Node {
    Node *ch[2];
    int val;
    Node() : val(0) {
        ch[1] = NULL;
        ch[0] = NULL;
    }
} *root;
// Pointer-version globals: n = array length, m = threshold k,
// d = running prefix XOR. e is not used by this version.
int n;
int m;
int e;
int d;
void init() {
root = new Node();
}
void add(int x) {
int id;
Node *p = root;
p->val ++;
FORD (i, L - 1, 0) {
id = (x >> i) & 1;
if (p->ch[id] == NULL) {
Node *tmp = new Node();
p->ch[id] = tmp;
}
p = p->ch[id], p->val ++;
}
}
int cal() {
int ans = 0, id, x;
Node *p = root;
FORD (i, L - 1, 0) {
id = (m >> i) & 1, x = (d >> i) & 1 ^ 1;
if (id) { // for bit = 1
if (p->ch[x] == NULL) return ans;
p = p->ch[x];
} else { // for bit = 0
if (p->ch[x] != NULL) ans += p->ch[x]->val;
if (p->ch[x ^ 1] != NULL) p = p->ch[x ^ 1];
else return ans;
}
}
return ans + p->val;
}
// Pointer-version driver. Reads test cases "n k" followed by n integers until
// input ends. For each case, prints how many subarrays have XOR >= k.
int main() {
    int x;
    LL ans;
    // `~scanf(...)` is only false at EOF. A malformed header (scanf returning
    // 0 or 1) would keep looping on stale n/m forever, so require both fields.
    while (scanf("%d %d", &n, &m) == 2) {
        init();
        ans = 0, d = 0;
        rep (i, n) {
            if (scanf("%d", &x) != 1) return 1;  // truncated or malformed input
            add(d);        // insert the prefix XOR of the first i elements
            d ^= x;        // d becomes the prefix XOR of the first i+1 elements
            ans += cal();  // subarrays ending at element i with XOR >= m
        }
        // "%I64d" is an MSVC-only length modifier (UB elsewhere); "%lld" is
        // the standard one. The cast keeps it correct whatever LL is typedef'd to.
        printf("%lld\n", (long long)ans);
    }
    return 0;
}