Created
June 5, 2016 11:07
-
-
Save soonraah/5001a9eafa7880fb84716618bc0f62de to your computer and use it in GitHub Desktop.
A base class of online learning for binary linear classification
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package mlp.onlineml.classification.binary | |
import breeze.linalg.{DenseMatrix, DenseVector} | |
/**
 * A base class of online learning for binary linear classification.
 *
 * The classifier is described by a weight vector and a covariance matrix. [[train]] does not
 * assign to either of them; it hands the (possibly updated) values to [[create]] and returns
 * whatever classifier that factory method builds.
 *
 * @param w     weight vector
 * @param sigma covariance matrix
 */
abstract class LinearClassifier protected (
    val w: DenseVector[Double],
    val sigma: DenseMatrix[Double]
) {

  /**
   * Threshold the signed margin `y * (w . x)` is compared against in [[train]]; the model is
   * updated only when the margin is strictly below this value.
   *
   * @param x training sample
   */
  protected def e(x: DenseVector[Double]): Double

  /**
   * Scalar coefficient applied to `sigma * x` in the weight update of [[train]].
   *
   * @param x training sample
   * @param y label for training sample
   */
  protected def alpha(x: DenseVector[Double], y: Label): Double

  /**
   * Scalar coefficient applied to `sigma * x * x.t * sigma` in the covariance update of
   * [[train]].
   *
   * @param x training sample
   * @param y label for training sample
   */
  protected def beta(x: DenseVector[Double], y: Label): Double

  /**
   * Online training
   *
   * When `y.value * (w . x) < e(x)` the new parameters are
   * {{{
   * w'     = w     + y.value * alpha(x, y) * (sigma * x)
   * sigma' = sigma - beta(x, y) * (sigma * x * x.t * sigma)
   * }}}
   * otherwise the current parameters are passed to [[create]] unchanged.
   *
   * @param x training sample; must have the same length as `w`
   * @param y label for training sample
   * @return updated classifier
   * @throws IllegalArgumentException if `x.length != w.length`
   */
  def train(x: DenseVector[Double], y: Label): LinearClassifier = {
    require(
      x.length == w.length,
      s"dimension mismatch: x has ${x.length} elements but w has ${w.length}"
    )
    if (y.value * (w.t * x) < e(x)) {
      // sigma * x appears in both updates, so evaluate the matrix-vector product only once.
      val sigmaX = sigma * x
      // sigma * x * x.t * sigma == (sigma * x) * (sigma.t * x).t
      // Written as an outer product of two vectors this costs O(n^2) instead of the O(n^3)
      // matrix-matrix product, and it does not rely on sigma being symmetric.
      val correction = sigmaX * (sigma.t * x).t
      create(
        w + y.value * alpha(x, y) * sigmaX,
        sigma - beta(x, y) * correction
      )
    } else {
      create(w, sigma)
    }
  }

  /**
   * Classifies a sample by the sign of `w . x`.
   *
   * @param x sample to classify; must have the same length as `w`
   * @return `Label(true)` when `w . x` is strictly positive, `Label(false)` otherwise
   * @throws IllegalArgumentException if `x.length != w.length`
   */
  def classify(x: DenseVector[Double]): Label = {
    // Same dimension check as train, so both entry points fail the same way on bad input.
    require(
      x.length == w.length,
      s"dimension mismatch: x has ${x.length} elements but w has ${w.length}"
    )
    Label(w.t * x > 0)
  }

  /**
   * Factory used by [[train]] to build the classifier it returns.
   *
   * @param w     weight vector of the new classifier
   * @param sigma covariance matrix of the new classifier
   * @return classifier holding the given parameters
   */
  protected def create(w: DenseVector[Double], sigma: DenseMatrix[Double]): LinearClassifier
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment