# Prototype for a TBLIS-backed `contract!` for ForwardDiff.jl.
# From GitHub gist xrq-phys/a90e3692a235f7c61c985df4039f6ff1 (last active July 25, 2020 15:08).
"""
    contract!(A::Array{T}, B::Array{T}, C::Array{T}, idx) where {T<:Dual}

Prototype entry point: contract the `Dual`-valued arrays `A` and `B` into `C`
according to `idx`. It forwards to the type-directed `contract!(T, ...)` method,
starting at byte offset 0 in each array.

`idx` describes the index labels of `A`, `B` and `C`, e.g. the three parts of
`"ik,jk->ij"`. NOTE(review): the concrete representation of `idx` is not fixed
yet — confirm.

The second forwarded argument is the "top-level stride": the number of base
scalars that one element of `T` occupies. It is computed with integer division
(`÷`), so it can scale the integer element strides of the arrays; `/` would
produce a `Float64`.
"""
contract!(A::Array{T}, B::Array{T}, C::Array{T}, idx) where {T<:Dual} =
    contract!(T, sizeof(T) ÷ sizeof(tovalue(T)),  # tovalue unveils the base scalar type (defined elsewhere — TODO)
              A, 0, B, 0, C, 0, idx)
"""
    contract!(::Type{Dual{Tg,V,ND}}, topst, A, sftA, B, sftB, C, sftC, idx)

Peel one layer of `Dual` off the element type and contract its components by the
product rule:

- value part: `value(C) = value(A) * value(B)`
- each of the `ND` partials: `partial_id(C) = partial_id(A) * value(B) + value(A) * partial_id(B)`

`V` may itself be a `Dual` (nested differentiation), in which case the calls
below recurse through this method until a scalar element type is reached.

# Arguments
- `topst`: top-level stride, i.e. base scalars per top-level array element;
  passed through unchanged.
- `A`, `B`, `C`: the top-level arrays (not reinterpreted).
- `sftA`, `sftB`, `sftC`: byte offsets of the current component inside one
  top-level element.
- `idx`: contraction index specification, passed through unchanged.

Returns `nothing`.
"""
function contract!(::Type{Dual{Tg,V,ND}}, topst,
                   A::Array, sftA,
                   B::Array, sftB,
                   C::Array, sftC,
                   idx) where {Tg,V,ND}
    # Value part: direct dispatch on the value type.
    contract!(V, topst, A, sftA, B, sftB, C, sftC, idx)
    # Partials: this code assumes partial `id` sits `id * sizeof(V)` bytes after
    # the value within each element.
    # NOTE(review): both calls below write the same slot of C, so the leaf kernel
    # must accumulate (C += A*B) into it rather than overwrite — TODO confirm.
    # TODO: consider exchangeability
    for id in 1:ND
        shift = id * sizeof(V)
        contract!(V, topst, A, sftA + shift, B, sftB, C, sftC + shift, idx)
        contract!(V, topst, A, sftA, B, sftB + shift, C, sftC + shift, idx)
    end
    return nothing
end
# Scalar element types TBLIS can contract natively (float, double, scomplex, dcomplex).
const ValueType = Union{Float32,Float64,ComplexF32,ComplexF64}

"""
    contract!(::Type{<:ValueType}, topst, A, sftA, B, sftB, C, sftC, idx)

Leaf of the recursion: contract one scalar component of the (possibly
`Dual`-valued) arrays `A`, `B` into `C`.

Unfinished prototype: the TBLIS call is not implemented yet. The method
currently computes the scalar-unit strides and returns `nothing` without
touching `C`.

# Arguments
- `topst`: base scalars per top-level array element. It must be integral; a
  non-integral value throws `InexactError`.
- `A`, `B`, `C`: the top-level arrays, as passed down by the `Dual` methods.
- `sftA`, `sftB`, `sftC`: byte offsets of this component within one element.
- `idx`: contraction index specification.
"""
function contract!(::Type{<:ValueType}, topst,
                   A::Array, sftA,
                   B::Array, sftB,
                   C::Array, sftC,
                   idx)
    # Convert element strides of the top-level arrays into scalar strides.
    unit = Int(topst)
    stA = unit .* strides(A)
    stB = unit .* strides(B)
    stC = unit .* strides(C)
    # TODO:
    # - build TBLIS tensor objects for A, B and C from their bare memory shifted by
    #   sft{A,B,C} bytes, using the strides st{A,B,C}.
    # - TBLIS-contract according to idx.
    return nothing
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment