dec 5

monte carlo tree search

monte carlo tree search is yet another concept that's pretty simple to understand but sounds v scary. i hate the fact that research concepts sound so overwhelming that people won't even dare explore them.

dw, i gotchu.

transformers are feedforward prediction approximators.

dw, im just messing w you. in simple terms: you give a transformer y tokens, and it tries to predict the (y+1)th token. it doesn't think. it doesn't search. it just maps an input to an output.

all the fancy reasoning methods (including chain of thought) are just that: fancy. the model is still doing next-token prediction. the only difference is that it gets to spend some extra tokens in between on the question asked before it answers.

true reasoning, however, requires search. think about how humans reason:

- understand the problem
- search our option space, see what would lead to a good/bad result
- ignore the bad ones
- try out the good ones. once something fails, mark it as bad
- repeat

transformers cannot do this natively because they generate tokens by sampling them one after another, not by exploring alternatives. MCTS fixes this by letting the model simulate multiple hypothetical futures.

let's take a step back. you have a model that can predict tokens. RL, or reinforcement learning, is about discovering how current actions (imagine tool calls) influence future outcomes (rewards), and choosing actions that maximize long-term payoff. pretty much like humans. but on its own, it still doesn't enable the model to search.

what does?

during RL, the model often also learns a value function V(s), an estimate of the expected future reward from state s (usually a linear layer on top of the final transformer hidden state). MCTS uses this value estimate during inference to score hypothetical future states during search. (inference means answering your "my friend is not replying, are they angry with me? was it even my fault?")

so during inference, say you are at a state A′. you can take actions 1, 2, or 3 and reach hypothetical future states A₁, A₂, and A₃.
for each of these:

- the transformer simulates the next hidden state in latent space (latent space is basically the model's internal representation of the situation at a given moment)
- the value head predicts the expected future reward of that state
- MCTS explores the most promising branches before choosing the best next action to actually execute

but how does MCTS "explore" the most promising branches? what even is a promising branch?

a promising branch is a branch that "might" lead to the best reward. MCTS scores every branch with the UCB score (Upper Confidence Bound) and goes down the one with the highest score:

UCB = Q(s) + c * sqrt( log(N_parent) / N_node )

where:

- Q(s) = average value estimate for this node
- N_node = how many times this branch was visited
- N_parent = how many times the parent was visited
- c = exploration constant

actually, just ignore the equation above. here is what it basically means:

- Q(s) high → this branch looks good
- sqrt( log(N_parent) / N_node ) high → this branch is not explored much yet

so a branch is "promising" unless proven otherwise. MCTS is a cute concept like that.

that's it for today. see you tomorrow in another lesson of making overwhelming research concepts less overwhelming.
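
ps. if code reads easier than equations for you, here is a minimal python sketch of just the UCB selection step from the formula above. it is not a full MCTS. the names `ucb_score` and `pick_branch`, the default c = 1.4, the numbers for A1/A2/A3, treating an unvisited branch as infinitely promising, and counting the parent's visits as the sum of its children's visits are all my assumptions for illustration.

```python
import math


def ucb_score(q, n_node, n_parent, c=1.4):
    """UCB = Q(s) + c * sqrt(log(N_parent) / N_node)."""
    if n_node == 0:
        # assumption: a branch that was never visited gets tried first
        return math.inf
    return q + c * math.sqrt(math.log(n_parent) / n_node)


def pick_branch(children, c=1.4):
    """Return the name of the child with the highest UCB score.

    children maps a branch name to {"q": average value, "n": visit count}.
    """
    # assumption: parent visits = total visits across its children
    n_parent = sum(child["n"] for child in children.values())
    return max(
        children,
        key=lambda name: ucb_score(
            children[name]["q"], children[name]["n"], n_parent, c
        ),
    )


# made-up stats for the three hypothetical future states
children = {
    "A1": {"q": 0.60, "n": 10},  # looks good, already explored a lot
    "A2": {"q": 0.50, "n": 2},   # looks okay, barely explored
    "A3": {"q": 0.20, "n": 10},  # looks bad, already explored a lot
}

best = pick_branch(children)
greedy = pick_branch(children, c=0)  # c = 0 switches exploration off
print(best, greedy)
```

with these made-up numbers, `best` is A2 even though A1 has the higher Q: A2 has only been visited twice, so its exploration bonus is large. set c to 0 and the bonus disappears, so the pick falls back to A1, the branch that simply looks best so far.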