Skip to content

Instantly share code, notes, and snippets.

@d0rc
Created October 25, 2024 01:36
Show Gist options
  • Save d0rc/e362c0e25573af86c1b1e3918082a22f to your computer and use it in GitHub Desktop.
Save d0rc/e362c0e25573af86c1b1e3918082a22f to your computer and use it in GitHub Desktop.
searching text space
package main
import (
	"fmt"
	"math"
	"time"

	"github.com/d0rc/mcts"
)
// Data types exchanged between the search callbacks in main and the
// token generator.
type (
	// Sequence is a single search node. Tokens holds the token IDs chosen so
	// far. Entropies holds one value per expansion step, appended alongside
	// each new token; the seed prompt contributes placeholder values.
	Sequence struct {
		Tokens    []int
		Entropies []float64
	}

	// GenerationResult holds the candidate next tokens for a sequence and
	// one entropy value for that generation step. It is presumably produced
	// by GetNextTokens, which is defined outside this file.
	GenerationResult struct {
		Tokens  []GeneratedToken
		Entropy float64
	}

	// GeneratedToken is one candidate continuation: a token ID plus the
	// probability reported for it. Probability is not read in this file.
	GeneratedToken struct {
		TokenId     int
		Probability float64
	}
)
// calculateVarentropy returns the population variance (divisor n, not n-1)
// of the given entropy values, i.e. the "varentropy" of a sequence.
// With fewer than two values there is no spread to measure, so the sentinel
// math.MaxFloat64 is returned instead.
func calculateVarentropy(entropies []float64) float64 {
	n := len(entropies)
	if n <= 1 {
		return math.MaxFloat64
	}
	count := float64(n)

	// First pass: arithmetic mean.
	var total float64
	for i := 0; i < n; i++ {
		total += entropies[i]
	}
	avg := total / count

	// Second pass: mean squared deviation from that mean.
	var sqDev float64
	for i := 0; i < n; i++ {
		d := entropies[i] - avg
		sqDev += d * d
	}
	return sqDev / count
}
// main runs an MCTS search that starts from the seed token sequence [1,2,3,4].
// Each node is expanded into one child per candidate token returned by
// GetNextTokens (defined outside this file). Each node is scored by a 70/30
// blend of the varentropy and the summed entropy of its per-step entropies.
// It prints the best sequence found and panics if mcts.Run returns an error.
func main() {
	// Seed prompt [1,2,3,4]. Entropies starts with one placeholder zero per
	// seed token. Nothing in this file ever replaces these zeros, so they are
	// part of every varentropy and total-entropy score.
	initialSeq := Sequence{
		Tokens:    []int{1, 2, 3, 4},
		Entropies: make([]float64, 4),
	}

	// nextElements expands a node into one child Sequence per candidate token.
	nextElements := func(seq interface{}) []interface{} {
		currentSeq := seq.(Sequence)
		result, err := GetNextTokens(currentSeq)
		if err != nil {
			// NOTE(review): the error is dropped and the node gets no
			// children. Presumably mcts treats a nil slice as a terminal
			// node — confirm, and consider logging err.
			return nil
		}
		moves := make([]interface{}, len(result.Tokens))
		for i, token := range result.Tokens {
			// Each child must own its slices; see extendSequence.
			moves[i] = extendSequence(currentSeq, token.TokenId, result.Entropy)
		}
		return moves
	}

	// fitnessFunc scores a sequence, weighting varentropy more heavily than
	// total entropy (70/30 split).
	fitnessFunc := func(seq interface{}) float64 {
		currentSeq := seq.(Sequence)
		varentropy := calculateVarentropy(currentSeq.Entropies)
		totalEntropy := sumEntropies(currentSeq.Entropies)
		return 0.7*varentropy + 0.3*totalEntropy
	}

	config := mcts.Config{
		ExplorationConstant: 2.0, // larger values bias the UCT-style selection toward exploration
		MaxIterations:       10000,
		TargetSeqLength:     0, // NOTE(review): presumably 0 means "no fixed length" — confirm against mcts
		RandomSeed:          time.Now().UnixNano(),
		DebugLevel:          1,
	}

	bestSeq, err := mcts.Run(
		initialSeq,
		nextElements,
		fitnessFunc,
		config,
	)
	if err != nil {
		panic(err)
	}

	finalSeq := bestSeq.(Sequence)
	fmt.Printf("Best sequence: %v\n", finalSeq.Tokens)
	fmt.Printf("Entropies: %v\n", finalSeq.Entropies)
	fmt.Printf("Varentropy: %v\n", calculateVarentropy(finalSeq.Entropies))
	fmt.Printf("Total entropy: %v\n", sumEntropies(finalSeq.Entropies))
}

// extendSequence returns a new Sequence that is seq with token and entropy
// appended. seq itself is never modified.
//
// The full slice expressions cap each source slice at its length, so append
// always allocates a fresh backing array. Without this, sibling children
// built from a parent with spare capacity would share one array, and each
// later sibling would overwrite the previous siblings' last element.
func extendSequence(seq Sequence, token int, entropy float64) Sequence {
	tokens := seq.Tokens[:len(seq.Tokens):len(seq.Tokens)]
	entropies := seq.Entropies[:len(seq.Entropies):len(seq.Entropies)]
	return Sequence{
		Tokens:    append(tokens, token),
		Entropies: append(entropies, entropy),
	}
}

// sumEntropies returns the sum of entropies (0 for an empty or nil slice).
func sumEntropies(entropies []float64) float64 {
	total := 0.0
	for _, e := range entropies {
		total += e
	}
	return total
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment