Last active
January 8, 2025 11:39
-
-
Save tokugh/7a8f8460f50c561ce260e02fbe82d311 to your computer and use it in GitHub Desktop.
subset_convolution
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Library Checker, "Subset Convolution": https://judge.yosupo.jp/problem/subset_convolution
# Time complexity: O(n^2 2^n) for inputs of length 2^n (n = number of bits).
using OffsetArrays: Origin | |
# Arithmetic in Z/NZ for the prime modulus N = 998244353.
# All three operators assume both operands are already reduced into 0:N-1.
const N::Int = 998244353

# Modular addition: x + y lies in 0:2N-2, so one conditional subtraction reduces it.
x ⊕ y = (z = x + y; ifelse(z < N, z, z - N))

# Modular subtraction: x - y lies in -(N-1):N-1, so one conditional addition reduces it.
x ⊖ y = (z = x - y; ifelse(z < 0, z + N, z))

# Modular multiplication: x * y < N^2 < typemax(Int), so the Int product cannot overflow.
x ⊗ y = mod(x * y, N)
# Parse one integer; with no argument, the next line of stdin is read.
function toI(s=readline())
    return parse(Int, s)
end

# Parse a whitespace-separated line of integers; with no argument, the next line of
# stdin is read.
function toVI(s=readline())
    return [toI(tok) for tok in eachsplit(s)]
end
"""
    main()

Read the bit count `n` and two lines of integers from stdin, compute their subset
convolution with `solve`, and print the result as one space-separated line.
"""
function main()
    nbits = toI()
    # Both input rows are re-indexed from 0 so bitmask values can be used as indices.
    fa = Origin(0)(toVI())
    fb = Origin(0)(toVI())
    result = solve(nbits, fa, fb)
    println(join(result, " "))
    return nothing
end
"""
    solve(n, a, b)

Return the subset convolution of `a` and `b` modulo `N`:
`c[S] = Σ a[T] ⊗ b[S ⊻ T]` over all masks `T ⊆ S`.

`a` and `b` are expected to be 0-based vectors of equal power-of-two length.
`n` is accepted for interface compatibility but is not used; the bit count is derived
from the array sizes.

Pipeline:
1. Split each input by popcount (ranked form).
2. Apply the subset zeta transform to each.
3. Multiply pointwise as polynomials in the rank.
4. Apply the Möbius inverse.
5. Read back the rank == popcount diagonal.
"""
function solve(n, a, b)
    # Ranked zeta transforms of the two inputs; each must keep its own name.
    fa = zeta_subset!(deposit_bitcount(a))
    fb = zeta_subset!(deposit_bitcount(b))
    # Rank-wise polynomial product of the two transforms, then the inverse transform.
    fc = mobius_subset!(convolve(fa, fb))
    return extract_bitcount(fc)
end
"""
    deposit_bitcount(xs)

Spread a vector of length `2^n` into a `(n+1) × 2^n` `Int` matrix with 0-based axes.
Element number `m` (counting from 0) goes into row `count_ones(m)`, column `m`; every
other entry is 0.

`xs` may use any index base. The mask of each element is its position in iteration
order, not its raw index, so 0-based and 1-based inputs give the same result.
"""
function deposit_bitcount(xs)
    len = length(xs)
    @assert ispow2(len)
    n = trailing_zeros(len)
    res = zeros(Int, 1 + n, len) |> Origin(0)
    for (pos, x) in enumerate(xs)
        # `enumerate` counts from 1 regardless of xs's axes, so mask ∈ 0:len-1 and
        # count_ones(mask) ≤ n; both indices are therefore in bounds for `res`.
        mask = pos - 1
        @inbounds res[count_ones(mask), mask] = x
    end
    return res
end
"""
    zeta_subset!(res)

In-place subset zeta transform, applied independently to every rank row `k`.
Afterwards, `res[k, S]` holds the sum (mod `N`) of the original `res[k, T]` over all
`T ⊆ S`.

`res` must have 0-based axes `(0:n, 0:2^n-1)`, as produced by `deposit_bitcount`.
Returns `res`.
"""
function zeta_subset!(res)
    len = size(res, 2)
    @assert ispow2(len)
    n = trailing_zeros(len)
    # Validate the whole index range once so the @inbounds accesses below are safe.
    checkbounds(res, 0:n, 0:len-1)
    for d in 0:n-1
        w = 1 << d
        # Blocks of 2w masks: the upper half has bit d set and pairs with the lower half.
        for base in 0:2w:len-1, hi in base+w:base+2w-1
            lo = hi ⊻ w
            for k in 0:n
                @inbounds res[k, hi] = res[k, hi] ⊕ res[k, lo]
            end
        end
    end
    return res
end
"""
    convolve(fa, fb)

For every mask column `i`, treat `fa[:, i]` and `fb[:, i]` as polynomials in the rank
`k`. Return a new array whose column `i` holds their product (mod `N`), truncated to
degree `n`:
`fc[k, i] = Σ_{j=0}^{k} fa[j, i] ⊗ fb[k-j, i]`.

`fa` and `fb` must share the same 0-based axes `(0:n, 0:2^n-1)`.
Throws `DimensionMismatch` if their axes differ.
"""
function convolve(fa, fb)
    len = size(fa, 2)
    @assert ispow2(len)
    n = trailing_zeros(len)
    # `fb` is read with `fa`'s indices under @inbounds, so its shape must match exactly.
    axes(fb) == axes(fa) ||
        throw(DimensionMismatch("convolve: axes of arguments differ: $(axes(fa)) vs $(axes(fb))"))
    checkbounds(fa, 0:n, 0:len-1)
    # `similar` keeps fa's 0-based axes, so fc accepts the same indices.
    fc = fill!(similar(fa), 0)
    for i in 0:len-1, k in 0:n
        acc = 0
        for j in 0:k
            @inbounds acc = acc ⊕ (fa[j, i] ⊗ fb[k - j, i])
        end
        @inbounds fc[k, i] = acc
    end
    return fc
end
"""
    mobius_subset!(res)

In-place subset Möbius transform (the inverse of `zeta_subset!`), applied independently
to every rank row `k`. Afterwards, `res[k, S]` holds the inclusion–exclusion difference
(mod `N`) of the original values over all `T ⊆ S`.

`res` must have 0-based axes `(0:n, 0:2^n-1)`. Returns `res`.
"""
function mobius_subset!(res)
    len = size(res, 2)
    @assert ispow2(len)
    n = trailing_zeros(len)
    # Validate the whole index range once so the @inbounds accesses below are safe.
    checkbounds(res, 0:n, 0:len-1)
    # Each pass only reads masks without bit d, which that pass never writes. The
    # per-bit passes commute, so neither pass order nor mask order affects the result.
    for d in 0:n-1
        w = 1 << d
        for base in 0:2w:len-1, hi in base+w:base+2w-1
            lo = hi ⊻ w
            for k in 0:n
                @inbounds res[k, hi] = res[k, hi] ⊖ res[k, lo]
            end
        end
    end
    return res
end
"""
    extract_bitcount(res)

Inverse of `deposit_bitcount`'s layout. For every mask `m`, take the entry in row
`count_ones(m)` of column `m`. Returns a 0-based `Int` vector of length `size(res, 2)`.
"""
function extract_bitcount(res)
    len = size(res, 2)
    out = Origin(0)(zeros(Int, len))
    for mask in 0:len-1
        @inbounds out[mask] = res[count_ones(mask), mask]
    end
    return out
end
# Run `main` only when this file is executed as a script (`julia file.jl`).
# This uses the check recommended in the Julia FAQ. An `endswith(@__FILE__, PROGRAM_FILE)`
# test would also fire when PROGRAM_FILE == "" (interactive session or `include`),
# because every string ends with "".
if abspath(PROGRAM_FILE) == @__FILE__
    main()
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment