Treasure Hunt Solution September Challenge 2021

Codechef Treasure Hunt Solution

Chef lives in an $N \times M$ grid. He is currently participating in a treasure hunt and has two items left to find. Chef knows that the Manhattan distance between the cells containing these two items is exactly $k$. He wants to know in how many different pairs of cells the two items can be present.

Let $A_k$ be the number of desired pairs when the Manhattan distance between the two cells containing these items is equal to $k$. Let $C = \sum_{i=1}^{N+M-2} A_i \cdot 31^{i-1}$. You have to find the value of $C$.

The answer may be large, so you need to find it modulo $998244353$.

The Manhattan distance between two points $(x, y)$ and $(x', y')$ is defined as $|x - x'| + |y - y'|$.

Also See: September Long Challenge 2021 Solutions

Input Format

  • The first line of the input contains a single integer $T$ denoting the number of test cases. The description of $T$ test cases follows.
  • Each test case consists of a single line of input containing two integers $N$ and $M$.

Output Format

For each test case, print $C$ modulo $998244353$ on a new line.

Constraints

  • $1 \le T \le 5$
  • $1 \le N, M \le 10^7$
  • The sum of $N$ over all tests does not exceed $10^7$.
  • The sum of $M$ over all tests does not exceed $10^7$.

Subtasks

Subtask #1 (5 points):

  • $1 \le N, M \le 10^4$
  • The sum of $N$ over all tests does not exceed $10^4$.
  • The sum of $M$ over all tests does not exceed $10^4$.

Subtask #2 (35 points):

  • $1 \le N, M \le 10^6$
  • The sum of $N$ over all tests does not exceed $10^6$.
  • The sum of $M$ over all tests does not exceed $10^6$.

Subtask #3 (60 points): original constraints

Sample Input 1 

3
2 3
2 4
100 350

Sample Output 1 

2115
65668
895852507

Explanation

Test case $1$:
The pairs of points with distance $1$ are:

  • $(1,1)$ and $(1,2)$
  • $(1,1)$ and $(2,1)$
  • $(1,2)$ and $(1,3)$
  • $(1,2)$ and $(2,2)$
  • $(1,3)$ and $(2,3)$
  • $(2,1)$ and $(2,2)$
  • $(2,2)$ and $(2,3)$

The pairs of points with distance $2$ are:

  • $(1,1)$ and $(1,3)$
  • $(1,1)$ and $(2,2)$
  • $(1,2)$ and $(2,3)$
  • $(1,2)$ and $(2,1)$
  • $(1,3)$ and $(2,2)$
  • $(2,1)$ and $(2,3)$

The pairs of points with distance $3$ are:

  • $(1,1)$ and $(2,3)$
  • $(2,1)$ and $(1,3)$

Therefore, the answer is $7 \cdot 31^0 + 6 \cdot 31^1 + 2 \cdot 31^2 = 2115$.

Follow us on Telegram for quick updates and an abundance of free knowledge: Click Here

Solution

Program C:

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

/* Prime modulus required by the problem statement. */
static const long long TH_MOD = 998244353;

/* Modular inverse of 6 modulo TH_MOD: 6 * 166374059 == TH_MOD + 1. */
static const long long TH_INV6 = 166374059;

/* Reduce x into [0, TH_MOD); unlike a bare `%`, also correct for negative x. */
static long long th_norm(long long x)
{
	x %= TH_MOD;
	return x < 0 ? x + TH_MOD : x;
}

/* 0 + 1 + ... + x modulo TH_MOD; 0 for x <= 0. x * (x + 1) fits for x <= 1e7. */
static long long th_sum1(long long x)
{
	if (x <= 0)
		return 0;
	return (x * (x + 1) / 2) % TH_MOD;
}

/* 0^2 + 1^2 + ... + x^2 modulo TH_MOD; 0 for x <= 0. */
static long long th_sum2(long long x)
{
	if (x <= 0)
		return 0;
	/* The operands are reduced before multiplying, so divide by 6 via its inverse. */
	return (x * (x + 1) % TH_MOD) * ((2 * x + 1) % TH_MOD) % TH_MOD * TH_INV6 % TH_MOD;
}

/*
 * A_k modulo TH_MOD: the number of unordered cell pairs at Manhattan
 * distance k in an r x s grid.
 *
 * d is the offset along the r side (0 <= d <= r - 1) and k - d the offset
 * along the s side (0 <= k - d <= s - 1). An offset with both parts non-zero
 * occurs in two diagonal directions, each realised by (r - d)(s - k + d)
 * pairs. An axis-aligned offset (d == 0 or k - d == 0) occurs only once.
 */
static long long th_pairs_at(long long r, long long s, long long k)
{
	long long lo = (k - (s - 1) > 0) ? k - (s - 1) : 0;
	long long hi = (k < r - 1) ? k : r - 1;
	long long cnt, sum_d, sum_d2, base, slope, total;

	if (lo > hi)
		return 0;
	cnt = (hi - lo + 1) % TH_MOD;
	sum_d = th_norm(th_sum1(hi) - th_sum1(lo - 1));
	sum_d2 = th_norm(th_sum2(hi) - th_sum2(lo - 1));
	/* (r - d)(s - k + d) = r(s - k) + d(r - s + k) - d^2 */
	base = th_norm(r * (s - k));
	slope = th_norm(r - s + k);
	total = th_norm(cnt * base % TH_MOD + slope * sum_d % TH_MOD - sum_d2);
	total = 2 * total % TH_MOD;
	if (lo == 0)
		total -= base;                  /* d == 0: one direction only */
	if (hi == k)
		total -= th_norm((r - k) * s);  /* k - d == 0: one direction only */
	return th_norm(total);
}

/* C = sum_{k=1}^{n+m-2} A_k * 31^(k-1) modulo TH_MOD. */
static long long th_solve(long long n, long long m)
{
	long long c = 0;
	long long pw = 1; /* 31^(k-1) mod TH_MOD */
	long long k;

	for (k = 1; k <= n + m - 2; ++k)
	{
		c = (c + th_pairs_at(n, m, k) * pw) % TH_MOD;
		pw = pw * 31 % TH_MOD;
	}
	return c;
}

/*
 * Reads the number of test cases, then one "N M" line per case, and prints C
 * for each. Stops with a non-zero status if the input is malformed, instead of
 * reading uninitialised values.
 */
int main()
{
	long long int q;

	if (scanf("%lld", &q) != 1)
		return 1;

	while (q > 0)
	{
		long long int n, m;

		if (scanf("%lld %lld", &n, &m) != 2)
			return 1;
		printf("%lld\n", th_solve(n, m));
		--q;
	}
	return 0;
}

Program C++:

#include <iostream>
#include<bits/stdc++.h>
using namespace std;
// Prime modulus required by the problem statement.
const long int M = 998244353;
typedef long long int ll;

namespace {

// Modular inverse of 6 modulo M (6 * 166374059 == M + 1). Once a product has
// been reduced modulo M it is, in general, no longer divisible by 3 or 6 as an
// integer. Division must therefore go through this inverse instead of `/`.
constexpr ll kInv6 = 166374059;

// Reduces x into [0, M); also correct for negative x.
ll norm_mod(ll x) {
    x %= M;
    return x < 0 ? x + M : x;
}

// 0 + 1 + ... + x modulo M; 0 for x <= 0. The product is exact for x <= 1e7.
ll prefix_sum(ll x) {
    return x <= 0 ? 0 : (x * (x + 1) / 2) % M;
}

// 0^2 + 1^2 + ... + x^2 modulo M; 0 for x <= 0.
ll prefix_sum_squares(ll x) {
    if (x <= 0) return 0;
    return x * (x + 1) % M * ((2 * x + 1) % M) % M * kInv6 % M;
}

// A_k modulo M: the number of unordered cell pairs at Manhattan distance k in
// an n x m grid. d is the offset along the n side (0 <= d <= n - 1) and k - d
// the offset along the m side (0 <= k - d <= m - 1). An offset with both parts
// non-zero occurs in two diagonal directions, each realised by
// (n - d)(m - k + d) pairs. An axis-aligned offset occurs only once.
ll pairs_at_distance(ll n, ll m, ll k) {
    const ll lo = std::max<ll>(0, k - (m - 1));
    const ll hi = std::min<ll>(k, n - 1);
    if (lo > hi) return 0;
    const ll cnt = (hi - lo + 1) % M;
    const ll sum_d = norm_mod(prefix_sum(hi) - prefix_sum(lo - 1));
    const ll sum_d2 = norm_mod(prefix_sum_squares(hi) - prefix_sum_squares(lo - 1));
    // (n - d)(m - k + d) = n(m - k) + d(n - m + k) - d^2
    const ll base = norm_mod(n * (m - k));
    const ll slope = norm_mod(n - m + k);
    ll total = norm_mod(cnt * base % M + slope * sum_d % M - sum_d2);
    total = 2 * total % M;
    if (lo == 0) total -= base;                   // d == 0: one direction only
    if (hi == k) total -= norm_mod((n - k) * m);  // k - d == 0: one direction only
    return norm_mod(total);
}

// C = sum_{k=1}^{n+m-2} A_k * 31^(k-1) modulo M.
ll treasure_hunt(ll n, ll m) {
    ll result = 0;
    ll power = 1;  // 31^(k-1) mod M
    for (ll k = 1; k <= n + m - 2; ++k) {
        result = (result + pairs_at_distance(n, m, k) * power) % M;
        power = power * 31 % M;
    }
    return result;
}

}  // namespace

// Reads T, then T lines "N M", and prints C modulo M for each grid.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    ll t = 0;
    std::cin >> t;
    while (t-- > 0) {
        ll n = 0;
        ll m = 0;
        if (!(std::cin >> n >> m)) break;
        std::cout << treasure_hunt(n, m) << '\n';
    }
    return 0;
}

Program Java:

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.math.BigInteger;
import java.nio.charset.Charset;
import java.util.Scanner;
public class Main {
  /** Debug switch: when true, prints the dimensions and every per-distance count. */
  static final boolean EA = false;
  private int n;
  private int m;

  /**
   * Reads T, then T lines "N M", and prints C modulo 998244353 for each grid.
   * When the {@code user.name} system property equals "dev", the sample input is used instead of stdin.
   */
  public static void main(String args[]) throws Exception {
    InputStream is;
    // Null-safe comparison: user.name may be unset in a sandboxed judge.
    if ("dev".equals(System.getProperty("user.name"))) {
      String in = "3\n" + "2 3\n" + "2 4\n" +
          "100 350\n";
      is = new ByteArrayInputStream(in.getBytes());

    } else {
      is = System.in;
    }
    BufferedReader br = new BufferedReader(new InputStreamReader(is, Charset.defaultCharset()), 1 << 17);
    Scanner sc = new Scanner(br);
    run(sc);
  }

  /** Reads the number of test cases and solves each one. */
  private static void run(Scanner sc) {
    int numCases = sc.nextInt();

    for (int caseNum = 0; caseNum < numCases; caseNum++) {
      runCase(sc);
    }
  }

  /** Reads one "N M" pair and prints its answer. */
  private static void runCase(Scanner sc) {
    int n = sc.nextInt();
    int m = sc.nextInt();
    System.out.println(new Main(n, m).getResult());
  }

  final static int M = 998244353;
  /** Modular inverse of 6 modulo M (6 * 166374059 == M + 1). */
  final static long INV6 = 166374059L;

  public Main(int n, int m) {
    this.n = n;
    this.m = m;
  }

  /**
   * Returns C = sum over k = 1..n+m-2 of A_k * 31^(k-1), modulo M.
   *
   * A_k is the number of unordered cell pairs at Manhattan distance k. For each k the offset is
   * split as a (along the m side, a <= m-1) plus b = k - a (along the n side, b <= n-1). An offset
   * with a, b both non-zero occurs in two diagonal directions, each realised by (m - a)(n - b)
   * pairs. An axis-aligned offset occurs only once.
   */
  private int getResult() {
    if (EA) {
      System.out.printf("@ %d %d\n", n, m);
    }
    long res = 0;
    long x31 = 1; // 31^(k-1) mod M
    for (int k = 1; k <= n + m - 2; k++) {

      int minA = Math.max(0, k - (n - 1));
      int maxA = Math.min(k, (m - 1));

      long cnt = 0;

      if (minA == 0) {
        // a == 0, b == k: same column, one direction only.
        cnt += (long) m * (n - k) % M;
        minA++;
      }
      if (k - maxA == 0) {
        // b == 0, a == k: same row, one direction only.
        cnt += (long) (m - k) * n % M;
        maxA--;
      }

      // Sum over the remaining a of (m - a)(n - k + a) = m(n - k) + a(m - n + k) - a^2.
      long s1 = Math.floorMod((long) maxA - minA + 1, (long) M);
      long sA = sumA(minA, maxA);
      long sAA = sumA2(minA, maxA);
      long slope = Math.floorMod((long) m - n + k, (long) M);
      long base = Math.floorMod((long) m * (n - k), (long) M);
      long cnt2 = Math.floorMod(s1 * base % M + sA * slope % M - sAA, (long) M);

      cnt = (cnt + 2 * cnt2) % M;

      if (EA) {
        System.out.printf("k:%d %d\n", k, cnt);
      }

      res = (res + cnt * x31) % M;
      x31 = x31 * 31 % M;
    }

    return (int) res;
  }

  /** Sum of the integers in [minA, maxA] modulo M; 0 for an empty range (minA >= 0). */
  private int sumA(int minA, int maxA) {
    return (int) Math.floorMod(prefixSum(maxA) - prefixSum(minA - 1), (long) M);
  }

  /** Sum of the squares of the integers in [minA, maxA] modulo M; 0 for an empty range (minA >= 0). */
  private int sumA2(int minA, int maxA) {
    return (int) Math.floorMod(prefixSumOfSquares(maxA) - prefixSumOfSquares(minA - 1), (long) M);
  }

  /** 0 + 1 + ... + x modulo M; 0 for x <= 0. */
  private static long prefixSum(long x) {
    return x <= 0 ? 0 : x * (x + 1) / 2 % M;
  }

  /** 0^2 + 1^2 + ... + x^2 modulo M; 0 for x <= 0. Divides by 6 via its modular inverse. */
  private static long prefixSumOfSquares(long x) {
    return x <= 0 ? 0 : x * (x + 1) % M * ((2 * x + 1) % M) % M * INV6 % M;
  }

}

Program Python:

MOD = 998244353


def _sum_of_squares(x):
    """Return 0^2 + 1^2 + ... + x^2, or 0 when x <= 0."""
    if x <= 0:
        return 0
    return x * (x + 1) * (2 * x + 1) // 6


def pairs_at_distance(rows, cols, k):
    """Return A_k, the number of unordered cell pairs at Manhattan distance k
    in a rows x cols grid (exact integer, not reduced modulo MOD).

    d is the offset along the rows side (0 <= d <= rows - 1) and k - d the
    offset along the cols side (0 <= k - d <= cols - 1).  An offset with both
    parts non-zero occurs in two diagonal directions, each realised by
    (rows - d) * (cols - k + d) pairs.  An axis-aligned offset occurs once.
    """
    lo = max(0, k - (cols - 1))
    hi = min(k, rows - 1)
    if lo > hi:
        return 0
    count = hi - lo + 1
    sum_d = (lo + hi) * count // 2
    sum_d2 = _sum_of_squares(hi) - _sum_of_squares(lo - 1)
    # (rows - d)(cols - k + d) = rows*(cols - k) + d*(rows - cols + k) - d^2
    total = 2 * (count * rows * (cols - k) + sum_d * (rows - cols + k) - sum_d2)
    if lo == 0:
        total -= rows * (cols - k)  # d == 0: one direction only
    if hi == k:
        total -= (rows - k) * cols  # k - d == 0: one direction only
    return total


def solve(n, m):
    """Return C = sum_{k=1}^{n+m-2} A_k * 31^(k-1) modulo MOD for an n x m grid.

    Each A_k is computed in closed form and folded into the running sum
    immediately, so memory use is O(1) instead of one big integer per distance.
    """
    result = 0
    weight = 1  # 31^(k-1) mod MOD
    for k in range(1, n + m - 1):
        result = (result + pairs_at_distance(n, m, k) * weight) % MOD
        weight = weight * 31 % MOD
    return result


def main():
    """Read T test cases of "N M" from stdin and print C for each."""
    t = int(input())
    for _ in range(t):
        n, m = map(int, input().split())
        print(solve(n, m))


if __name__ == "__main__":
    main()

Codechef Long Challenges

September Long Challenge 2021 Solution

August Long Challenge 2021 Solutions

Leave a Comment