XOR Sums Solution Codechef

XOR Sums Solution February Challenge 2021

You are given a sequence of positive integers $A_1, A_2, \ldots, A_N$. You should answer $Q$ queries. In each query:

  • You are given a positive integer $M$.
  • Consider all non-empty subsequences of $A$ with length $\le M$. Recall that a subsequence is any sequence that can be created by deleting zero or more elements without changing the order of the remaining elements.
  • For each of these subsequences, compute the bitwise XOR of its elements. Your task is to determine the sum of these values. Since this sum can be very large, compute it modulo $998{,}244{,}353$.

Also See: February Long Challenge 2021 Solutions

Input

  • The first line of the input contains a single integer $N$.
  • The second line contains $N$ space-separated integers $A_1, A_2, \ldots, A_N$.
  • The third line contains a single integer $Q$.
  • $Q$ lines follow. Each of these lines contains a single integer $M$ describing a query.

Output

For each query, print a single line containing one integer ― the sum of bitwise XORs for all subsequences of $A$ with length $\le M$, modulo $998{,}244{,}353$.

Constraints

  • $1 \le N, Q \le 2 \cdot 10^5$
  • $1 \le A_i < 2^{30}$ for each valid $i$
  • $1 \le M \le N$

Subtask

Subtask #1 (10 points): $1 \le N, Q \le 1{,}000$

Subtask #2 (90 points): original constraints

Example Input

4
1 3 5 2
2
1
2

Example Output

11
34

Explanation

In the first query, the answer is just the sum of elements of $A$ (modulo $998{,}244{,}353$), which is $1+3+5+2=11$.

In the second query, the answer is the sum of bitwise XORs for all subsequences with length $1$ or $2$, which is $1+3+5+2+(1\oplus 3)+(1\oplus 5)+(1\oplus 2)+(3\oplus 5)+(3\oplus 2)+(5\oplus 2)=34$.

Follow us on Telegram for quick updates and an abundance of free knowledge: Click Here

Solution

Program C:

#include <stdio.h>
/* ll: shorthand for the 64-bit integer type used for all arithmetic below. */
#define ll long long int
/* Prime modulus required by the problem (998,244,353). */
#define mod 998244353
/* Upper bound on N from the constraints; sizes the global tables. */
#define N 200000
/* fact[i] = i! mod `mod`, filled in main(). */
ll fact[N+1];
/* arr[i]: the i-th input value, read in main(). */
ll arr[N+1];
/* sum[i]: after main()'s precomputation, the answer for a query with M = i.
 * sum[0] is never written and stays 0, acting as the empty prefix. */
ll sum[N+1] = {0ll};
/*
 * Binomial coefficient C(n, r) modulo the prime `mod`, using the fact[] table.
 * Returns 0 when r is outside [0, n].
 *
 * fact[] entries are already reduced modulo `mod`, so ordinary integer
 * division of them is meaningless (and can divide by zero). The division is
 * instead done by multiplying with the Fermat inverse denom^(mod-2).
 */
ll ncr(ll n, ll r){
    if(r < 0 || r > n)
        return 0;
    ll denom = ((fact[n-r] % mod) * (fact[r] % mod)) % mod;
    ll inverse = 1;
    ll square = denom;
    ll e = mod - 2;
    while(e > 0){
        if(e & 1)
            inverse = (inverse * square) % mod;
        square = (square * square) % mod;
        e >>= 1;
    }
    return ((fact[n] % mod) * inverse) % mod;
}
/*
 * a^e modulo `mod` by binary exponentiation (e >= 0).
 */
static ll mod_pow(ll a, ll e)
{
    ll result = 1;
    a %= mod;
    while (e > 0) {
        if (e & 1)
            result = (result * a) % mod;
        a = (a * a) % mod;
        e >>= 1;
    }
    return result;
}

/*
 * C(n, r) modulo the prime `mod`, computed as fact[n] / (fact[r] * fact[n-r]).
 * The division uses a single Fermat inverse of the combined denominator
 * (one exponentiation instead of one per factorial).
 * Returns 1 for r == 0, and 0 when r > n or r < 0.
 * Requires fact[] to be filled up to index n.
 */
ll nCrModPFermat(ll n, ll r)
{
    if (r == 0)
        return 1;
    if (n < r || r < 0)
        return 0;
    ll denom = ((fact[r] % mod) * (fact[n - r] % mod)) % mod;
    /* denom^(mod-2) == denom^(-1) (mod mod) because mod is prime. */
    return ((fact[n] % mod) * mod_pow(denom, mod - 2)) % mod;
}
/*
 * Reads N, the array and Q queries. For each query M it prints the sum of the
 * XORs of all non-empty subsequences of length <= M, modulo `mod`.
 *
 * Bit j of a subsequence's XOR is set exactly when the subsequence takes an
 * odd number k of the x[j] elements having bit j. The remaining i - k
 * elements come from the y[j] elements without that bit. That gives
 * C(x[j], k) * C(y[j], i - k) length-i subsequences, each contributing 2^j.
 *
 * The nested loops perform roughly 31 * N^2 / 2 binomial evaluations, so this
 * version is only practical for small N (subtask 1).
 */
int main(){
    ll n;
    /* %lld is the conversion for long long; %ld was undefined behaviour here. */
    if(scanf("%lld", &n) != 1 || n < 0 || n > N)
        return 1;
    ll x[31] = {0};   /* x[j]: elements whose bit j is set   */
    ll y[31] = {0};   /* y[j]: elements whose bit j is clear */

    /* Fill fact[0..N] inclusive: nCrModPFermat reads fact[n] with n up to N. */
    fact[0] = 1;
    for(ll i = 1; i <= N; i++){
        fact[i] = (fact[i-1] * i) % mod;
    }

    for(ll i = 0; i < n; i++){
        scanf("%lld", &arr[i]);
        for(int j = 0; j < 31; j++){
            if(arr[i] & (1ll << j))
                x[j]++;
            else
                y[j]++;
        }
    }

    for(ll i = 1; i <= n; i++){
        for(int j = 0; j < 31; j++){
            if(x[j] == 0)
                continue;   /* bit j is never set, so it never contributes */
            ll weight = (1ll << j) % mod;
            for(ll k = 1; k <= i && k <= x[j]; k += 2){
                if(i - k > y[j])
                    continue;   /* not enough elements without bit j */
                ll ways = (nCrModPFermat(x[j], k) * nCrModPFermat(y[j], i - k)) % mod;
                sum[i] = (sum[i] + ways * weight) % mod;
            }
        }
        /* Turn "length exactly i" into "length at most i". */
        sum[i] = (sum[i-1] + sum[i]) % mod;
    }

    int q;
    scanf("%d", &q);
    for(int i = 0; i < q; i++){
        int m;
        scanf("%d", &m);
        printf("%lld\n", sum[m]);
    }
    return 0;
}

Program C++:

#include<bits/stdc++.h>
// Shorthand macros; `mp` is not used anywhere in this program.
#define mp make_pair
#define ll long long
using namespace std;
                         
// NOTE: from here on every use of the keyword `int` means `long long`
// (which is why the entry point is declared as `int32_t main`).
#define int long long

// Prime modulus required by the problem.
const int mod = 998244353;


// Complex number with double-precision real part `x` and imaginary part `y`;
// the element type of the FFT below. Default-constructs to 0 + 0i.
struct base {
    double x, y;
    base() : x(0), y(0) {}
    base(double re, double im) : x(re), y(im) {}
};

// Complex addition, component by component.
inline base operator + (base a, base b) {
    const double re = a.x + b.x;
    const double im = a.y + b.y;
    return base(re, im);
}

// Complex subtraction, component by component.
inline base operator - (base a, base b) {
    const double re = a.x - b.x;
    const double im = a.y - b.y;
    return base(re, im);
}

// Complex multiplication: (a.x + a.y i)(b.x + b.y i).
inline base operator * (base a, base b) {
    const double re = a.x * b.x - a.y * b.y;
    const double im = a.x * b.y + a.y * b.x;
    return base(re, im);
}

// Complex conjugate: negates the imaginary part.
inline base conj(base a) {
    const double negated_im = -a.y;
    return base(a.x, negated_im);
}

// lim: log2 of the largest transform size that `rev` and `roots` currently
// cover; grown by ensure_base().
int lim = 1;


// roots[k + j]: twiddle factor used by fft() for butterfly offset j at half-size k.
vector<base> roots = {{0, 0}, {1, 0}};
// rev[i]: i with its lowest `lim` bits reversed (fft() shifts it for smaller sizes).
vector<int> rev = {0, 1};
// pi, computed in long double and stored as double.
const double PI = acosl(- 1.0);
// Grows the shared FFT tables `rev` and `roots` so they cover transforms of
// size 2^p. Does nothing when p <= lim. Otherwise, on return lim == p and:
//   rev[i]       = i with its lowest p bits reversed, for 0 <= i < 2^p;
//   roots[k + j] = exp(2*pi*i * j / (2k)) for each power of two k < 2^p, 0 <= j < k.
void ensure_base(int p) {
    if(p <= lim) 
    return;
    // Bit reversal over p bits, derived from the already-computed entry for i >> 1.
    rev.resize(1 << p);
    for(int i = 0; i < (1 << p); i++) 
    rev[i] = (rev[i >> 1] >> 1) + ((i & 1)  <<  (p - 1));
    roots.resize(1 << p);
    // Extend the root table one level at a time: even slots of the new level
    // reuse the previous level's roots, odd slots are evaluated directly.
    while(lim < p) {
        double angle = 2 * PI / (1 << (lim + 1));
        for(int i = 1 << (lim - 1); i < (1 << lim); i++) {
            roots[i << 1] = roots[i];
            double angle_i = angle * (2 * i + 1 - (1 << lim));
            roots[(i << 1) + 1] = base(cos(angle_i), sin(angle_i));
        }
        lim++;
    }
}

// In-place iterative radix-2 FFT of the first n entries of `a`.
// n defaults to a.size() and must be a power of two (checked by assert).
// Twiddles come from `roots`, i.e. the exp(+2*pi*i*j/(2k)) convention.
// There is no separate inverse routine. multiply() reuses this forward
// transform for its inverse step: it scales by `ratio` and writes its
// pointwise results at mirrored indices i <-> j.
void fft(vector<base> &a, int n = -1) {
    if(n == -1) 
    n = a.size();
    assert((n & (n - 1)) == 0);
    // zeros = log2(n)
    int zeros = __builtin_ctz(n);
    ensure_base(zeros);
    // rev is built for 2^lim points; shifting right by `shift` yields the
    // bit reversal over `zeros` bits.
    int shift = lim - zeros;
    for(int i = 0; i < n; i++) 
    if(i < (rev[i] >> shift)) 
    swap(a[i], a[rev[i] >> shift]);
    // Butterfly passes with half-size k = 1, 2, 4, ...
    for(int k = 1; k < n; k <<= 1) {
        for(int i = 0; i < n; i += 2 * k) {
            for(int j = 0; j < k; j++) {
                base z = a[i + j + k] * roots[j + k];
                a[i + j + k] = a[i + j] - z;a[i + j] = a[i + j] + z;
            }
        }
    }
}

// Multiplies the polynomials `a` and `b` (coefficients first reduced into
// [0, mod)). Returns the need = a.size() + b.size() - 1 product coefficients,
// each reduced modulo `mod`. `a` and `b` are only read, never modified.
//
// Each coefficient is split into its low 15 bits (real part) and the
// remaining high bits (imaginary part). The three partial products
// low*low, cross terms and high*high are then recombined with weights
// 1, 2^15 and 2^30.
// If `eq` is non-zero, b's contents are ignored and B reuses A's transform,
// so `a` is multiplied by itself; b.size() still determines the result length.
// NOTE(review): exact rounding relies on double precision being sufficient
// for the transform sizes used here — confirm for the largest sz.
vector<int> multiply(vector<int> &a, vector<int> &b, int eq = 0) {
    int need = a.size() + b.size() - 1;
    // Smallest power of two 2^p >= need.
    int p = 0;while((1 << p) < need) p++;
    ensure_base(p);int sz = 1 << p;
    vector<base> A, B;
    if(sz > (int)A.size()) A.resize(sz);
    // Pack a: low 15 bits -> real part, high bits -> imaginary part.
    for(int i = 0; i < (int)a.size(); i++) {
        int x = (a[i] % mod + mod) % mod;
        A[i] = base(x & ((1 << 15) - 1), x >> 15);
    }
    fill(A.begin() + a.size(), A.begin() + sz, base{0, 0});
    fft(A, sz);
    if(sz > (int)B.size()) B.resize(sz);
    if(eq) copy(A.begin(), A.begin() + sz, B.begin());
    else {
        for(int i = 0; i < (int)b.size(); i++) {
            int x = (b[i] % mod + mod) % mod;
            B[i] = base(x & ((1 << 15) - 1), x >> 15);
        }
        fill(B.begin() + b.size(), B.begin() + sz, base{0, 0});
        fft(B, sz);
    }
    // Separate the transforms of the low/high halves using conjugate symmetry
    // (index j mirrors i), multiply them pointwise, and pack the three partial
    // products back into A and B. Results are stored at the mirrored index,
    // and `ratio` folds in the 1/sz inverse-FFT scale.
    double ratio = 0.25 / sz;
    base r2(0,  - 1), r3(ratio, 0), r4(0,  - ratio), r5(0, 1);
    for(int i = 0; i <= (sz >> 1); i++) {
        int j = (sz - i) & (sz - 1);
        base a1 = (A[i] + conj(A[j])), a2 = (A[i] - conj(A[j])) * r2;
        base b1 = (B[i] + conj(B[j])) * r3, b2 = (B[i] - conj(B[j])) * r4;
        if(i != j) {
            base c1 = (A[j] + conj(A[i])), c2 = (A[j] - conj(A[i])) * r2;
            base d1 = (B[j] + conj(B[i])) * r3, d2 = (B[j] - conj(B[i])) * r4;A[i] = c1 * d1 + c2 * d2 * r5;
            B[i] = c1 * d2 + c2 * d1;
        }
        A[j] = a1 * b1 + a2 * b2 * r5;
        B[j] = a1 * b2 + a2 * b1;
    }
    fft(A, sz); 
    fft(B, sz);
    vector<int> res(need);
    // Round to nearest integer and recombine:
    // aa (weight 1) + bb (weight 2^15) + cc (weight 2^30).
    for(int i = 0; i < need; i++) {
        long long aa = A[i].x + 0.5;
        long long bb = B[i].x + 0.5;
        long long cc = A[i].y + 0.5;
        res[i] = (aa + ((bb % mod) << 15) + ((cc % mod) << 30))%mod;
    }
    return res;
}

// Returns a * b reduced modulo `mod`; the product is formed in 64 bits.
int mul(int a, int b) {
    const long long product = static_cast<long long>(a) * b;
    return product % mod;
}


// Fast exponentiation: a^deg modulo `mod` by repeated squaring (deg >= 0).
int po(int a, int deg){
    int result = 1;
    int square = a;
    while (deg > 0) {
        if (deg & 1) result = mul(result, square);
        square = mul(square, square);
        deg >>= 1;
    }
    return result;
}

// Modular inverse of n by Fermat's little theorem (mod is prime; n must not be 0 mod p).
int inv(int n){
    const int fermat_exponent = mod - 2;
    return po(n, fermat_exponent);
}

// Factorial table size: indices 0..N-1 are valid, which covers n <= 2*10^5.
const int N = 200007;

// facs[i] = i! mod p and invfacs[i] = (i!)^{-1} mod p, both filled by init().
vector<int> facs(N), invfacs(N);

// Fills facs[i] = i! and invfacs[i] = (i!)^{-1} (mod p) for every 0 <= i < N.
void init(){
    facs[0] = 1;
    for (int i = 1; i < N; ++i) {
        facs[i] = mul(facs[i - 1], i);
    }
    // One modular inverse for the largest factorial, then walk downwards
    // using 1/(i-1)! = 1/i! * i.
    int i = N - 1;
    invfacs[i] = inv(facs[i]);
    while (i > 0) {
        invfacs[i - 1] = mul(invfacs[i], i);
        --i;
    }
}

// Binomial coefficient C(n, k) mod p from the precomputed tables.
// Returns 0 when k is outside [0, n]; the k < 0 check prevents reading
// invfacs at a negative index.
int C(int n, int k){
    if (k < 0 || n < k) return 0;
    return mul(facs[n], mul(invfacs[k], invfacs[n-k]));
}

int32_t main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    init();
    int n;
    cin>>n;
    vector<int> v(n);
    for(int i=0;i<n;i++){
        cin>>v[i];
    }
    vector<vector<int>> nbit(n+1,vector<int>(32,0));
    vector<int> ct(32,0);
    for(int i=0;i<n;i++){
        for(int j=0;j<32;j++){
            if((v[i]>>j)&1)ct[j]++;
        }
    }
    for(int j=0;j<32;j++){
        vector<int> v1(ct[j]+1,0);
        vector<int> v2(n-ct[j]+1,0);
        for(int i=0;i<=n;i++){
            if(i%2&&i<=ct[j])v1[i]=C(ct[j],i);
            if(i<=n-ct[j])v2[i]=C(n-ct[j],i);
        }
        vector<int> v3 = multiply(v1,v2);
        for(int i=1;i<=n;i++){
            nbit[i][j]=nbit[i-1][j]+v3[i];
        }
    }
    int q;
    cin>>q;
    while(q--){
        int m;
        cin>>m;
        int ans=0;
        for(int i=0;i<32;i++){
            ans+=(nbit[m][i]%mod*(1ll<<i)%mod)%mod;
            ans%=mod;
        }
        cout<<ans<<endl;
    }
    return 0;
}

Program Java:

import java.util.*;
import java.lang.*;
import java.io.*;
import java.math.*;

class Codechef
{
    /** Prime modulus of the problem; 998244353 = 119 * 2^23 + 1, so NTT sizes up to 2^23 exist. */
    static final long MOD = 998244353L;
    /** Primitive root modulo MOD, used to derive the NTT roots of unity. */
    static final long PRIMITIVE_ROOT = 3L;
    /** Bit positions examined per element (inputs are positive ints, so 32 covers every bit). */
    static final int BITS = 32;

    /**
     * Computes a^x mod p in O(log x) multiplications (x >= 0).
     * Intermediate products fit in a long when a is in [0, p) and p < 2^31.
     */
    public static long modPow(long a, long x, long p) {
        long res = 1;
        while(x > 0) {
            if( x % 2 != 0) {
                res = (res * a) % p;
            }
            a = (a * a) % p;
            x /= 2;
        }
        return res;
    }

    /** Multiplicative inverse of a modulo the prime p, via Fermat's little theorem. */
    public static long modInverse(long a, long p) {
        return modPow(a, p-2, p);
    }

    /**
     * C(n, k) mod p for prime p in O(k) time, without precomputed tables.
     * Returns 0 when k > n >= 0, because the numerator then contains a zero factor.
     */
    public static long modBinomial(long n, long k, long p) {
        long numerator = 1; // n * (n-1) * ... * (n-k+1)
        for (int i=0; i<k; i++) {
            numerator = (numerator * (n-i) ) % p;
        }

        long denominator = 1; // k!
        for (int i=1; i<=k; i++) {
            denominator = (denominator * i) % p;
        }

        // numerator / denominator mod p.
        return ( numerator* modInverse(denominator,p) ) % p;
    }

    /**
     * Legacy O(32 * i) evaluation of the XOR sum over subsequences of length exactly i.
     * It uses a map keyed by "n k" holding C(n, k) mod 998244353. yes[x] and no[x]
     * are the counts of elements with/without bit x, and po[x] is the weight of bit x.
     * It is no longer used by main(), which does not build such a map; kept for compatibility.
     */
    @Deprecated
    public static long calc(int yes[],int no[],int i,LinkedHashMap<String,Long> hm,long po[]){
        long q=998244353;
        long p=0;
        for(int x=0;x<32;x++){
            for(int w=1;w<=yes[x] && w<=i;w+=2){
                if((i-w)>no[x]){
                    continue;
                }
                long y=hm.get(yes[x]+" "+w)%q;
                y=y*(hm.get(no[x]+" "+(i-w))%q)%q;
                y=y*(((po[x]%q)+q)%q)%q;   // normalise a possibly negative weight
                p=(p+y)%q;
            }
        }
        return p;
    }

    /** C(n, k) mod MOD from factorial tables; 0 when k is outside [0, n]. */
    static long binom(long[] fact, long[] invFact, int n, int k) {
        if (k < 0 || k > n) return 0;
        return fact[n] * invFact[k] % MOD * invFact[n - k] % MOD;
    }

    /**
     * In-place number-theoretic transform modulo MOD.
     * a.length must be a power of two and entries must lie in [0, MOD).
     * With invert == true this computes the inverse transform, including the 1/n scale.
     */
    static void ntt(long[] a, boolean invert) {
        int n = a.length;
        // Bit-reversal permutation.
        for (int i = 1, j = 0; i < n; i++) {
            int bit = n >> 1;
            for (; (j & bit) != 0; bit >>= 1) {
                j ^= bit;
            }
            j ^= bit;
            if (i < j) {
                long tmp = a[i];
                a[i] = a[j];
                a[j] = tmp;
            }
        }
        for (int len = 2; len <= n; len <<= 1) {
            long wLen = modPow(PRIMITIVE_ROOT, (MOD - 1) / len, MOD);
            if (invert) wLen = modInverse(wLen, MOD);
            int half = len >> 1;
            for (int i = 0; i < n; i += len) {
                long w = 1;
                for (int j = 0; j < half; j++) {
                    long u = a[i + j];
                    long v = a[i + j + half] * w % MOD;
                    long sum = u + v;
                    a[i + j] = sum >= MOD ? sum - MOD : sum;
                    long diff = u - v;
                    a[i + j + half] = diff < 0 ? diff + MOD : diff;
                    w = w * wLen % MOD;
                }
            }
        }
        if (invert) {
            long nInv = modInverse(n, MOD);
            for (int i = 0; i < n; i++) a[i] = a[i] * nInv % MOD;
        }
    }

    /** Product of two polynomials with coefficients in [0, MOD); returns a.length + b.length - 1 terms. */
    static long[] multiply(long[] a, long[] b) {
        int need = a.length + b.length - 1;
        int size = 1;
        while (size < need) size <<= 1;
        long[] fa = Arrays.copyOf(a, size);
        long[] fb = Arrays.copyOf(b, size);
        ntt(fa, false);
        ntt(fb, false);
        for (int i = 0; i < size; i++) fa[i] = fa[i] * fb[i] % MOD;
        ntt(fa, true);
        return Arrays.copyOf(fa, need);
    }

    /**
     * Reads N, the array and Q queries. For each M, prints the sum of XORs over
     * all non-empty subsequences of length <= M, modulo MOD.
     *
     * For bit b, the number of length-k subsequences whose XOR has bit b set is
     * the coefficient of t^k in
     *   (sum over odd j of C(yes[b], j) t^j) * (sum over j of C(no[b], j) t^j).
     * That product is computed with an NTT, giving O(BITS * N log N) overall.
     * This replaces the former O(N^2) Pascal tables and BigInteger binomials,
     * which could not even be allocated for N = 2*10^5.
     */
    public static void main (String[] args) throws java.lang.Exception
    {
        StreamTokenizer in = new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));
        in.nextToken();
        int n = (int) in.nval;

        // yes[b] = number of elements with bit b set.
        int[] yes = new int[BITS];
        for (int i = 0; i < n; i++) {
            in.nextToken();
            int value = (int) in.nval;   // values < 2^30 are represented exactly in a double
            for (int b = 0; b < BITS; b++) {
                if (((value >>> b) & 1) != 0) yes[b]++;
            }
        }

        long[] fact = new long[n + 1];
        long[] invFact = new long[n + 1];
        fact[0] = 1;
        for (int i = 1; i <= n; i++) fact[i] = fact[i - 1] * i % MOD;
        invFact[n] = modInverse(fact[n], MOD);
        for (int i = n; i > 0; i--) invFact[i - 1] = invFact[i] * i % MOD;

        // exact[k] = sum of XORs over subsequences of length exactly k (mod MOD).
        long[] exact = new long[n + 1];
        for (int b = 0; b < BITS; b++) {
            int set = yes[b];
            if (set == 0) continue;   // bit never set: no contribution
            int clear = n - set;
            long[] odd = new long[set + 1];
            for (int k = 1; k <= set; k += 2) odd[k] = binom(fact, invFact, set, k);
            long[] any = new long[clear + 1];
            for (int k = 0; k <= clear; k++) any[k] = binom(fact, invFact, clear, k);
            long[] ways = multiply(odd, any);   // length n + 1
            long weight = (1L << b) % MOD;
            for (int len = 1; len <= n; len++) {
                exact[len] = (exact[len] + ways[len] * weight) % MOD;
            }
        }

        // x[m - 1] = answer for a query with M = m (prefix sums of exact).
        long[] x = new long[n];
        long running = 0;
        for (int m = 1; m <= n; m++) {
            running = (running + exact[m]) % MOD;
            x[m - 1] = running;
        }

        in.nextToken();
        int q = (int) in.nval;
        StringBuilder out = new StringBuilder();
        for (int i = 0; i < q; i++) {
            in.nextToken();
            int p = (int) in.nval;
            out.append(x[p - 1]).append('\n');
        }
        System.out.print(out);
    }
}

Program Python:

MOD = 998244353  # prime modulus required by the problem (previously assigned twice)

# bitPos[j] = number of input values whose bit j is set (inputs are < 2**30).
bitPos = [0] * 30
# ans[m - 1] = answer for a query with M = m; filled by xorsum().
ans = []
# Factorial and inverse-factorial tables modulo MOD; filled by init().
facs = []
invfacs = []

def mul(a, b):
    """Return the product a * b reduced modulo MOD."""
    product = a * b
    return product % MOD


def sub(a, b):
    """Return a - b modulo MOD; intended for a and b already in [0, MOD)."""
    diff = a - b + MOD
    return diff - MOD if diff >= MOD else diff


def po(a, deg):
    """Return a ** deg modulo MOD (deg >= 0) by binary exponentiation."""
    result = 1
    square = a
    while deg > 0:
        if deg & 1:
            result = mul(result, square)
        square = mul(square, square)
        deg >>= 1
    return result


def inv(n):
    """Modular inverse of n modulo the prime MOD (Fermat's little theorem)."""
    exponent = MOD - 2
    return po(n, exponent)


def init(N):
    """Set facs[i] = i! and invfacs[i] = (i!)^-1 modulo MOD for 0 <= i < N."""
    global facs
    global invfacs

    facs = [1] * N
    for i in range(1, N):
        facs[i] = mul(facs[i - 1], i)

    # A single modular inverse for the largest factorial, then walk down
    # using 1/i! = 1/(i+1)! * (i+1).
    invfacs = [0] * N
    invfacs[-1] = inv(facs[-1])
    for i in reversed(range(N - 1)):
        invfacs[i] = mul(invfacs[i + 1], i + 1)


def nCr(n, k):
    """Binomial coefficient C(n, k) modulo MOD from the precomputed tables.

    Returns 0 when k is outside [0, n]. The k < 0 check matters in Python:
    a negative index would silently read invfacs from the end of the list
    instead of failing.
    """
    if k < 0 or n < k:
        return 0
    return mul(facs[n], mul(invfacs[k], invfacs[n - k]))


def makeBitPos(arr):
    """For every value in arr, add 1 to bitPos[j] for each set bit j."""
    global bitPos
    for value in arr:
        lsb_first = reversed(bin(value)[2:])
        for position, digit in enumerate(lsb_first):
            if digit == '1':
                bitPos[position] += 1


def xorsum(n):
    """Fill ans so that ans[m - 1] is the answer for a query with M = m.

    Reads the module-level `arr` (the n input values) and `bitPos`.
    For length-m subsequences, bit b contributes 2**b once for every choice of
    an odd number `no` of the bitPos[b] elements having bit b set, combined
    with m - no of the n - bitPos[b] elements without it.
    """
    global ans
    global arr
    global bitPos
    global MOD

    ans.append(sum(arr) % MOD)
    for m in range(2, n + 1):
        cur_sum = 0
        for b_pos in range(30):
            ones = bitPos[b_pos]
            if ones == 0:
                continue  # bit never set, never contributes
            zeros = n - ones
            # C(ones, no) is 0 for no > ones, and C(zeros, m - no) is 0 for
            # m - no > zeros; iterate only over odd `no` where both are non-zero.
            lo = max(1, m - zeros)
            if lo % 2 == 0:
                lo += 1
            hi = min(m, ones)
            sm = 0
            for no in range(lo, hi + 1, 2):
                sm = (sm + nCr(ones, no) * nCr(zeros, m - no)) % MOD
            cur_sum += sm * (1 << b_pos) % MOD
        ans.append((cur_sum + ans[-1]) % MOD)


def processQuery():
    """Read Q, then Q lines each holding M, and print ans[M - 1] for each."""
    global ans

    remaining = int(input())
    while remaining > 0:
        m = int(input())
        print(ans[m - 1])
        remaining -= 1



# Script entry: read N and the array, build factorial tables, answer queries.
n = int(input())
arr = list(map(int, input().split()))
# nCr is only called with arguments <= the element count, so tables sized from
# the input suffice. This avoids a fixed-size table that is wasteful for small
# N and too small for larger N.
init(max(n, len(arr)) + 1)
makeBitPos(arr)
xorsum(n)
processQuery()

Codechef Long Challenges

September Long Challenge 2021 Solution

February Long Challenge 2021

Leave a Comment