From 8fdb8aacfae72cf1a2232370af18ddc435cc3cc3 Mon Sep 17 00:00:00 2001 From: lorenzobalzani Date: Sat, 25 Jan 2025 13:23:18 +0100 Subject: [PATCH] feat: custom A* algorithm --- .github/workflows/lint_and_style.yaml | 34 + .gitignore | 6 + .pre-commit-config.yaml | 62 ++ README.md | 74 ++ env.yml | 180 +++++ planning.ipynb | 973 ++++++++++++++++++++++++++ problem_statement.txt | 43 ++ setup_env.sh | 34 + 8 files changed, 1406 insertions(+) create mode 100644 .github/workflows/lint_and_style.yaml create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 README.md create mode 100644 env.yml create mode 100644 planning.ipynb create mode 100644 problem_statement.txt create mode 100755 setup_env.sh diff --git a/.github/workflows/lint_and_style.yaml b/.github/workflows/lint_and_style.yaml new file mode 100644 index 0000000..4c3e8e3 --- /dev/null +++ b/.github/workflows/lint_and_style.yaml @@ -0,0 +1,34 @@ +name: lint_and_style + +on: + pull_request: + push: + branches: [main, master] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: conda-incubator/setup-miniconda@v2 + with: + environment-file: env.yml + activate-environment: lint_and_style + auto-activate-base: false + - uses: pre-commit/action@v3.0.1 + + pylint: + runs-on: ubuntu-latest + needs: pre-commit + continue-on-error: true + if: github.event_name == 'pull_request' + steps: + - uses: actions/checkout@v3 + - uses: conda-incubator/setup-miniconda@v2 + with: + environment-file: env.yml + activate-environment: lint_and_style + auto-activate-base: false + - uses: pre-commit/action@v3.0.1 + with: + extra_args: --hook-stage manual pylint-all --all-files diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8cc5d7c --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +*.env +__pycache__ +.vscode +.idea +**/.DS_Store +mcp_solver.egg-info diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 
0000000..8542b1d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,62 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-docstring-first + - id: check-toml + - id: check-yaml + exclude: packaging/.* + - id: mixed-line-ending + args: [--fix=lf] + - id: end-of-file-fixer + - repo: https://github.com/hhatto/autopep8 + rev: v2.1.0 + hooks: + - id: autopep8 + args: [--in-place, --aggressive, --exit-code] + types: [python] + - repo: local + hooks: + - id: pylint + name: pylint + entry: pylint + language: system + types: [python] + args: + [ + "--max-line-length=120", + "--errors-only", + ] + - id: pylint + alias: pylint-all + name: pylint-all + entry: pylint + language: system + types: [python] + args: + [ + "--max-line-length=120", + "--disable=W2402", # non-ascii-file-name, but Streamlit page names contain emojis + ] + stages: [manual] + - repo: https://github.com/pycqa/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + args: [--config=setup.cfg] + additional_dependencies: + - flake8-bugbear==22.10.27 + - flake8-comprehensions==3.10.1 + - torchfix==0.0.2 + - repo: https://github.com/facebook/usort + rev: v1.0.7 + hooks: + - id: usort + name: Sort imports with µsort + description: Safe, minimal import sorting + language: python + types_or: + - python + - pyi + entry: usort format + require_serial: true diff --git a/README.md b/README.md new file mode 100644 index 0000000..5bc761c --- /dev/null +++ b/README.md @@ -0,0 +1,74 @@ +# A3I Project - Choreography Planning by Decoupling A* Deterministic Search and Heuristic Computation + +## Overview +This project uses advanced techniques to design and evaluate dance choreographies for the NAO robot. By combining heuristic methods, a custom A* search algorithm, and large language model (LLM) evaluations, the system generates creative and engaging robot movements tailored to specific constraints and criteria. 
+ +## Features +- **Choreography Evaluation**: Rates choreography based on storytelling, rhythm, movement technique, public involvement, space use, human characterization, and human reproducibility. +- **A* Search Algorithm**: Custom implementation to find the optimal sequence of dance moves while adhering to constraints like duration and mandatory moves. +- **Heuristic Computation**: Combines past scores (g-function) and future estimations (h-function) to guide the search process. +- **Dynamic Scoring**: Allows adjusting weights for different scoring components (e.g., waypoint adherence, duration bonus). + +## Components +1. **LLM Scoring Functions**: + - `g_llm`: Evaluates the coolness of a complete dance sequence. + - `h_llm`: Estimates the coolness of potential moves for an incomplete sequence. +2. **Custom A* Implementation**: + - Uses LLM-generated scores to find optimal paths for robot choreography. + - Incorporates dynamic constraint satisfaction. +3. **Scoring Criteria**: + - `gh_weight`: Importance of coolness scores. + - `t_weight`: Importance of adhering to waypoints. + - `db_weight`: Bonus for optimal duration. + - `csc_weight`: Importance of satisfying constraints. + +## Input Data +- **Initial State**: Starting position of the robot. +- **Final State**: Target position to complete the choreography. +- **Waypoints**: Intermediate positions the robot must pass through. +- **Mandatory Positions**: Specific moves required in the choreography. +- **Intermediate Positions**: Optional moves for enhanced creativity. +- **Graph**: A connectivity graph representing valid transitions between moves. + +## Outputs +- A complete sequence of moves. +- Coolness scores for the choreography. +- Computation cost and number of LLM calls. + +## Installation +1. Clone this repository: + ```bash + git clone https://github.com/lorenzobalzani/llm-a-star-planner.git + ``` +2. Make sure conda is installed on your system. +3. 
Set up the conda environment by launching the script `./setup_env.sh` in the terminal. +4. Activate it with `conda activate a3i`. +5. To install the IPython kernel, run `python -m ipykernel install --user --name=a3i`. + +## Setup +Create a file named `secrets.env` in the root directory of the project and add the following lines: +```bash +AZURE_OPENAI_ENDPOINT = +AZURE_OPENAI_API_KEY = +OPENAI_API_VERSION = +``` +If you prefer to use an LLM other than Azure OpenAI, update the LLM client definitions in the notebook accordingly. + +## Example Output +``` +Choreography: ['INITIAL_stand_init', 'mandatory_stand', 'diagonal_left', ... 'FINAL_crouch'] +Cost: 0.06451€ +Number of calls to LLM: 59 +Elapsed time: 31.50s +``` + +## Contributing +Contributions are welcome! Please open an issue or submit a pull request. + +## Authors +- **Davide Bombardi** (davide.bombardi@studio.unibo.it) +- **Lorenzo Balzani** (lorenzo.balzani@studio.unibo.it) + +## License +This project is licensed under the MIT License. 
+ diff --git a/env.yml b/env.yml new file mode 100644 index 0000000..d792967 --- /dev/null +++ b/env.yml @@ -0,0 +1,180 @@ +name: a3i +channels: + - conda-forge + - defaults + - https://repo.anaconda.com/pkgs/main + - https://repo.anaconda.com/pkgs/r +dependencies: + - bzip2=1.0.8=h80987f9_6 + - ca-certificates=2024.9.24=hca03da5_0 + - expat=2.6.3=h313beb8_0 + - libcxx=14.0.6=h848a8c0_0 + - libffi=3.4.4=hca03da5_1 + - ncurses=6.4=h313beb8_0 + - openssl=3.0.15=h80987f9_0 + - pip=24.2=py312hca03da5_0 + - python=3.12.7=h99e199e_0 + - readline=8.2=h1a28f6b_0 + - setuptools=75.1.0=py312hca03da5_0 + - sqlite=3.45.3=h80987f9_0 + - tk=8.6.14=h6ba3021_0 + - tzdata=2024b=h04d1e81_0 + - wheel=0.44.0=py312hca03da5_0 + - xz=5.4.6=h80987f9_1 + - zlib=1.2.13=h18a0788_1 + - pip: + - aiohappyeyeballs==2.4.3 + - aiohttp==3.11.2 + - aiosignal==1.3.1 + - annotated-types==0.7.0 + - anyio==4.6.2.post1 + - appnope==0.1.4 + - argon2-cffi==23.1.0 + - argon2-cffi-bindings==21.2.0 + - arrow==1.3.0 + - astroid==3.3.5 + - asttokens==2.4.1 + - async-lru==2.0.4 + - attrs==24.2.0 + - babel==2.16.0 + - beautifulsoup4==4.12.3 + - bleach==6.2.0 + - certifi==2024.8.30 + - cffi==1.17.1 + - cfgv==3.4.0 + - charset-normalizer==3.4.0 + - comm==0.2.2 + - contourpy==1.3.1 + - cycler==0.12.1 + - dataclasses-json==0.6.7 + - debugpy==1.8.8 + - decorator==5.1.1 + - defusedxml==0.7.1 + - dill==0.3.9 + - distlib==0.3.9 + - distro==1.9.0 + - executing==2.1.0 + - fastjsonschema==2.20.0 + - filelock==3.16.1 + - fonttools==4.55.0 + - fqdn==1.5.1 + - frozenlist==1.5.0 + - h11==0.14.0 + - httpcore==1.0.7 + - httpx==0.27.2 + - httpx-sse==0.4.0 + - identify==2.6.2 + - idna==3.10 + - ipykernel==6.29.5 + - ipython==8.29.0 + - ipywidgets==8.1.5 + - isoduration==20.11.0 + - isort==5.13.2 + - jedi==0.19.2 + - jinja2==3.1.4 + - jiter==0.7.1 + - json5==0.9.28 + - jsonpatch==1.33 + - jsonpointer==3.0.0 + - jsonschema==4.23.0 + - jsonschema-specifications==2024.10.1 + - jupyter==1.1.1 + - jupyter-client==8.6.3 + - 
jupyter-console==6.6.3 + - jupyter-core==5.7.2 + - jupyter-events==0.10.0 + - jupyter-lsp==2.2.5 + - jupyter-server==2.14.2 + - jupyter-server-terminals==0.5.3 + - jupyterlab==4.2.6 + - jupyterlab-pygments==0.3.0 + - jupyterlab-server==2.27.3 + - jupyterlab-widgets==3.0.13 + - kiwisolver==1.4.7 + - langchain==0.3.7 + - langchain-community==0.3.7 + - langchain-core==0.3.19 + - langchain-openai==0.2.8 + - langchain-text-splitters==0.3.2 + - langsmith==0.1.143 + - load-dotenv==0.1.0 + - loguru==0.7.2 + - markupsafe==3.0.2 + - marshmallow==3.23.1 + - matplotlib==3.9.2 + - matplotlib-inline==0.1.7 + - mccabe==0.7.0 + - mistune==3.0.2 + - multidict==6.1.0 + - mypy-extensions==1.0.0 + - nbclient==0.10.0 + - nbconvert==7.16.4 + - nbformat==5.10.4 + - nest-asyncio==1.6.0 + - networkx==3.4.2 + - nodeenv==1.9.1 + - notebook==7.2.2 + - notebook-shim==0.2.4 + - numpy==1.26.4 + - openai==1.54.4 + - orjson==3.10.11 + - overrides==7.7.0 + - packaging==24.2 + - pandocfilters==1.5.1 + - parso==0.8.4 + - pexpect==4.9.0 + - pillow==11.0.0 + - platformdirs==4.3.6 + - pre-commit==4.0.1 + - prometheus-client==0.21.0 + - prompt-toolkit==3.0.48 + - propcache==0.2.0 + - psutil==6.1.0 + - ptyprocess==0.7.0 + - pure-eval==0.2.3 + - pycparser==2.22 + - pydantic==2.9.2 + - pydantic-core==2.23.4 + - pydantic-settings==2.6.1 + - pygments==2.18.0 + - pylint==3.3.1 + - pyparsing==3.2.0 + - python-dateutil==2.9.0.post0 + - python-dotenv==1.0.1 + - python-json-logger==2.0.7 + - pyyaml==6.0.2 + - pyzmq==26.2.0 + - referencing==0.35.1 + - regex==2024.11.6 + - requests==2.32.3 + - requests-toolbelt==1.0.0 + - rfc3339-validator==0.1.4 + - rfc3986-validator==0.1.1 + - rpds-py==0.21.0 + - scipy==1.14.1 + - send2trash==1.8.3 + - six==1.16.0 + - sniffio==1.3.1 + - soupsieve==2.6 + - sqlalchemy==2.0.35 + - stack-data==0.6.3 + - tenacity==9.0.0 + - terminado==0.18.1 + - tiktoken==0.8.0 + - tinycss2==1.4.0 + - tomlkit==0.13.2 + - tornado==6.4.1 + - tqdm==4.67.0 + - traitlets==5.14.3 + - 
types-python-dateutil==2.9.0.20241003 + - typing-extensions==4.12.2 + - typing-inspect==0.9.0 + - uri-template==1.3.0 + - urllib3==2.2.3 + - virtualenv==20.27.1 + - wcwidth==0.2.13 + - webcolors==24.11.1 + - webencodings==0.5.1 + - websocket-client==1.8.0 + - widgetsnbextension==4.0.13 + - yarl==1.17.1 diff --git a/planning.ipynb b/planning.ipynb new file mode 100644 index 0000000..b5c25cd --- /dev/null +++ b/planning.ipynb @@ -0,0 +1,973 @@ +{ + "cells": [ + { + "metadata": { + "id": "df7334004bee5154" + }, + "cell_type": "markdown", + "source": [ + "

A3I Project - Choreography Planning by Decoupling A* Deterministic Search and Heuristic Computation

\n", + "

Authors: Davide Bombardi, Lorenzo Balzani

\n", + "

\n", + " davide.bombardi@studio.unibo.it -\n", + " lorenzo.balzani@studio.unibo.it\n", + "

\n", + "\n", + "\n", + "

Abstract

\n", + "
\n", + "

\n", + " Our project focuses on designing and implementing a customized A* algorithm to generate robotic dance choreographies. This algorithm is inspired by the traditional A* search method but incorporates a unique heuristic powered by a Large Language Model (LLM). The primary goal of the algorithm is to create choreographies that are both dynamic and visually appealing (or “cool”), while adhering to specific constraints. These constraints ensure that the choreographies are functional, safe, and compatible with the capabilities of the robotic platform.\n", + "

\n", + "
" + ], + "id": "df7334004bee5154" + }, + { + "metadata": { + "id": "890a231612e4c9b1" + }, + "cell_type": "markdown", + "source": [ + "## Setup\n", + "Perform the following steps to setup the environment for running the code in this notebook:\n", + "1. Create a file named `secrets.env` in the root directory of the project and add the following lines:\n", + " ```bash\n", + " OPENAI_API_KEY=\n", + " OPENAI_API_URL=https://api.openai.com\n", + " ```\n", + "2. Setup the environment by launching the script `./setup_env.sh` in the terminal.\n", + "3. After activate it with `conda activate a3i`.\n", + "4. For installing the ipython kernel, please run `python -m ipykernel install --user --name=a3i`." + ], + "id": "890a231612e4c9b1" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-12T14:21:27.761455Z", + "start_time": "2025-01-12T14:21:24.132599Z" + }, + "id": "346c1ee767482e62" + }, + "cell_type": "code", + "source": [ + "import random\n", + "import re\n", + "import heapq\n", + "\n", + "from time import time\n", + "from collections import defaultdict\n", + "from itertools import product\n", + "\n", + "import networkx as nx\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "from dotenv import load_dotenv\n", + "from langchain_openai import AzureChatOpenAI\n", + "from langchain_core.prompts import PromptTemplate\n", + "\n", + "# Models\n", + "gpt_4o = AzureChatOpenAI(\n", + " deployment_name=\"gpt-4o\",\n", + " temperature=0,\n", + " verbose=True,\n", + ")\n", + "\n", + "gpt_4o_mini = AzureChatOpenAI(\n", + " deployment_name=\"gpt-4o-mini\",\n", + " temperature=0,\n", + " verbose=True,\n", + ")\n", + "\n", + "load_dotenv(\"secrets.env\")\n", + "SEED = 42\n", + "random.seed(SEED)" + ], + "id": "346c1ee767482e62", + "outputs": [], + "execution_count": null + }, + { + "metadata": { + "id": "e0fc61e0fb53ddf6" + }, + "cell_type": "markdown", + "source": [ + "## Position Graph Creation" + ], + "id": "e0fc61e0fb53ddf6" + }, + { + 
"metadata": { + "ExecuteTime": { + "end_time": "2025-01-12T14:21:07.617880Z", + "start_time": "2025-01-12T14:21:06.957604Z" + }, + "id": "eeaa4b0ca3e7bb68", + "outputId": "d2ea8478-36b5-4669-89d0-2b9706d91658" + }, + "cell_type": "code", + "source": [ + "initial_state = 'INITIAL_stand_init'\n", + "final_state = 'FINAL_crouch'\n", + "\n", + "# Define the list of nodes\n", + "nodes = {initial_state, 'rotation_handgun_object', 'right_arm_rotation', 'double_movement_rotation_of_arms',\n", + " 'arms_opening', 'union_arms', 'move_forward', 'move_backward', 'diagonal_left', 'diagonal_right',\n", + " 'rotation_foot_left_leg', 'rotation_foot_right_leg', 'play_guitar', 'arms_dance', 'birthday_dance',\n", + " 'sprinkler_dance', 'workout_legs_and_arms', 'superman', 'mandatory_sit', 'mandatory_wipe_forehead',\n", + " 'mandatory_hello', 'mandatory_stand', 'mandatory_zero', final_state}\n", + "\n", + "# Define the exception nodes for which there are incompatible transitions\n", + "exception_nodes = {\"mandatory_sit\": {'diagonal_right', 'diagonal_left', 'rotation_foot_right_leg', 'rotation_foot_left_leg',\n", + " final_state, 'move_backward', 'move_forward', 'play_guitar', 'arms_dance', 'birthday_dance',\n", + " 'sprinkler_dance', 'workout_legs_and_arms', 'superman'}\n", + "}\n", + "\n", + "mandatory_positions = {'mandatory_sit', 'mandatory_wipe_forehead', 'mandatory_hello', 'mandatory_stand', 'mandatory_zero'}\n", + "intermediate_positions = list(nodes - {initial_state, final_state} - mandatory_positions)\n", + "\n", + "# Initialize the directed graph, since there are incompatible transitions between some nodes\n", + "G = nx.DiGraph()\n", + "G.add_nodes_from(nodes)\n", + "\n", + "# Add edges between all pairs of nodes except the specified exceptions\n", + "for u, v in product(nodes, repeat=2):\n", + " if u == final_state: # Skip adding edges from 'FINAL_crouch' to other nodes, since it is the final node\n", + " continue\n", + " if v == initial_state: # Skip adding edges to 
'INITIAL_stand_init', since it is the initial node\n", + " continue\n", + " if u in exception_nodes and v in exception_nodes[u]: # Skip adding edges if the pair is in the exception list\n", + " continue\n", + " G.add_edge(u, v, weight=1) # Add an edge with a cost of 1\n", + "\n", + "# Visualize the graph\n", + "plt.figure(figsize=(12, 12))\n", + "pos = nx.spring_layout(G, k=0.15)\n", + "nx.draw_networkx_nodes(G, pos, node_size=500, node_color='skyblue')\n", + "nx.draw_networkx_edges(G, pos, width=1.0, alpha=0.5)\n", + "nx.draw_networkx_labels(G, pos, font_size=8, font_family='sans-serif')\n", + "\n", + "plt.axis('off')\n", + "plt.title('Positions Graph')\n", + "plt.show()" + ], + "id": "eeaa4b0ca3e7bb68", + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": null + }, + { + "metadata": { + "id": "9485e0715b382212" + }, + "cell_type": "markdown", + "source": [ + "## Graph Statistics" + ], + "id": "9485e0715b382212" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-12T14:21:33.862338Z", + "start_time": "2025-01-12T14:21:29.498152Z" + }, + "id": "62f284f78c2b1419", + "outputId": "4a79a953-6004-403a-ff1e-89ba88d81fcf" + }, + "cell_type": "code", + "source": [ + "# 1. Number of nodes\n", + "num_nodes = G.number_of_nodes()\n", + "print(f\"1. Number of nodes: {num_nodes}\")\n", + "\n", + "# 2. Number of edges\n", + "num_edges = G.number_of_edges()\n", + "print(f\"2. Number of edges: {num_edges}\")\n", + "\n", + "# 3. Average in-degree and out-degree\n", + "in_degrees = G.in_degree()\n", + "out_degrees = G.out_degree()\n", + "avg_in_degree = sum(dict(in_degrees).values()) / num_nodes\n", + "avg_out_degree = sum(dict(out_degrees).values()) / num_nodes\n", + "print(f\"3. Average in-degree: {avg_in_degree:.2f}\\t--\\tAverage out-degree: {avg_out_degree:.2f}\")\n", + "\n", + "# 4. Density of the graph\n", + "density = nx.density(G)\n", + "print(f\"4. Graph density: {density:.4f}\")\n", + "\n", + "# 5. Is the graph strongly connected?\n", + "is_strongly_connected = nx.is_strongly_connected(G)\n", + "print(f\"5. Is the graph strongly connected? {is_strongly_connected}\")\n", + "\n", + "# 6. Number of strongly connected components\n", + "num_scc = nx.number_strongly_connected_components(G)\n", + "print(f\"6. Number of strongly connected components: {num_scc}\")\n", + "\n", + "# 7. Average clustering coefficient\n", + "# For DiGraph, we use the directed version of clustering coefficient\n", + "avg_clustering = nx.average_clustering(G)\n", + "print(f\"7. Average clustering coefficient: {avg_clustering:.4f}\")\n", + "\n", + "# 8. 
Top 3 nodes by PageRank\n", + "pagerank = nx.pagerank(G)\n", + "top_3_pagerank = sorted(pagerank.items(), key=lambda x: x[1], reverse=True)[:3]\n", + "print(\"8. Top 3 nodes by PageRank:\")\n", + "for node, pr in top_3_pagerank:\n", + " print(f\" {node}: {pr:.4f}\")\n", + "\n", + "# 9. In-degree and out-degree of 'mandatory_sit'\n", + "in_degree_mandatory_sit = G.in_degree('mandatory_sit')\n", + "out_degree_mandatory_sit = G.out_degree('mandatory_sit')\n", + "print(f\"9. 'mandatory_sit' in-degree: {in_degree_mandatory_sit}\\t--\\tout-degree: {out_degree_mandatory_sit}\")\n", + "\n", + "# 10. Average shortest path length (if the graph is strongly connected)\n", + "if is_strongly_connected:\n", + " avg_shortest_path_length = nx.average_shortest_path_length(G)\n", + " print(f\"10. Average shortest path length: {avg_shortest_path_length:.4f}\")\n", + "else:\n", + " print(\"10. Graph is not strongly connected; average shortest path length not defined.\")\n", + "\n", + "# 11. Number of nodes with zero in-degree (sources)\n", + "zero_in_degree = [n for n, d in G.in_degree() if d == 0]\n", + "print(f\"11. Number of nodes with zero in-degree (sources): {len(zero_in_degree)}\")\n", + "\n", + "# 12. Number of nodes with zero out-degree (sinks)\n", + "zero_out_degree = [n for n, d in G.out_degree() if d == 0]\n", + "print(f\"12. Number of nodes with zero out-degree (sinks): {len(zero_out_degree)}\")\n", + "\n", + "# 13. Diameter of the graph (if strongly connected)\n", + "if is_strongly_connected:\n", + " diameter = nx.diameter(G)\n", + " print(f\"13. Diameter of the graph: {diameter}\")\n", + "else:\n", + " print(\"13. Graph is not strongly connected; diameter not defined.\")\n", + "\n", + "# 14. Eccentricity of the initial state (if strongly connected)\n", + "if is_strongly_connected:\n", + " eccentricity = nx.eccentricity(G, initial_state)\n", + " print(f\"14. Eccentricity of '{initial_state}': {eccentricity}\")\n", + "else:\n", + " print(\"14. 
Graph is not strongly connected; eccentricity not defined.\")" + ], + "id": "62f284f78c2b1419", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1. Number of nodes: 24\n", + "2. Number of edges: 516\n", + "3. Average in-degree: 21.50\t--\tAverage out-degree: 21.50\n", + "4. Graph density: 0.9348\n", + "5. Is the graph strongly connected? False\n", + "6. Number of strongly connected components: 3\n", + "7. Average clustering coefficient: 0.9365\n", + "8. Top 3 nodes by PageRank:\n", + " mandatory_wipe_forehead: 0.0453\n", + " right_arm_rotation: 0.0453\n", + " double_movement_rotation_of_arms: 0.0453\n", + "9. 'mandatory_sit' in-degree: 23\t--\tout-degree: 10\n", + "10. Graph is not strongly connected; average shortest path length not defined.\n", + "11. Number of nodes with zero in-degree (sources): 1\n", + "12. Number of nodes with zero out-degree (sinks): 1\n", + "13. Graph is not strongly connected; diameter not defined.\n", + "14. Graph is not strongly connected; eccentricity not defined.\n" + ] + } + ], + "execution_count": null + }, + { + "metadata": { + "id": "8b557b83ee39cad5" + }, + "cell_type": "markdown", + "source": [ + "## Graph Waypoints definition" + ], + "id": "8b557b83ee39cad5" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-12T14:21:41.047327Z", + "start_time": "2025-01-12T14:21:40.228032Z" + }, + "id": "d202d2924ac24915", + "outputId": "77c08c14-ba99-43ef-fcd1-d2b38215ca30" + }, + "cell_type": "code", + "source": [ + "# Define a prompt template for getting waypoints\n", + "waypoint_prompt = PromptTemplate(\n", + " input_variables=['initial_state', 'final_state', 'positions'],\n", + " template=\"\"\"\n", + " Given the initial state {initial_state} and the final state {final_state}, output a **strictly comma-separated list** of positions from the available ones that create a very cool, elegant and creative movement sequence. Include initial and final state. 
Do not include any other text in the output, and format the list exactly as: `position1, position2, position3`.\n", + "\n", + " # POSITIONS\n", + " {positions}\n", + " \"\"\"\n", + ")\n", + "\n", + "# Create an LLMChain for waypoints\n", + "waypoint_chain = waypoint_prompt | gpt_4o\n", + "\n", + "# Define the available positions as the union of all the nodes except the initial and final states\n", + "available_positions = list(nodes - {initial_state, final_state})\n", + "random.shuffle(available_positions)\n", + "\n", + "# Run the chain\n", + "result = waypoint_chain.invoke({\"initial_state\": initial_state, \"final_state\": final_state, \"positions\": available_positions})\n", + "\n", + "# Extract the comma-separated list using regex\n", + "waypoints = re.findall(r'\\b\\w+(?:, \\w+)*\\b', result.content)\n", + "print(f\"Waypoints sequence: {waypoints}\")" + ], + "id": "d202d2924ac24915", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Waypoints sequence: ['INITIAL_stand_init', 'mandatory_stand', 'diagonal_left', 'superman', 'rotation_foot_right_leg', 'arms_opening', 'FINAL_crouch']\n" + ] + } + ], + "execution_count": null + }, + { + "cell_type": "markdown", + "source": [ + "# Custom A* Algorithm\n", + "This function is the implementation of a personalized A\\* algorithm. In our case, the solution space Σ is the set of all the possible choreographies composed of the given moves (i.e. of the moves stored in the set `nodes`). We consider a solution (i.e. a choreography) σ ∈ Σ correct if it satisfies the following constraints:\n", + "1. The choreography σ doesn't contain any unfeasible transitions, i.e. σ respects all the incompatibilities between moves in the dict `exception_nodes`. In other words, for every move 𝑘 in `exception_nodes.keys()`, if 𝑘 ∈ σ then 𝑘 is not followed by any of the moves in `exception_nodes[𝑘]`;\n", + "2. The choreography σ contains all the mandatory positions (i.e. 
moves in the set `mandatory_positions`);\n", + "3. The choreography σ contains at least `n_intermediate_positions` intermediate positions (i.e. moves in the set `intermediate_positions`). In our case `n_intermediate_positions = 5`, as stated in the reference paper;\n", + "4. The execution of σ takes a number of seconds that is greater than 120 - `duration_tolerance` and lower than 120 + `duration_tolerance`. In our case `duration_tolerance = 5`, as stated in the reference paper.\n", + "\n", + "Our A\\* algorithm builds the final choreography by adding a move to the most promising solution found so far at each iteration. From this perspective, the first constraint is achieved by considering only the neighbours of the last node (i.e. the last move) of the choreography considered at the current iteration of the A\\* algorithm. The second and third constraints are achieved via a simple check on the moves selected in the candidate solution.\n", + "For the fourth constraint we need to estimate the duration of the execution of each move. This is done once before running A\\* by calling the LLM with prompt `moves_duration_prompt`.\n" + ], + "metadata": { + "id": "NbtfZLSbKZCS" + }, + "id": "NbtfZLSbKZCS" + }, + { + "cell_type": "markdown", + "source": [ + "### The heuristic\n", + "A\\* builds the solution by taking at each iteration the most promising solution found so far and by adding a move at the end of the considered solution. Therefore, the definition of the function `f_function` that computes the score of a solution (where the most promising solution is the one with the highest score) is fundamental in order to guarantee that A\\* reaches a good solution in a reasonable amount of time.\n", + "Once A\\* has considered the most promising solution σ, for every neighbour 𝑘 of the last move of σ, the scoring function `f_function` computes the score of σ + [𝑘] by computing a weighted sum of 4 normalized scores:\n", + "1. 
`g_value + h_value`, where:\n", + " - `g_value` is the (normalized) coolness score of σ, computed by calling the LLM with `g_prompt` as prompt, weighted by the (normalized) duration of σ. The prompt takes into account the evaluation targets described in the reference paper. Then `g_value` is the average of the scores given by the LLM in each evaluation target;\n", + " - `h_value` is the (normalized) coolness score of 𝑘 if added at the end of σ, computed by calling the LLM with `h_prompt` as prompt, weighted by the (normalized) time left to consider σ a complete choreography. In order to reduce the calls to the LLM, the `h_value` is computed in one call to the LLM for all the neighbours of the last move of σ. Then when dealing with the neighbour 𝑘 we consider the correspondent `h_value`;\n", + "2. `t_value`\n", + "3. `duration_bonus`\n", + "4. `constraint_sat_count`\n" + ], + "metadata": { + "id": "Sklws2sGYrsb" + }, + "id": "Sklws2sGYrsb" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "### Setup and Utilities", + "id": "cae3165ec93d5850" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "max_duration = 120\n", + "duration_tolerance = 5\n", + "max_coolness_value = 5\n", + "debug = False\n", + "\n", + "# LLM pricing\n", + "gpt_4o_mini_costs = {\"prompt\": 0.14311/1e6, \"completion\": 0.5725/1e6}\n", + "gpt_4o_costs = {\"prompt\": 2.31514/1e6, \"completion\": 9.2606/1e6}" + ], + "id": "2a5bdc4ae9a4664a" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "def asymmetric_distribution(x, mu=max_duration, sigma=50, alpha=-0.1):\n", + " \"\"\"\n", + " Computes an asymmetric distribution value for a given input `x`.\n", + "\n", + " The distribution behaves like a normal (Gaussian) distribution for `x <= mu`\n", + " and introduces an exponential penalty for `x > mu`. 
This is useful for scenarios\n", + " where values greater than a specified mean (`mu`) should be penalized more heavily\n", + " than those below it.\n", + "\n", + " Args:\n", + " x (float): The input value for which the distribution is calculated.\n", + " mu (float, optional): The mean of the distribution. Default is `max_duration`.\n", + " sigma (float, optional): The standard deviation of the normal distribution. Default is 50.\n", + " alpha (float, optional): The exponential decay factor for values of `x > mu`. Default is -0.1.\n", + "\n", + " Returns:\n", + " float: The calculated value of the asymmetric distribution at `x`.\n", + " \"\"\"\n", + " if x <= mu:\n", + " return np.exp(- (x - mu)**2 / (2 * sigma**2))\n", + "\n", + " return np.exp(- (x - mu)**2 / (2 * sigma**2)) * np.exp(alpha * (x - mu))" + ], + "id": "edfb9714922fab65" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "### Prompts Definition", + "id": "8d45b54517004c26" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "#### Moves Duration Computation", + "id": "84d3d30c84ec1f9" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-12T16:14:19.924338Z", + "start_time": "2025-01-12T16:14:18.440913Z" + }, + "id": "c78a51d64a58220b" + }, + "cell_type": "code", + "source": [ + "moves_duration_prompt = PromptTemplate(\n", + " input_variables=[\"moves\"],\n", + " template=\"\"\"\n", + " You are a robotics expert analyzing the motion efficiency of a NAO robot. Below is a list of moves the robot will perform. 
For each move, provide the estimated duration in seconds based on typical execution times for similar robotic actions:\n", + "\n", + " Moves list:\n", + " {moves}\n", + "\n", + " ** Criteria **\n", + " 1.\tProvide an estimate for each move in seconds.\n", + "\t2.\tConsider the complexity, speed, and transition time for typical NAO robot motions.\n", + "\n", + "\tOutput Format:\n", + " Provide N mappings (separated by newlines) between move and estimated duration (without unit of measurement) in the format:\n", + " `move_name:estimated_duration_in_seconds`\n", + "\n", + " DO NOT OUTPUT ANY OTHER INFORMATION, INCLUDING BACKTICKS.\n", + " \"\"\"\n", + ")\n", + "\n", + "response = (moves_duration_prompt | gpt_4o).invoke(\"\\n- \".join(nodes)).content\n", + "moves_duration = {value.split(\":\")[0].strip(): int(value.split(\":\")[1].strip()) for value in response.split(\"\\n\")}\n" + ], + "id": "c78a51d64a58220b", + "outputs": [], + "execution_count": null + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-12T16:15:25.108653Z", + "start_time": "2025-01-12T16:15:25.105927Z" + }, + "id": "2209c196bd9d3d1e" + }, + "cell_type": "code", + "source": [ + "def compute_total_duration(dance_sequence):\n", + " total_duration = 0\n", + " for move in dance_sequence:\n", + " total_duration += moves_duration[move]\n", + " return total_duration" + ], + "id": "2209c196bd9d3d1e", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "#### Sequence Evaluation and Heuristic", + "id": "bc666b31bfb9d7da" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-12T16:15:32.596452Z", + "start_time": "2025-01-12T16:15:32.592515Z" + }, + "id": "defef91d61e54ff6" + }, + "cell_type": "code", + "source": [ + "g_prompt = PromptTemplate(\n", + " input_variables=['len_sequence', 'dance_sequence'],\n", + " template=\"\"\"\n", + "Role and Context: You are an expert robotic dance choreographer. 
Your task is to evaluate a complete choreography for a NAO robot, represented as a sequence of {len_sequence} dance moves or positions. You must provide 7 numerical outputs as the 7 coolness scores of the choreography.\n", + "\n", + "Task:\n", + "Evaluate the given dance sequence based on the following criteria:\n", + "\n", + "**Coolness Score:**\n", + "Rate the choreography on a scale from 0 to 5 in each of the following categories: (1) Storytelling; (2) Rhythm; (3) Movement Technique; (4) Public Involvement; (5) Space Use; (6) Human Characterization; (7) Human Reproducibility.\n", + "\n", + "Output Format:\n", + "Provide ONLY 7 numerical values in the format:\n", + "`Storytelling_score,Rhythm_score,MovementTechnique_score,PublicInvolvement_score,SpaceUse_score,HumanCharacterization_score,HumanReproducibility_score`\n", + "\n", + "Do not include any text, commentary, or explanation.\n", + "\n", + "Input:\n", + "Dance Sequence: {dance_sequence}\n", + "\"\"\")\n", + "\n", + "h_prompt = PromptTemplate(\n", + " input_variables=['n_neighbors', 'neighbors_list', 'dance_sequence', 'duration'],\n", + " template=\"\"\"\n", + " Role and Context: You are an expert robotic dance choreographer. Your task is to improve an incomplete choreography for a NAO robot. You will be provided with:\n", + " 1. An incomplete list of dance moves or positions (referred to as the current choreography).\n", + " 2. A list of possible moves to evaluate for inclusion in the choreography (referred to as the possible moves). Your job is to assign a coolness score to each possible move based on how well it improves the choreography. The highest-rated moves will be considered for inclusion.\n", + "\n", + "Task: Provide {n_neighbors} numerical scores between 0 and 5, one for each possible move, in the exact order they are listed. 
The first score corresponds to the first move in the list, the second score to the second move, and so on.\n", + "\n", + "Important Guidelines:\n", + "- Scoring Criteria: Moves should be evaluated based on how much they improve the choreography in storytelling, rhythm, movement technique, public involvement, space use, human characterization and human reproducibility.\n", + "- Duration: The choreography should have a total duration of {duration} seconds.\n", + "- Consistency: Rate moves consistently and fairly. Similar moves should receive similar scores.\n", + "- Diversity: Consider the variety, creativity, and aesthetic appeal of the moves. Give lower scores to moves that are repetitive.\n", + "- Output Format: Provide ONLY the {n_neighbors} scores in the format: coolness_score_first_move,coolness_score_second_move,...,coolness_score_{n_neighbors}-th_move. Do not include any explanations, text, or comments.\n", + "\n", + "Input:\n", + "Current choreography: {dance_sequence}\n", + "Possible Moves: {neighbors_list}\n", + "\"\"\")\n" + ], + "id": "defef91d61e54ff6", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "### Sequence Evaluation via LLM", + "id": "3673387a05ba5bfc" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-12T16:21:57.796107Z", + "start_time": "2025-01-12T16:21:57.779457Z" + }, + "id": "c4a58597c651819e" + }, + "cell_type": "code", + "source": [ + "def g_llm(dance_sequence):\n", + " \"\"\"\n", + " Evaluate a complete dance_sequence using the LLM.\n", + "\n", + " Parameters:\n", + " - dance_sequence (List[str]): The sequence of dance moves.\n", + "\n", + " Output:\n", + " - coolness_score (float): The estimated coolness score for the dance_sequence.\n", + " \"\"\"\n", + " # Remove INITIAL_stand_init from the dance_sequence\n", + " if dance_sequence[0] == 'INITIAL_stand_init':\n", + " dance_sequence = dance_sequence[1:]\n", + "\n", + " dance_sequence = 
tuple(dance_sequence)\n", + "\n", + " # Check if the coolness value is already computed and cached\n", + " if dance_sequence in heuristic_cache[\"g\"]:\n", + " return heuristic_cache[\"g\"][dance_sequence], 0\n", + "\n", + " # Prepare the inputs for the LLM prompt\n", + " prompt_inputs = {\n", + " 'len_sequence': len(dance_sequence),\n", + " 'dance_sequence': str(list(dance_sequence))\n", + " }\n", + "\n", + " # Run the chain, extract and process the response\n", + " try:\n", + " # Call the LLM to get the coolness score\n", + " response = (g_prompt | gpt_4o).invoke(prompt_inputs)\n", + " cost = response.response_metadata[\"token_usage\"]\n", + " total_cost = cost[\"prompt_tokens\"] * gpt_4o_costs[\"prompt\"] + cost[\"completion_tokens\"] * gpt_4o_costs[\"completion\"]\n", + "\n", + " output_list = response.content.strip().split(\",\")\n", + " all_scores = [float(score.strip()) for score in output_list]\n", + "\n", + " assert len(all_scores) == 7\n", + " coolness_estimation = sum(all_scores)/len(all_scores) # average of the scores\n", + "\n", + " # Force the coolness from 0 to max_coolness_score\n", + " coolness_estimation = max(0.0, float(coolness_estimation))\n", + " coolness_estimation = min(coolness_estimation, 5)\n", + "\n", + " # Store the coolness score in the cache\n", + " heuristic_cache[\"g\"][dance_sequence] = coolness_estimation\n", + "\n", + " return heuristic_cache[\"g\"][dance_sequence], total_cost\n", + " except ValueError as e:\n", + " raise ValueError(f\"Invalid response from the g_LLM: {e}\") from e" + ], + "id": "c4a58597c651819e", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "### Heuristic Computation via LLM", + "id": "4443bedf83b2a318" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "def h_llm(neighbors_list, dance_sequence):\n", + " \"\"\"\n", + " Give one score for each move in a list of moves neighbors_list that 
could follow the current dance_sequence, using the LLM.\n", + "\n", + " Parameters:\n", + " - neighbors_list (List[str]): The list of possible moves to evaluate.\n", + " - dance_sequence (List[str]): The sequence of dance moves.\n", + "\n", + " Output:\n", + " - coolness_scores (List[float]): The estimated coolness scores for the possible moves.\n", + " \"\"\"\n", + " dance_sequence = tuple(dance_sequence)\n", + "\n", + " # Check if the coolness value is already computed and cached\n", + " if dance_sequence in heuristic_cache:\n", + " return heuristic_cache[\"h\"][dance_sequence], 0\n", + "\n", + " # Prepare the inputs for the LLM prompt\n", + " prompt_inputs = {\n", + " 'n_neighbors': len(neighbors_list),\n", + " 'neighbors_list': str(neighbors_list),\n", + " 'dance_sequence': str(list(dance_sequence)),\n", + " 'duration': max_duration\n", + " }\n", + "\n", + " # Run the chain, extract and process the response\n", + " try:\n", + " # Call the LLM to get the coolness score\n", + " response = (h_prompt | gpt_4o).invoke(prompt_inputs)\n", + " cost = response.response_metadata[\"token_usage\"]\n", + " total_cost = cost[\"prompt_tokens\"] * gpt_4o_costs[\"prompt\"] + cost[\"completion_tokens\"] * gpt_4o_costs[\"completion\"]\n", + " coolness_estimations = response.content.strip().split(\",\")\n", + "\n", + " if len(neighbors_list) != len(coolness_estimations):\n", + " raise ValueError(\"The number of coolness estimations does not match the number of neighbors\")\n", + "\n", + " for i in range(len(coolness_estimations)):\n", + " coolness_estimations[i] = max(0.0, float(coolness_estimations[i]))\n", + " coolness_estimations[i] = min(coolness_estimations[i], 5)\n", + "\n", + " # Store the coolness score in the cache\n", + " heuristic_cache[\"h\"][dance_sequence] = coolness_estimations\n", + "\n", + " return coolness_estimations, total_cost\n", + " except ValueError as e:\n", + " raise ValueError(f\"Invalid response from the h_LLM: {e}\") from e" + ], + "id": 
"f092b715fb7998a9" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "### Combining Past and Future (f-function)", + "id": "4a268e2cd1321e49" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-12T16:26:42.769481Z", + "start_time": "2025-01-12T16:26:42.764876Z" + }, + "id": "b29734fb6a1fcd68" + }, + "cell_type": "code", + "source": [ + "def f_function(coolness_value, cool_estimation, choreo_duration, new_sequence, current_target, shortest_path, shortest_path_inverted, final_state, mandatory_positions, intermediate_positions, n_intermediate_positions):\n", + " \"\"\"\n", + " Compute the total score of a choreography as the weighted sum of four elements. This is intended to use within our A* implementation.\n", + " 1. g + h: the sum of the coolness (g) and heuristic (h) scores\n", + " 2. t: progress indicator that leads to the next waypoint\n", + " 3. duration bonus: the more similar to the duration constraints the higher the bonus, i.e. follows an asymmetric distribution where points smaller than the average behave like a normal distribution, whereas points larger than the average are penalized more as they increase in values.\n", + " 4. constraint satisfaction count: percentage of given satisfied constraints.\n", + "\n", + " Args:\n", + " coolness_value (float): A measure of the choreography's coolness. Output of LLM via g-function.\n", + " cool_estimation (float): A heuristic estimation of the choreography’s coolness. Output of LLM via h-function.\n", + " choreo_duration (float): The duration of the choreography. Output of LLM via g-function.\n", + " new_sequence (list[str]): The sequence of moves in the choreography.\n", + " current_target (str): The current waypoint target in the choreography.\n", + " shortest_path (dict): Mapping (matrix) of nodes to the shortest path lengths for each target (computed via NetworkX), i.e. 
source in rows and target in columns.\n", + " shortest_path_inverted (dict): Inverted mapping of targets to path lengths for each node, i.e. target in rows and source in columns.\n", + " final_state (str): The final target state for the choreography.\n", + " mandatory_positions (set): Moves that are mandatory in the choreography sequence.\n", + " intermediate_positions (set): Optional intermediate moves for added complexity.\n", + " n_intermediate_positions (int): The minimum number of distinct intermediate moves required.\n", + "\n", + " Returns:\n", + " float: the weighed sum of the 4 components.\n", + " \"\"\"\n", + "\n", + " # g and h such that g+h in [0,1], 1 ==> the coolest choreography\n", + " g_value = coolness_value * min(choreo_duration, max_duration) / (max_duration * max_coolness_value)\n", + " h_value = cool_estimation * max((max_duration-choreo_duration), 0) / (max_duration * max_coolness_value)\n", + "\n", + " # t such that t in [0,1], 1 ==> the shortest path to the target (i.e. target itself)\n", + " current_node = new_sequence[-1]\n", + " #If the choreography is complete but there are still waypoints t_value = 0\n", + " if current_node == final_state and current_target!=final_state:\n", + " t_value = 0\n", + " #If the choreography is complete and the only waypoint left is final_state t_value = 1\n", + " elif current_node == final_state and current_target==final_state:\n", + " t_value = 1\n", + " #If the choreography is not complete and the only waypoint left is final_state t_value = 1 !!! 
More freedom if the next waypoint is the final one\n", + " elif current_node != final_state and current_target==final_state:\n", + " t_value = 1\n", + " else:\n", + " t_value = 1-(shortest_path[current_node][current_target] / (max(shortest_path_inverted[current_target].values()) + 1))\n", + "\n", + " # duration bonus such that it is in [0,1], 1 ==> choreography has the perfect duration\n", + " duration_bonus = asymmetric_distribution(choreo_duration)\n", + "\n", + " # constraint satisfaction counts in [0,1], 1 ==> all constraint are satisfied\n", + " n_mandatory_check = len({move for move in new_sequence if move in mandatory_positions})\n", + " n_intermediate_check = min(len({move for move in new_sequence if move in intermediate_positions}), n_intermediate_positions)\n", + " constraint_sat_count = (n_mandatory_check + n_intermediate_check)/(len(mandatory_positions) + n_intermediate_positions)\n", + "\n", + " if debug:\n", + " print(f\"F-FUNCTION -> count={constraint_sat_count}, n_mandatory_check={n_mandatory_check}, n_intermediate_check={n_intermediate_check} \"\"(g+h)={(g_value + h_value)}, choreo_duration={choreo_duration}, duration_bonus={duration_bonus}\")\n", + "\n", + " # final score\n", + " final_score = ((g_value + h_value) * gh_weight) + (t_value * t_weight) + (duration_bonus * db_weight) + (constraint_sat_count * csc_weight)\n", + "\n", + " return final_score" + ], + "id": "b29734fb6a1fcd68", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "### Solution Definition", + "id": "6bc58043b54fed76" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-12T16:26:44.501436Z", + "start_time": "2025-01-12T16:26:44.493089Z" + }, + "id": "b5ddf80ca2833e6" + }, + "cell_type": "code", + "source": [ + "def custom_a_star(graph, initial_state, final_state, targets, mandatory_positions, intermediate_positions, n_intermediate_positions, f_function, timeout=60):\n", + " open_set = []\n", + " 
candidate_return = []\n", + " total_cost = 0\n", + " total_calls = 0\n", + " heapq.heappush(open_set, (0, [initial_state], 1)) # (estimated coolness of complete path, uncompleted path, index of current target)\n", + " shortest_path = dict(nx.all_pairs_shortest_path_length(graph))\n", + "\n", + " # Used in t-functon\n", + " shortest_path_inverted = defaultdict(dict)\n", + " for source, target_dict in shortest_path.items():\n", + " for target, length in target_dict.items():\n", + " shortest_path_inverted[target][source] = length\n", + "\n", + " # Run the A* algorithm\n", + " time_now = time()\n", + " while open_set:\n", + " elapsed_time = time()-time_now\n", + "\n", + " #If timeout is reached and a candidate solution is found, return it.\n", + " if candidate_return and elapsed_time > timeout:\n", + " candidate_return.sort(key=lambda x: x[0])\n", + " if debug:\n", + " print(f\"Timeout reached. Returning best candidate so far (score: {candidate_return[0][1]})\")\n", + " return candidate_return[0], total_cost, total_calls, elapsed_time\n", + " _, current_sequence, current_target_index = heapq.heappop(open_set)\n", + " current_target = targets[current_target_index]\n", + " current_node = current_sequence[-1]\n", + "\n", + " #If the sequence with the best score is a complete choreography, return it.\n", + " if current_node == final_state:\n", + " return current_sequence, total_cost, total_calls, elapsed_time\n", + "\n", + " #If the sequence reached a target, update the index to point to the next target\n", + " if current_node == current_target:\n", + " if current_target_index < len(targets) - 1:\n", + " current_target_index += 1\n", + " current_target = targets[current_target_index]\n", + "\n", + " n_mandatory_check = len({move for move in current_sequence if move in mandatory_positions})\n", + " n_intermediate_check = len({move for move in current_sequence if move in intermediate_positions})\n", + "\n", + " all_constraint_satisfied = n_mandatory_check == 
len(mandatory_positions) and n_intermediate_check >= n_intermediate_positions\n", + "\n", + " #Estimate the coolness of current_sequence\n", + " coolness_value, heuristic_cost = g_llm(current_sequence)\n", + " choreo_duration = compute_total_duration(current_sequence)\n", + " total_cost += heuristic_cost\n", + " total_calls += 1\n", + "\n", + " neighbors_list = list(graph.neighbors(current_node))\n", + "\n", + " if choreo_duration >= max_duration - duration_tolerance and all_constraint_satisfied and final_state in neighbors_list: # Push to heap IFF the max_duration is reached. Hopefully will be returned at next iteration\n", + " estimated_coolness = f_function(coolness_value, 0, choreo_duration, current_sequence, current_target, shortest_path, shortest_path_inverted, final_state, mandatory_positions, intermediate_positions, n_intermediate_positions)\n", + " candidate_return.append((-estimated_coolness, current_sequence + [final_state]))\n", + " if debug:\n", + " print(f\"Found a candidate with score {coolness_value} and duration {choreo_duration}\")\n", + " heapq.heappush(open_set, (*candidate_return[-1], current_target_index))\n", + " # if you receive a longer choreography (+5%), just discard it\n", + " elif choreo_duration <= max_duration + duration_tolerance:\n", + " # If the duration is not reached, the final state cannot be reach\n", + " if final_state in neighbors_list:\n", + " neighbors_list.remove(final_state)\n", + "\n", + " # Check the neighbors\n", + " coolness_estimations, heuristic_cost = h_llm(neighbors_list, current_sequence)\n", + "\n", + " total_cost += heuristic_cost\n", + " total_calls += 1\n", + "\n", + " for move, cool_estimation in zip(neighbors_list, coolness_estimations):\n", + " new_sequence = current_sequence + [move]\n", + "\n", + " estimated_coolness = f_function(coolness_value, cool_estimation, choreo_duration, new_sequence, current_target, shortest_path, shortest_path_inverted, final_state, mandatory_positions, intermediate_positions, 
n_intermediate_positions)\n", + " heapq.heappush(open_set, (-estimated_coolness, new_sequence, current_target_index))\n", + "\n", + " if debug:\n", + " print(f\"CURRENT sequence: {current_sequence} - Total duration: {choreo_duration}s - Coolness: {coolness_value} - All constraint satisfied: {all_constraint_satisfied}\")\n", + " print(f\"CURRENT cost: {total_cost:.6f}€ - Total calls: {total_calls}\\n\")" + ], + "id": "b5ddf80ca2833e6", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "### Solution Execution\n", + "\n", + "#### Weights of the 4 Components of the Score\n", + "\n", + "- **`gh_weight`**: Higher values indicate greater importance placed on coolness.\n", + "- **`t_weight`**: Higher values indicate greater importance placed on following the waypoints.\n", + "- **`db_weight`**: Higher values provide a greater bonus to the choreographer with the right duration (A* checks fewer alternatives).\n", + "- **`csc_weight`**: Higher values indicate greater importance placed on satisfying the constraint." 
+ ], + "id": "51321fae20f94868" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "# Balanced\n", + "gh_weight, t_weight, db_weight, csc_weight = 1, 3, 4, 3\n", + "\n", + "# Very high Constraint and Waypoints\n", + "#gh_weight, t_weight, db_weight, csc_weight = 1,10,1,10\n", + "\n", + "#High Constraint and Waypoints\n", + "#gh_weight, t_weight, db_weight, csc_weight = 1,2,1,2\n", + "\n", + "#Less importance to waypoints\n", + "#gh_weight, t_weight, db_weight, csc_weight = 3,1,1,2" + ], + "id": "b0b49a7a95e3034b" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-12T16:28:09.863731Z", + "start_time": "2025-01-12T16:27:38.357202Z" + }, + "id": "8f08f9668aac9921", + "outputId": "402742e8-0708-4e8a-acd1-a081e6f8a9f0" + }, + "cell_type": "code", + "source": [ + "heuristic_cache = {'g': {}, 'h': {}}\n", + "choreography, cost, n_calls, elapsed_time = custom_a_star(graph=G, initial_state=initial_state, final_state=final_state, targets=waypoints, mandatory_positions=mandatory_positions, intermediate_positions=intermediate_positions, n_intermediate_positions=5, f_function=f_function)\n", + "print(f\"Choreography: {choreography}\\nCost: {cost:.5f}€\\nNumber of calls to LLM: {n_calls}\\nElapsed time: {elapsed_time:.2f}s\")" + ], + "id": "8f08f9668aac9921", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Choreography: ['INITIAL_stand_init', 'mandatory_stand', 'diagonal_left', 'superman', 'rotation_foot_right_leg', 'arms_opening', 'birthday_dance', 'mandatory_wipe_forehead', 'mandatory_hello', 'mandatory_sit', 'mandatory_zero', 'play_guitar', 'double_movement_rotation_of_arms', 'sprinkler_dance', 'workout_legs_and_arms', 'move_backward', 'arms_dance', 'move_forward', 'rotation_handgun_object', 'diagonal_right', 'right_arm_rotation', 'rotation_foot_left_leg', 'arms_dance', 'arms_dance', 'union_arms', 'arms_opening', 'birthday_dance', 'mandatory_hello', 'FINAL_crouch']\n", 
+ "Cost: 0.06451€\n", + "Number of calls to LLM: 59\n", + "Elapsed time: 31.50s\n" + ] + } + ], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "a3i", + "language": "python", + "name": "a3i" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + }, + "colab": { + "provenance": [] + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/problem_statement.txt b/problem_statement.txt new file mode 100644 index 0000000..809ab60 --- /dev/null +++ b/problem_statement.txt @@ -0,0 +1,43 @@ +Initial state: INITIAL_stand_init + +Final state: FINAL_crouch + +Constraints: - number of DIFFERENT mandatory positions >= 5 + - number of DIFFERENT intermediate positions >= 5 + +Incompatibilities between consecutive positions: +(('Mandatory_sit', 'diagonal_right')) +(('Mandatory_sit', 'diagonal_left')) +(('Mandatory_sit', 'rotation_foot_right_leg')) +(('Mandatory_sit', 'rotation_foot_left_leg')) +(('Mandatory_sit', 'FINAL_crouch')) +(('Mandatory_sit', 'move_backward')) +(('Mandatory_sit', 'move_forward')) +(('Mandatory_sit', 'play_guitar')) +(('Mandatory_sit', 'arms_dance')) +(('Mandatory_sit', 'birthday_dance')) +(('Mandatory_sit', 'sprinkler_dance')) +(('Mandatory_sit', 'workout_legs_and_arms')) +(('Mandatory_sit', 'superman')) + +List of intermediate positions: +[rotation_handgun_object, right_arm_rotation, double_movement_rotation_of_arms, arms_opening, union_arms, move_forward, move_backward, diagonal_left, diagonal_right, rotation_foot_left_leg, rotation_foot_right_leg, play_guitar, arms_dance, birthday_dance, sprinkler_dance, workout_legs_and_arms, superman] + +List of mandatory positions: +[Mandatory_sit, Mandatory_wipe_forehead, Mandatory_hello, Mandatory_stand, Mandatory_zero] + +Example of prompt tested with a LLM in zero-shot 
modality: +"Create a choreography as a sequence of positions for a NAO robot, by respecting the following constraints: -Start position:[INITIAL_stand_init]. +- End position: [FINAL_crouch]. - Mandatory positions: [Mandatory_sit, Mandatory_wipe_forehead, Mandatory_hello, Mandatory_stand, Mandatory_zero]. +-Intermediate positions: [rotation_handgun_object, right_arm_rotation, double_movement_rotation_of _arms, arms_opening, union_arms, move_forward, +move_backward, diagonal_left, diagonal_right, rotation_foot_left_leg, rotation_foot_right_leg, play_guitar, arms_dance, birthday_dance, sprinkler_dance, +workout_legs_and_arms, superman]. + +- Use all the mandatory positions. +- Use at least 5 intermediate positions. +- Positions can be repeated. +- Shuffle the order of mandatory and intermediate positions." + +Repo: https://github.com/alleDe/LLMsChoreography + +Paper: https://www.ijcai.org/proceedings/2024/844 diff --git a/setup_env.sh b/setup_env.sh new file mode 100755 index 0000000..4e719a0 --- /dev/null +++ b/setup_env.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# Setup the Conda environment for the project +ENV_NAME="a3i" +YAML_FILE="env.yml" + +# Ensure `conda` is initialized +eval "$(conda shell.bash hook)" + +# Check if the Conda environment exists +conda env list | grep -q "^$ENV_NAME " + +# shellcheck disable=SC2181 +if [ $? -eq 0 ]; then + echo "Conda environment '$ENV_NAME' already exists." +else + echo "Conda environment '$ENV_NAME' does not exist. Creating it from '$YAML_FILE'..." + if [ -f "$YAML_FILE" ]; then + conda env create -f "$YAML_FILE" + if [ $? -eq 0 ]; then + echo "Environment '$ENV_NAME' created successfully." + else + echo "Error: Failed to create the environment. Check the YAML file." + exit 1 + fi + else + echo "Error: YAML file '$YAML_FILE' not found." + exit 1 + fi +fi + +# Output the command to activate the environment in the parent shell +echo "Run the following command to activate the environment in your shell:" +echo "conda activate $ENV_NAME"