Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
okdshin committed Jan 13, 2025
1 parent 3eb720d commit 1fd5fca
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 28 deletions.
8 changes: 8 additions & 0 deletions src/reasoning_llm_mcts/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import argparse
import asyncio
import uuid
import traceback
import sys
from pathlib import Path
from typing import Any, Optional

Expand Down Expand Up @@ -85,6 +87,7 @@ async def completions(request: CompletionRequest):
)

best_node = await mcts.search(initial_state)
print(f"{best_node.state.total_prompt=}")

return CompletionResponse(
id=f"cmpl-{uuid.uuid4()}", # Generate a unique ID
Expand All @@ -106,6 +109,11 @@ async def completions(request: CompletionRequest):
},
)
except Exception as e:
# スタックトレースを標準出力に出力
traceback.print_exc(file=sys.stdout)
# エラーの詳細情報をログに残す
print(f"Error details: {str(e)}", file=sys.stdout)
# HTTPエラーを発生させる
raise HTTPException(status_code=500, detail=str(e))


Expand Down
19 changes: 17 additions & 2 deletions src/reasoning_llm_mcts/mcts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
Expand All @@ -23,6 +22,10 @@ def __str__(self) -> str:
"""String representation of the state"""
raise NotImplementedError

@abstractmethod
def is_terminal(self) -> bool:
return False


@dataclass
class SearchNode:
Expand All @@ -39,6 +42,9 @@ def __post_init__(self):
def is_leaf(self) -> bool:
return len(self.children) == 0

def is_terminal(self) -> bool:
return self.state.is_terminal()

async def expand(self, expand_num: int) -> None:
child_states = await self.state.expand(expand_num=expand_num)
for child_state in child_states:
Expand Down Expand Up @@ -81,11 +87,20 @@ async def search(self, initial_state: State) -> SearchNode:
)
continue
assert current_node.is_leaf()
if (not current_node.is_terminal()) and (
current_node.visit_count >= self.visit_count_threshold):
await current_node.expand(expand_num=self.expand_num)
current_node = root_node
continue
assert current_node.is_terminal() or current_node.visit_count < self.visit_count_threshold
"""
if current_node.visit_count >= self.visit_count_threshold:
await current_node.expand(expand_num=self.expand_num)
current_node = root_node
continue
assert current_node.visit_count < self.visit_count_threshold
"""

value = await current_node.evaluate()
self.backpropagate(start_node=current_node, value=value)
break
Expand All @@ -100,6 +115,6 @@ def backpropagate(self, start_node: SearchNode, value: float) -> None:
def get_best_child(self, start_node: SearchNode) -> SearchNode:
current_node = start_node
while not current_node.is_leaf():
print(f"{str(current_node.state)=} {current_node.children=}")
#print(f"{str(current_node.state)=} {current_node.children=}")
current_node = max(current_node.children, key=lambda node: node.visit_count)
return current_node
106 changes: 80 additions & 26 deletions src/reasoning_llm_mcts/reasoning_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,16 @@ def __post_init__(self):

async def expand(self, expand_num: int) -> State:
"""Generate a new state by expanding the current state"""
assert expand_num <= len(self.child_state_candidates)
responses = list(
zip(*(sorted(self.child_state_candidates, reverse=True)[:expand_num]))
)[1]
print("expand", self.total_prompt)
print(f"{self.child_state_candidates=}")
assert expand_num <= len(self.child_state_candidates), "here"
assert not self.is_terminal()
top_candidates = sorted(
self.child_state_candidates,
reverse=True,
key=lambda candidate: candidate[0],
)[:expand_num]
responses = [top_candidate[1] for top_candidate in top_candidates]

child_states = []
for response in responses:
Expand All @@ -59,17 +65,24 @@ async def expand(self, expand_num: int) -> State:
]

former_confidence_score = calc_confidence_score(
token_logprobs=token_logprobs, top_logprobs=top_logprobs
token_logprobs=token_logprobs,
top_logprobs=top_logprobs,
debug_logprobs=logprobs,
)
print(f"{former_confidence_score=}")
token_delta_num = len(token_logprobs)
child_confidence_score = (
self.total_new_token_num * self.confidence_score
+ token_delta_num * former_confidence_score
) / (self.total_new_token_num + token_delta_num)
print(f"{child_confidence_score=}")

# Construct text_delta from the first max_delta_new_tokens tokens
tokens = logprobs.tokens[: self.max_new_tokens_delta]
text_delta = "".join(bytes(t.bytes).decode("utf-8") for t in tokens)
# print(f"{tokens=}")
# text_delta = "".join(bytes(t.bytes).decode("utf-8") for t in tokens)
text_delta = "".join(tokens)
print(f"{text_delta=}")

child_states.append(
ReasoningState(
Expand All @@ -83,49 +96,75 @@ async def expand(self, expand_num: int) -> State:
top_logprobs_num=self.top_logprobs_num,
)
)
# print(f"{child_states}")
return child_states

async def evaluate(self) -> float:
"""Evaluate the current state and return a value"""
max_new_tokens = max(self.max_total_tokens - self.total_new_token_num, 0)
if max_new_tokens == 0:
print(f"{self.total_prompt=}")
if self.is_terminal():
return self.confidence_score
response = await AsyncOpenAI(base_url=self.api_base_url).completions.create(
prompt=self.total_prompt,
logprobs=self.top_logprobs_num,
max_tokens=max_new_tokens,
)
max_new_tokens = max(self.max_total_tokens - self.total_new_token_num, 0)
while True:
response = await AsyncOpenAI(base_url=self.api_base_url).completions.create(
model="swallow-mx-4bit", # TODO
prompt=self.total_prompt,
logprobs=self.top_logprobs_num,
max_tokens=max_new_tokens,
stop=["\n\n"],
)
if response.choices[0].text != "":
break

print(response.choices[0].text)
logprobs = response.choices[0].logprobs

print(response.choices[0].finish_reason)

if logprobs is None:
print(f"{response=}")
print(f"{response.choices[0].text=}")
# top_logprobs may contain self.top_logprobs + 1 elements
top_logprobs = [
sorted(top_lps.values())[: self.top_logprobs_num]
for top_lps in logprobs.top_logprobs
]
# print(f"{top_logprobs=}")

rest_confidence_score = calc_confidence_score(
token_logprobs=logprobs.token_logprobs, top_logprobs=top_logprobs
token_logprobs=logprobs.token_logprobs,
top_logprobs=top_logprobs,
debug_logprobs=logprobs,
)
print(f"{rest_confidence_score=}")

rest_token_num = len(logprobs.token_logprobs)
print(f"{rest_token_num=}")

confidence_score = (
combined_confidence_score = (
self.total_new_token_num * self.confidence_score
+ rest_token_num * rest_confidence_score
) / (self.total_new_token_num + rest_token_num)
print(f"{combined_confidence_score=}")

# Update candidates for expanding
self.child_state_candidates.append((confidence_score, response))
self.child_state_candidates.append((combined_confidence_score, response))

return 0.1 * combined_confidence_score

return confidence_score
def is_terminal(self) -> bool:
if self.parent_state is None:
return False # root state is not terminal
return (self.max_total_tokens <= self.total_new_token_num) or (
self.num_tokens_delta < self.max_new_tokens_delta)

def __str__(self) -> str:
return f"{self.text_delta}"

@cached_property
def total_new_token_num(self) -> int:
if self.parent_state is None:
assert self.num_tokens_delta == 0
assert self.num_tokens_delta == 0, "here2"
return self.num_tokens_delta
return self.parent_state.total_new_token_num + self.num_tokens_delta

Expand All @@ -137,20 +176,35 @@ def total_prompt(self) -> str:


def calc_confidence_score(
token_logprobs: list[float], top_logprobs: list[list[float]]
token_logprobs: list[float],
top_logprobs: list[list[float]],
debug_logprobs,
) -> float:
assert len(token_logprobs) == len(top_logprobs)
assert len(token_logprobs) == len(top_logprobs), "here3"
prob_sum = 0.0
for token_lp, top_lps in zip(token_logprobs, top_logprobs):
assert len(top_lps) == len(top_logprobs[0])
count = 0
for i, (token_lp, top_lps) in enumerate(zip(token_logprobs, top_logprobs)):
# print(f"{token_lp=}, {top_lps=}")
# print(f"{debug_logprobs.top_logprobs[i]=}")
# TODO
# assert len(top_lps) == len(top_logprobs[0])

# This is actual computation
# ci_sum += math.exp(token_lp) / sum([math.exp(top_lps) for top_lps in top_lps])

max_lp = max(top_lps)
prob_sum += math.exp(token_lp - max_lp) / sum(
max_lp = max(top_lps + [token_lp])
# print(f"{max_lp=}")
prob = math.exp(token_lp - max_lp) / sum(
math.exp(lp - max_lp) for lp in top_lps
)

confidence_score = prob_sum / len(token_logprobs)
if len(top_lps) != 5: # TODO
continue
# print(f"{prob=}")
prob_sum += prob
# print(f"{prob_sum=}, {len(token_logprobs)=}")
count += 1
assert prob_sum <= count

confidence_score = prob_sum / count
print(f"{confidence_score=}")
return confidence_score

0 comments on commit 1fd5fca

Please sign in to comment.