Created
December 6, 2016 03:37
-
-
Save xuzhongxing/1da20b2dad815d542d54a7f6e4c7a02f to your computer and use it in GitHub Desktop.
fix for UNK
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java b/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java | |
index ca619ee..99f20b6 100755 | |
--- a/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java | |
+++ b/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java | |
@@ -68,6 +68,7 @@ import java.nio.file.Files; | |
import java.nio.file.Paths; | |
import java.nio.file.StandardCopyOption; | |
import java.util.ArrayList; | |
+import java.util.Collection; | |
import java.util.List; | |
import java.util.concurrent.atomic.AtomicInteger; | |
import java.util.zip.GZIPInputStream; | |
@@ -404,13 +405,13 @@ public class WordVectorSerializer { | |
PrintWriter writer = new PrintWriter(new OutputStreamWriter(stream, "UTF-8")); | |
- for (int x = 0; x < vocabCache.numWords(); x++) { | |
- T element = vocabCache.elementAtIndex(x); | |
+ Collection<String> words = vocabCache.words(); | |
+ for (String w : words) { | |
StringBuilder builder = new StringBuilder(); | |
- builder.append(encodeB64(element.getLabel())).append(" "); | |
- INDArray vec = lookupTable.vector(element.getLabel()); | |
+ builder.append(encodeB64(w)).append(" "); | |
+ INDArray vec = lookupTable.vector(w); | |
for (int i = 0; i < vec.length(); i++) { | |
builder.append(vec.getDouble(i)); | |
if (i < vec.length() - 1) builder.append(" "); | |
@@ -582,6 +583,12 @@ public class WordVectorSerializer { | |
// writing out huffman tree | |
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileCodes))) { | |
for (int i = 0; i < vectors.getVocab().numWords(); i++) { | |
+ | |
+ // UNK symbol is not in huffman tree | |
+ if (vectors.getConfiguration().isUseHierarchicSoftmax() && (i == vectors.getVocab().numWords() - 1)) { | |
+ continue; | |
+ } | |
+ | |
VocabWord word = vectors.getVocab().elementAtIndex(i); | |
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); | |
for (int code: word.getCodes()) { | |
@@ -606,6 +613,12 @@ public class WordVectorSerializer { | |
// writing out huffman tree | |
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileHuffman))) { | |
for (int i = 0; i < vectors.getVocab().numWords(); i++) { | |
+ | |
+ // UNK symbol is not in huffman tree | |
+ if (vectors.getConfiguration().isUseHierarchicSoftmax() && (i == vectors.getVocab().numWords() - 1)) { | |
+ continue; | |
+ } | |
+ | |
VocabWord word = vectors.getVocab().elementAtIndex(i); | |
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" "); | |
for (int point: word.getPoints()) { | |
@@ -629,6 +642,12 @@ public class WordVectorSerializer { | |
// writing out word frequencies | |
try (PrintWriter writer = new PrintWriter(new FileWriter(tempFileFreqs))) { | |
for (int i = 0; i < vectors.getVocab().numWords(); i++) { | |
+ | |
+ // UNK symbol is not in huffman tree; skip its frequency entry too, matching the codes and points loops above | |
+ if (vectors.getConfiguration().isUseHierarchicSoftmax() && (i == vectors.getVocab().numWords() - 1)) { | |
+ continue; | |
+ } | |
+ | |
VocabWord word = vectors.getVocab().elementAtIndex(i); | |
StringBuilder builder = new StringBuilder(encodeB64(word.getLabel())).append(" ").append(word.getElementFrequency()).append(" ").append(vectors.getVocab().docAppearedIn(word.getLabel())); | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment