Skip to content

Instantly share code, notes, and snippets.

@tokugh
Last active January 8, 2025 11:39
Show Gist options
  • Save tokugh/7a8f8460f50c561ce260e02fbe82d311 to your computer and use it in GitHub Desktop.
Save tokugh/7a8f8460f50c561ce260e02fbe82d311 to your computer and use it in GitHub Desktop.
subset_convolution
# https://judge.yosupo.jp/problem/subset_convolution
# Time complexity: O(n^2 * 2^n), where the input sequences have length 2^n
using OffsetArrays: Origin
# Modulus used by the ⊕ / ⊖ / ⊗ helpers below.
const N = 998244353

# Modular addition. Both operands must already lie in 0:N-1, because only a single
# conditional subtraction is performed.
x ⊕ y = (z = x + y; ifelse(z < N, z, z - N))

# Modular subtraction. Both operands must already lie in 0:N-1.
x ⊖ y = (z = x - y; ifelse(z < 0, z + N, z))

# Modular multiplication. For operands in 0:N-1 the product x*y fits in a 64-bit Int
# (it would overflow a 32-bit Int).
x ⊗ y = mod(x * y, N)
# Parse a single Int from `s`; by default `s` is the next line read from stdin.
function toI(s = readline())
    return parse(Int, s)
end

# Parse every whitespace-separated token of `s` as an Int and collect them into a
# Vector{Int}; by default `s` is the next line read from stdin.
function toVI(s = readline())
    return [toI(token) for token in eachsplit(s)]
end
# Read the problem input from stdin and print the answer as one space-separated line.
# Input: a line with the bit count, then one line each for the sequences a and b.
# Both sequences are re-indexed from 0 so that an index is directly a bitmask.
function main()
    bits = toI()
    f = Origin(0)(toVI())
    g = Origin(0)(toVI())
    h = solve(bits, f, g)
    join(stdout, h, " ")
    println()
end
# Subset convolution of the 0-indexed sequences `a` and `b`:
# for every mask S, c[S] = Σ a[T] ⊗ b[S \ T] over T ⊆ S, computed with ranked
# zeta / Möbius transforms.
# `n` is accepted for interface compatibility but unused; the bit width is derived
# from the length of `a` inside the helpers.
function solve(n, a, b)
    # Split each input by popcount (rank) and take its subset-sum transform.
    ranked_a = zeta_subset!(deposit_bitcount(a))
    ranked_b = zeta_subset!(deposit_bitcount(b))
    # Multiply rank-polynomials per mask, then invert the subset-sum transform.
    ranked_c = mobius_subset!(convolve(ranked_a, ranked_b))
    # Keep only the entries whose rank equals the mask's popcount.
    return extract_bitcount(ranked_c)
end
"""
    deposit_bitcount(xs)

Spread the length-`2^n` sequence `xs` into a ranked table `res` of size `(n+1) × 2^n`,
indexed from 0 on both axes: `res[count_ones(S), S] = xs[S]`, with every other entry zero.

The bitmask of element `i` is its offset from `firstindex(xs)`. A 0-based input (e.g.
wrapped with `Origin(0)`) behaves exactly as before, and a plain 1-based `Vector` no longer
writes past the last column.
"""
function deposit_bitcount(xs)
    len = length(xs)
    @assert ispow2(len)
    n = trailing_zeros(len)
    res = zeros(Int, 1 + n, len) |> Origin(0)
    base = firstindex(xs)
    for i in eachindex(xs)
        mask = i - base  # 0-based position, used as the subset bitmask
        @inbounds res[count_ones(mask), mask] = xs[i]
    end
    return res
end
"""
    zeta_subset!(res)

In-place subset-sum (zeta) transform along the second (mask) axis of `res`, applied to
every rank row `k` independently. Afterwards `res[k, S]` is the `⊕`-sum of the original
`res[k, T]` over all `T ⊆ S`.

`res` is expected to be indexed `res[k, mask]` with `k ∈ 0:n` and `mask ∈ 0:2^n-1`, as
returned by `deposit_bitcount`. Entries must already lie in `0:N-1`, since `⊕` performs
only one conditional subtraction. Returns `res`.
"""
function zeta_subset!(res)
    len = size(res, 2)
    @assert ispow2(len)
    n = trailing_zeros(len)
    for d in 0:n-1
        w = 1 << d
        for i in 0:len-1
            iszero(i & w) && continue  # only masks that contain bit d
            i′ = i ⊻ w                 # the same mask with bit d cleared
            for k in 0:n               # first index innermost: column-major friendly
                @inbounds res[k, i] = res[k, i] ⊕ res[k, i′]
            end
        end
    end
    return res
end
"""
    convolve(𝕒, 𝕓)

Rank-wise product of two ranked tables. For every mask `i` and rank `k`,
`𝕔[k, i] = ⊕-sum over j in 0:k of 𝕒[j, i] ⊗ 𝕓[k-j, i]`, i.e. the product of the
rank-polynomials of `𝕒` and `𝕓` at mask `i`, truncated at degree `n`.

Both arguments are expected to be indexed `[0:n, 0:2^n-1]` (as produced by
`deposit_bitcount`/`zeta_subset!`). They must have identical axes, because `𝕓` is read
with `𝕒`'s bounds under `@inbounds`; a `DimensionMismatch` is thrown otherwise.
Returns a new array shaped like `𝕒`.
"""
function convolve(𝕒, 𝕓)
    axes(𝕒) == axes(𝕓) ||
        throw(DimensionMismatch("convolve: axes $(axes(𝕒)) and $(axes(𝕓)) differ"))
    len = size(𝕒, 2)
    @assert ispow2(len)
    n = trailing_zeros(len)
    𝕔 = fill!(similar(𝕒), 0)
    for i in 0:len-1, k in 0:n
        acc = 0
        for j in 0:k
            # ⊗ binds tighter than ⊕, so this is acc ⊕ (𝕒 ⊗ 𝕓).
            @inbounds acc = acc ⊕ 𝕒[j, i] ⊗ 𝕓[k-j, i]
        end
        @inbounds 𝕔[k, i] = acc
    end
    return 𝕔
end
"""
    mobius_subset!(res)

In-place inverse of `zeta_subset!` (Möbius transform over subsets) along the second
(mask) axis of `res`, applied to every rank row `k` independently. Each bit's pass
replaces `res[k, i]` with `res[k, i] ⊖ res[k, i without that bit]`.

`res` is expected to be indexed `res[k, mask]` with `k ∈ 0:n` and `mask ∈ 0:2^n-1`.
Entries must lie in `0:N-1`, as `⊖` assumes reduced operands. Returns `res`.
"""
function mobius_subset!(res)
    len = size(res, 2)
    @assert ispow2(len)
    n = trailing_zeros(len)
    for d in reverse(0:n-1)
        w = 1 << d
        for i in reverse(0:len-1)
            iszero(i & w) && continue  # only masks that contain bit d
            i′ = i ⊻ w                 # the same mask with bit d cleared
            for k in 0:n
                @inbounds res[k, i] = res[k, i] ⊖ res[k, i′]
            end
        end
    end
    return res
end
"""
    extract_bitcount(res)

Collapse a ranked table back into a 0-indexed `Int` vector. For each mask (column) `m`,
keep only the entry in row `count_ones(m)`.
"""
function extract_bitcount(res)
    ncols = size(res, 2)
    out = Origin(0)(Vector{Int}(undef, ncols))  # every slot is written below
    for mask in 0:ncols-1
        rank = count_ones(mask)
        @inbounds out[mask] = res[rank, mask]
    end
    return out
end
# Run `main` only when this file is the script Julia was started with, not when it is
# `include`d. A suffix test such as `endswith(@__FILE__, PROGRAM_FILE)` is also true when
# PROGRAM_FILE is "" (REPL) or merely a suffix of this file's path.
if abspath(PROGRAM_FILE) == @__FILE__
    main()
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment