Commit

Merge pull request #54 from web-arena-x/new_eval
Massive PR on many parts
shuyanzhou authored Oct 21, 2023
2 parents ba3c07c + 9f0900f commit a68aa1a
Showing 25 changed files with 1,365 additions and 678 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -34,7 +34,7 @@ jobs:
mypy --version
# Run this mypy instance against our main package.
mypy --install-types --non-interactive .
mypy --strict .
mypy --strict . --exclude scripts
- name: Enviroment prepare
run: |
bash prepare.sh
23 changes: 12 additions & 11 deletions .gitignore
@@ -141,18 +141,19 @@ run.sh

# trajectory visualization
render_cache/*
cache/*

# TMP IGNORE
agent/prompts/jsons/*
log_files/
config_files/*0.json
config_files/*1.json
config_files/*2.json
config_files/*3.json
config_files/*4.json
config_files/*5.json
config_files/*6.json
config_files/*7.json
config_files/*8.json
config_files/*9.json
config_files/test.json
config_files*/*0.json
config_files*/*1.json
config_files*/*2.json
config_files*/*3.json
config_files*/*4.json
config_files*/*5.json
config_files*/*6.json
config_files*/*7.json
config_files*/*8.json
config_files*/*9.json
config_files*/test.json
95 changes: 33 additions & 62 deletions agent/agent.py
@@ -15,11 +15,14 @@
create_playwright_action,
)
from browser_env.utils import Observation, StateInfo
from llms import lm_config
from llms.providers.openai_utils import (
from llms import (
call_llm,
generate_from_huggingface_completion,
generate_from_openai_chat_completion,
generate_from_openai_completion,
lm_config,
)
from llms.tokenizers import Tokenizer


class Agent:
@@ -120,82 +123,50 @@ def next_action(
trajectory, intent, meta_data
)
lm_config = self.lm_config
if lm_config.provider == "openai":
if lm_config.mode == "chat":
response = generate_from_openai_chat_completion(
messages=prompt,
model=lm_config.model,
temperature=lm_config.gen_config["temperature"],
top_p=lm_config.gen_config["top_p"],
context_length=lm_config.gen_config["context_length"],
max_tokens=lm_config.gen_config["max_tokens"],
stop_token=None,
)
elif lm_config.mode == "completion":
response = generate_from_openai_completion(
prompt=prompt,
engine=lm_config.model,
temperature=lm_config.gen_config["temperature"],
max_tokens=lm_config.gen_config["max_tokens"],
top_p=lm_config.gen_config["top_p"],
stop_token=lm_config.gen_config["stop_token"],
)
else:
raise ValueError(
f"OpenAI models do not support mode {lm_config.mode}"
n = 0
while True:
response = call_llm(lm_config, prompt)
force_prefix = self.prompt_constructor.instruction[
"meta_data"
].get("force_prefix", "")
response = f"{force_prefix}{response}"
n += 1
try:
parsed_response = self.prompt_constructor.extract_action(
response
)
else:
raise NotImplementedError(
f"Provider {lm_config.provider} not implemented"
)

try:
parsed_response = self.prompt_constructor.extract_action(response)
if self.action_set_tag == "id_accessibility_tree":
action = create_id_based_action(parsed_response)
elif self.action_set_tag == "playwright":
action = create_playwright_action(parsed_response)
else:
raise ValueError(f"Unknown action type {self.action_set_tag}")

action["raw_prediction"] = response

except ActionParsingError as e:
action = create_none_action()
action["raw_prediction"] = response
if self.action_set_tag == "id_accessibility_tree":
action = create_id_based_action(parsed_response)
elif self.action_set_tag == "playwright":
action = create_playwright_action(parsed_response)
else:
raise ValueError(
f"Unknown action type {self.action_set_tag}"
)
action["raw_prediction"] = response
break
except ActionParsingError as e:
if n >= lm_config.gen_config["max_retry"]:
action = create_none_action()
action["raw_prediction"] = response
break

return action

def reset(self, test_config_file: str) -> None:
pass


def construct_llm_config(args: argparse.Namespace) -> lm_config.LMConfig:
llm_config = lm_config.LMConfig(
provider=args.provider, model=args.model, mode=args.mode
)
if args.provider == "openai":
llm_config.gen_config["temperature"] = args.temperature
llm_config.gen_config["top_p"] = args.top_p
llm_config.gen_config["context_length"] = args.context_length
llm_config.gen_config["max_tokens"] = args.max_tokens
llm_config.gen_config["stop_token"] = args.stop_token
llm_config.gen_config["max_obs_length"] = args.max_obs_length
else:
raise NotImplementedError(f"provider {args.provider} not implemented")
return llm_config


def construct_agent(args: argparse.Namespace) -> Agent:
llm_config = construct_llm_config(args)
llm_config = lm_config.construct_llm_config(args)

agent: Agent
if args.agent_type == "teacher_forcing":
agent = TeacherForcingAgent()
elif args.agent_type == "prompt":
with open(args.instruction_path) as f:
constructor_type = json.load(f)["meta_data"]["prompt_constructor"]
tokenizer = tiktoken.encoding_for_model(llm_config.model)
tokenizer = Tokenizer(args.provider, args.model)
prompt_constructor = eval(constructor_type)(
args.instruction_path, lm_config=llm_config, tokenizer=tokenizer
)
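The rewritten next_action above routes every provider through call_llm and retries until the response parses into an action or the max_retry budget from gen_config is exhausted, then falls back to a none action. Below is a minimal, self-contained sketch of that retry pattern; call_llm and extract_action are illustrative stubs, not the real llms helpers.

```python
import re


class ActionParsingError(Exception):
    """Raised when a model response cannot be parsed into an action."""


def call_llm(prompt: str) -> str:
    # Placeholder for llms.call_llm(lm_config, prompt).
    return "In summary, the next action I will perform is ```click [1234]```"


def extract_action(response: str) -> str:
    # Placeholder for PromptConstructor.extract_action with a backtick splitter.
    match = re.search(r"```((.|\n)*?)```", response)
    if match:
        return match.group(1).strip()
    raise ActionParsingError(f"Cannot parse action from response {response}")


def next_action(prompt: str, max_retry: int = 3, force_prefix: str = "") -> dict[str, str]:
    n = 0
    while True:
        response = f"{force_prefix}{call_llm(prompt)}"
        n += 1
        try:
            # Parsed successfully: return the action plus the raw model output.
            return {"action": extract_action(response), "raw_prediction": response}
        except ActionParsingError:
            if n >= max_retry:
                # Retry budget exhausted: fall back to a no-op ("none") action.
                return {"action": "none", "raw_prediction": response}


print(next_action("OBSERVATION: ..."))
```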
58 changes: 45 additions & 13 deletions agent/prompts/prompt_constructor.py
@@ -3,14 +3,12 @@
from pathlib import Path
from typing import Any, TypedDict

import tiktoken

from browser_env import Action, ActionParsingError, Trajectory
from browser_env.env_config import URL_MAPPINGS
from browser_env.utils import StateInfo
from llms import lm_config

APIInput = str | list[Any] | dict[str, Any]
from llms.tokenizers import Tokenizer
from llms.utils import APIInput


class Instruction(TypedDict):
@@ -27,12 +25,12 @@ def __init__(
self,
instruction_path: str | Path,
lm_config: lm_config.LMConfig,
tokenizer: tiktoken.core.Encoding,
tokenizer: Tokenizer,
):
self.instrction_path = Path(instruction_path)
self.instruction_path = Path(instruction_path)
self.obs_modality = "text"
self.lm_config = lm_config
instruction = json.load(open(self.instrction_path))
instruction = json.load(open(self.instruction_path))
instruction["examples"] = [tuple(e) for e in instruction["examples"]]
self.instruction: Instruction = instruction
self.tokenizer = tokenizer
@@ -77,6 +75,37 @@ def get_lm_api_input(
raise ValueError(
f"OpenAI models do not support mode {self.lm_config.mode}"
)
elif "huggingface" in self.lm_config.provider:
# https://huggingface.co/blog/llama2#how-to-prompt-llama-2
# https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L320
if "Llama-2" in self.lm_config.model:
if self.lm_config.mode == "chat":
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
BOS, EOS = "<s>", "</s>"
# adding the system message to be the starting of the first example
examples = [
(
B_SYS + intro + E_SYS + examples[0][0],
examples[0][1],
)
] + examples[1:]
message = "".join(
[
f"{BOS}{B_INST} {x.strip()} {E_INST} {y.strip()} {EOS}"
for (x, y) in examples
]
)
# add the current observation
message += f"{BOS}{B_INST} {current.strip()} {E_INST} {self.instruction['meta_data'].get('force_prefix', '')}"

return message
else:
raise ValueError("Only chat mode is supported for Llama-2")
else:
raise ValueError(
f"Huggingface models do not support model_tag {self.lm_config.gen_config['model_tag']}"
)
else:
raise NotImplementedError(
f"Provider {self.lm_config.provider} not implemented"
@@ -102,6 +131,9 @@ def map_url_to_local(self, url: str) -> str:
for i, j in URL_MAPPINGS.items():
if j in url:
url = url.replace(j, i)
# https
if j.replace("http", "https") in url:
url = url.replace(j.replace("http", "https"), i)
return url

def _extract_action(self, response: str) -> str:
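The map_url_to_local hunk above now also rewrites https variants of the mapped public URLs back to their local counterparts. A small sketch of that behavior, using an illustrative mapping rather than the real URL_MAPPINGS from browser_env.env_config:

```python
# Illustrative mapping; the real URL_MAPPINGS lives in browser_env.env_config.
URL_MAPPINGS = {"http://localhost:7770": "http://onestopmarket.com"}


def map_url_to_local(url: str) -> str:
    for local, public in URL_MAPPINGS.items():
        if public in url:
            url = url.replace(public, local)
        # New in this commit: also map the https variant of the public URL.
        if public.replace("http", "https") in url:
            url = url.replace(public.replace("http", "https"), local)
    return url


print(map_url_to_local("https://onestopmarket.com/office-products"))
# -> http://localhost:7770/office-products
```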
@@ -120,7 +152,7 @@ def __init__(
self,
instruction_path: str | Path,
lm_config: lm_config.LMConfig,
tokenizer: tiktoken.core.Encoding,
tokenizer: Tokenizer,
):
super().__init__(instruction_path, lm_config, tokenizer)

@@ -161,10 +193,10 @@ def construct(

def _extract_action(self, response: str) -> str:
action_splitter = self.instruction["meta_data"]["action_splitter"]
pattern = rf"{action_splitter}(.*?){action_splitter}"
pattern = rf"{action_splitter}((.|\n)*?){action_splitter}"
match = re.search(pattern, response)
if match:
return match.group(1)
return match.group(1).strip()
else:
raise ActionParsingError(
f"Cannot parse action from response {response}"
@@ -178,7 +210,7 @@ def __init__(
self,
instruction_path: str | Path,
lm_config: lm_config.LMConfig,
tokenizer: tiktoken.core.Encoding,
tokenizer: Tokenizer,
):
super().__init__(instruction_path, lm_config, tokenizer)
self.answer_phrase = self.instruction["meta_data"]["answer_phrase"]
@@ -218,10 +250,10 @@ def construct(
def _extract_action(self, response: str) -> str:
# find the first occurence of action
action_splitter = self.instruction["meta_data"]["action_splitter"]
pattern = rf"{action_splitter}(.*?){action_splitter}"
pattern = rf"{action_splitter}((.|\n)*?){action_splitter}"
match = re.search(pattern, response)
if match:
return match.group(1)
return match.group(1).strip()
else:
raise ActionParsingError(
f'Cannot find the answer phrase "{self.answer_phrase}" in "{response}"'
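Among the hunks above, get_lm_api_input gains a Llama-2 chat branch that folds the system message into the first in-context example and serializes each (input, output) pair as a <s>[INST] ... [/INST] ... </s> turn. A self-contained sketch of that assembly, with intro, examples, current, and force_prefix as illustrative placeholders for the real prompt pieces:

```python
# Sketch of the Llama-2 chat formatting added to get_lm_api_input.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
BOS, EOS = "<s>", "</s>"

intro = "You are an autonomous intelligent agent tasked with navigating a web browser."
examples = [
    ("OBSERVATION: ...\nOBJECTIVE: What is the price of HP Inkjet Fax Machine",
     "In summary, the next action I will perform is ```stop [$279.49]```"),
    ("OBSERVATION: ...\nOBJECTIVE: Show me the restaurants near CMU",
     "In summary, the next action I will perform is ```type [164] [restaurants near CMU] [1]```"),
]
current = "OBSERVATION: ...\nOBJECTIVE: ..."
force_prefix = ""

# Fold the system message into the first example's input.
examples = [(B_SYS + intro + E_SYS + examples[0][0], examples[0][1])] + examples[1:]
# One <s>[INST] input [/INST] output </s> turn per in-context example.
message = "".join(
    f"{BOS}{B_INST} {x.strip()} {E_INST} {y.strip()} {EOS}" for (x, y) in examples
)
# The current observation opens the final turn, which the model completes.
message += f"{BOS}{B_INST} {current.strip()} {E_INST} {force_prefix}"
print(message)
```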
82 changes: 82 additions & 0 deletions agent/prompts/raw/p_cot_id_actree_2s_no_na.py
@@ -0,0 +1,82 @@
prompt = {
"intro": """You are an autonomous intelligent agent tasked with navigating a web browser. You will be given web-based tasks. These tasks will be accomplished through the use of specific actions you can issue.
Here's the information you'll have:
The user's objective: This is the task you're trying to complete.
The current web page's accessibility tree: This is a simplified representation of the webpage, providing key information.
The current web page's URL: This is the page you're currently navigating.
The open tabs: These are the tabs you have open.
The previous action: This is the action you just performed. It may be helpful to track your progress.
The actions you can perform fall into several categories:
Page Operation Actions:
`click [id]`: This action clicks on an element with a specific id on the webpage.
`type [id] [content] [press_enter_after=0|1]`: Use this to type the content into the field with id. By default, the "Enter" key is pressed after typing unless press_enter_after is set to 0.
`hover [id]`: Hover over an element with id.
`press [key_comb]`: Simulates the pressing of a key combination on the keyboard (e.g., Ctrl+v).
`scroll [direction=down|up]`: Scroll the page up or down.
Tab Management Actions:
`new_tab`: Open a new, empty browser tab.
`tab_focus [tab_index]`: Switch the browser's focus to a specific tab using its index.
`close_tab`: Close the currently active tab.
URL Navigation Actions:
`goto [url]`: Navigate to a specific URL.
`go_back`: Navigate to the previously viewed page.
`go_forward`: Navigate to the next page (if a previous 'go_back' action was performed).
Completion Action:
`stop [answer]`: Issue this action when you believe the task is complete. If the objective is to find a text-based answer, provide the answer in the bracket.
Homepage:
If you want to visit other websites, check out the homepage at http://homepage.com. It has a list of websites you can visit.
http://homepage.com/password.html lists all the account name and password for the websites. You can use them to log in to the websites.
To be successful, it is very important to follow the following rules:
1. You should only issue an action that is valid given the current observation
2. You should only issue one action at a time.
3. You should follow the examples to reason step by step and then issue the next action.
4. Generate the action in the correct format. Start with a "In summary, the next action I will perform is" phrase, followed by action inside ``````. For example, "In summary, the next action I will perform is ```click [1234]```".
5. Issue stop action when you think you have achieved the objective. Don't generate anything after stop.""",
"examples": [
(
"""OBSERVATION:
[1744] link 'HP CB782A#ABA 640 Inkjet Fax Machine (Renewed)'
[1749] StaticText '$279.49'
[1757] button 'Add to Cart'
[1760] button 'Add to Wish List'
[1761] button 'Add to Compare'
URL: http://onestopmarket.com/office-products/office-electronics.html
OBJECTIVE: What is the price of HP Inkjet Fax Machine
PREVIOUS ACTION: None""",
"Let's think step-by-step. This page list the information of HP Inkjet Fax Machine, which is the product identified in the objective. Its price is $279.49. I think I have achieved the objective. I will issue the stop action with the answer. In summary, the next action I will perform is ```stop [$279.49]```",
),
(
"""OBSERVATION:
[164] textbox 'Search' focused: True required: False
[171] button 'Go'
[174] link 'Find directions between two points'
[212] heading 'Search Results'
[216] button 'Close'
URL: http://openstreetmap.org
OBJECTIVE: Show me the restaurants near CMU
PREVIOUS ACTION: None""",
"Let's think step-by-step. This page has a search box whose ID is [164]. According to the nominatim rule of openstreetmap, I can search for the restaurants near a location by \"restaurants near\". I can submit my typing by pressing the Enter afterwards. In summary, the next action I will perform is ```type [164] [restaurants near CMU] [1]```",
),
],
"template": """OBSERVATION:
{observation}
URL: {url}
OBJECTIVE: {objective}
PREVIOUS ACTION: {previous_action}""",
"meta_data": {
"observation": "accessibility_tree",
"action_type": "id_accessibility_tree",
"keywords": ["url", "objective", "observation", "previous_action"],
"prompt_constructor": "CoTPromptConstructor",
"answer_phrase": "In summary, the next action I will perform is",
"action_splitter": "```"
},
}
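For reference, the template above is instantiated per step with the current observation, URL, objective, and previous action, and the answer is later pulled back out with the action_splitter pattern shown earlier. A minimal sketch of the template fill (the real logic lives in CoTPromptConstructor.construct; the values here are illustrative):

```python
# Fill the new prompt's per-step template with example values.
template = """OBSERVATION:
{observation}
URL: {url}
OBJECTIVE: {objective}
PREVIOUS ACTION: {previous_action}"""

current = template.format(
    observation="[1744] link 'HP CB782A#ABA 640 Inkjet Fax Machine (Renewed)'",
    url="http://onestopmarket.com/office-products/office-electronics.html",
    objective="What is the price of HP Inkjet Fax Machine",
    previous_action="None",
)
print(current)
```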
(Diffs for the remaining changed files were not loaded.)
