From cc798c40ffdc75dac872e5c51e529b7167b15a85 Mon Sep 17 00:00:00 2001 From: rahulbshrestha Date: Mon, 27 May 2024 11:29:59 +0200 Subject: [PATCH 1/6] Added rought draft of notebook Signed-off-by: rahulbshrestha --- tree-of-thoughts.ipynb | 377 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 377 insertions(+) create mode 100644 tree-of-thoughts.ipynb diff --git a/tree-of-thoughts.ipynb b/tree-of-thoughts.ipynb new file mode 100644 index 0000000000..c2b4914d99 --- /dev/null +++ b/tree-of-thoughts.ipynb @@ -0,0 +1,377 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tree of Thoughts for problem solving with large language models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "TLDR: This blog post is about using \"Tree of Thoughts\", a tree-based framework to solve the Game of 24 tasks with a large language model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the paper, \"Tree of Thoughts\", the authors introduced a new tree-based approach to solve LLMs " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. Load Model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll be using Hugging face ```transformers``` to generate text with our LLMs. First, we start off by importing the necessary libraries." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", + "import itertools" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll use the popular open-source language model, Mistral-7B. We can load the model and the tokenizer by:\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_id = \"mistralai/Mistral-7B-v0.3\"\n", + "model = AutoModelForCausalLM.from_pretrained(model_id)\n", + "tokenizer = AutoTokenizer.from_pretrained(model_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To test out if your model works, you can run the following code:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "inputs = tokenizer(\"Hi! My name is \", return_tensors=\"pt\")\n", + "outputs = model.generate(**inputs, max_new_tokens=20)\n", + "print(tokenizer.decode(outputs[0], skip_special_tokens=True))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Implement Tree of Thought (ToT) algorithm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The ToT algorithm is a tree-based approach that uses the LLM to generate a tree of possible solutions to a problem. The tree is constructed by recursively generating text from the LLM and selecting the most likely continuation at each node. The algorithm is designed to be flexible and can be applied to a wide range of problems. The core feature of the ToT algorithm can be separted into 4 parts:\n", + "\n", + "\n", + "- Generation\n", + "- Evaluation\n", + "- Selection\n", + "\n", + "\n", + "Below, we define the prompts (taken from the original repo for ToT) for guiding each of the different parts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 5-shot\n", + "standard_prompt = '''Use numbers and basic arithmetic operations (+ - * /) to obtain 24.\n", + "Input: 4 4 6 8\n", + "Answer: (4 + 8) * (6 - 4) = 24\n", + "Input: 2 9 10 12\n", + "Answer: 2 * 12 * (10 - 9) = 24\n", + "Input: 4 9 10 13\n", + "Answer: (13 - 9) * (10 - 4) = 24\n", + "Input: 1 4 8 8\n", + "Answer: (8 / 4 + 1) * 8 = 24\n", + "Input: 5 5 5 9\n", + "Answer: 5 + 5 + 5 + 9 = 24\n", + "Input: {input}\n", + "'''\n", + "\n", + "# 5-shot\n", + "cot_prompt = '''Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.\n", + "Input: 4 4 6 8\n", + "Steps:\n", + "4 + 8 = 12 (left: 4 6 12)\n", + "6 - 4 = 2 (left: 2 12)\n", + "2 * 12 = 24 (left: 24)\n", + "Answer: (6 - 4) * (4 + 8) = 24\n", + "Input: 2 9 10 12\n", + "Steps:\n", + "12 * 2 = 24 (left: 9 10 24)\n", + "10 - 9 = 1 (left: 1 24)\n", + "24 * 1 = 24 (left: 24)\n", + "Answer: (12 * 2) * (10 - 9) = 24\n", + "Input: 4 9 10 13\n", + "Steps:\n", + "13 - 10 = 3 (left: 3 4 9)\n", + "9 - 3 = 6 (left: 4 6)\n", + "4 * 6 = 24 (left: 24)\n", + "Answer: 4 * (9 - (13 - 10)) = 24\n", + "Input: 1 4 8 8\n", + "Steps:\n", + "8 / 4 = 2 (left: 1 2 8)\n", + "1 + 2 = 3 (left: 3 8)\n", + "3 * 8 = 24 (left: 24)\n", + "Answer: (1 + 8 / 4) * 8 = 24\n", + "Input: 5 5 5 9\n", + "Steps:\n", + "5 + 5 = 10 (left: 5 9 10)\n", + "10 + 5 = 15 (left: 9 15)\n", + "15 + 9 = 24 (left: 24)\n", + "Answer: ((5 + 5) + 5) + 9 = 24\n", + "Input: {input}\n", + "'''\n", + "\n", + "# 1-shot\n", + "propose_prompt = '''Input: 2 8 8 14\n", + "Possible next steps:\n", + "2 + 8 = 10 (left: 8 10 14)\n", + "8 / 2 = 4 (left: 4 8 14)\n", + "14 + 2 = 16 (left: 8 8 16)\n", + "2 * 8 = 16 (left: 8 14 16)\n", + "8 - 2 = 6 (left: 6 8 14)\n", + "14 - 8 = 6 (left: 2 6 8)\n", + "14 / 2 = 7 (left: 7 8 8)\n", + "14 - 2 = 12 (left: 8 8 12)\n", + "Input: {input}\n", + "Possible next steps:\n", + "'''\n", + "\n", + "value_prompt = '''Evaluate if given numbers can reach 24 (sure/likely/impossible)\n", + "10 14\n", + "10 + 14 = 24\n", + "sure\n", + "11 12\n", + "11 + 12 = 23\n", + "12 - 11 = 1\n", + "11 * 12 = 132\n", + "11 / 12 = 0.91\n", + "impossible\n", + "4 4 10\n", + "4 + 4 + 10 = 8 + 10 = 18\n", + "4 * 10 - 4 = 40 - 4 = 36\n", + "(10 - 4) * 4 = 6 * 4 = 24\n", + "sure\n", + "4 9 11\n", + "9 + 11 + 4 = 20 + 4 = 24\n", + "sure\n", + "5 7 8\n", + "5 + 7 + 8 = 12 + 8 = 20\n", + "(8 - 5) * 7 = 3 * 7 = 21\n", + "I cannot obtain 24 now, but numbers are within a reasonable range\n", + "likely\n", + "5 6 6\n", + "5 + 6 + 6 = 17\n", + "(6 - 5) * 6 = 1 * 6 = 6\n", + "I cannot obtain 24 now, but numbers are within a reasonable range\n", + "likely\n", + "10 10 11\n", + "10 + 10 + 11 = 31\n", + "(11 - 10) * 10 = 10\n", + "10 10 10 are all too big\n", + "impossible\n", + "1 3 3\n", + "1 * 3 * 3 = 9\n", + "(1 + 3) * 3 = 12\n", + "1 3 3 are all too small\n", + "impossible\n", + "{input}\n", + "'''\n", + "\n", + "value_last_step_prompt = '''Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Given an input and an answer, give a judgement (sure/impossible) if the answer is correct, i.e. it uses each input exactly once and no other numbers, and reach 24.\n", + "Input: 4 4 6 8\n", + "Answer: (4 + 8) * (6 - 4) = 24\n", + "Judge: \n", + "sure\n", + "Input: 2 9 10 12\n", + "Answer: 2 * 12 * (10 - 9) = 24\n", + "Judge: \n", + "sure\n", + "Input: 4 9 10 13\n", + "Answer: (13 - 9) * (10 - 4) = 24\n", + "Judge: \n", + "sure\n", + "Input: 4 4 6 8\n", + "Answer: (4 + 8) * (6 - 4) + 1 = 25\n", + "Judge: \n", + "impossible\n", + "Input: 2 9 10 12\n", + "Answer: 2 * (12 - 10) = 24\n", + "Judge: \n", + "impossible\n", + "Input: 4 9 10 13\n", + "Answer: (13 - 4) * (10 - 9) = 24\n", + "Judge: \n", + "impossible\n", + "Input: {input}\n", + "Answer: {answer}\n", + "Judge:'''" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we'll start implementing our ToT algorithm. We'll define a function for each core part of the ToT algorithm.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generation\n", + "def get_proposals(task, x, y): \n", + " propose_prompt = task.propose_prompt_wrap(x, y)\n", + " proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\\n') #TODO: Change GPT to another function that uses mistral\n", + " return [y + _ + '\\n' for _ in proposals]\n", + "\n", + "\n", + "# Evaluation\n", + "def get_value(task, x, y, n_evaluate_sample, cache_value=True):\n", + " value_prompt = task.value_prompt_wrap(x, y)\n", + " if cache_value and value_prompt in task.value_cache:\n", + " return task.value_cache[value_prompt]\n", + " value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)\n", + " value = task.value_outputs_unwrap(x, y, value_outputs)\n", + " if cache_value:\n", + " task.value_cache[value_prompt] = value\n", + " return value\n", + "\n", + "def get_values(task, x, ys, n_evaluate_sample, cache_value=True):\n", + " values = []\n", + " local_value_cache = {}\n", + " for y in ys: # each partial output\n", + " if y in local_value_cache: # avoid duplicate candidates\n", + " value = 0\n", + " else: \n", + " value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)\n", + " local_value_cache[y] = value\n", + " values.append(value)\n", + " return values\n", + "\n", + "\n", + "# Search" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Run ToT with sample data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we'll take some example data i.e the sequence 4 5 6 10, and check if ToT can generate the correct expression." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data = \"4 5 6 10\"\n", + "ys = ['']\n", + "x = data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "TODO: Finish for loop " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "num_of_steps = 4\n", + "\n", + "for step in num_of_steps:\n", + " \n", + " # Generation (Propose / Sample)\n", + " new_ys = [get_proposals(x, y) for y in ys]\n", + " new_ys = list(itertools.chain(*new_ys))\n", + " ids = list(range(len(new_ys)))\n", + "\n", + " # Evaluation (Value / Vote)\n", + " values = get_values(task, x, new_ys, args.n_evaluate_sample)\n", + " \n", + " # Selection (Sample/Greedy)\n", + " select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]\n", + " select_new_ys = [new_ys[select_id] for select_id in select_ids]\n", + " \n", + " #infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})\n", + " ys = select_new_ys\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "transformers", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 320b7d48e4de3902f7924e7aae456b339cf95b80 Mon Sep 17 00:00:00 2001 From: rahulbshrestha Date: Tue, 28 May 2024 00:28:10 +0200 Subject: [PATCH 2/6] Added thought generation + evaluation Signed-off-by: rahulbshrestha --- tree-of-thoughts.ipynb | 203 ++++++++++++++++++++++++++++++----------- 1 file changed, 149 insertions(+), 54 deletions(-) diff --git a/tree-of-thoughts.ipynb b/tree-of-thoughts.ipynb index c2b4914d99..72ff7b571e 100644 --- a/tree-of-thoughts.ipynb +++ b/tree-of-thoughts.ipynb @@ -18,7 +18,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In the paper, \"Tree of Thoughts\", the authors introduced a new tree-based approach to solve LLMs " + "Tree of Thoughts (ToT) is a framework used by LLMs to solve complex reasoning problems. The intermediate steps in a reasoning process are split into “thoughts”, with the ToT algorithm encouraging exploration of these thoughts through search algorithms.\n" ] }, { @@ -41,8 +41,7 @@ "metadata": {}, "outputs": [], "source": [ - "from transformers import AutoTokenizer, AutoModelForCausalLM\n", - "import itertools" + "from transformers import AutoTokenizer, AutoModelForCausalLM" ] }, { @@ -68,7 +67,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To test out if your model works, you can run the following code:" + "Next, we create a function called ```mistral``` which we'll use to feed in our prompts and receive completions." ] }, { @@ -77,31 +76,42 @@ "metadata": {}, "outputs": [], "source": [ - "inputs = tokenizer(\"Hi! My name is \", return_tensors=\"pt\")\n", - "outputs = model.generate(**inputs, max_new_tokens=20)\n", - "print(tokenizer.decode(outputs[0], skip_special_tokens=True))" + "def mistral(prompt):\n", + " inputs = tokenizer(prompt, return_tensors=\"pt\")\n", + " outputs = model.generate(**inputs, max_new_tokens=20)\n", + " return tokenizer.decode(outputs[0], skip_special_tokens=True)\n", + "\n", + "mistral(\"Hi! My name is \")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### 2. Implement Tree of Thought (ToT) algorithm" + "### 2. Implementing Tree of Thoughts (ToT) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The ToT algorithm is a tree-based approach that uses the LLM to generate a tree of possible solutions to a problem. The tree is constructed by recursively generating text from the LLM and selecting the most likely continuation at each node. The algorithm is designed to be flexible and can be applied to a wide range of problems. The core feature of the ToT algorithm can be separted into 4 parts:\n", + "ToT can be broken down into 4 key steps:\n", + "\n", + "(a) Thought Decomposition\n", "\n", + "(b) Thought Generation\n", + "- In this step, the LLM is prompted to generate thoughts by either one of two ways:\n", + " - Sample: The thoughts are generated by sampling i.i.d thoughts from a Chain of Thought prompt.\n", + " - Propose: The thoughts are propsed sequentially depending on the previous prompts. \n", "\n", - "- Generation\n", - "- Evaluation\n", - "- Selection\n", + "(c) Thought Evaluation\n", + "- The LLMs are prompted to evaluate the thoughts generated in the previous step, by either: \n", + " - Value:\n", + " - Vote: \n", "\n", + "(d) Search Algorithm\n", "\n", - "Below, we define the prompts (taken from the original repo for ToT) for guiding each of the different parts." + "." ] }, { @@ -254,44 +264,111 @@ "\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we'll define functions necessary for \"Thought Generation\"." + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ + "# def propose_prompt_wrap(x: str, y: str='') -> str:\n", + "# current_numbers = get_current_numbers(y if y else x)\n", + "# if current_numbers == '24':\n", + "# prompt = cot_prompt.format(input=x) + 'Steps:' + y\n", + "# # print([prompt])\n", + "# else:\n", + "# prompt = propose_prompt.format(input=current_numbers)\n", + "# return prompt\n", + " \n", + "\n", "# Generation\n", + "def generate_thoughts(prompt):\n", + " \n", + " current_numbers = get_current_numbers(prompt)\n", + " prompt = propose_prompt.format(input=current_numbers)\n", + " \n", + " thoughts = mistral(prompt)[0].split('\\n') # TODO: Test this out\n", + "\n", + " return thoughts\n", + "\n", + "\n", "def get_proposals(task, x, y): \n", " propose_prompt = task.propose_prompt_wrap(x, y)\n", - " proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\\n') #TODO: Change GPT to another function that uses mistral\n", - " return [y + _ + '\\n' for _ in proposals]\n", + " proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\\n')\n", + " return [y + _ + '\\n' for _ in proposals]\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we'll create the functions necessary for \"Thought Evaluation\", where each of the thoughts are evaluated by the LLM." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_current_numbers(y: str) -> str:\n", + " last_line = y.strip().split('\\n')[-1]\n", + " return last_line.split('left: ')[-1].split(')')[0]\n", + "\n", + "# def value_prompt_wrap(x: str, y: str) -> str:\n", + "# last_line = y.strip().split('\\n')[-1]\n", + "# if 'left: ' not in last_line: # last step\n", + "# ans = last_line.lower().replace('answer: ', '')\n", + "# return value_last_step_prompt.format(input=x, answer=ans)\n", + "# current_numbers = get_current_numbers(y)\n", + "# return value_prompt.format(input=current_numbers) # This replaces the input term\n", "\n", "\n", - "# Evaluation\n", - "def get_value(task, x, y, n_evaluate_sample, cache_value=True):\n", - " value_prompt = task.value_prompt_wrap(x, y)\n", - " if cache_value and value_prompt in task.value_cache:\n", - " return task.value_cache[value_prompt]\n", - " value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)\n", - " value = task.value_outputs_unwrap(x, y, value_outputs)\n", - " if cache_value:\n", - " task.value_cache[value_prompt] = value\n", + "def get_value(thought):\n", + " \n", + " current_numbers = get_current_numbers(thought)\n", + " value_prompt = value_prompt.format(input=current_numbers)\n", + " value_outputs = mistral(value_prompt)\n", + " \n", + " value_names = [_.split('\\n')[-1] for _ in value_outputs]\n", + " value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20} \n", + " \n", + " for name, value in value_maps.items():\n", + " value = sum(value * value_names.count(name))\n", + "\n", " return value\n", "\n", - "def get_values(task, x, ys, n_evaluate_sample, cache_value=True):\n", + "\n", + "def evaluate_thoughts(thoughts):\n", + " \n", " values = []\n", - " local_value_cache = {}\n", - " for y in ys: # each partial output\n", - " if y in local_value_cache: # avoid duplicate candidates\n", - " value = 0\n", - " else: \n", - " value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)\n", - " local_value_cache[y] = value\n", - " values.append(value)\n", - " return values\n", "\n", + " for thought in thoughts:\n", + " value = get_value(thought)\n", "\n", - "# Search" + " values.append(value)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we'll implement the \"Search Algorithm\" which will be used to search through the thoughts generated by the LLM." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: Implement search algorithm" ] }, { @@ -305,7 +382,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Next, we'll take some example data i.e the sequence 4 5 6 10, and check if ToT can generate the correct expression." + "We'll test our implementation with some sample data i.e the sequence 4 5 6 10. If ToT works sucessfully, it should output the operations that can be performed to reach 24." ] }, { @@ -314,16 +391,31 @@ "metadata": {}, "outputs": [], "source": [ - "data = \"4 5 6 10\"\n", - "ys = ['']\n", - "x = data" + "data = \"4 5 6 10\"" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "TODO: Finish for loop " + "import itertools\n", + "num_of_steps = 4\n", + "thoughts = []\n", + "\n", + "for step in num_of_steps:\n", + " \n", + " # Thought Generation\n", + " thoughts = generate_thoughts(thoughts)\n", + " thoughts = list(itertools.chain(*thoughts))\n", + " ids = list(range(len(thoughts)))\n", + "\n", + " # Thought evaluation\n", + " values = evaluate_thoughts(thoughts)\n", + "\n", + " # Search algorithm\n", + "\n" ] }, { @@ -332,24 +424,27 @@ "metadata": {}, "outputs": [], "source": [ - "num_of_steps = 4\n", + "# import itertools\n", "\n", - "for step in num_of_steps:\n", + "# num_of_steps = 4\n", + "\n", + "# for step in num_of_steps:\n", + "# thoughts = \n", " \n", - " # Generation (Propose / Sample)\n", - " new_ys = [get_proposals(x, y) for y in ys]\n", - " new_ys = list(itertools.chain(*new_ys))\n", - " ids = list(range(len(new_ys)))\n", + "# # Generation (Propose / Sample)\n", + "# new_ys = [get_proposals(x, y) for y in ys]\n", + "# new_ys = list(itertools.chain(*new_ys))\n", + "# ids = list(range(len(new_ys)))\n", "\n", - " # Evaluation (Value / Vote)\n", - " values = get_values(task, x, new_ys, args.n_evaluate_sample)\n", + "# # Evaluation (Value / Vote)\n", + "# values = get_values(task, x, new_ys, args.n_evaluate_sample)\n", " \n", - " # Selection (Sample/Greedy)\n", - " select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]\n", - " select_new_ys = [new_ys[select_id] for select_id in select_ids]\n", + "# # Selection (Sample/Greedy)\n", + "# select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]\n", + "# select_new_ys = [new_ys[select_id] for select_id in select_ids]\n", " \n", - " #infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})\n", - " ys = select_new_ys\n" + "# #infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})\n", + "# ys = select_new_ys\n" ] } ], From 5304cefe7d6f116f3ff16a8ece38cb14af6eb209 Mon Sep 17 00:00:00 2001 From: rahulbshrestha Date: Sun, 2 Jun 2024 14:50:08 +0200 Subject: [PATCH 3/6] Added support for OpenAI models Signed-off-by: rahulbshrestha --- tree-of-thoughts.ipynb | 219 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 207 insertions(+), 12 deletions(-) diff --git a/tree-of-thoughts.ipynb b/tree-of-thoughts.ipynb index 72ff7b571e..aa3f9122a9 100644 --- a/tree-of-thoughts.ipynb +++ b/tree-of-thoughts.ipynb @@ -84,6 +84,57 @@ "mistral(\"Hi! My name is \")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Alternative, we can also use OpenAI's GPT-3.5/4 models. We can load the model by:" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import openai\n", + "from openai import OpenAI\n", + "from constants import OPENAI_API_KEY\n", + "\n", + "api_key = os.getenv(\"OPENAI_API_KEY\", OPENAI_API_KEY)\n", + "\n", + "if api_key != \"\":\n", + " openai.api_key = OPENAI_API_KEY\n", + "else:\n", + " print(\"Warning: OPENAI_API_KEY is not set\")\n", + "\n", + "client = OpenAI(api_key=api_key)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [], + "source": [ + "global response\n", + "\n", + "def gpt(prompt, model=\"gpt-4\", temperature=0.7, max_tokens=1000, n=1, stop=None) -> list:\n", + " \n", + " messages = [{\"role\": \"user\", \"content\": prompt}]\n", + " \n", + " outputs = []\n", + "\n", + " res = client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, n=n, stop=stop)\n", + " response = res\n", + "\n", + " for choice in res.choices:\n", + " outputs.extend([choice.message.content])\n", + "\n", + " return outputs " + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -111,12 +162,17 @@ "\n", "(d) Search Algorithm\n", "\n", - "." + ".\n", + "\n", + "\n", + "In this tutorial, we'll be using ToT with Mistral to solve the Game of 24.\n", + "\n", + "The Game of 24 is a task where given a sequence of 4 numbers, we’ll need to find the correct mathematical operations (add, subtract, multiply, divide) that’ll lead to the number 24. For example, if the sequence is {4, 9, 10, 13}, the correct operations using the 4 numbers are: (10 - 4) * (13 - 9) = 24. Each number in the sequence can only be used once.\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 56, "metadata": {}, "outputs": [], "source": [ @@ -273,7 +329,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 57, + "metadata": {}, + "outputs": [], + "source": [ + "def get_current_numbers(y: str) -> str:\n", + " last_line = y.strip().split('\\n')[-1]\n", + " return last_line.split('left: ')[-1].split(')')[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 47, "metadata": {}, "outputs": [], "source": [ @@ -290,10 +357,10 @@ "# Generation\n", "def generate_thoughts(prompt):\n", " \n", - " current_numbers = get_current_numbers(prompt)\n", + " current_numbers = get_current_numbers(prompt) # current_numbers = get_current_numbers(y if y else x)\n", " prompt = propose_prompt.format(input=current_numbers)\n", " \n", - " thoughts = mistral(prompt)[0].split('\\n') # TODO: Test this out\n", + " thoughts = gpt(prompt)[0].split('\\n')\n", "\n", " return thoughts\n", "\n", @@ -304,6 +371,71 @@ " return [y + _ + '\\n' for _ in proposals]\n" ] }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PROMPT: ['']\n" + ] + }, + { + "ename": "AttributeError", + "evalue": "'list' object has no attribute 'strip'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[59], line 5\u001b[0m\n\u001b[1;32m 2\u001b[0m thoughts \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, num_of_steps):\n\u001b[0;32m----> 5\u001b[0m thoughts \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_thoughts\u001b[49m\u001b[43m(\u001b[49m\u001b[43mthoughts\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mThoughts: \u001b[39m\u001b[38;5;124m'\u001b[39m, thoughts)\n", + "Cell \u001b[0;32mIn[52], line 15\u001b[0m, in \u001b[0;36mgenerate_thoughts\u001b[0;34m(prompt)\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mgenerate_thoughts\u001b[39m(prompt):\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPROMPT: \u001b[39m\u001b[38;5;124m'\u001b[39m, prompt)\n\u001b[0;32m---> 15\u001b[0m current_numbers \u001b[38;5;241m=\u001b[39m \u001b[43mget_current_numbers\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprompt\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m prompt \u001b[38;5;241m=\u001b[39m propose_prompt\u001b[38;5;241m.\u001b[39mformat(\u001b[38;5;28minput\u001b[39m\u001b[38;5;241m=\u001b[39mcurrent_numbers)\n\u001b[1;32m 18\u001b[0m thoughts \u001b[38;5;241m=\u001b[39m gpt(prompt)[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n", + "Cell \u001b[0;32mIn[57], line 2\u001b[0m, in \u001b[0;36mget_current_numbers\u001b[0;34m(y)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_current_numbers\u001b[39m(y: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mstr\u001b[39m:\n\u001b[0;32m----> 2\u001b[0m last_line \u001b[38;5;241m=\u001b[39m \u001b[43my\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstrip\u001b[49m()\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m)[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m last_line\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mleft: \u001b[39m\u001b[38;5;124m'\u001b[39m)[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m)\u001b[39m\u001b[38;5;124m'\u001b[39m)[\u001b[38;5;241m0\u001b[39m]\n", + "\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'strip'" + ] + } + ], + "source": [ + "num_of_steps = 1\n", + "thoughts = ['']\n", + "\n", + "for _ in range(0, num_of_steps):\n", + " thoughts = generate_thoughts(thoughts)\n", + " print('Thoughts: ', thoughts)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import itertools\n", + "num_of_steps = 1\n", + "thoughts = ['']\n", + "\n", + " new_ys = [get_proposals(task, x, y) for y in ys]\n", + " # new_ys = list(itertools.chain(*new_ys))\n", + " # print('FINAL YS: ', new_ys)\n", + " \n", + "for step in range(0, num_of_steps):\n", + " \n", + " # Thought Generation\n", + " thoughts = generate_thoughts(thoughts)\n", + " #thoughts = list(itertools.chain(*thoughts))\n", + " ids = list(range(len(thoughts)))\n", + "\n", + " print('THOUGHTS: ', thoughts)\n", + "\n", + " # Thought evaluation\n", + " #values = evaluate_thoughts(thoughts)\n", + "\n", + " # Search algorithm\n", + "\n" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -339,7 +471,7 @@ " value_names = [_.split('\\n')[-1] for _ in value_outputs]\n", " value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20} \n", " \n", - " for name, value in value_maps.items():\n", + " for name, value in value_map.items():\n", " value = sum(value * value_names.count(name))\n", "\n", " return value\n", @@ -396,23 +528,86 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 52, "metadata": {}, "outputs": [], + "source": [ + "# def propose_prompt_wrap(x: str, y: str='') -> str:\n", + "# current_numbers = get_current_numbers(y if y else x)\n", + "# if current_numbers == '24':\n", + "# prompt = cot_prompt.format(input=x) + 'Steps:' + y\n", + "# # print([prompt])\n", + "# else:\n", + "# prompt = propose_prompt.format(input=current_numbers)\n", + "# return prompt\n", + " \n", + "\n", + "# Generation\n", + "def generate_thoughts(prompt):\n", + " for p in prompt:\n", + " \n", + " \n", + " print('PROMPT: ', prompt)\n", + " current_numbers = get_current_numbers(prompt)\n", + " prompt = propose_prompt.format(input=current_numbers)\n", + " \n", + " thoughts = gpt(prompt)[0].split('\\n')\n", + "\n", + " return thoughts\n", + "\n", + "\n", + "def get_proposals(task, x, y): \n", + " propose_prompt = task.propose_prompt_wrap(x, y)\n", + " proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\\n')\n", + " return [y + _ + '\\n' for _ in proposals]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PROMPT: []\n" + ] + }, + { + "ename": "AttributeError", + "evalue": "'list' object has no attribute 'strip'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[53], line 8\u001b[0m\n\u001b[1;32m 3\u001b[0m thoughts \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m step \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, num_of_steps):\n\u001b[1;32m 6\u001b[0m \n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# Thought Generation\u001b[39;00m\n\u001b[0;32m----> 8\u001b[0m thoughts \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_thoughts\u001b[49m\u001b[43m(\u001b[49m\u001b[43mthoughts\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;66;03m#thoughts = list(itertools.chain(*thoughts))\u001b[39;00m\n\u001b[1;32m 10\u001b[0m ids \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(thoughts)))\n", + "Cell \u001b[0;32mIn[52], line 15\u001b[0m, in \u001b[0;36mgenerate_thoughts\u001b[0;34m(prompt)\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mgenerate_thoughts\u001b[39m(prompt):\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPROMPT: \u001b[39m\u001b[38;5;124m'\u001b[39m, prompt)\n\u001b[0;32m---> 15\u001b[0m current_numbers \u001b[38;5;241m=\u001b[39m \u001b[43mget_current_numbers\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprompt\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m prompt \u001b[38;5;241m=\u001b[39m propose_prompt\u001b[38;5;241m.\u001b[39mformat(\u001b[38;5;28minput\u001b[39m\u001b[38;5;241m=\u001b[39mcurrent_numbers)\n\u001b[1;32m 18\u001b[0m thoughts \u001b[38;5;241m=\u001b[39m gpt(prompt)[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n", + "Cell \u001b[0;32mIn[44], line 2\u001b[0m, in \u001b[0;36mget_current_numbers\u001b[0;34m(y)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_current_numbers\u001b[39m(y: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mstr\u001b[39m:\n\u001b[0;32m----> 2\u001b[0m last_line \u001b[38;5;241m=\u001b[39m \u001b[43my\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstrip\u001b[49m()\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m)[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m last_line\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mleft: \u001b[39m\u001b[38;5;124m'\u001b[39m)[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m)\u001b[39m\u001b[38;5;124m'\u001b[39m)[\u001b[38;5;241m0\u001b[39m]\n", + "\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'strip'" + ] + } + ], "source": [ "import itertools\n", - "num_of_steps = 4\n", - "thoughts = []\n", + "num_of_steps = 1\n", + "thoughts = ['']\n", "\n", - "for step in num_of_steps:\n", + " new_ys = [get_proposals(task, x, y) for y in ys]\n", + " # new_ys = list(itertools.chain(*new_ys))\n", + " # print('FINAL YS: ', new_ys)\n", + " \n", + "for step in range(0, num_of_steps):\n", " \n", " # Thought Generation\n", " thoughts = generate_thoughts(thoughts)\n", - " thoughts = list(itertools.chain(*thoughts))\n", + " #thoughts = list(itertools.chain(*thoughts))\n", " ids = list(range(len(thoughts)))\n", "\n", + " print('THOUGHTS: ', thoughts)\n", + "\n", " # Thought evaluation\n", - " values = evaluate_thoughts(thoughts)\n", + " #values = evaluate_thoughts(thoughts)\n", "\n", " # Search algorithm\n", "\n" From 63d3b1d83209cb849b4516a998473b0d15185636 Mon Sep 17 00:00:00 2001 From: rahulbshrestha Date: Mon, 3 Jun 2024 03:16:36 +0200 Subject: [PATCH 4/6] Added search algorithm Signed-off-by: rahulbshrestha --- tree-of-thoughts.ipynb | 356 +++++++++++++++-------------------------- 1 file changed, 126 insertions(+), 230 deletions(-) diff --git a/tree-of-thoughts.ipynb b/tree-of-thoughts.ipynb index aa3f9122a9..f5426ed05e 100644 --- a/tree-of-thoughts.ipynb +++ b/tree-of-thoughts.ipynb @@ -93,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 62, "metadata": {}, "outputs": [], "source": [ @@ -114,22 +114,18 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 63, "metadata": {}, "outputs": [], "source": [ - "global response\n", - "\n", "def gpt(prompt, model=\"gpt-4\", temperature=0.7, max_tokens=1000, n=1, stop=None) -> list:\n", " \n", " messages = [{\"role\": \"user\", \"content\": prompt}]\n", - " \n", " outputs = []\n", "\n", - " res = client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, n=n, stop=stop)\n", - " response = res\n", - "\n", - " for choice in res.choices:\n", + " response = client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, n=n, stop=stop)\n", + " \n", + " for choice in response.choices:\n", " outputs.extend([choice.message.content])\n", "\n", " return outputs " @@ -172,10 +168,12 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 64, "metadata": {}, "outputs": [], "source": [ + "# Prompts\n", + "\n", "# 5-shot\n", "standard_prompt = '''Use numbers and basic arithmetic operations (+ - * /) to obtain 24.\n", "Input: 4 4 6 8\n", @@ -191,7 +189,8 @@ "Input: {input}\n", "'''\n", "\n", - "# 5-shot\n", + "\n", + "# PROMPTS FOR THOUGHT GENERATION\n", "cot_prompt = '''Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.\n", "Input: 4 4 6 8\n", "Steps:\n", @@ -241,6 +240,8 @@ "Possible next steps:\n", "'''\n", "\n", + "# PROMPTS FOR THOUGHT EVALUATION\n", + "\n", "value_prompt = '''Evaluate if given numbers can reach 24 (sure/likely/impossible)\n", "10 14\n", "10 + 14 = 24\n", @@ -329,111 +330,34 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 65, "metadata": {}, "outputs": [], "source": [ "def get_current_numbers(y: str) -> str:\n", - " last_line = y.strip().split('\\n')[-1]\n", - " return last_line.split('left: ')[-1].split(')')[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "metadata": {}, - "outputs": [], - "source": [ - "# def propose_prompt_wrap(x: str, y: str='') -> str:\n", - "# current_numbers = get_current_numbers(y if y else x)\n", - "# if current_numbers == '24':\n", - "# prompt = cot_prompt.format(input=x) + 'Steps:' + y\n", - "# # print([prompt])\n", - "# else:\n", - "# prompt = propose_prompt.format(input=current_numbers)\n", - "# return prompt\n", - " \n", - "\n", - "# Generation\n", - "def generate_thoughts(prompt):\n", - " \n", - " current_numbers = get_current_numbers(prompt) # current_numbers = get_current_numbers(y if y else x)\n", - " prompt = propose_prompt.format(input=current_numbers)\n", - " \n", - " thoughts = gpt(prompt)[0].split('\\n')\n", - "\n", - " return thoughts\n", "\n", + " last_line = y.strip().split('\\n')[-1]\n", + " return last_line.split('left: ')[-1].split(')')[0]\n", "\n", - "def get_proposals(task, x, y): \n", - " propose_prompt = task.propose_prompt_wrap(x, y)\n", - " proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\\n')\n", - " return [y + _ + '\\n' for _ in proposals]\n" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "PROMPT: ['']\n" - ] - }, - { - "ename": "AttributeError", - "evalue": "'list' object has no attribute 'strip'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[59], line 5\u001b[0m\n\u001b[1;32m 2\u001b[0m thoughts \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, num_of_steps):\n\u001b[0;32m----> 5\u001b[0m thoughts \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_thoughts\u001b[49m\u001b[43m(\u001b[49m\u001b[43mthoughts\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mThoughts: \u001b[39m\u001b[38;5;124m'\u001b[39m, thoughts)\n", - "Cell \u001b[0;32mIn[52], line 15\u001b[0m, in \u001b[0;36mgenerate_thoughts\u001b[0;34m(prompt)\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mgenerate_thoughts\u001b[39m(prompt):\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPROMPT: \u001b[39m\u001b[38;5;124m'\u001b[39m, prompt)\n\u001b[0;32m---> 15\u001b[0m current_numbers \u001b[38;5;241m=\u001b[39m \u001b[43mget_current_numbers\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprompt\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m prompt \u001b[38;5;241m=\u001b[39m propose_prompt\u001b[38;5;241m.\u001b[39mformat(\u001b[38;5;28minput\u001b[39m\u001b[38;5;241m=\u001b[39mcurrent_numbers)\n\u001b[1;32m 18\u001b[0m thoughts \u001b[38;5;241m=\u001b[39m gpt(prompt)[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n", - "Cell \u001b[0;32mIn[57], line 2\u001b[0m, in \u001b[0;36mget_current_numbers\u001b[0;34m(y)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_current_numbers\u001b[39m(y: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mstr\u001b[39m:\n\u001b[0;32m----> 2\u001b[0m last_line \u001b[38;5;241m=\u001b[39m \u001b[43my\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstrip\u001b[49m()\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m)[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m last_line\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mleft: \u001b[39m\u001b[38;5;124m'\u001b[39m)[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m)\u001b[39m\u001b[38;5;124m'\u001b[39m)[\u001b[38;5;241m0\u001b[39m]\n", - "\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'strip'" - ] - } - ], - "source": [ - "num_of_steps = 1\n", - "thoughts = ['']\n", "\n", - "for _ in range(0, num_of_steps):\n", - " thoughts = generate_thoughts(thoughts)\n", - " print('Thoughts: ', thoughts)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import itertools\n", - "num_of_steps = 1\n", - "thoughts = ['']\n", + "def generate_thoughts(data, thoughts):\n", "\n", - " new_ys = [get_proposals(task, x, y) for y in ys]\n", - " # new_ys = list(itertools.chain(*new_ys))\n", - " # print('FINAL YS: ', new_ys)\n", - " \n", - "for step in range(0, num_of_steps):\n", + " new_thoughts = []\n", " \n", - " # Thought Generation\n", - " thoughts = generate_thoughts(thoughts)\n", - " #thoughts = list(itertools.chain(*thoughts))\n", - " ids = list(range(len(thoughts)))\n", + " for thought in thoughts:\n", "\n", - " print('THOUGHTS: ', thoughts)\n", + " # Prepare prompt\n", + " current_numbers = get_current_numbers(thought if thought else data)\n", + " if current_numbers == '24':\n", + " prompt = cot_prompt.format(input=data) + 'Steps: ' + thought\n", + " else:\n", + " prompt = propose_prompt.format(input=current_numbers)\n", "\n", - " # Thought evaluation\n", - " #values = evaluate_thoughts(thoughts)\n", + " # Generate thoughts with prompt\n", + " proposals = gpt(prompt, n=1, stop=None)[0].split('\\n')\n", + " new_thoughts.extend([thought + _ + '\\n' for _ in proposals])\n", "\n", - " # Search algorithm\n", - "\n" + " return new_thoughts" ] }, { @@ -445,46 +369,35 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 66, "metadata": {}, "outputs": [], "source": [ - "def get_current_numbers(y: str) -> str:\n", + "def value_prompt_wrap(x: str, y: str) -> str:\n", " last_line = y.strip().split('\\n')[-1]\n", - " return last_line.split('left: ')[-1].split(')')[0]\n", - "\n", - "# def value_prompt_wrap(x: str, y: str) -> str:\n", - "# last_line = y.strip().split('\\n')[-1]\n", - "# if 'left: ' not in last_line: # last step\n", - "# ans = last_line.lower().replace('answer: ', '')\n", - "# return value_last_step_prompt.format(input=x, answer=ans)\n", - "# current_numbers = get_current_numbers(y)\n", - "# return value_prompt.format(input=current_numbers) # This replaces the input term\n", - "\n", - "\n", - "def get_value(thought):\n", - " \n", - " current_numbers = get_current_numbers(thought)\n", - " value_prompt = value_prompt.format(input=current_numbers)\n", - " value_outputs = mistral(value_prompt)\n", - " \n", + " if 'left: ' not in last_line: # last step\n", + " ans = last_line.lower().replace('answer: ', '')\n", + " # print([value_last_step_prompt.format(input=x, answer=ans)])\n", + " return value_last_step_prompt.format(input=x, answer=ans)\n", + " current_numbers = get_current_numbers(y)\n", + " return value_prompt.format(input=current_numbers)\n", + "\n", + "def value_outputs_unwrap(x: str, y: str, value_outputs: list) -> float:\n", + " if len(y.strip().split('\\n')) == 4 and 'answer' not in y.lower():\n", + " return 0\n", " value_names = [_.split('\\n')[-1] for _ in value_outputs]\n", - " value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20} \n", - " \n", - " for name, value in value_map.items():\n", - " value = sum(value * value_names.count(name))\n", - "\n", + " value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20} # TODO: ad hoc\n", + " value = sum(value * value_names.count(name) for name, value in value_map.items())\n", " return value\n", - "\n", - "\n", - "def evaluate_thoughts(thoughts):\n", " \n", - " values = []\n", - "\n", + "def evaluate_thoughts(data, thoughts, n_evaluate_sample):\n", + " scores = []\n", " for thought in thoughts:\n", - " value = get_value(thought)\n", - "\n", - " values.append(value)\n" + " value_prompt = value_prompt_wrap(data, thought)\n", + " value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)\n", + " value = value_outputs_unwrap(data, thought, value_outputs)\n", + " scores.append(value)\n", + " return scores" ] }, { @@ -496,11 +409,12 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 67, "metadata": {}, "outputs": [], "source": [ - "# TODO: Implement search algorithm" + "# selected_ids = sorted(ids, key=lambda x: scores[x], reverse=True)[:n_select_sample] # Take top n_select_sample from list based on scores\n", + "# select_new_thoughts = [new_thoughts[select_id] for select_id in selected_ids] " ] }, { @@ -514,132 +428,114 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We'll test our implementation with some sample data i.e the sequence 4 5 6 10. If ToT works sucessfully, it should output the operations that can be performed to reach 24." + "We'll test our implementation with some sample data i.e the sequence 4 5 6 10. We'll comebine the functions from above. If ToT works sucessfully, it should output the operations that can be performed to reach 24." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 68, "metadata": {}, "outputs": [], "source": [ - "data = \"4 5 6 10\"" + "data = '4 5 6 10'" ] }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 69, "metadata": {}, "outputs": [], "source": [ - "# def propose_prompt_wrap(x: str, y: str='') -> str:\n", - "# current_numbers = get_current_numbers(y if y else x)\n", - "# if current_numbers == '24':\n", - "# prompt = cot_prompt.format(input=x) + 'Steps:' + y\n", - "# # print([prompt])\n", - "# else:\n", - "# prompt = propose_prompt.format(input=current_numbers)\n", - "# return prompt\n", - " \n", + "from functools import partial\n", "\n", - "# Generation\n", - "def generate_thoughts(prompt):\n", - " for p in prompt:\n", - " \n", - " \n", - " print('PROMPT: ', prompt)\n", - " current_numbers = get_current_numbers(prompt)\n", - " prompt = propose_prompt.format(input=current_numbers)\n", - " \n", - " thoughts = gpt(prompt)[0].split('\\n')\n", "\n", - " return thoughts\n", + "def solve():\n", + "\n", + " backend='gpt-3.5-turbo'\n", + " temperature=0.7\n", + " n_evaluate_sample=3\n", + " n_select_sample=5\n", + "\n", + " global gpt\n", + " gpt = partial(gpt, model=backend, temperature=temperature)\n", + " print(gpt)\n", + "\n", + " thoughts = ['']\n", + " data = '4 5 6 10'\n", + "\n", + " steps = 4\n", + "\n", + " for step in range(steps):\n", + "\n", + " print('STEP NUMBER ::::: ', step)\n", + " print('(Step 0) Thoughts so far: ', thoughts)\n", + "\n", + " # Step 1: Thought Generation\n", + " new_thoughts = generate_thoughts(data, thoughts)\n", + " ids = list(range(len(new_thoughts)))\n", + "\n", + " print('(Step 1) new_thoughts: ', new_thoughts)\n", + " print('(Step 1) ids ', ids)\n", "\n", + " # Step 2: Thought Evaluation\n", + " scores = evaluate_thoughts(data, new_thoughts, n_evaluate_sample)\n", + " print('(Step 2) Values: ', scores)\n", "\n", - "def get_proposals(task, x, y): \n", - " propose_prompt = task.propose_prompt_wrap(x, y)\n", - " proposals = gpt(propose_prompt, n=1, stop=None)[0].split('\\n')\n", - " return [y + _ + '\\n' for _ in proposals]\n" + " # Step 3: Search algorithm\n", + " selected_ids = sorted(ids, key=lambda x: scores[x], reverse=True)[:n_select_sample] # Take top n_select_sample from list based on scores\n", + " select_new_thoughts = [new_thoughts[select_id] for select_id in selected_ids] \n", + " print('(Step 3) Selected new thoughts: ', select_new_thoughts)\n", + "\n", + " thoughts = select_new_thoughts\n", + "\n", + " print('-----------------------------------------------------------------------------------------------------------------')\n", + " \n", + " return thoughts" ] }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 70, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "PROMPT: []\n" + "functools.partial(, model='gpt-3.5-turbo', temperature=0.7)\n", + "STEP NUMBER ::::: 0\n", + "(Step 0) Thoughts so far: ['']\n", + "(Step 1) new_thoughts: ['4 + 5 = 9 (left: 6 9 10)\\n', '5 + 6 = 11 (left: 4 11 10)\\n', '6 + 4 = 10 (left: 10 5 6)\\n', '10 - 4 = 6 (left: 6 5 10)\\n', '6 * 5 = 30 (left: 4 30 10)\\n', '10 - 5 = 5 (left: 4 6 5)\\n', '10 - 6 = 4 (left: 4 5 10)\\n', '6 / 4 = 1.5 (left: 5 1.5 10)\\n']\n", + "(Step 1) ids [0, 1, 2, 3, 4, 5, 6, 7]\n", + "(Step 2) Values: [2.001, 1.002, 3.0, 21.001, 41.0, 41.0, 22.0, 2.0]\n", + "(Step 3) Selected new thoughts: ['6 * 5 = 30 (left: 4 30 10)\\n', '10 - 5 = 5 (left: 4 6 5)\\n', '10 - 6 = 4 (left: 4 5 10)\\n', '10 - 4 = 6 (left: 6 5 10)\\n', '6 + 4 = 10 (left: 10 5 6)\\n']\n", + "-----------------------------------------------------------------------------------------------------------------\n", + "STEP NUMBER ::::: 1\n", + "(Step 0) Thoughts so far: ['6 * 5 = 30 (left: 4 30 10)\\n', '10 - 5 = 5 (left: 4 6 5)\\n', '10 - 6 = 4 (left: 4 5 10)\\n', '10 - 4 = 6 (left: 6 5 10)\\n', '6 + 4 = 10 (left: 10 5 6)\\n']\n", + "(Step 1) new_thoughts: ['6 * 5 = 30 (left: 4 30 10)\\n4 + 30 = 34 (left: 10 34)\\n', '6 * 5 = 30 (left: 4 30 10)\\n30 / 4 = 7.5 (left: 7.5 10)\\n', '6 * 5 = 30 (left: 4 30 10)\\n10 + 4 = 14 (left: 14 30)\\n', '6 * 5 = 30 (left: 4 30 10)\\n4 * 30 = 120 (left: 10 120)\\n', '6 * 5 = 30 (left: 4 30 10)\\n30 - 4 = 26 (left: 26 10)\\n', '6 * 5 = 30 (left: 4 30 10)\\n10 / 4 = 2.5 (left: 2.5 10)\\n', '6 * 5 = 30 (left: 4 30 10)\\n30 - 10 = 20 (left: 4 20)\\n', '6 * 5 = 30 (left: 4 30 10)\\n30 / 10 = 3 (left: 3 4)\\n', '10 - 5 = 5 (left: 4 6 5)\\n4 + 6 = 10 (left: 5 10)\\n', '10 - 5 = 5 (left: 4 6 5)\\n6 - 4 = 2 (left: 2 5)\\n', '10 - 5 = 5 (left: 4 6 5)\\n4 * 6 = 24 (left: 5 24)\\n', '10 - 5 = 5 (left: 4 6 5)\\n6 / 4 = 1.5 (left: 1.5 5)\\n', '10 - 5 = 5 (left: 4 6 5)\\n5 + 4 = 9 (left: 6 9)\\n', '10 - 5 = 5 (left: 4 6 5)\\n5 - 4 = 1 (left: 1 6)\\n', '10 - 5 = 5 (left: 4 6 5)\\nInput: 3 9 4 12\\n', '10 - 5 = 5 (left: 4 6 5)\\nPossible next steps:\\n', '10 - 5 = 5 (left: 4 6 5)\\n3 + 9 = 12 (left: 4 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n9 / 3 = 3 (left: 3 4 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n4 * 3 = 12 (left: 12 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n12 - 3 = 9 (left: 9 4 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n9 / 3 = 3 (left: 4 3 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n12 - 4 = 8 (left: 3 8 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n12 / 4 = 3 (left: 3 9 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n9 + 3 = 12 (left: 4 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n9 - 3 = 6 (left: 6 4 12)\\n', '10 - 6 = 4 (left: 4 5 10)\\n4 + 5 = 9 (left: 9 10)\\n', '10 - 6 = 4 (left: 4 5 10)\\n5 * 4 = 20 (left: 20 10)\\n', '10 - 6 = 4 (left: 4 5 10)\\n10 - 4 = 6 (left: 6 5)\\n', '10 - 6 = 4 (left: 4 5 10)\\n10 / 5 = 2 (left: 4 2)\\n', '10 - 6 = 4 (left: 4 5 10)\\n5 - 4 = 1 (left: 1 10)\\n', '10 - 6 = 4 (left: 4 5 10)\\n10 - 5 = 5 (left: 4 5)\\n', '10 - 6 = 4 (left: 4 5 10)\\nInput: 1 3 5 7\\n', '10 - 6 = 4 (left: 4 5 10)\\nPossible next steps:\\n', '10 - 6 = 4 (left: 4 5 10)\\n1 + 3 = 4 (left: 4 5 7)\\n', '10 - 6 = 4 (left: 4 5 10)\\n3 * 5 = 15 (left: 1 15 7)\\n', '10 - 6 = 4 (left: 4 5 10)\\n7 - 1 = 6 (left: 6 5 7)\\n', '10 - 6 = 4 (left: 4 5 10)\\n5 - 3 = 2 (left: 1 2 7)\\n', '10 - 6 = 4 (left: 4 5 10)\\n7 - 5 = 2 (left: 1 3 2)\\n', '10 - 6 = 4 (left: 4 5 10)\\n7 - 3 = 4 (left: 1 5 4)\\n', '10 - 6 = 4 (left: 4 5 10)\\n7 / 1 = 7 (left: 3 5 7)\\n', '10 - 4 = 6 (left: 6 5 10)\\n6 + 5 = 11 (left: 10)\\n', '10 - 4 = 6 (left: 6 5 10)\\n6 * 5 = 30 (left: 10)\\n', '10 - 4 = 6 (left: 6 5 10)\\n6 - 5 = 1 (left: 1 10)\\n', '10 - 4 = 6 (left: 6 5 10)\\n5 * 10 = 50 (left: 6)\\n', '10 - 4 = 6 (left: 6 5 10)\\n5 + 10 = 15 (left: 6)\\n', '10 - 4 = 6 (left: 6 5 10)\\n10 - 6 = 4 (left: 4 5)\\n', '10 - 4 = 6 (left: 6 5 10)\\n10 / 5 = 2 (left: 6)\\n', '10 - 4 = 6 (left: 6 5 10)\\nInput: 3 9 27\\n', '10 - 4 = 6 (left: 6 5 10)\\nPossible next steps:\\n', '10 - 4 = 6 (left: 6 5 10)\\n3 * 9 = 27 (left: 27)\\n', '10 - 4 = 6 (left: 6 5 10)\\n3 + 9 = 12 (left: 12 27)\\n', '10 - 4 = 6 (left: 6 5 10)\\n9 / 3 = 3 (left: 3 27)\\n', '10 - 4 = 6 (left: 6 5 10)\\n9 * 27 = 243 (left: 3)\\n', '10 - 4 = 6 (left: 6 5 10)\\n9 + 27 = 36 (left: 3)\\n', '10 - 4 = 6 (left: 6 5 10)\\n27 / 3 = 9 (left: 9)\\n', '10 - 4 = 6 (left: 6 5 10)\\n27 - 9 = 18 (left: 3)\\n', '10 - 4 = 6 (left: 6 5 10)\\n9 - 3 = 6 (left: 6 27)\\n', '10 - 4 = 6 (left: 6 5 10)\\nInput: 4 2 2\\n', '10 - 4 = 6 (left: 6 5 10)\\nPossible next steps:\\n', '10 - 4 = 6 (left: 6 5 10)\\n4 + 2 = 6 (left: 2 6)\\n', '10 - 4 = 6 (left: 6 5 10)\\n4 * 2 = 8 (left: 2 8)\\n', '10 - 4 = 6 (left: 6 5 10)\\n4 / 2 = 2 (left: 2 2)\\n', '10 - 4 = 6 (left: 6 5 10)\\n2 * 2 = 4 (left: 4 2)\\n', '10 - 4 = 6 (left: 6 5 10)\\n2 / 2 = 1 (left: 1 2)\\n', '10 - 4 = 6 (left: 6 5 10)\\n2 + 2 = 4 (left: 4)\\n', '10 - 4 = 6 (left: 6 5 10)\\n2 - 2 = 0 (left: 4)\\n', '10 - 4 = 6 (left: 6 5 10)\\n4 - 2 = 2 (left: 2 2)\\n', '6 + 4 = 10 (left: 10 5 6)\\n10 + 5 = 15 (left: 6 15)\\n', '6 + 4 = 10 (left: 10 5 6)\\n10 / 5 = 2 (left: 2 6)\\n', '6 + 4 = 10 (left: 10 5 6)\\n10 - 5 = 5 (left: 5 6)\\n', '6 + 4 = 10 (left: 10 5 6)\\n5 + 6 = 11 (left: 10 11)\\n', '6 + 4 = 10 (left: 10 5 6)\\n6 - 5 = 1 (left: 1 6)\\n']\n", + "(Step 1) ids [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71]\n", + "(Step 2) Values: [0.003, 0.003, 20.002, 0.002, 0.003, 0.003, 40.001, 21.0, 0.003, 60.0, 0.002, 1.001, 0.003, 60.0, 0.002, 0.003, 0.003, 40.001, 40.001, 60.0, 60.0, 40.001, 1.002, 0.003, 60.0, 0.003, 0.001, 0.002, 1.001, 60.0, 40.001, 0.003, 0.003, 1.002, 2.0, 2.001, 1.002, 3.0, 2.0, 3.0, 0.002, 0.002, 60.0, 0.003, 0.003, 40.001, 0.003, 0.003, 0.003, 0.002, 0.003, 60.0, 0.0, 0.0, 0.003, 0.002, 0.003, 0.002, 0.003, 1.0, 1.002, 1.0, 1.0, 0.003, 0.0, 0.002, 21.0, 0.003, 0.003, 0.002, 0.003, 60.0]\n", + "(Step 3) Selected new thoughts: ['10 - 5 = 5 (left: 4 6 5)\\n6 - 4 = 2 (left: 2 5)\\n', '10 - 5 = 5 (left: 4 6 5)\\n5 - 4 = 1 (left: 1 6)\\n', '10 - 5 = 5 (left: 4 6 5)\\n12 - 3 = 9 (left: 9 4 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n9 / 3 = 3 (left: 4 3 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n9 - 3 = 6 (left: 6 4 12)\\n']\n", + "-----------------------------------------------------------------------------------------------------------------\n" ] }, { - "ename": "AttributeError", - "evalue": "'list' object has no attribute 'strip'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[53], line 8\u001b[0m\n\u001b[1;32m 3\u001b[0m thoughts \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m step \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, num_of_steps):\n\u001b[1;32m 6\u001b[0m \n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# Thought Generation\u001b[39;00m\n\u001b[0;32m----> 8\u001b[0m thoughts \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_thoughts\u001b[49m\u001b[43m(\u001b[49m\u001b[43mthoughts\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;66;03m#thoughts = list(itertools.chain(*thoughts))\u001b[39;00m\n\u001b[1;32m 10\u001b[0m ids \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(thoughts)))\n", - "Cell \u001b[0;32mIn[52], line 15\u001b[0m, in \u001b[0;36mgenerate_thoughts\u001b[0;34m(prompt)\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mgenerate_thoughts\u001b[39m(prompt):\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPROMPT: \u001b[39m\u001b[38;5;124m'\u001b[39m, prompt)\n\u001b[0;32m---> 15\u001b[0m current_numbers \u001b[38;5;241m=\u001b[39m \u001b[43mget_current_numbers\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprompt\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m prompt \u001b[38;5;241m=\u001b[39m propose_prompt\u001b[38;5;241m.\u001b[39mformat(\u001b[38;5;28minput\u001b[39m\u001b[38;5;241m=\u001b[39mcurrent_numbers)\n\u001b[1;32m 18\u001b[0m thoughts \u001b[38;5;241m=\u001b[39m gpt(prompt)[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n", - "Cell \u001b[0;32mIn[44], line 2\u001b[0m, in \u001b[0;36mget_current_numbers\u001b[0;34m(y)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_current_numbers\u001b[39m(y: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mstr\u001b[39m:\n\u001b[0;32m----> 2\u001b[0m last_line \u001b[38;5;241m=\u001b[39m \u001b[43my\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstrip\u001b[49m()\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m)[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m last_line\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mleft: \u001b[39m\u001b[38;5;124m'\u001b[39m)[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m)\u001b[39m\u001b[38;5;124m'\u001b[39m)[\u001b[38;5;241m0\u001b[39m]\n", - "\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'strip'" - ] + "data": { + "text/plain": [ + "['10 - 5 = 5 (left: 4 6 5)\\n6 - 4 = 2 (left: 2 5)\\n',\n", + " '10 - 5 = 5 (left: 4 6 5)\\n5 - 4 = 1 (left: 1 6)\\n',\n", + " '10 - 5 = 5 (left: 4 6 5)\\n12 - 3 = 9 (left: 9 4 12)\\n',\n", + " '10 - 5 = 5 (left: 4 6 5)\\n9 / 3 = 3 (left: 4 3 12)\\n',\n", + " '10 - 5 = 5 (left: 4 6 5)\\n9 - 3 = 6 (left: 6 4 12)\\n']" + ] + }, + "execution_count": 70, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "import itertools\n", - "num_of_steps = 1\n", - "thoughts = ['']\n", - "\n", - " new_ys = [get_proposals(task, x, y) for y in ys]\n", - " # new_ys = list(itertools.chain(*new_ys))\n", - " # print('FINAL YS: ', new_ys)\n", - " \n", - "for step in range(0, num_of_steps):\n", - " \n", - " # Thought Generation\n", - " thoughts = generate_thoughts(thoughts)\n", - " #thoughts = list(itertools.chain(*thoughts))\n", - " ids = list(range(len(thoughts)))\n", - "\n", - " print('THOUGHTS: ', thoughts)\n", - "\n", - " # Thought evaluation\n", - " #values = evaluate_thoughts(thoughts)\n", - "\n", - " # Search algorithm\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# import itertools\n", - "\n", - "# num_of_steps = 4\n", - "\n", - "# for step in num_of_steps:\n", - "# thoughts = \n", - " \n", - "# # Generation (Propose / Sample)\n", - "# new_ys = [get_proposals(x, y) for y in ys]\n", - "# new_ys = list(itertools.chain(*new_ys))\n", - "# ids = list(range(len(new_ys)))\n", - "\n", - "# # Evaluation (Value / Vote)\n", - "# values = get_values(task, x, new_ys, args.n_evaluate_sample)\n", - " \n", - "# # Selection (Sample/Greedy)\n", - "# select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]\n", - "# select_new_ys = [new_ys[select_id] for select_id in select_ids]\n", - " \n", - "# #infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})\n", - "# ys = select_new_ys\n" + "solve()" ] } ], From 2755e3c7e2c3ef9a70dff1a90b30cdb57d0ff84d Mon Sep 17 00:00:00 2001 From: rahulbshrestha Date: Tue, 4 Jun 2024 17:46:26 +0200 Subject: [PATCH 5/6] Update instructions Signed-off-by: rahulbshrestha --- tree-of-thoughts.ipynb | 197 +++++++++++++++++++++++++++++------------ 1 file changed, 138 insertions(+), 59 deletions(-) diff --git a/tree-of-thoughts.ipynb b/tree-of-thoughts.ipynb index f5426ed05e..4cba55edae 100644 --- a/tree-of-thoughts.ipynb +++ b/tree-of-thoughts.ipynb @@ -18,7 +18,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Tree of Thoughts (ToT) is a framework used by LLMs to solve complex reasoning problems. The intermediate steps in a reasoning process are split into “thoughts”, with the ToT algorithm encouraging exploration of these thoughts through search algorithms.\n" + "Tree of Thoughts (ToT) is a framework used by LLMs to solve complex reasoning problems. The intermediate steps in a reasoning process are split into \"thoughts\" as similar to Chain of Thought, but there are multiple thoughts generated per step, resulting in a tree-like structure. A search algorithm is implemented allowing ToT to explore among the thoughts." ] }, { @@ -32,7 +32,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We'll be using Hugging face ```transformers``` to generate text with our LLMs. First, we start off by importing the necessary libraries." + "In this tutorial, we'll use two different LLMs: Mistral and GPT-4. \n", + "\n", + "We can use the Hugging face ```transformers``` library to generate text with our LLMs. First, we start off by importing the necessary libraries." ] }, { @@ -142,28 +144,47 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "ToT can be broken down into 4 key steps:\n", + "ToT can be broken down into 4 key steps as described below. \n", "\n", - "(a) Thought Decomposition\n", + "(Step 0): Thought Decomposition\n", "\n", - "(b) Thought Generation\n", + "(Step 1): Thought Generation\n", "- In this step, the LLM is prompted to generate thoughts by either one of two ways:\n", " - Sample: The thoughts are generated by sampling i.i.d thoughts from a Chain of Thought prompt.\n", " - Propose: The thoughts are propsed sequentially depending on the previous prompts. \n", "\n", - "(c) Thought Evaluation\n", + "(Step 2): Thought Evaluation\n", "- The LLMs are prompted to evaluate the thoughts generated in the previous step, by either: \n", - " - Value:\n", - " - Vote: \n", + " - Value: The thoughts are assigned a score individually. \n", + " - Vote: All of thoughts are evaluated together and assigned a score.\n", "\n", - "(d) Search Algorithm\n", + "(Step 3): Search Algorithm\n", + "- The search algorithm is used to explore the thoughts generated in the previous steps:\n", + " - Breadth first search\n", + " - Depth first search\n", "\n", - ".\n", "\n", + "In this tutorial, we'll be using ToT with OpenAI's GPT-3.5 Turbo to solve the Game of 24.\n", "\n", - "In this tutorial, we'll be using ToT with Mistral to solve the Game of 24.\n", + "The Game of 24 is a task where given a sequence of 4 numbers, we’ll need to find the correct mathematical operations (add, subtract, multiply, divide) that’ll lead to the number 24. For example, if the sequence is {4, 9, 10, 13}, the correct operations using the 4 numbers are: (10 - 4) * (13 - 9) = 24. Each number in the sequence can only be used once.\n", "\n", - "The Game of 24 is a task where given a sequence of 4 numbers, we’ll need to find the correct mathematical operations (add, subtract, multiply, divide) that’ll lead to the number 24. For example, if the sequence is {4, 9, 10, 13}, the correct operations using the 4 numbers are: (10 - 4) * (13 - 9) = 24. Each number in the sequence can only be used once.\n" + "In this tutorial, we'll be using 'Propose' for Thought Generation, 'Value' for Thought Evaluation and 'Breadth-first search' for the search algorithm." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2.0 Thought Decomposition" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We define the prompts necessary for the ToT framework. Each prompt will be used in different stages of the ToT process. The propose_prompt is meant to guide the LLM to come up with possible next steps, given a certain point in the problem.\n", + "\n", + "The value_prompt assigns a classification (sure/likely/impossible) to each thought, depending on how likely it is to reach the number 24 given the current sequence of thoughts." ] }, { @@ -325,7 +346,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "First, we'll define functions necessary for \"Thought Generation\"." + "#### 2.1 Thought Generation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we'll define functions necessary for \"Thought Generation\".\n" ] }, { @@ -339,6 +367,15 @@ " last_line = y.strip().split('\\n')[-1]\n", " return last_line.split('left: ')[-1].split(')')[0]\n", "\n", + "def prepare_generate_prompt(current_numbers, thought, data):\n", + " \n", + " if current_numbers == '24':\n", + " prompt = cot_prompt.format(input=data) + 'Steps: ' + thought\n", + " else:\n", + " prompt = propose_prompt.format(input=current_numbers)\n", + "\n", + " return prompt\n", + "\n", "\n", "def generate_thoughts(data, thoughts):\n", "\n", @@ -348,11 +385,8 @@ "\n", " # Prepare prompt\n", " current_numbers = get_current_numbers(thought if thought else data)\n", - " if current_numbers == '24':\n", - " prompt = cot_prompt.format(input=data) + 'Steps: ' + thought\n", - " else:\n", - " prompt = propose_prompt.format(input=current_numbers)\n", - "\n", + " prompt = prepare_generate_prompt(current_numbers, thought, data)\n", + " \n", " # Generate thoughts with prompt\n", " proposals = gpt(prompt, n=1, stop=None)[0].split('\\n')\n", " new_thoughts.extend([thought + _ + '\\n' for _ in proposals])\n", @@ -360,6 +394,13 @@ " return new_thoughts" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2.2 Thought Evaluation" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -367,39 +408,81 @@ "Next, we'll create the functions necessary for \"Thought Evaluation\", where each of the thoughts are evaluated by the LLM." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we create the ```prepare_evaluate_prompt``` function which turns our current thought into an evaluation prompt by using the ```value_prompt``` from above." + ] + }, { "cell_type": "code", - "execution_count": 66, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "def value_prompt_wrap(x: str, y: str) -> str:\n", - " last_line = y.strip().split('\\n')[-1]\n", - " if 'left: ' not in last_line: # last step\n", + "def prepare_evaluate_prompt(data: str, thought: str) -> str:\n", + " last_line = thought.strip().split('\\n')[-1]\n", + " if 'left: ' not in last_line:\n", " ans = last_line.lower().replace('answer: ', '')\n", - " # print([value_last_step_prompt.format(input=x, answer=ans)])\n", - " return value_last_step_prompt.format(input=x, answer=ans)\n", - " current_numbers = get_current_numbers(y)\n", - " return value_prompt.format(input=current_numbers)\n", - "\n", - "def value_outputs_unwrap(x: str, y: str, value_outputs: list) -> float:\n", - " if len(y.strip().split('\\n')) == 4 and 'answer' not in y.lower():\n", + " return value_last_step_prompt.format(input=data, answer=ans) \n", + " current_numbers = get_current_numbers(thought)\n", + " return value_prompt.format(input=current_numbers)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We then create the ```evaluate_outputs_unwrap``` which converts the values assigned to each thought into a list of integers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate_outputs_unwrap(thought: str, evaluate_outputs: list) -> float:\n", + " if len(thought.strip().split('\\n')) == 4 and 'answer' not in thought.lower():\n", " return 0\n", - " value_names = [_.split('\\n')[-1] for _ in value_outputs]\n", - " value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20} # TODO: ad hoc\n", - " value = sum(value * value_names.count(name) for name, value in value_map.items())\n", - " return value\n", - " \n", + " value_names = [_.split('\\n')[-1] for _ in evaluate_outputs]\n", + " value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}\n", + " score = sum(value * value_names.count(name) for name, value in value_map.items())\n", + " return score" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And, finally we wrap the above functions into the ```evaluate``` function." + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [], + "source": [ "def evaluate_thoughts(data, thoughts, n_evaluate_sample):\n", " scores = []\n", " for thought in thoughts:\n", - " value_prompt = value_prompt_wrap(data, thought)\n", - " value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)\n", - " value = value_outputs_unwrap(data, thought, value_outputs)\n", - " scores.append(value)\n", + " evaluate_prompt = prepare_evaluate_prompt(data, thought)\n", + " evaluate_outputs = gpt(evaluate_prompt, n=n_evaluate_sample, stop=None)\n", + " print('Value outputs: ', evaluate_outputs)\n", + " score = evaluate_outputs_unwrap(thought, evaluate_outputs)\n", + " scores.append(score)\n", " return scores" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2.3 Search algorithm" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -413,8 +496,10 @@ "metadata": {}, "outputs": [], "source": [ - "# selected_ids = sorted(ids, key=lambda x: scores[x], reverse=True)[:n_select_sample] # Take top n_select_sample from list based on scores\n", - "# select_new_thoughts = [new_thoughts[select_id] for select_id in selected_ids] " + "def search_algorithm(new_thoughts, ids, scores, n_select_sample):\n", + " selected_ids = sorted(ids, key=lambda x: scores[x], reverse=True)[:n_select_sample] # Take top n_select_sample from list based on scores\n", + " select_new_thoughts = [new_thoughts[select_id] for select_id in selected_ids] \n", + " return select_new_thoughts" ] }, { @@ -449,17 +534,11 @@ "from functools import partial\n", "\n", "\n", - "def solve():\n", - "\n", - " backend='gpt-3.5-turbo'\n", - " temperature=0.7\n", - " n_evaluate_sample=3\n", - " n_select_sample=5\n", + "def solve(model, temperature, n_evaluate_sample, n_select_sample):\n", "\n", " global gpt\n", - " gpt = partial(gpt, model=backend, temperature=temperature)\n", - " print(gpt)\n", - "\n", + " gpt = partial(gpt, model=model, temperature=temperature)\n", + " \n", " thoughts = ['']\n", " data = '4 5 6 10'\n", "\n", @@ -467,8 +546,8 @@ "\n", " for step in range(steps):\n", "\n", - " print('STEP NUMBER ::::: ', step)\n", - " print('(Step 0) Thoughts so far: ', thoughts)\n", + " print('Step Number ::', step)\n", + " print('(Step 0) Thoughts: ', thoughts)\n", "\n", " # Step 1: Thought Generation\n", " new_thoughts = generate_thoughts(data, thoughts)\n", @@ -479,17 +558,17 @@ "\n", " # Step 2: Thought Evaluation\n", " scores = evaluate_thoughts(data, new_thoughts, n_evaluate_sample)\n", - " print('(Step 2) Values: ', scores)\n", + " print('(Step 2) Scores: ', scores)\n", "\n", " # Step 3: Search algorithm\n", - " selected_ids = sorted(ids, key=lambda x: scores[x], reverse=True)[:n_select_sample] # Take top n_select_sample from list based on scores\n", - " select_new_thoughts = [new_thoughts[select_id] for select_id in selected_ids] \n", - " print('(Step 3) Selected new thoughts: ', select_new_thoughts)\n", + " \n", + " selected_new_thoughts = search_algorithm(new_thoughts, ids, scores, n_select_sample) \n", + " print('(Step 3) Selected new thoughts: ', selected_new_thoughts)\n", "\n", - " thoughts = select_new_thoughts\n", + " thoughts = selected_new_thoughts\n", + "\n", + " print('--------')\n", "\n", - " print('-----------------------------------------------------------------------------------------------------------------')\n", - " \n", " return thoughts" ] }, @@ -535,7 +614,7 @@ } ], "source": [ - "solve()" + "solve(model='gpt-3.5-turbo', temperature=0.7, n_evaluate_sample=3, n_select_sample=5)" ] } ], From 6ed80ebd49340e9a4481440a9abfc65c2e295c06 Mon Sep 17 00:00:00 2001 From: Rahul Shrestha Date: Tue, 25 Jun 2024 16:16:34 +0200 Subject: [PATCH 6/6] Update blog --- tree-of-thoughts.ipynb | 207 ++++++++++++++++++++++++++++++----------- 1 file changed, 153 insertions(+), 54 deletions(-) diff --git a/tree-of-thoughts.ipynb b/tree-of-thoughts.ipynb index 4cba55edae..6856286921 100644 --- a/tree-of-thoughts.ipynb +++ b/tree-of-thoughts.ipynb @@ -39,7 +39,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3a9253adff2a409084d62fe896a7b04e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(HTML(value='
, model='gpt-3.5-turbo', temperature=0.7)\n", - "STEP NUMBER ::::: 0\n", - "(Step 0) Thoughts so far: ['']\n", - "(Step 1) new_thoughts: ['4 + 5 = 9 (left: 6 9 10)\\n', '5 + 6 = 11 (left: 4 11 10)\\n', '6 + 4 = 10 (left: 10 5 6)\\n', '10 - 4 = 6 (left: 6 5 10)\\n', '6 * 5 = 30 (left: 4 30 10)\\n', '10 - 5 = 5 (left: 4 6 5)\\n', '10 - 6 = 4 (left: 4 5 10)\\n', '6 / 4 = 1.5 (left: 5 1.5 10)\\n']\n", - "(Step 1) ids [0, 1, 2, 3, 4, 5, 6, 7]\n", - "(Step 2) Values: [2.001, 1.002, 3.0, 21.001, 41.0, 41.0, 22.0, 2.0]\n", - "(Step 3) Selected new thoughts: ['6 * 5 = 30 (left: 4 30 10)\\n', '10 - 5 = 5 (left: 4 6 5)\\n', '10 - 6 = 4 (left: 4 5 10)\\n', '10 - 4 = 6 (left: 6 5 10)\\n', '6 + 4 = 10 (left: 10 5 6)\\n']\n", - "-----------------------------------------------------------------------------------------------------------------\n", - "STEP NUMBER ::::: 1\n", - "(Step 0) Thoughts so far: ['6 * 5 = 30 (left: 4 30 10)\\n', '10 - 5 = 5 (left: 4 6 5)\\n', '10 - 6 = 4 (left: 4 5 10)\\n', '10 - 4 = 6 (left: 6 5 10)\\n', '6 + 4 = 10 (left: 10 5 6)\\n']\n", - "(Step 1) new_thoughts: ['6 * 5 = 30 (left: 4 30 10)\\n4 + 30 = 34 (left: 10 34)\\n', '6 * 5 = 30 (left: 4 30 10)\\n30 / 4 = 7.5 (left: 7.5 10)\\n', '6 * 5 = 30 (left: 4 30 10)\\n10 + 4 = 14 (left: 14 30)\\n', '6 * 5 = 30 (left: 4 30 10)\\n4 * 30 = 120 (left: 10 120)\\n', '6 * 5 = 30 (left: 4 30 10)\\n30 - 4 = 26 (left: 26 10)\\n', '6 * 5 = 30 (left: 4 30 10)\\n10 / 4 = 2.5 (left: 2.5 10)\\n', '6 * 5 = 30 (left: 4 30 10)\\n30 - 10 = 20 (left: 4 20)\\n', '6 * 5 = 30 (left: 4 30 10)\\n30 / 10 = 3 (left: 3 4)\\n', '10 - 5 = 5 (left: 4 6 5)\\n4 + 6 = 10 (left: 5 10)\\n', '10 - 5 = 5 (left: 4 6 5)\\n6 - 4 = 2 (left: 2 5)\\n', '10 - 5 = 5 (left: 4 6 5)\\n4 * 6 = 24 (left: 5 24)\\n', '10 - 5 = 5 (left: 4 6 5)\\n6 / 4 = 1.5 (left: 1.5 5)\\n', '10 - 5 = 5 (left: 4 6 5)\\n5 + 4 = 9 (left: 6 9)\\n', '10 - 5 = 5 (left: 4 6 5)\\n5 - 4 = 1 (left: 1 6)\\n', '10 - 5 = 5 (left: 4 6 5)\\nInput: 3 9 4 12\\n', '10 - 5 = 5 (left: 4 6 5)\\nPossible next steps:\\n', '10 - 5 = 5 (left: 4 6 5)\\n3 + 9 = 12 (left: 4 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n9 / 3 = 3 (left: 3 4 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n4 * 3 = 12 (left: 12 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n12 - 3 = 9 (left: 9 4 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n9 / 3 = 3 (left: 4 3 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n12 - 4 = 8 (left: 3 8 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n12 / 4 = 3 (left: 3 9 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n9 + 3 = 12 (left: 4 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n9 - 3 = 6 (left: 6 4 12)\\n', '10 - 6 = 4 (left: 4 5 10)\\n4 + 5 = 9 (left: 9 10)\\n', '10 - 6 = 4 (left: 4 5 10)\\n5 * 4 = 20 (left: 20 10)\\n', '10 - 6 = 4 (left: 4 5 10)\\n10 - 4 = 6 (left: 6 5)\\n', '10 - 6 = 4 (left: 4 5 10)\\n10 / 5 = 2 (left: 4 2)\\n', '10 - 6 = 4 (left: 4 5 10)\\n5 - 4 = 1 (left: 1 10)\\n', '10 - 6 = 4 (left: 4 5 10)\\n10 - 5 = 5 (left: 4 5)\\n', '10 - 6 = 4 (left: 4 5 10)\\nInput: 1 3 5 7\\n', '10 - 6 = 4 (left: 4 5 10)\\nPossible next steps:\\n', '10 - 6 = 4 (left: 4 5 10)\\n1 + 3 = 4 (left: 4 5 7)\\n', '10 - 6 = 4 (left: 4 5 10)\\n3 * 5 = 15 (left: 1 15 7)\\n', '10 - 6 = 4 (left: 4 5 10)\\n7 - 1 = 6 (left: 6 5 7)\\n', '10 - 6 = 4 (left: 4 5 10)\\n5 - 3 = 2 (left: 1 2 7)\\n', '10 - 6 = 4 (left: 4 5 10)\\n7 - 5 = 2 (left: 1 3 2)\\n', '10 - 6 = 4 (left: 4 5 10)\\n7 - 3 = 4 (left: 1 5 4)\\n', '10 - 6 = 4 (left: 4 5 10)\\n7 / 1 = 7 (left: 3 5 7)\\n', '10 - 4 = 6 (left: 6 5 10)\\n6 + 5 = 11 (left: 10)\\n', '10 - 4 = 6 (left: 6 5 10)\\n6 * 5 = 30 (left: 10)\\n', '10 - 4 = 6 (left: 6 5 10)\\n6 - 5 = 1 (left: 1 10)\\n', '10 - 4 = 6 (left: 6 5 10)\\n5 * 10 = 50 (left: 6)\\n', '10 - 4 = 6 (left: 6 5 10)\\n5 + 10 = 15 (left: 6)\\n', '10 - 4 = 6 (left: 6 5 10)\\n10 - 6 = 4 (left: 4 5)\\n', '10 - 4 = 6 (left: 6 5 10)\\n10 / 5 = 2 (left: 6)\\n', '10 - 4 = 6 (left: 6 5 10)\\nInput: 3 9 27\\n', '10 - 4 = 6 (left: 6 5 10)\\nPossible next steps:\\n', '10 - 4 = 6 (left: 6 5 10)\\n3 * 9 = 27 (left: 27)\\n', '10 - 4 = 6 (left: 6 5 10)\\n3 + 9 = 12 (left: 12 27)\\n', '10 - 4 = 6 (left: 6 5 10)\\n9 / 3 = 3 (left: 3 27)\\n', '10 - 4 = 6 (left: 6 5 10)\\n9 * 27 = 243 (left: 3)\\n', '10 - 4 = 6 (left: 6 5 10)\\n9 + 27 = 36 (left: 3)\\n', '10 - 4 = 6 (left: 6 5 10)\\n27 / 3 = 9 (left: 9)\\n', '10 - 4 = 6 (left: 6 5 10)\\n27 - 9 = 18 (left: 3)\\n', '10 - 4 = 6 (left: 6 5 10)\\n9 - 3 = 6 (left: 6 27)\\n', '10 - 4 = 6 (left: 6 5 10)\\nInput: 4 2 2\\n', '10 - 4 = 6 (left: 6 5 10)\\nPossible next steps:\\n', '10 - 4 = 6 (left: 6 5 10)\\n4 + 2 = 6 (left: 2 6)\\n', '10 - 4 = 6 (left: 6 5 10)\\n4 * 2 = 8 (left: 2 8)\\n', '10 - 4 = 6 (left: 6 5 10)\\n4 / 2 = 2 (left: 2 2)\\n', '10 - 4 = 6 (left: 6 5 10)\\n2 * 2 = 4 (left: 4 2)\\n', '10 - 4 = 6 (left: 6 5 10)\\n2 / 2 = 1 (left: 1 2)\\n', '10 - 4 = 6 (left: 6 5 10)\\n2 + 2 = 4 (left: 4)\\n', '10 - 4 = 6 (left: 6 5 10)\\n2 - 2 = 0 (left: 4)\\n', '10 - 4 = 6 (left: 6 5 10)\\n4 - 2 = 2 (left: 2 2)\\n', '6 + 4 = 10 (left: 10 5 6)\\n10 + 5 = 15 (left: 6 15)\\n', '6 + 4 = 10 (left: 10 5 6)\\n10 / 5 = 2 (left: 2 6)\\n', '6 + 4 = 10 (left: 10 5 6)\\n10 - 5 = 5 (left: 5 6)\\n', '6 + 4 = 10 (left: 10 5 6)\\n5 + 6 = 11 (left: 10 11)\\n', '6 + 4 = 10 (left: 10 5 6)\\n6 - 5 = 1 (left: 1 6)\\n']\n", - "(Step 1) ids [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71]\n", - "(Step 2) Values: [0.003, 0.003, 20.002, 0.002, 0.003, 0.003, 40.001, 21.0, 0.003, 60.0, 0.002, 1.001, 0.003, 60.0, 0.002, 0.003, 0.003, 40.001, 40.001, 60.0, 60.0, 40.001, 1.002, 0.003, 60.0, 0.003, 0.001, 0.002, 1.001, 60.0, 40.001, 0.003, 0.003, 1.002, 2.0, 2.001, 1.002, 3.0, 2.0, 3.0, 0.002, 0.002, 60.0, 0.003, 0.003, 40.001, 0.003, 0.003, 0.003, 0.002, 0.003, 60.0, 0.0, 0.0, 0.003, 0.002, 0.003, 0.002, 0.003, 1.0, 1.002, 1.0, 1.0, 0.003, 0.0, 0.002, 21.0, 0.003, 0.003, 0.002, 0.003, 60.0]\n", - "(Step 3) Selected new thoughts: ['10 - 5 = 5 (left: 4 6 5)\\n6 - 4 = 2 (left: 2 5)\\n', '10 - 5 = 5 (left: 4 6 5)\\n5 - 4 = 1 (left: 1 6)\\n', '10 - 5 = 5 (left: 4 6 5)\\n12 - 3 = 9 (left: 9 4 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n9 / 3 = 3 (left: 4 3 12)\\n', '10 - 5 = 5 (left: 4 6 5)\\n9 - 3 = 6 (left: 6 4 12)\\n']\n", - "-----------------------------------------------------------------------------------------------------------------\n" + "(Step 2) Scores: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n", + "(Step 3) Selected new thoughts: ['Input: 2 8 8 14\\n', 'Possible next steps:\\n', '2 + 8 = 10 (left: 8 10 14)\\n', '8 / 2 = 4 (left: 4 8 14)\\n', '14 + 2 = 16 (left: 8 8 16)\\n']\n", + "--------\n" ] }, { "data": { "text/plain": [ - "['10 - 5 = 5 (left: 4 6 5)\\n6 - 4 = 2 (left: 2 5)\\n',\n", - " '10 - 5 = 5 (left: 4 6 5)\\n5 - 4 = 1 (left: 1 6)\\n',\n", - " '10 - 5 = 5 (left: 4 6 5)\\n12 - 3 = 9 (left: 9 4 12)\\n',\n", - " '10 - 5 = 5 (left: 4 6 5)\\n9 / 3 = 3 (left: 4 3 12)\\n',\n", - " '10 - 5 = 5 (left: 4 6 5)\\n9 - 3 = 6 (left: 6 4 12)\\n']" + "['Input: 2 8 8 14\\n',\n", + " 'Possible next steps:\\n',\n", + " '2 + 8 = 10 (left: 8 10 14)\\n',\n", + " '8 / 2 = 4 (left: 4 8 14)\\n',\n", + " '14 + 2 = 16 (left: 8 8 16)\\n']" ] }, - "execution_count": 70, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -620,9 +719,9 @@ ], "metadata": { "kernelspec": { - "display_name": "transformers", + "display_name": "Python [conda env:.conda-rahul_env]", "language": "python", - "name": "python3" + "name": "conda-env-.conda-rahul_env-py" }, "language_info": { "codemirror_mode": { @@ -634,9 +733,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.12.4" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 }