fork download
  1. MOD = 998244353
  2. INV3 = pow(3, MOD - 2, MOD)
  3. PRIMITIVE_ROOT = 3 # for MOD = 998244353
  4.  
  5. def count_valid(n, k, m, arr):
  6. cnt = [0,0,0]
  7. for x in arr:
  8. cnt[x % 3] += 1
  9. c0 = k // 3
  10. c1 = (k + 2) // 3
  11. c2 = (k + 1) // 3
  12. S = n - m
  13. diff = (cnt[1] - cnt[2]) % 3
  14.  
  15. # case S == 0
  16. if S == 0:
  17. return 1 if (cnt[0] >= 1 and (cnt[1] - cnt[2]) % 3 == 0 and cnt[1] == cnt[2]) else 0
  18.  
  19. # prepare omega (primitive 3rd root)
  20. w = pow(PRIMITIVE_ROOT, (MOD - 1) // 3, MOD) # omega
  21. w_pows = [1, w, (w * w) % MOD]
  22.  
  23. def total_with(c0_, c1_, c2_):
  24. # compute sum_{t=0..2} omega^{t*diff} * (c0_ + c1_*omega^t + c2_*omega^{2t})^S
  25. s = 0
  26. for t in range(3):
  27. base = (c0_ + c1_ * w_pows[t] + c2_ * w_pows[(2*t) % 3]) % MOD
  28. val = pow(base, S, MOD)
  29. wt = pow(w_pows[1], (t * diff) % 3, MOD) # omega^{t*diff}; w_pows[1]=omega
  30. s = (s + val * wt) % MOD
  31. return s * INV3 % MOD
  32.  
  33. total = total_with(c0, c1, c2)
  34. # if prefix has no divisible-by-3, subtract sequences that also choose none (i.e., c0'=0)
  35. if cnt[0] == 0:
  36. bad = total_with(0, c1, c2)
  37. ans = (total - bad) % MOD
  38. else:
  39. ans = total % MOD
  40. return ans
  41.  
  42. # input driver (the problem statement uses t testcases)
  43. t = int(input().strip())
  44. for _ in range(t):
  45. n, k, m = map(int, input().split())
  46. arr = list(map(int, input().split())) if m > 0 else []
  47. print(count_valid(n, k, m, arr))
  48.  
Success #stdin #stdout 0.1s 14132KB
stdin
6
3 4 0
3 4 1
1
3 4 1
2
3 4 1
3
3 4 1
4
3 4 3
2 3 2
stdout
122628784
861709705
562599328
831343105
861709705
0