-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 8fdb8aa
Showing
8 changed files
with
1,406 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
name: lint_and_style | ||
|
||
on: | ||
pull_request: | ||
push: | ||
branches: [main, master] | ||
|
||
jobs: | ||
pre-commit: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- uses: conda-incubator/setup-miniconda@v2 | ||
with: | ||
environment-file: env.yml | ||
activate-environment: lint_and_style | ||
auto-activate-base: false | ||
- uses: pre-commit/action@v3.0.1 | ||
|
||
pylint: | ||
runs-on: ubuntu-latest | ||
needs: pre-commit | ||
continue-on-error: true | ||
if: github.event_name == 'pull_request' | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- uses: conda-incubator/setup-miniconda@v2 | ||
with: | ||
environment-file: env.yml | ||
activate-environment: lint_and_style | ||
auto-activate-base: false | ||
- uses: pre-commit/action@v3.0.1 | ||
with: | ||
extra_args: --hook-stage manual pylint-all --all-files |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
*.env | ||
__pycache__ | ||
.vscode | ||
.idea | ||
**/.DS_Store | ||
mcp_solver.egg-info |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
repos: | ||
- repo: https://github.com/pre-commit/pre-commit-hooks | ||
rev: v4.5.0 | ||
hooks: | ||
- id: check-docstring-first | ||
- id: check-toml | ||
- id: check-yaml | ||
exclude: packaging/.* | ||
- id: mixed-line-ending | ||
args: [--fix=lf] | ||
- id: end-of-file-fixer | ||
- repo: https://github.com/hhatto/autopep8 | ||
rev: v2.1.0 | ||
hooks: | ||
- id: autopep8 | ||
args: [--in-place, --aggressive, --exit-code] | ||
types: [python] | ||
- repo: local | ||
hooks: | ||
- id: pylint | ||
name: pylint | ||
entry: pylint | ||
language: system | ||
types: [python] | ||
args: | ||
[ | ||
"--max-line-length=120", | ||
"--errors-only", | ||
] | ||
- id: pylint | ||
alias: pylint-all | ||
name: pylint-all | ||
entry: pylint | ||
language: system | ||
types: [python] | ||
args: | ||
[ | ||
"--max-line-length=120", | ||
"--disable=W2402", # non-ascii-file-name, but Streamlit page names contain emojis | ||
] | ||
stages: [manual] | ||
- repo: https://github.com/pycqa/flake8 | ||
rev: 7.0.0 | ||
hooks: | ||
- id: flake8 | ||
args: [--config=setup.cfg] | ||
additional_dependencies: | ||
- flake8-bugbear==22.10.27 | ||
- flake8-comprehensions==3.10.1 | ||
- torchfix==0.0.2 | ||
- repo: https://github.com/facebook/usort | ||
rev: v1.0.7 | ||
hooks: | ||
- id: usort | ||
name: Sort imports with µsort | ||
description: Safe, minimal import sorting | ||
language: python | ||
types_or: | ||
- python | ||
- pyi | ||
entry: usort format | ||
require_serial: true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# A3I Project - Choreography Planning by Decoupling A* Deterministic Search and Heuristic Computation | ||
|
||
## Overview | ||
This project uses advanced techniques to design and evaluate dance choreographies for the NAO robot. By combining heuristic methods, a custom A* search algorithm, and large language model (LLM) evaluations, the system generates creative and engaging robot movements tailored to specific constraints and criteria. | ||
|
||
## Features | ||
- **Choreography Evaluation**: Rates choreography based on storytelling, rhythm, movement technique, public involvement, space use, human characterization, and human reproducibility. | ||
- **A* Search Algorithm**: Custom implementation to find the optimal sequence of dance moves while adhering to constraints like duration and mandatory moves. | ||
- **Heuristic Computation**: Combines past scores (g-function) and future estimations (h-function) to guide the search process. | ||
- **Dynamic Scoring**: Allows adjusting weights for different scoring components (e.g., waypoint adherence, duration bonus). | ||
|
||
## Components | ||
1. **LLM Scoring Functions**: | ||
- `g_llm`: Evaluates the coolness of a complete dance sequence. | ||
- `h_llm`: Estimates the coolness of potential moves for an incomplete sequence. | ||
2. **Custom A* Implementation**: | ||
- Uses LLM-generated scores to find optimal paths for robot choreography. | ||
- Incorporates dynamic constraint satisfaction. | ||
3. **Scoring Criteria**: | ||
- `gh_weight`: Importance of coolness scores. | ||
- `t_weight`: Importance of adhering to waypoints. | ||
- `db_weight`: Bonus for optimal duration. | ||
- `csc_weight`: Importance of satisfying constraints. | ||
|
||
## Input Data | ||
- **Initial State**: Starting position of the robot. | ||
- **Final State**: Target position to complete the choreography. | ||
- **Waypoints**: Intermediate positions the robot must pass through. | ||
- **Mandatory Positions**: Specific moves required in the choreography. | ||
- **Intermediate Positions**: Optional moves for enhanced creativity. | ||
- **Graph**: A connectivity graph representing valid transitions between moves. | ||
|
||
## Outputs | ||
- A complete sequence of moves. | ||
- Coolness scores for the choreography. | ||
- Computation cost and number of LLM calls. | ||
|
||
## Installation | ||
1. Clone this repository: | ||
```bash | ||
git clone https://github.com/lorenzobalzani/llm-a-star-planner.git | ||
``` | ||
2. Make sure to have conda installed onto your system. | ||
3. Set up the conda environment by launching the script `./setup_env.sh` in the terminal. | ||
4. Activate it with `conda activate a3i`. | ||
5. For installing the ipython kernel, please run `python -m ipykernel install --user --name=a3i`. | ||
|
||
## Setup | ||
Create a file named `secrets.env` in the root directory of the project and add the following lines: | ||
```bash | ||
AZURE_OPENAI_ENDPOINT = <your-deployment-edpoint> | ||
AZURE_OPENAI_API_KEY = <your-api-key> | ||
OPENAI_API_VERSION = <your-api-version> | ||
``` | ||
If you prefer to use an LLM other than Azure OpenAI, update the LLM client definitions in the notebook accordingly | ||
|
||
## Example Output | ||
``` | ||
Choreography: ['INITIAL_stand_init', 'mandatory_stand', 'diagonal_left', ... 'FINAL_crouch'] | ||
Cost: 0.06451€ | ||
Number of calls to LLM: 59 | ||
Elapsed time: 31.50s | ||
``` | ||
|
||
## Contributing | ||
Contributions are welcome! Please open an issue or submit a pull request. | ||
|
||
## Authors | ||
- **Davide Bombardi** (<a href="mailto:davide.bombardi@studio.unibo.it">davide.bombardi@studio.unibo.it</a>) | ||
- **Lorenzo Balzani** (<a href="mailto:lorenzo.balzani@studio.unibo.it">lorenzo.balzani@studio.unibo.it</a>) | ||
|
||
## License | ||
This project is licensed under the MIT License. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
name: a3i | ||
channels: | ||
- conda-forge | ||
- defaults | ||
- https://repo.anaconda.com/pkgs/main | ||
- https://repo.anaconda.com/pkgs/r | ||
dependencies: | ||
- bzip2=1.0.8=h80987f9_6 | ||
- ca-certificates=2024.9.24=hca03da5_0 | ||
- expat=2.6.3=h313beb8_0 | ||
- libcxx=14.0.6=h848a8c0_0 | ||
- libffi=3.4.4=hca03da5_1 | ||
- ncurses=6.4=h313beb8_0 | ||
- openssl=3.0.15=h80987f9_0 | ||
- pip=24.2=py312hca03da5_0 | ||
- python=3.12.7=h99e199e_0 | ||
- readline=8.2=h1a28f6b_0 | ||
- setuptools=75.1.0=py312hca03da5_0 | ||
- sqlite=3.45.3=h80987f9_0 | ||
- tk=8.6.14=h6ba3021_0 | ||
- tzdata=2024b=h04d1e81_0 | ||
- wheel=0.44.0=py312hca03da5_0 | ||
- xz=5.4.6=h80987f9_1 | ||
- zlib=1.2.13=h18a0788_1 | ||
- pip: | ||
- aiohappyeyeballs==2.4.3 | ||
- aiohttp==3.11.2 | ||
- aiosignal==1.3.1 | ||
- annotated-types==0.7.0 | ||
- anyio==4.6.2.post1 | ||
- appnope==0.1.4 | ||
- argon2-cffi==23.1.0 | ||
- argon2-cffi-bindings==21.2.0 | ||
- arrow==1.3.0 | ||
- astroid==3.3.5 | ||
- asttokens==2.4.1 | ||
- async-lru==2.0.4 | ||
- attrs==24.2.0 | ||
- babel==2.16.0 | ||
- beautifulsoup4==4.12.3 | ||
- bleach==6.2.0 | ||
- certifi==2024.8.30 | ||
- cffi==1.17.1 | ||
- cfgv==3.4.0 | ||
- charset-normalizer==3.4.0 | ||
- comm==0.2.2 | ||
- contourpy==1.3.1 | ||
- cycler==0.12.1 | ||
- dataclasses-json==0.6.7 | ||
- debugpy==1.8.8 | ||
- decorator==5.1.1 | ||
- defusedxml==0.7.1 | ||
- dill==0.3.9 | ||
- distlib==0.3.9 | ||
- distro==1.9.0 | ||
- executing==2.1.0 | ||
- fastjsonschema==2.20.0 | ||
- filelock==3.16.1 | ||
- fonttools==4.55.0 | ||
- fqdn==1.5.1 | ||
- frozenlist==1.5.0 | ||
- h11==0.14.0 | ||
- httpcore==1.0.7 | ||
- httpx==0.27.2 | ||
- httpx-sse==0.4.0 | ||
- identify==2.6.2 | ||
- idna==3.10 | ||
- ipykernel==6.29.5 | ||
- ipython==8.29.0 | ||
- ipywidgets==8.1.5 | ||
- isoduration==20.11.0 | ||
- isort==5.13.2 | ||
- jedi==0.19.2 | ||
- jinja2==3.1.4 | ||
- jiter==0.7.1 | ||
- json5==0.9.28 | ||
- jsonpatch==1.33 | ||
- jsonpointer==3.0.0 | ||
- jsonschema==4.23.0 | ||
- jsonschema-specifications==2024.10.1 | ||
- jupyter==1.1.1 | ||
- jupyter-client==8.6.3 | ||
- jupyter-console==6.6.3 | ||
- jupyter-core==5.7.2 | ||
- jupyter-events==0.10.0 | ||
- jupyter-lsp==2.2.5 | ||
- jupyter-server==2.14.2 | ||
- jupyter-server-terminals==0.5.3 | ||
- jupyterlab==4.2.6 | ||
- jupyterlab-pygments==0.3.0 | ||
- jupyterlab-server==2.27.3 | ||
- jupyterlab-widgets==3.0.13 | ||
- kiwisolver==1.4.7 | ||
- langchain==0.3.7 | ||
- langchain-community==0.3.7 | ||
- langchain-core==0.3.19 | ||
- langchain-openai==0.2.8 | ||
- langchain-text-splitters==0.3.2 | ||
- langsmith==0.1.143 | ||
- load-dotenv==0.1.0 | ||
- loguru==0.7.2 | ||
- markupsafe==3.0.2 | ||
- marshmallow==3.23.1 | ||
- matplotlib==3.9.2 | ||
- matplotlib-inline==0.1.7 | ||
- mccabe==0.7.0 | ||
- mistune==3.0.2 | ||
- multidict==6.1.0 | ||
- mypy-extensions==1.0.0 | ||
- nbclient==0.10.0 | ||
- nbconvert==7.16.4 | ||
- nbformat==5.10.4 | ||
- nest-asyncio==1.6.0 | ||
- networkx==3.4.2 | ||
- nodeenv==1.9.1 | ||
- notebook==7.2.2 | ||
- notebook-shim==0.2.4 | ||
- numpy==1.26.4 | ||
- openai==1.54.4 | ||
- orjson==3.10.11 | ||
- overrides==7.7.0 | ||
- packaging==24.2 | ||
- pandocfilters==1.5.1 | ||
- parso==0.8.4 | ||
- pexpect==4.9.0 | ||
- pillow==11.0.0 | ||
- platformdirs==4.3.6 | ||
- pre-commit==4.0.1 | ||
- prometheus-client==0.21.0 | ||
- prompt-toolkit==3.0.48 | ||
- propcache==0.2.0 | ||
- psutil==6.1.0 | ||
- ptyprocess==0.7.0 | ||
- pure-eval==0.2.3 | ||
- pycparser==2.22 | ||
- pydantic==2.9.2 | ||
- pydantic-core==2.23.4 | ||
- pydantic-settings==2.6.1 | ||
- pygments==2.18.0 | ||
- pylint==3.3.1 | ||
- pyparsing==3.2.0 | ||
- python-dateutil==2.9.0.post0 | ||
- python-dotenv==1.0.1 | ||
- python-json-logger==2.0.7 | ||
- pyyaml==6.0.2 | ||
- pyzmq==26.2.0 | ||
- referencing==0.35.1 | ||
- regex==2024.11.6 | ||
- requests==2.32.3 | ||
- requests-toolbelt==1.0.0 | ||
- rfc3339-validator==0.1.4 | ||
- rfc3986-validator==0.1.1 | ||
- rpds-py==0.21.0 | ||
- scipy==1.14.1 | ||
- send2trash==1.8.3 | ||
- six==1.16.0 | ||
- sniffio==1.3.1 | ||
- soupsieve==2.6 | ||
- sqlalchemy==2.0.35 | ||
- stack-data==0.6.3 | ||
- tenacity==9.0.0 | ||
- terminado==0.18.1 | ||
- tiktoken==0.8.0 | ||
- tinycss2==1.4.0 | ||
- tomlkit==0.13.2 | ||
- tornado==6.4.1 | ||
- tqdm==4.67.0 | ||
- traitlets==5.14.3 | ||
- types-python-dateutil==2.9.0.20241003 | ||
- typing-extensions==4.12.2 | ||
- typing-inspect==0.9.0 | ||
- uri-template==1.3.0 | ||
- urllib3==2.2.3 | ||
- virtualenv==20.27.1 | ||
- wcwidth==0.2.13 | ||
- webcolors==24.11.1 | ||
- webencodings==0.5.1 | ||
- websocket-client==1.8.0 | ||
- widgetsnbextension==4.0.13 | ||
- yarl==1.17.1 |
Oops, something went wrong.