Skip to content

Instantly share code, notes, and snippets.

@xuzhongxing
Created December 6, 2016 03:37
Show Gist options
  • Save xuzhongxing/1da20b2dad815d542d54a7f6e4c7a02f to your computer and use it in GitHub Desktop.
Save xuzhongxing/1da20b2dad815d542d54a7f6e4c7a02f to your computer and use it in GitHub Desktop.
Fix for UNK symbol handling in WordVectorSerializer
diff --git a/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java b/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
index ca619ee..99f20b6 100755
--- a/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
+++ b/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java
@@ -68,6 +68,7 @@ import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
+import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.zip.GZIPInputStream;
@@ -404,13 +405,13 @@ public class WordVectorSerializer {
PrintWriter writer = new PrintWriter(new OutputStreamWriter(stream, "UTF-8"));
- for (int x = 0; x < vocabCache.numWords(); x++) {
- T element = vocabCache.elementAtIndex(x);
+ Collection<String> words = vocabCache.words();
+ for (String w : words) {
StringBuilder builder = new StringBuilder();
- builder.append(encodeB64(element.getLabel())).append(" ");
- INDArray vec = lookupTable.vector(element.getLabel());
+ builder.append(encodeB64(w)).append(" ");
+ INDArray vec = lookupTable.vector(w);
for (int i = 0; i < vec.length(); i++) {
builder.append(vec.getDouble(i));
if (i < vec.length() - 1) builder.append(" ");
@@ -582,6 +583,12 @@ public class WordVectorSerializer {
// writing out huffman tree
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
+
+ // UNK symbol is not in huffman tree, so skip writing its codes
+ if (vectors.getConfiguration().isUseHierarchicSoftmax() && (i == vectors.getVocab().numWords() - 1)) {
+ continue;
+ }
+
VocabWord word = vectors.getVocab().elementAtIndex(i);
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
for (int code: word.getCodes()) {
@@ -606,6 +613,12 @@ public class WordVectorSerializer {
// writing out huffman tree
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
+
+ // UNK symbol is not in huffman tree, so skip writing its points
+ if (vectors.getConfiguration().isUseHierarchicSoftmax() && (i == vectors.getVocab().numWords() - 1)) {
+ continue;
+ }
+
VocabWord word = vectors.getVocab().elementAtIndex(i);
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ");
for (int point: word.getPoints()) {
@@ -629,6 +642,12 @@ public class WordVectorSerializer {
// writing out word frequencies
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) {
for (int i = 0; i < vectors.getVocab().numWords(); i++) {
+
+ // UNK symbol is not in huffman tree; skip its frequency entry as well
+ if (vectors.getConfiguration().isUseHierarchicSoftmax() && (i == vectors.getVocab().numWords() - 1)) {
+ continue;
+ }
+
VocabWord word = vectors.getVocab().elementAtIndex(i);
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency()).append(" ").append(vectors.getVocab().docAppearedIn(word.getLabel()));
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment