Created
October 21, 2015 00:39
-
-
Save pitzcarraldo/ed7cdda48b7f66f66257 to your computer and use it in GitHub Desktop.
MAB (Multi-Armed Bandit) Java Implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** | |
* @descreption Calculate impression ratio for Arms by ArmStatistics. | |
* @author Minkyu Cho([email protected]) | |
* The logic of this class is based on R code in below article(Figure 4). | |
* http://mktg455cnu.net/wp-content/uploads/2014/10/scott.pdf | |
*/ | |
@Slf4j | |
@Component | |
public class RatioCalculator { | |
private static final int SIMULATE_COUNT_BASE = 100; | |
/** | |
* | |
* Calculate ratios by ArmStatistics. | |
* | |
* In R Language | |
* compute.win.prob(clicks , views, simulateCount); | |
* | |
* clicks and views included in ArmStatistic and SimulateCount is number of Arms * SIMULATE_COUNT_BASE | |
* | |
* @param statistics the ArmStatistics of single Arm | |
* @return impressionRationMap key : itemId, value : impressionRatio | |
*/ | |
public Map<Long, Integer> computeImpressionRatio(List<ArmStatistic> statistics) { | |
int simulateCount = statistics.size() * SIMULATE_COUNT_BASE; | |
Map<Long, Integer> ratioMap = sortAndFillPercentage(computeImpressionRatio(makeSimulateMap(statistics, simulateCount))); | |
for (ArmStatistic statistic : statistics) { | |
if (!ratioMap.containsKey(statistic.getArmId())) { | |
ratioMap.put(statistic.getArmId(), 0); | |
} | |
} | |
return ratioMap; | |
} | |
/** | |
* In R Language | |
* prob.winner <- function(post){ | |
* k <- ncol(post) | |
* w <- table(factor(max.col(post), levels=1:k)) | |
* return (w/sum(w)) | |
* }; | |
* @param simulateMap | |
* @return | |
*/ | |
private Map<Long, Integer> computeImpressionRatio(Multimap<Long, Double> simulateMap) { | |
Map<Long, AtomicLong> winCountMap = getWinCountMap(getWinItems(simulateMap)); | |
long sum = 0; | |
for (Map.Entry<Long, AtomicLong> entry : winCountMap.entrySet()) { | |
sum += entry.getValue().longValue(); | |
} | |
BigDecimal sumBigDecimal = new BigDecimal(sum); | |
Map<Long, Integer> winners = Maps.newHashMap(); | |
final int roundUpLimit = 2; | |
for (Map.Entry<Long, AtomicLong> entry : winCountMap.entrySet()) { | |
BigDecimal wins = new BigDecimal(entry.getValue().longValue()); | |
BigDecimal ratio = wins.divide(sumBigDecimal, roundUpLimit, BigDecimal.ROUND_HALF_UP).multiply(new BigDecimal(SIMULATE_COUNT_BASE)); | |
winners.put(entry.getKey(), ratio.intValue()); | |
} | |
return winners; | |
} | |
/** | |
* Make simulation map from log of click and presentation. | |
* {item1 : [betaValue1, betaValue2 ... ~ simulateCount]} | |
* {item2 : [betaValue1, betaValue2 ... ~ simulateCount]} | |
* ... | |
* betaValue = rbeta(clicks + 1,view - click + 1) > Estimated Conversion Rate. | |
* | |
* In R Language | |
* sim.post <- function(clicks , views, simulateCount){ | |
* nItems <- length(clicks) | |
* simulateMatrix <- matrix(nrow=simulateCount, ncol=nItems) | |
* no <- views-clicks | |
* for(i in 1:nItems) | |
* simulateMatrix[,i]<-rbeta(simulateCount,clicks[i]+1,no[i]+1) | |
* return(simulateMatrix) | |
* }; | |
* | |
* @param stats | |
* @param simulateCount | |
* @return simulateMap | |
*/ | |
private Multimap<Long, Double> makeSimulateMap(List<ArmStatistic> stats, int simulateCount) { | |
Multimap<Long, Double> simulateMap = ArrayListMultimap.create(); | |
for (ArmStatistic stat : stats) { | |
BetaDistribution beta = new BetaDistributionImpl(stat.getClick() + 1, stat.getView() - stat.getClick() + 1); | |
for (int i = 0; i < simulateCount; i++) { | |
double betaValue = 0; | |
try { | |
betaValue = beta.inverseCumulativeProbability(Math.random()); | |
} catch (Exception e) { | |
log.error(e.getMessage(), e); | |
} | |
simulateMap.put(stat.getArmId(), betaValue); | |
} | |
} | |
return simulateMap; | |
} | |
/** | |
* Pick and sort highest betaValue. | |
* orderedDrawMap > Sorted item map by betaValue per simulation. | |
* { simulation 1 : {{betaValue : itemId}, ... 오름차순}} | |
* { simulation 2 : {{betaValue : itemId}, ... }} | |
* ... | |
* winItems > Set of highest betaValue. | |
* | |
* In R Language | |
* w <- table(factor(max.col(post), levels=1:k)) | |
* | |
* @param simulateMap | |
* @return ordered itemId list by simulation. | |
*/ | |
private List<Long> getWinItems(Multimap<Long, Double> simulateMap) { | |
Map<Integer, Map<Double, Long>> orderedDrawMap = Maps.newHashMap(); | |
for (Long itemId : simulateMap.keySet()) { | |
List<Double> draws = Lists.newArrayList(simulateMap.get(itemId)); | |
for (int i = 0; i < draws.size(); i++) { | |
if (!orderedDrawMap.containsKey(i)) { | |
Map<Double, Long> newMap = Maps.newTreeMap(); | |
orderedDrawMap.put(i, newMap); | |
} | |
orderedDrawMap.get(i).put(draws.get(i), itemId); | |
} | |
} | |
List<Long> winItems = Lists.newArrayList(); | |
for (Map.Entry<Integer, Map<Double, Long>> entry : orderedDrawMap.entrySet()) { | |
TreeMap<Double, Long> currentMap = (TreeMap<Double, Long>) entry.getValue(); | |
winItems.add(currentMap.lastEntry().getValue()); | |
} | |
return winItems; | |
} | |
/** | |
* Calculate count of test what has highest betaValue per item. | |
* In R Language | |
* max.col(post) | |
* @param winItems | |
* @return winCountMap | |
*/ | |
private Map<Long, AtomicLong> getWinCountMap(List<Long> winItems) { | |
ConcurrentMap<Long, AtomicLong> winCountMap = Maps.newConcurrentMap(); | |
for (Long itemId : winItems) { | |
winCountMap.putIfAbsent(itemId, new AtomicLong(0)); | |
winCountMap.get(itemId).incrementAndGet(); | |
} | |
return winCountMap; | |
} | |
/** | |
* Sort presentation ratio by DSC and fill to 100 when size of map is lower than 100. | |
* @param ratioMap | |
* @return sortedRatioMap | |
*/ | |
private Map<Long, Integer> sortAndFillPercentage(Map<Long, Integer> ratioMap) { | |
Map<Long, Integer> sortedMap = sortByDesc(ratioMap); | |
fillPercentageToOneHundred(sortedMap); | |
return sortedMap; | |
} | |
/** | |
* Sort by DSC. | |
* @param unsortedMap | |
* @return sortedMap | |
*/ | |
private Map<Long, Integer> sortByDesc(Map<Long, Integer> unsortedMap) { | |
List list = new LinkedList(unsortedMap.entrySet()); | |
Collections.sort(list, new Comparator() { | |
public int compare(Object o1, Object o2) { | |
return ((Comparable) ((Map.Entry) (o2)).getValue()).compareTo(((Map.Entry) (o1)).getValue()); | |
} | |
}); | |
Map<Long, Integer> sortedMap = Maps.newLinkedHashMap(); | |
for (Iterator it = list.iterator(); it.hasNext();) { | |
Map.Entry<Long, Integer> entry = (Map.Entry<Long, Integer>) it.next(); | |
sortedMap.put(entry.getKey(), entry.getValue()); | |
} | |
return sortedMap; | |
} | |
/** | |
* Fill sortedMap to 100. | |
* @param sortedMap | |
*/ | |
private void fillPercentageToOneHundred(Map<Long, Integer> sortedMap) { | |
int sum = 0; | |
for (Map.Entry<Long, Integer> entry : sortedMap.entrySet()) { | |
sum += entry.getValue(); | |
} | |
if (sum <= 0) { | |
return; | |
} | |
if (sum != SIMULATE_COUNT_BASE) { | |
int gap = SIMULATE_COUNT_BASE - sum; | |
Long lastKey = Iterables.getLast(sortedMap.keySet(), null); | |
sortedMap.put(lastKey, sortedMap.get(lastKey) + gap); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment