Skip to content

Instantly share code, notes, and snippets.

@PovelikinRostislav
Last active November 17, 2019 14:00
Show Gist options
  • Save PovelikinRostislav/f279e5893fbbdea71c61560413500926 to your computer and use it in GitHub Desktop.
Save PovelikinRostislav/f279e5893fbbdea71c61560413500926 to your computer and use it in GitHub Desktop.
BackPropagation.md

Back Propagation

Устройство сети

ИЛЛЮСТРАЦИЯ СЕТЕЙ - по ссылкам в комментарии.

  • 2 нейрона входа
  • 3 нейрона выхода, SoftMax
  • Cross Entropy loss
W =
[
  w11 w12 w13
  w21 w22 w23
]

Терминология:

  1. In - взвешенная сумма
  2. Out - результат активации
  3. CE - Cross Entropy

Подсчет градиента

Для примера, подсчитаем производную для w11.

Сначала разберем влияние w11:

  • w11 влияет только на In1
  • In1 влияет на все Out1,Out2,Out3 т.к. SoftMax использует In1 значение в знаменателе при подсчете любого из Out
  • Out1, Out2, Out3 влияют на CE

Для использования chain rule инвертируем понятие "влияние" и определим зависимости:

  • CE зависит от Out1, Out2, Out3
  • Каждая из Out1, Out2, Out3 зависит от In1
  • In1 зависит от w11

Таким образом получаем:

dCE/dw11
= dCE/dOut1 * dOut1/dw11 + dCE/dOut2 * dOut2/dw11 + dCE/dOut3 * dOut3/dw11
= (dCE/dOut1 * dOut1/dIn1 + dCE/dOut2 * dOut2/dIn1 + dCE/dOut3 * dOut3/dIn1) * dIn1/dw11

Всё ли здесь верно?

Усложнение сети и подсчет градиента

Изменим слои так, что:

  • 4 нейрона входа
  • 2 нейрона скрытого слоя, ReLU
  • 3 нейрона выхода, SoftMax
  • Cross Entropy loss

Параметры для скрытого слоя будем обозначать со звездочкой * перед индексами.

W* =
[
  w*11 w*12
  w*21 w*22
  w*31 w*32
  w*41 w*42
]
W =
[
  w11 w12 w13
  w21 w22 w23
]

Out*1 стал вточности x1 из предыдущего примера

Чему равно dCE/dw*11?

Аналогично, распишем влияние:

  • w*11 влияет только на In*1
  • In*1 влияет только на Out*1 т.к. выход ReLU зависит только от входа с аналогичным индексом
  • Out*1 влияет и на In1 и на In2 и на In3
  • Каждый из In1, In2, In3 влияют на Out1,Out2,Out3
  • Out1, Out2, Out3 влияют на CE

Тогда производная dCE/dw*11:

dCE/dw*11
= dCE/dOut1 * dOut1/dw*11 + dCE/dOut2 * dOut2/dw*11 + dCE/dOut3 * dOut3/dw*11

dOut1/dw*11 = dOut1/dIn1 * dIn1/dw*11 + dOut1/dIn2 * dIn2/dw*11 + dOut1/dIn3 * dIn3/dw*11
dOut2/dw*11 = dOut2/dIn1 * dIn1/dw*11 + dOut2/dIn2 * dIn2/dw*11 + dOut2/dIn3 * dIn3/dw*11
dOut3/dw*11 = dOut3/dIn1 * dIn1/dw*11 + dOut3/dIn2 * dIn2/dw*11 + dOut3/dIn3 * dIn3/dw*11

dIn1/dw*11 = dIn1/dOut*1 * dOut*1/dw*11
dIn2/dw*11 = dIn2/dOut*1 * dOut*1/dw*11
dIn3/dw*11 = dIn3/dOut*1 * dOut*1/dw*11

dOut*1/dw*11 = dOut*1/dIn*1 * dIn*1/dw*11

# Итого:
dCE/dw*11 =
= dCE/dOut1 * (dOut1/dIn1 * dIn1/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11 +
               dOut1/dIn2 * dIn2/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11 +
               dOut1/dIn3 * dIn3/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11) +
  dCE/dOut2 * (dOut2/dIn1 * dIn1/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11 +
               dOut2/dIn2 * dIn2/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11 +
               dOut2/dIn3 * dIn3/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11) +
  dCE/dOut3 * (dOut3/dIn1 * dIn1/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11 +
               dOut3/dIn2 * dIn2/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11 +
               dOut3/dIn3 * dIn3/dOut*1 * dOut*1/dIn*1 * dIn*1/dw*11)
               
= dCE/dOut1 * dOut*1/dIn*1 * dIn*1/dw*11 * (dOut1/dIn1 * dIn1/dOut*1 + dOut1/dIn2 * dIn2/dOut*1 + dOut1/dIn3 * dIn3/dOut*1) +
  dCE/dOut2 * dOut*1/dIn*1 * dIn*1/dw*11 * (dOut2/dIn1 * dIn1/dOut*1 + dOut2/dIn2 * dIn2/dOut*1 + dOut2/dIn3 * dIn3/dOut*1) +
  dCE/dOut3 * dOut*1/dIn*1 * dIn*1/dw*11 * (dOut3/dIn1 * dIn1/dOut*1 + dOut3/dIn2 * dIn2/dOut*1 + dOut3/dIn3 * dIn3/dOut*1)
  
= dOut*1/dIn*1 * dIn*1/dw*11 * (dCE/dOut1 * (dOut1/dIn1 * dIn1/dOut*1 + dOut1/dIn2 * dIn2/dOut*1 + dOut1/dIn3 * dIn3/dOut*1)
                                dCE/dOut2 * (dOut2/dIn1 * dIn1/dOut*1 + dOut2/dIn2 * dIn2/dOut*1 + dOut2/dIn3 * dIn3/dOut*1)
                                dCE/dOut3 * (dOut3/dIn1 * dIn1/dOut*1 + dOut3/dIn2 * dIn2/dOut*1 + dOut3/dIn3 * dIn3/dOut*1))
@PovelikinRostislav
Copy link
Author

Иллюстрации сетей находится ниже. Ручкой прорисованы лишь влияющие связи, остальное - карандашом для полноты картины.

  1. Простая сеть
    simple
  2. Расширенная сеть
    complex

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment