ИЛЛЮСТРАЦИЯ СЕТЕЙ - по ссылкам в комментарии.
- 2 нейрона входа
- 3 нейрона выхода, SoftMax
- Cross Entropy loss
W =
[
w11 w12 w13
w21 w22 w23
]
Терминология:
In
- взвешенная суммаOut
- результат активацииCE
- Cross Entropy
Для примера, подсчитаем производную для w11
.
Сначала разберем влияние w11:
w11
влияет только наIn1
In1
влияет на всеOut1
,Out2
,Out3
т.к. SoftMax используетIn1
значение в знаменателе при подсчете любого изOut
Out1
,Out2
,Out3
влияют наCE
Для использования chain rule инвертируем понятие "влияние" и определим зависимости:
CE
зависит отOut1
,Out2
,Out3
- Каждая из
Out1
,Out2
,Out3
зависит отIn1
In1
зависит отw11
Таким образом получаем:
dCE/dw11
= dCE/dOut1 * dOut1/dw11 + dCE/dOut2 * dOut2/dw11 + dCE/dOut3 * dOut3/dw11
= (dCE/dOut1 * dOut1/dIn1 + dCE/dOut2 * dOut2/dIn1 + dCE/dOut3 * dOut3/dIn1) * dIn1/dw11
Всё ли здесь верно?
Изменим слои так, что:
- 4 нейрона входа
- 2 нейрона скрытого слоя, ReLU
- 3 нейрона выхода, SoftMax
- Cross Entropy loss
Параметры для скрытого слоя будем обозначать со звездочкой *
перед индексами.
W* =
[
w*11 w*12
w*21 w*22
w*31 w*32
w*41 w*42
]
W =
[
w11 w12 w13
w21 w22 w23
]
Out*1 стал вточности x1 из предыдущего примера
Чему равно dCE/dw*11
?
Аналогично, распишем влияние:
w*11
влияет только наIn*1
In*1
влияет только наOut*1
т.к. выход ReLU зависит только от входа с аналогичным индексомOut*1
влияет и наIn1
и наIn2
и наIn3
- Каждый из
In1
,In2
,In3
влияют наOut1
,Out2
,Out3
Out1
,Out2
,Out3
влияют наCE
Тогда производная dCE/dw*11
:
dCE/dw*11
= dCE/dOut1 * dOut1/dw*11 + dCE/dOut2 * dOut2/dw*11 + dCE/dOut3 * dOut3/dw*11
dOut1/dw*11 = dOut1/dIn1 * dIn1/dw*11 + dOut1/dIn2 * dIn2/dw*11 + dOut1/dIn3 * dIn3/dw*11
dOut2/dw*11 = dOut2/dIn1 * dIn1/dw*11 + dOut2/dIn2 * dIn2/dw*11 + dOut2/dIn3 * dIn3/dw*11
dOut3/dw*11 = dOut3/dIn1 * dIn1/dw*11 + dOut3/dIn2 * dIn2/dw*11 + dOut3/dIn3 * dIn3/dw*11
dIn1/dw*11 = dIn1/dOut*1 * dOut*1/dw*11
dIn2/dw*11 = dIn2/dOut*1 * dOut*1/dw*11
dIn3/dw*11 = dIn3/dOut*1 * dOut*1/dw*11
dOut*1/dw*11 = dOut*1/dIn*1 * dIn*1/dw*11
# Итого:
dCE/dw*11 =
= dCE/dOut1 * (dOut1/dIn1 * dIn1/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11 +
dOut1/dIn2 * dIn2/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11 +
dOut1/dIn3 * dIn3/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11) +
dCE/dOut2 * (dOut2/dIn1 * dIn1/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11 +
dOut2/dIn2 * dIn2/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11 +
dOut2/dIn3 * dIn3/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11) +
dCE/dOut3 * (dOut3/dIn1 * dIn1/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11 +
dOut3/dIn2 * dIn2/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11 +
dOut3/dIn3 * dIn3/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11)
= dCE/dOut1 * dOut*1/dIn*1 * dIn*1/dw*11 * (dOut1/dIn1 * dIn1/dOut*1 + dOut1/dIn2 * dIn2/dOut*1 + dOut1/dIn3 * dIn3/dOut*1) +
dCE/dOut2 * dOut*1/dIn*1 * dIn*1/dw*11 * (dOut2/dIn1 * dIn1/dOut*1 + dOut2/dIn2 * dIn2/dOut*1 + dOut2/dIn3 * dIn3/dOut*1) +
dCE/dOut3 * dOut*1/dIn*1 * dIn*1/dw*11 * (dOut3/dIn1 * dIn1/dOut*1 + dOut3/dIn2 * dIn2/dOut*1 + dOut3/dIn3 * dIn3/dOut*1)
= dOut*1/dIn*1 * dIn*1/dw*11 * (dCE/dOut1 * (dOut1/dIn1 * dIn1/dOut*1 + dOut1/dIn2 * dIn2/dOut*1 + dOut1/dIn3 * dIn3/dOut*1)
dCE/dOut2 * (dOut2/dIn1 * dIn1/dOut*1 + dOut2/dIn2 * dIn2/dOut*1 + dOut2/dIn3 * dIn3/dOut*1)
dCE/dOut3 * (dOut3/dIn1 * dIn1/dOut*1 + dOut3/dIn2 * dIn2/dOut*1 + dOut3/dIn3 * dIn3/dOut*1))
Иллюстрации сетей находится ниже. Ручкой прорисованы лишь влияющие связи, остальное - карандашом для полноты картины.
simple
complex