Created
October 25, 2024 01:36
-
-
Save d0rc/e362c0e25573af86c1b1e3918082a22f to your computer and use it in GitHub Desktop.
searching text space
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main
import (
	"fmt"
	"math"
	"time"

	"github.com/d0rc/mcts"
)
// Sequence is a search state: the token ids chosen so far together with one
// entropy value recorded per generation step.
type Sequence struct {
	// Tokens holds the token ids of the sequence, in order.
	Tokens []int
	// Entropies holds the entropy values accumulated while the sequence was
	// extended; main appends one value per added token.
	Entropies []float64
}
type GenerationResult struct { | |
Tokens []GeneratedToken | |
Entropy float64 | |
} | |
// GeneratedToken is one candidate next token.
type GeneratedToken struct {
	// TokenId is the id of the candidate token. (The name is kept as-is for
	// compatibility; Go convention would spell it TokenID.)
	TokenId int
	// Probability is the value reported for this candidate.
	// NOTE(review): presumably the model's probability in [0, 1] — confirm
	// against the producer of GenerationResult.
	Probability float64
}
// calculateVarentropy returns the population variance (divisor n, not n-1) of
// the given entropy values.
//
// With fewer than two values a variance is not meaningful, so the function
// returns math.MaxFloat64 as a sentinel instead.
func calculateVarentropy(entropies []float64) float64 {
	if len(entropies) < 2 {
		return math.MaxFloat64
	}

	n := float64(len(entropies))

	// First pass: mean of the values.
	sum := 0.0
	for _, e := range entropies {
		sum += e
	}
	mean := sum / n

	// Second pass: mean of the squared deviations from that mean.
	sqDev := 0.0
	for _, e := range entropies {
		d := e - mean
		sqDev += d * d
	}
	return sqDev / n
}
func main() { | |
// Initial sequence [1,2,3,4] | |
initialSeq := Sequence{ | |
Tokens: []int{1, 2, 3, 4}, | |
Entropies: make([]float64, 4), // Will be populated on first evaluation | |
} | |
// Define next elements function for MCTS | |
nextElements := func(seq interface{}) []interface{} { | |
currentSeq := seq.(Sequence) | |
result, err := GetNextTokens(currentSeq) | |
if err != nil { | |
return nil | |
} | |
// For each possible next token, create a new sequence | |
moves := make([]interface{}, len(result.Tokens)) | |
for i, token := range result.Tokens { | |
newTokens := append(currentSeq.Tokens, token.TokenId) | |
newEntropies := append(currentSeq.Entropies, result.Entropy) | |
moves[i] = Sequence{ | |
Tokens: newTokens, | |
Entropies: newEntropies, | |
} | |
} | |
return moves | |
} | |
// Define fitness function for MCTS | |
fitnessFunc := func(seq interface{}) float64 { | |
currentSeq := seq.(Sequence) | |
varentropy := calculateVarentropy(currentSeq.Entropies) | |
totalEntropy := 0.0 | |
for _, e := range currentSeq.Entropies { | |
totalEntropy += e | |
} | |
// Weight varentropy more heavily (70/30 split) | |
return 0.7*varentropy + 0.3*totalEntropy | |
} | |
// Configure MCTS | |
config := mcts.Config{ | |
ExplorationConstant: 2.0, // Favor exploration | |
MaxIterations: 10000, // Large number to ensure convergence | |
TargetSeqLength: 0, // Variable length sequences | |
RandomSeed: time.Now().UnixNano(), | |
DebugLevel: 1, // Show basic progress | |
} | |
// Run MCTS | |
bestSeq, err := mcts.Run( | |
initialSeq, | |
nextElements, | |
fitnessFunc, | |
config, | |
) | |
if err != nil { | |
panic(err) | |
} | |
// Get results | |
finalSeq := bestSeq.(Sequence) | |
fmt.Printf("Best sequence: %v\n", finalSeq.Tokens) | |
fmt.Printf("Entropies: %v\n", finalSeq.Entropies) | |
fmt.Printf("Varentropy: %v\n", calculateVarentropy(finalSeq.Entropies)) | |
fmt.Printf("Total entropy: %v\n", | |
func() float64 { | |
sum := 0.0 | |
for _, e := range finalSeq.Entropies { | |
sum += e | |
} | |
return sum | |
}()) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment