Created
April 4, 2012 14:26
-
-
Save Melipone/2301728 to your computer and use it in GitHub Desktop.
Viterbi algorithm
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(ns ident.viterbi | |
(:use [clojure.pprint])) | |
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
;; Example | |
;; (def initpr (to-array [0.6 0.4])) | |
;; (def transpr (to-array-2d [[0.7 0.3][0.4 0.6]])) | |
;; (def emisspr (to-array-2d [[0.1 0.4 0.5][0.6 0.3 0.1]])) | |
;; (def hmm (make-hmm {:states ["rainy" "sunny"] :obs ["walk" "shop" "clean"] :init-probs initpr :emission-probs emisspr :state-transitions transpr})) | |
;; viterbi returns the optimal state sequence and the total (forward) probability of the observation sequence
;; (viterbi hmm [2 1 0]) -> [(0 0 1) 0.03628] | |
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; | |
;; Struct basis for a hidden Markov model. make-hmm additionally attaches the
;; :states and :obs label collections as extra (non-basis) keys.
(defstruct hmm :n :m :init-probs :emission-probs :state-transitions)

(defn make-hmm
  "Builds an hmm struct from a map with the keys :states, :obs, :init-probs,
  :emission-probs and :state-transitions.

  :n and :m are set to the number of states and of observation symbols.
  The functions in this file read the probability tables with `aget`,
  indexing :init-probs by state, :emission-probs by [state observation]
  and :state-transitions by [from-state to-state]."
  [{:keys [states obs init-probs emission-probs state-transitions]}]
  (assoc (struct hmm
                 (count states)
                 (count obs)
                 init-probs
                 emission-probs
                 state-transitions)
         :states states
         :obs obs))
(defn indexed
  "Returns a lazy sequence of [index element] vectors for the items of s,
  with indices counting up from 0."
  [s]
  (map-indexed vector s))
(defn argmax
  "Returns [index element] for the largest element of coll, or nil when
  coll is empty. If several elements tie for the maximum, the one with the
  lowest index is returned. Elements are compared with `>`, so they must be
  numbers."
  [coll]
  (when-let [pairs (seq (map-indexed vector coll))]
    (reduce (fn [best candidate]
              ;; strict `>` keeps the earliest of several equal maxima
              (if (> (second candidate) (second best))
                candidate
                best))
            pairs)))
(defn pprint-hmm
  "Prints a summary of hmm to *out*: one line with the state and observation
  counts, then each probability table pretty-printed after its label.
  Returns nil."
  [hmm]
  (println "number of states: " (:n hmm) " number of observations: " (:m hmm))
  ;; each entry pairs the printed label with the key of the table to show
  (doseq [[label table-key] [["init probabilities: " :init-probs]
                             ["emission probs: " :emission-probs]
                             ["state-transitions: " :state-transitions]]]
    (print label)
    (pprint (get hmm table-key))))
(defn init-alphas
  "Returns a lazy sequence with one value per state index in
  (range (:n hmm)): that state's entry in :init-probs multiplied by its
  :emission-probs entry for observation index obs. Both tables are read
  with `aget`, so they must be Java arrays."
  [hmm obs]
  (let [priors (:init-probs hmm)
        emissions (:emission-probs hmm)]
    (for [state (range (:n hmm))]
      (* (aget priors state)
         (aget emissions state obs)))))
(defn forward
  "One step of the forward recursion. For every target state j in
  (range (:n hmm)) it sums alphas[i] * state-transitions[i][j] over all
  source states i, then multiplies the sum by emission-probs[j][obs].
  Returns the new values as a lazy sequence."
  [hmm alphas obs]
  (let [n (:n hmm)
        transitions (:state-transitions hmm)
        emissions (:emission-probs hmm)]
    (for [j (range n)]
      (let [incoming (for [i (range n)]
                       (* (nth alphas i) (aget transitions i j)))]
        (* (reduce + 0 incoming)
           (aget emissions j obs))))))
(defn delta-max
  "One step of the Viterbi recursion. For every target state j in
  (range (:n hmm)) it takes the largest deltas[i] * state-transitions[i][j]
  over all source states i and multiplies it by emission-probs[j][obs].
  Returns the new values as a lazy sequence."
  [hmm deltas obs]
  (let [n (:n hmm)
        transitions (:state-transitions hmm)
        emissions (:emission-probs hmm)]
    (for [j (range n)]
      (let [candidates (for [i (range n)]
                         (* (nth deltas i) (aget transitions i j)))]
        (* (reduce max candidates)
           (aget emissions j obs))))))
(defn backtrack
  "Reconstructs a state sequence from back-pointers.

  paths is a sequence of back-pointer collections in forward order, and
  deltas holds the final scores. The walk starts at the index of the
  largest delta (via argmax) and goes through paths from last to first,
  each time looking up the current state's entry with `nth`. Returns the
  visited states as a list in forward order."
  [paths deltas]
  (let [last-state (first (argmax deltas))]
    ;; conj on a list prepends, so the list ends up in forward order and
    ;; (first states) is always the most recently recovered state
    (reduce (fn [states pointers]
              (conj states (nth pointers (first states))))
            (list last-state)
            (reverse paths))))
(defn update-paths
  "Computes the back-pointers for one Viterbi step. For every target state j
  in (range (:n hmm)) it returns the index i that maximises
  deltas[i] * state-transitions[i][j], as chosen by argmax. The result is
  a lazy sequence with one index per target state."
  [hmm deltas]
  (let [n (:n hmm)
        transitions (:state-transitions hmm)]
    (for [j (range n)]
      (let [scores (for [i (range n)]
                     (* (nth deltas i) (aget transitions i j)))]
        (first (argmax scores))))))
(defn viterbi
  "Runs the Viterbi and forward recursions over observs, a non-empty
  sequence of observation indices.

  Returns a two-element vector:
    - the most likely state sequence, as produced by backtrack
    - the sum of the final forward values (alphas), coerced with `float`

  Throws IllegalArgumentException when observs is empty."
  [hmm observs]
  (when (empty? observs)
    (throw (IllegalArgumentException.
            "viterbi requires at least one observation")))
  ;; Every intermediate result is realised into a vector. The step functions
  ;; return lazy seqs that each refer to the previous step's seq; left
  ;; unrealised, they form a chain as long as the observation sequence that
  ;; is only evaluated at the very end, recursing once per time step and
  ;; overflowing the stack on long inputs. Vectors also make `nth` constant
  ;; time.
  (loop [obs (rest observs)
         alphas (vec (init-alphas hmm (first observs)))
         deltas alphas
         paths []]
    (if (empty? obs)
      [(backtrack paths deltas) (float (reduce + alphas))]
      (let [o (first obs)]
        (recur (rest obs)
               (vec (forward hmm alphas o))
               (vec (delta-max hmm deltas o))
               ;; back-pointers come from the deltas of the previous step
               (conj paths (vec (update-paths hmm deltas))))))))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment