Last active
December 15, 2015 14:38
-
-
Save berlinbrown/5275387 to your computer and use it in GitHub Desktop.
Use markov chain model to generate text. Train the system and based on the data, return generated random text.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* Copyright (c) 2013 Berlin Brown (berlin2research.com)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.berlin.crawl.util.text;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Random;
import java.util.TreeMap;
/**
 * Use a markov chain model to generate text. Train the system and, based on the
 * data, return generated random text.
 *
 * <pre>
 * Example input:
 * This is going to be a good day for all the powerful people.
 *
 * this is going to be a bad
 * </pre>
 *
 * @author bbrown (berlin.brown at gmail.com)
 */
public class MarkovChainTextGenerator {

    // With a more expressive language or more time and a better approach
    // we could use a combination of words/pairs as the current sequence
    // (what word follows next?), e.g. in python:
    //   ('brown', 'fox'): ['jumps', 'who', 'who']

    /** Raw normalized text of each loaded document (index = document id). */
    private final List<String> docsData = new ArrayList<String>();

    /** Position in document -> (document id -> word at that position). */
    private final Map<Integer, Map<Integer, String>> posByWordsData = new TreeMap<Integer, Map<Integer, String>>();

    /** Position in document -> (word -> probability of that word at the position). */
    private final Map<Integer, Map<String, Double>> posWordsFreqData = new TreeMap<Integer, Map<String, Double>>();

    /** Current word -> list of observed next words (repeats kept for counting). */
    private final Map<String, List<String>> nextWords = new TreeMap<String, List<String>>();

    /** Current word -> (next word -> probability that it follows the current word). */
    private final Map<String, Map<String, Double>> nextWordsFreqPerc = new TreeMap<String, Map<String, Double>>();

    /**
     * Current word -> navigable map keyed by cumulative weight, sampled via
     * {@code ceilingEntry(r)} for weighted roulette-wheel selection.
     */
    private final Map<String, NavigableMap<Double, String>> reverseFreqToWord = new TreeMap<String, NavigableMap<Double, String>>();

    /** Current word -> (word -> running cumulative weight), parallel to reverseFreqToWord. */
    private final Map<String, Map<String, Double>> reverseFreqToTotal = new TreeMap<String, Map<String, Double>>();

    /** Document position used to seed text generation. */
    private int txtGenStartPos = 0;

    /** Number of words generated after the seed word. */
    private int txtGenMaxWords = 6;

    private final Random random = new Random(System.currentTimeMillis());

    /**
     * Load the document at the given path into a single normalized string:
     * lines joined by single spaces, lower-cased, periods removed.
     *
     * @param filename path of the plain-text file to read
     * @return normalized document text (empty string for an empty file)
     * @throws Exception on any I/O failure
     */
    public static String loadDocument(final String filename) throws Exception {
        final StringBuilder buf = new StringBuilder();
        // try-with-resources guarantees the reader is closed even when
        // readLine throws (the original leaked the stream in that case).
        try (final BufferedReader reader = new BufferedReader(
                new InputStreamReader(
                        new BufferedInputStream(new FileInputStream(new File(filename))),
                        StandardCharsets.UTF_8))) {
            String data;
            while ((data = reader.readLine()) != null) {
                final String line = data.trim();
                if (line.length() > 0) {
                    // Join lines with a space so the last word of one line and
                    // the first word of the next do not run together (bug fix:
                    // the original appended lines with no separator).
                    if (buf.length() > 0) {
                        buf.append(' ');
                    }
                    buf.append(line);
                }
            }
        }
        // Normalize: lower-case (locale-independent) and strip periods.
        return buf.toString().trim().toLowerCase(Locale.ROOT).replaceAll("\\.", "");
    }

    /**
     * Tokenize the document on whitespace and record, for each position, which
     * word this document has there (into {@code posByWordsData}).
     *
     * @param doc   normalized document text
     * @param docId identifier of the document
     */
    public void loadAndTokenizeDocument(final String doc, final int docId) {
        final String[] words = doc.split("\\s+");
        for (int posi = 0; posi < words.length; posi++) {
            Map<Integer, String> docWordMap = posByWordsData.get(posi);
            if (docWordMap == null) {
                docWordMap = new TreeMap<Integer, String>();
                posByWordsData.put(posi, docWordMap);
            }
            docWordMap.put(docId, words[posi]);
        }
    }

    /**
     * Load the position/probability data structure: for every position, the
     * probability of each word at that position across all documents.
     * Exits with {@code posWordsFreqData} populated.
     */
    public void findWordsFreqPos() {
        for (final Map.Entry<Integer, Map<Integer, String>> posEntry : posByWordsData.entrySet()) {
            final Map<Integer, String> docWordMap = posEntry.getValue();
            // Temporary count of each word at this position, used to convert to probability.
            final Map<String, Integer> counts = new HashMap<String, Integer>();
            for (final String word : docWordMap.values()) {
                final Integer ct = counts.get(word);
                counts.put(word, ct == null ? 1 : ct + 1);
            }
            // Convert counts to probabilities: count / documents-at-this-position.
            final double n = docWordMap.size();
            Map<String, Double> wordByFreq = posWordsFreqData.get(posEntry.getKey());
            if (wordByFreq == null) {
                wordByFreq = new HashMap<String, Double>();
                posWordsFreqData.put(posEntry.getKey(), wordByFreq);
            }
            for (final Map.Entry<String, Integer> ctEntry : counts.entrySet()) {
                wordByFreq.put(ctEntry.getKey(), ctEntry.getValue() / n);
            }
        }
    }

    /**
     * Build the digram word list: each current word mapped to the list of
     * words that were observed to follow it in the same document.
     * Exits with {@code nextWords} populated.
     */
    public void buildDigramWordList() {
        for (final Integer posi : posByWordsData.keySet()) {
            final Map<Integer, String> curPosData = posByWordsData.get(posi);
            final Map<Integer, String> nextPosData = posByWordsData.get(posi + 1);
            if (curPosData == null || nextPosData == null) {
                continue; // last position has no successors
            }
            for (final Map.Entry<Integer, String> docEntry : curPosData.entrySet()) {
                final String curword = docEntry.getValue();
                final String nextword = nextPosData.get(docEntry.getKey());
                if (curword == null || nextword == null) {
                    // This document ends before posi + 1. Bug fix: the original
                    // kept the previous iteration's digram and re-added it here,
                    // double-counting pairs; we simply skip instead.
                    continue;
                }
                List<String> listOfWordsNext = nextWords.get(curword);
                if (listOfWordsNext == null) {
                    listOfWordsNext = new ArrayList<String>();
                    nextWords.put(curword, listOfWordsNext);
                }
                listOfWordsNext.add(nextword);
            }
        }
    }

    /**
     * Build the digram statistics: for each current word, the probability of
     * every possible next word, plus the navigable sampling map.
     * Exits with {@code nextWordsFreqPerc} and the nav maps populated.
     */
    public void buildDigramStats() {
        for (final Map.Entry<String, List<String>> entry : nextWords.entrySet()) {
            final String keyWordCurrent = entry.getKey();
            final List<String> successors = entry.getValue();
            // Temporary count map, converted to probabilities below.
            final Map<String, Integer> counts = new HashMap<String, Integer>();
            for (final String keyWordNext : successors) {
                final Integer ct = counts.get(keyWordNext);
                counts.put(keyWordNext, ct == null ? 1 : ct + 1);
            }
            Map<String, Double> mapfreq = nextWordsFreqPerc.get(keyWordCurrent);
            if (mapfreq == null) {
                mapfreq = new HashMap<String, Double>();
                nextWordsFreqPerc.put(keyWordCurrent, mapfreq);
            }
            final double n = successors.size();
            for (final Map.Entry<String, Integer> ctEntry : counts.entrySet()) {
                mapfreq.put(ctEntry.getKey(), ctEntry.getValue().doubleValue() / n);
            }
            // Also load the navigable (cumulative-weight) sampling map.
            buildProbNavMap(keyWordCurrent);
        }
    }

    /**
     * Generate text by seeding from the configured start position and walking
     * the chain for {@code txtGenMaxWords} steps.
     *
     * @return the generated text, or the empty string when untrained
     */
    public String generateText() {
        final Map<Integer, String> firstPosWords = posByWordsData.get(txtGenStartPos);
        if (firstPosWords == null || firstPosWords.isEmpty()) {
            return ""; // nothing trained at the start position (original NPE'd)
        }
        // Prefer document 0's word; fall back to any document's word.
        String lastword = firstPosWords.get(0);
        if (lastword == null) {
            lastword = firstPosWords.values().iterator().next();
        }
        final StringBuilder buf = new StringBuilder();
        buf.append(lastword).append(' ');
        for (int i = 0; i < txtGenMaxWords; i++) {
            lastword = next(lastword);
            buf.append(lastword).append(' ');
        }
        return buf.toString();
    }

    /** Print one generated text sample to stdout. */
    public void printText() {
        System.out.println(generateText());
    }

    /**
     * Train on the default document set. These simple input text documents are
     * mostly variants of the string:
     * 'This is going to be a good day for all the powerful people'.
     *
     * @throws Exception on I/O failure reading a document
     */
    public void train() throws Exception {
        train(Arrays.asList(
                "docs/testmark/doc1.txt",
                "docs/testmark/doc2.txt",
                "docs/testmark/doc3.txt",
                "docs/testmark/doc4.txt"));
    }

    /**
     * Train the markov chain on the given documents: load and tokenize each
     * one, then build the positional and digram statistics.
     *
     * @param filenames paths of the plain-text training documents
     * @throws Exception on I/O failure reading a document
     */
    public void train(final List<String> filenames) throws Exception {
        int id = 0;
        for (final String filename : filenames) {
            final String doc = loadDocument(filename);
            docsData.add(doc);
            loadAndTokenizeDocument(doc, id);
            id++;
        }
        findWordsFreqPos();
        buildDigramWordList();
        buildDigramStats();
    }

    /**
     * Digram/bigram: key/value pair of 'current' word and the word following it.
     * Static (holds no reference to the enclosing generator); retained for
     * compatibility although the rebuilt digram pass no longer allocates one
     * per pair.
     *
     * @author bbrown
     */
    private static final class Digram {
        private final String cur;
        private final String next;

        public Digram(final String k, final String v) {
            this.cur = k;
            this.next = v;
        }
    }

    /** Get or create the cumulative-weight sampling map for the given word. */
    private NavigableMap<Double, String> navMapFor(final String cur) {
        NavigableMap<Double, String> m = reverseFreqToWord.get(cur);
        if (m == null) {
            m = new TreeMap<Double, String>();
            reverseFreqToWord.put(cur, m);
        }
        return m;
    }

    /** Get or create the running-total map for the given word. */
    private Map<String, Double> totalMapFor(final String cur) {
        Map<String, Double> m = reverseFreqToTotal.get(cur);
        if (m == null) {
            m = new TreeMap<String, Double>();
            reverseFreqToTotal.put(cur, m);
        }
        return m;
    }

    /**
     * Sample the next word after {@code cur} from the trained distribution.
     * Returns the empty string when the word has no known successors.
     */
    private String next(final String cur) {
        return nextnav(navMapFor(cur), totalMapFor(cur), cur);
    }

    /**
     * Load the cumulative-weight navigable map for {@code cur} from its
     * next-word probabilities. No-op when the word has no digram stats.
     */
    private void buildProbNavMap(final String cur) {
        final Map<String, Double> mapfreq = nextWordsFreqPerc.get(cur);
        if (mapfreq == null) {
            return; // could possibly be null, do not continue
        }
        final NavigableMap<Double, String> revFreqToWord = navMapFor(cur);
        final Map<String, Double> revFreqToTotal = totalMapFor(cur);
        for (final Map.Entry<String, Double> e : mapfreq.entrySet()) {
            final Double weight = e.getValue();
            if (weight != null) {
                addnav(revFreqToWord, revFreqToTotal, weight, cur, e.getKey());
            }
        }
    }

    /**
     * Accumulate one (word, weight) pair into the cumulative-weight map.
     * Keys are running totals so that {@code ceilingEntry(r)} with
     * {@code 0 <= r < total} picks words proportionally to their weight.
     */
    private static void addnav(final NavigableMap<Double, String> revFreqToWord,
            final Map<String, Double> revFreqToTotal,
            final double weight, final String cur, final String resultar) {
        final Double prev = revFreqToTotal.get(cur);
        double curtotal = (prev == null || prev < 0) ? 0 : prev;
        curtotal += weight;
        revFreqToTotal.put(cur, curtotal);
        revFreqToWord.put(curtotal, resultar);
    }

    /**
     * Weighted random pick from the cumulative-weight map.
     *
     * @return the sampled next word, or the empty string when no successor
     *         exists (bug fix: the original dereferenced a null ceilingEntry)
     */
    private String nextnav(final NavigableMap<Double, String> revFreqToWord,
            final Map<String, Double> revFreqToTotal, final String cur) {
        if (revFreqToWord == null || revFreqToTotal == null || revFreqToWord.isEmpty()) {
            return "";
        }
        final Double total = revFreqToTotal.get(cur);
        final double curtotal = (total == null || total < 0) ? 0 : total;
        final double value = random.nextDouble() * curtotal;
        final Map.Entry<Double, String> hit = revFreqToWord.ceilingEntry(value);
        return hit == null ? "" : hit.getValue();
    }

    /**
     * Main entry point for program.
     *
     * @param args unused
     * @throws Exception on I/O failure during training
     */
    public static void main(final String[] args) throws Exception {
        final MarkovChainTextGenerator c = new MarkovChainTextGenerator();
        c.train();
        c.printText();
    } // End of the method //

} // End of the class //
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment