diff --git a/ai_dungeon_cli/__init__.py b/ai_dungeon_cli/__init__.py index fca3c0b..a121e8d 100755 --- a/ai_dungeon_cli/__init__.py +++ b/ai_dungeon_cli/__init__.py @@ -38,10 +38,21 @@ class QuitSession(Exception): class AbstractAiDungeonGame(ABC): def __init__(self, api: AiDungeonApiClient, conf: Config, user_io: UserIo): self.stop_session: bool = False + self.user_id: str = None self.session_id: str = None + + self.scenario_id: str = '' # REVIEW: maybe call it setting_id ? + self.character_name: str = '' + self.adventure_id: str = '' self.public_id: str = None + + self.story_pitch_template: str = '' + self.story_pitch: str = '' + self.quests: str = '' + self.setting_name: str = None + self.is_multiplayer: bool = False self.story_configuration: Dict[str, str] = {} self.session: requests.Session = requests.Session() @@ -83,7 +94,7 @@ def choose_selection(self, allowed_values: Dict[str, str], k_or_v='v') -> str: continue - def choose_config(self): + def make_user_choose_config(self): pass # Initialize story @@ -149,15 +160,16 @@ def _choose_character_name(self): if character_name == "/quit": raise QuitSession("/quit") - self.api.character_name = character_name # TODO: create a setter instead + self.character_name = character_name # TODO: create a setter instead def join_multiplayer(self): - self.api.character_name = self.conf.character_name - self.api.join_multi_adventure(self.conf.public_adventure_id) + self.is_multiplayer = True + self.character_name = self.conf.character_name + self.adventure_id = self.api.join_multi_adventure(self.conf.public_adventure_id) - def choose_config(self): + def make_user_choose_config(self): # self.api.perform_init_handshake() ## SETTING SELECTION @@ -174,18 +186,19 @@ def choose_config(self): # setting_select_dict['0'] = '0' # secret mode selected_i = self.choose_selection(setting_select_dict, 'k') setting_id, self.setting_name = settings[selected_i] - self.api.scenario_id = setting_id # TODO: create a setter instead + self.scenario_id = setting_id if self.setting_name == "custom": return elif self.setting_name == "archive": while True: - prompt, options = self.api.get_options(self.api.scenario_id) + prompt, options = self.api.get_options(self.scenario_id) if options is None: - self.api.story_pitch_template = prompt + self.story_pitch_template = prompt self._choose_character_name() - self.api.set_story_pitch() + self.story_pitch = self.api.make_story_pitch(self.story_pitch_template, + self.character_name) return print(prompt + "\n") @@ -198,12 +211,12 @@ def choose_config(self): # setting_select_dict['0'] = '0' # secret mode selected_i = self.choose_selection(select_dict, 'k') option_id, option_name = options[selected_i] - self.api.scenario_id = option_id # TODO: create a setter instead + self.scenario_id = option_id ## CHARACTER SELECTION - prompt, characters = self.api.get_characters() + prompt, characters = self.api.get_characters(self.scenario_id) print(prompt + "\n") @@ -214,30 +227,29 @@ def choose_config(self): character_select_dict[str(i)] = character_type selected_i = self.choose_selection(character_select_dict, 'k') character_id, character_type = characters[selected_i] - self.api.scenario_id = character_id # TODO: create a setter instead + self.scenario_id = character_id # TODO: create a setter instead self._choose_character_name() ## PITCH - self.api.get_story_for_scenario() - self.api.set_story_pitch() + self.story_pitch_template = self.api.get_story_template_for_scenario(self.scenario_id) + self.story_pitch = self.api.make_story_pitch(self.story_pitch_template, + self.character_name) # Initialize story def init_story(self): - if self.setting_name == "custom": + if self.is_multiplayer: + self.api.init_story_multi_adventure(self.conf.public_adventure_id) + elif self.setting_name == "custom": self.init_story_custom() else: print("Generating story... Please wait...\n") - self.api.init_story() + self.adventure_id, self.public_id, self.story_pitch, self.quests = self.api.init_story(self.scenario_id, + self.story_pitch) - self.user_io.handle_story_output(self.api.story_pitch) - - - def init_story_multi_adventure(self): - self.api.init_story_multi_adventure(self.conf.public_adventure_id) - self.user_io.handle_story_output(self.api.story_pitch) + self.user_io.handle_story_output(self.story_pitch) def init_story_custom(self): @@ -248,9 +260,9 @@ def init_story_custom(self): ) user_story_pitch = self.user_io.handle_user_input() - self.api.story_pitch = None - self.api._create_adventure(self.api.scenario_id) - self.api.init_custom_story_pitch(user_story_pitch) + self.story_pitch = None + self.adventure_id, _ = self.api.create_adventure(self.scenario_id, self.story_pitch) + self.story_pitch = self.api.init_custom_story_pitch(self.adventure_id, user_story_pitch) def find_action_type(self, user_input: str): @@ -281,12 +293,12 @@ def process_regular_action(self, user_input: str): (action, user_input) = self.find_action_type(user_input) - resp = self.api.perform_regular_action(action, user_input) + resp = self.api.perform_regular_action(self.adventure_id, action, user_input, self.character_name) self.user_io.handle_story_output(resp) def process_remember_action(self, user_input: str): - self.api.perform_remember_action(user_input) + self.api.perform_remember_action(user_input, self.adventure_id) def process_next_action(self): user_input = self.user_io.handle_user_input() @@ -339,17 +351,12 @@ def main(): # Loads the current session configuration if conf.public_adventure_id: - # print('1') ai_dungeon.join_multiplayer() else: - # print('2') - ai_dungeon.choose_config() + ai_dungeon.make_user_choose_config() # Initializes the story - if conf.public_adventure_id: - ai_dungeon.init_story_multi_adventure() - else: - ai_dungeon.init_story() + ai_dungeon.init_story() # Starts the game ai_dungeon.start_game() diff --git a/ai_dungeon_cli/impl/api/client.py b/ai_dungeon_cli/impl/api/client.py index 39570a8..63d20b2 100644 --- a/ai_dungeon_cli/impl/api/client.py +++ b/ai_dungeon_cli/impl/api/client.py @@ -18,13 +18,6 @@ def __init__(self): self.access_token: str = '' self.single_player_mode_id: str = 'scenario:458612' - self.scenario_id: str = '' # REVIEW: maybe call it setting_id ? - self.character_name: str = '' - self.story_pitch_template: str = '' - self.story_pitch: str = '' - self.adventure_id: str = '' - self.public_id: str = '' - self.quests: str = '' async def _execute_query_pseudo_async(self, query, params={}): @@ -154,11 +147,11 @@ def join_multi_adventure(self, public_adventure_id): mutation ($adventurePlayPublicId: String) { addUserToAdventure(adventurePlayPublicId: $adventurePlayPublicId)} ''', {"adventurePlayPublicId": public_adventure_id}) - self.adventure_id = result['addUserToAdventure'] debug_print(result) + return result['addUserToAdventure'] - def get_characters(self): + def get_characters(self, scenario_id): prompt = '' characters = {} @@ -166,7 +159,7 @@ def get_characters(self): result = self._execute_query(''' query ($id: String) { user { id username __typename } content(id: $id) { id userId contentType contentId prompt gameState options { id title __typename } playPublicId __typename }} ''', - {"id": self.scenario_id}) + {"id": scenario_id}) debug_print(result) prompt = result['content']['prompt'] characters = self.normalize_options(result['content']['options']) @@ -175,7 +168,7 @@ def get_characters(self): # result = self._execute_query(''' # query ($id: String) { content(id: $id) { id contentType contentId title description prompt memory tags nsfw published createdAt updatedAt deletedAt options { id title __typename } __typename }} # ''', - # {"id": self.scenario_id}) + # {"id": scenario_id}) # debug_print(result) # prompt = result['content']['prompt'] # characters = self.normalize_options(result['content']['options']) @@ -183,15 +176,15 @@ def get_characters(self): return [prompt, characters] - def get_story_for_scenario(self): + def get_story_template_for_scenario(self, scenario_id): debug_print("query get story for scenario") result = self._execute_query(''' query ($id: String) { user { id username __typename } content(id: $id) { id userId contentType contentId prompt gameState options { id title __typename } playPublicId __typename }} ''', - {"id": self.scenario_id}) + {"id": scenario_id}) debug_print(result) - self.story_pitch_template = result['content']['prompt'] + return result['content']['prompt'] # debug_print("query get story for scenario #2") # result = self._execute_query(''' @@ -212,11 +205,11 @@ def initial_story_from_history_list(history_list): return pitch - def set_story_pitch(self): - self.story_pitch = self.story_pitch_template.replace('${character.name}', self.character_name) + def make_story_pitch(self, story_pitch_template, character_name): + return story_pitch_template.replace('${character.name}', character_name) - def init_custom_story_pitch(self, user_input): + def init_custom_story_pitch(self, adventure_id, user_input): debug_print("send custom settings story pitch") result = self._execute_query(''' @@ -226,25 +219,27 @@ def init_custom_story_pitch(self, user_input): "input": { "type": "story", "text": user_input, - "id": self.adventure_id}}) + "id": adventure_id}}) debug_print(result) - self.story_pitch = ''.join([a['text'] for a in result['sendAction']['actions']]) + return ''.join([a['text'] for a in result['sendAction']['actions']]) - def _create_adventure(self, scenario_id): + def create_adventure(self, scenario_id, story_pitch): debug_print("create adventure") result = self._execute_query(''' mutation ($id: String, $prompt: String) { createAdventureFromScenarioId(id: $id, prompt: $prompt) { id contentType contentId title description musicTheme tags nsfw published createdAt updatedAt deletedAt publicId historyList __typename }} ''', { "id": scenario_id, - "prompt": self.story_pitch + "prompt": story_pitch }) debug_print(result) - self.adventure_id = result['createAdventureFromScenarioId']['id'] + adventure_id = result['createAdventureFromScenarioId']['id'] + story_pitch = None if 'historyList' in result['createAdventureFromScenarioId']: - # NB: not present when self.story_pitch is None, as is the case for a custom scenario - self.story_pitch = self.initial_story_from_history_list(result['createAdventureFromScenarioId']['historyList']) + # NB: not present when story_pitch is None, as is the case for a custom scenario + story_pitch = self.initial_story_from_history_list(result['createAdventureFromScenarioId']['historyList']) + return [adventure_id, story_pitch] def init_story_multi_adventure(self, public_adventure_id): @@ -262,27 +257,29 @@ def init_story_multi_adventure(self, public_adventure_id): if entry.startswith("\n>"): entry = "\n" + entry + "\n" # mo' spacing please entries.append(entry) - self.story_pitch = ''.join(entries) + return ''.join(entries) - def init_story(self): - - self._create_adventure(self.scenario_id) + def init_story(self, scenario_id, story_pitch): + adventure_id, story_pitch = self.create_adventure(scenario_id, story_pitch) debug_print("get created adventure ids") result = self._execute_query(''' query ($id: String, $playPublicId: String) { content(id: $id, playPublicId: $playPublicId) { id historyList quests playPublicId userId __typename }} ''', { - "id": self.adventure_id, + "id": adventure_id, }) debug_print(result) - self.quests = result['content']['quests'] - self.public_id = result['content']['playPublicId'] - # self.story_pitch = self.initial_story_from_history_list(result['content']['historyList']) + quests = result['content']['quests'] + public_id = result['content']['playPublicId'] + # story_pitch = self.initial_story_from_history_list(result['content']['historyList']) + + return [adventure_id, public_id, story_pitch, quests] + - def perform_remember_action(self, user_input): + def perform_remember_action(self, user_input, adventure_id): debug_print("remember something") result = self._execute_query(''' mutation ($input: ContentActionInput) { updateMemory(input: $input) { id memory __typename }} @@ -292,13 +289,13 @@ def perform_remember_action(self, user_input): { "text": user_input, "type":"remember", - "id": self.adventure_id + "id": adventure_id } }) debug_print(result) - def perform_regular_action(self, action, user_input): + def perform_regular_action(self, adventure_id, action, user_input, character_name = None): story_continuation = "" @@ -310,8 +307,8 @@ def perform_regular_action(self, action, user_input): "input": { "type": action, "text": user_input, - "id": self.adventure_id, - "characterName": self.character_name + "id": adventure_id, + "characterName": character_name } }) debug_print(result) @@ -330,7 +327,7 @@ def perform_regular_action(self, action, user_input): } ''', { - "id": self.adventure_id + "id": adventure_id }) debug_print(result) story_continuation = result['content']['actions'][-1]['text']