Skip to content

Commit

Permalink
Merge pull request #29 from Christopher-R-Perkins/feature/mlstacks
Browse files Browse the repository at this point in the history
Stack + stack component creation, updates and deletion
  • Loading branch information
strickvl authored Jul 25, 2024
2 parents a0ce66e + 1bb20f9 commit 6ac8078
Show file tree
Hide file tree
Showing 33 changed files with 2,791 additions and 352 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ this extension and your Python version needs to be 3.8 or greater.
## Using ZenML in VSCode

- **Manage Server Connections**: Connect or disconnect from ZenML servers and refresh server status.
- **Stack Operations**: View stack details, rename, copy, or set active stacks directly from VSCode.
- **Stack Operations**: View stack details, register, update, delete, copy, or set active stacks directly from VSCode.
- **Stack Component Operations**: View stack component details, register, update, or delete stack components directly from VSCode.
- **Pipeline Runs**: Monitor and manage pipeline runs, including deleting runs from the system and rendering DAGs.
- **Environment Information**: Get detailed snapshots of the development environment, aiding troubleshooting.

Expand Down
94 changes: 82 additions & 12 deletions bundled/tool/lsp_zenml.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
from zen_watcher import ZenConfigWatcher
from zenml_client import ZenMLClient

zenml_init_error = {"error": "ZenML is not initialized. Please check ZenML version requirements."}
zenml_init_error = {
"error": "ZenML is not initialized. Please check ZenML version requirements."
}


class ZenLanguageServer(LanguageServer):
Expand All @@ -58,7 +60,9 @@ async def is_zenml_installed(self) -> bool:
if process.returncode == 0:
self.show_message_log("✅ ZenML installation check: Successful.")
return True
self.show_message_log("❌ ZenML installation check failed.", lsp.MessageType.Error)
self.show_message_log(
"❌ ZenML installation check failed.", lsp.MessageType.Error
)
return False
except Exception as e:
self.show_message_log(
Expand Down Expand Up @@ -93,7 +97,9 @@ async def initialize_zenml_client(self):
# initialize watcher
self.initialize_global_config_watcher()
except Exception as e:
self.notify_user(f"Failed to initialize ZenML client: {str(e)}", lsp.MessageType.Error)
self.notify_user(
f"Failed to initialize ZenML client: {str(e)}", lsp.MessageType.Error
)

def initialize_global_config_watcher(self):
"""Sets up and starts the Global Configuration Watcher."""
Expand Down Expand Up @@ -133,7 +139,9 @@ def wrapper(*args, **kwargs):

with suppress_stdout_temporarily():
if wrapper_name:
wrapper_instance = getattr(self.zenml_client, wrapper_name, None)
wrapper_instance = getattr(
self.zenml_client, wrapper_name, None
)
if not wrapper_instance:
return {"error": f"Wrapper '{wrapper_name}' not found."}
return func(wrapper_instance, *args, **kwargs)
Expand Down Expand Up @@ -177,25 +185,33 @@ def _construct_version_validation_response(self, meets_requirement, version_str)

def send_custom_notification(self, method: str, args: dict):
"""Sends a custom notification to the LSP client."""
self.show_message_log(f"Sending custom notification: {method} with args: {args}")
self.show_message_log(
f"Sending custom notification: {method} with args: {args}"
)
self.send_notification(method, args)

def update_python_interpreter(self, interpreter_path):
"""Updates the Python interpreter path and handles errors."""
try:
self.python_interpreter = interpreter_path
self.show_message_log(f"LSP_Python_Interpreter Updated: {self.python_interpreter}")
self.show_message_log(
f"LSP_Python_Interpreter Updated: {self.python_interpreter}"
)
# pylint: disable=broad-exception-caught
except Exception as e:
self.show_message_log(
f"Failed to update Python interpreter: {str(e)}", lsp.MessageType.Error
)

def notify_user(self, message: str, msg_type: lsp.MessageType = lsp.MessageType.Info):
def notify_user(
self, message: str, msg_type: lsp.MessageType = lsp.MessageType.Info
):
"""Logs a message and also notifies the user."""
self.show_message(message, msg_type)

def log_to_output(self, message: str, msg_type: lsp.MessageType = lsp.MessageType.Log) -> None:
def log_to_output(
self, message: str, msg_type: lsp.MessageType = lsp.MessageType.Log
) -> None:
"""Log to output."""
self.show_message_log(message, msg_type)

Expand Down Expand Up @@ -261,6 +277,60 @@ def rename_stack(wrapper_instance, args):
def copy_stack(wrapper_instance, args):
"""Copies a specified ZenML stack to a new stack."""
return wrapper_instance.copy_stack(args)

@self.command(f"{TOOL_MODULE_NAME}.registerStack")
@self.zenml_command(wrapper_name="stacks_wrapper")
def register_stack(wrapper_instance, args):
"""Registers a new ZenML stack."""
return wrapper_instance.register_stack(args)

@self.command(f"{TOOL_MODULE_NAME}.updateStack")
@self.zenml_command(wrapper_name="stacks_wrapper")
def update_stack(wrapper_instance, args):
"""Updates a specified ZenML stack ."""
return wrapper_instance.update_stack(args)

@self.command(f"{TOOL_MODULE_NAME}.deleteStack")
@self.zenml_command(wrapper_name="stacks_wrapper")
def delete_stack(wrapper_instance, args):
"""Deletes a specified ZenML stack ."""
return wrapper_instance.delete_stack(args)

@self.command(f"{TOOL_MODULE_NAME}.registerComponent")
@self.zenml_command(wrapper_name="stacks_wrapper")
def register_component(wrapper_instance, args):
"""Registers a Zenml stack component"""
return wrapper_instance.register_component(args)

@self.command(f"{TOOL_MODULE_NAME}.updateComponent")
@self.zenml_command(wrapper_name="stacks_wrapper")
def update_component(wrapper_instance, args):
"""Updates a ZenML stack component"""
return wrapper_instance.update_component(args)

@self.command(f"{TOOL_MODULE_NAME}.deleteComponent")
@self.zenml_command(wrapper_name="stacks_wrapper")
def delete_component(wrapper_instance, args):
"""Deletes a specified ZenML stack component"""
return wrapper_instance.delete_component(args)

@self.command(f"{TOOL_MODULE_NAME}.listComponents")
@self.zenml_command(wrapper_name="stacks_wrapper")
def list_components(wrapper_instance, args):
"""Get paginated list of stack components from ZenML"""
return wrapper_instance.list_components(args)

@self.command(f"{TOOL_MODULE_NAME}.getComponentTypes")
@self.zenml_command(wrapper_name="stacks_wrapper")
def get_component_types(wrapper_instance, args):
"""Get list of component types from ZenML"""
return wrapper_instance.get_component_types()

@self.command(f"{TOOL_MODULE_NAME}.listFlavors")
@self.zenml_command(wrapper_name="stacks_wrapper")
def list_flavors(wrapper_instance, args):
"""Get paginated list of component flavors from ZenML"""
return wrapper_instance.list_flavors(args)

@self.command(f"{TOOL_MODULE_NAME}.getPipelineRuns")
@self.zenml_command(wrapper_name="pipeline_runs_wrapper")
Expand All @@ -273,13 +343,13 @@ def fetch_pipeline_runs(wrapper_instance, args):
def delete_pipeline_run(wrapper_instance, args):
"""Deletes a specified ZenML pipeline run."""
return wrapper_instance.delete_pipeline_run(args)

@self.command(f"{TOOL_MODULE_NAME}.getPipelineRun")
@self.zenml_command(wrapper_name="pipeline_runs_wrapper")
def get_pipeline_run(wrapper_instance, args):
"""Gets a specified ZenML pipeline run."""
return wrapper_instance.get_pipeline_run(args)

@self.command(f"{TOOL_MODULE_NAME}.getPipelineRunStep")
@self.zenml_command(wrapper_name="pipeline_runs_wrapper")
def get_run_step(wrapper_instance, args):
Expand All @@ -291,9 +361,9 @@ def get_run_step(wrapper_instance, args):
def get_run_artifact(wrapper_instance, args):
"""Gets a specified ZenML pipeline artifact"""
return wrapper_instance.get_run_artifact(args)

@self.command(f"{TOOL_MODULE_NAME}.getPipelineRunDag")
@self.zenml_command(wrapper_name="pipeline_runs_wrapper")
def get_run_dag(wrapper_instance, args):
"""Gets graph data for a specified ZenML pipeline run"""
return wrapper_instance.get_pipeline_run_graph(args)
return wrapper_instance.get_pipeline_run_graph(args)
59 changes: 51 additions & 8 deletions bundled/tool/type_hints.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
from typing import Any, TypedDict, Dict, List, Union
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Any, TypedDict, Dict, List, Optional
from uuid import UUID


Expand All @@ -20,8 +32,6 @@ class GraphEdge(TypedDict):
source: str
target: str



class GraphResponse(TypedDict):
nodes: List[GraphNode]
edges: List[GraphEdge]
Expand All @@ -37,9 +47,9 @@ class RunStepResponse(TypedDict):
id: str
status: str
author: Dict[str, str]
startTime: Union[str, None]
endTime: Union[str, None]
duration: Union[str, None]
startTime: Optional[str]
endTime: Optional[str]
duration: Optional[str]
stackName: str
orchestrator: Dict[str, str]
pipeline: Dict[str, str]
Expand Down Expand Up @@ -71,7 +81,7 @@ class ZenmlStoreInfo(TypedDict):
class ZenmlStoreConfig(TypedDict):
type: str
url: str
api_token: Union[str, None]
api_token: Optional[str]

class ZenmlServerInfoResp(TypedDict):
store_info: ZenmlStoreInfo
Expand All @@ -84,4 +94,37 @@ class ZenmlGlobalConfigResp(TypedDict):
version: str
active_stack_id: str
active_workspace_name: str
store: ZenmlStoreConfig
store: ZenmlStoreConfig

class StackComponent(TypedDict):
id: str
name: str
flavor: str
type: str
config: Dict[str, Any]

class ListComponentsResponse(TypedDict):
index: int
max_size: int
total_pages: int
total: int
items: List[StackComponent]

class Flavor(TypedDict):
id: str
name: str
type: str
logo_url: str
config_schema: Dict[str, Any]
docs_url: Optional[str]
sdk_docs_url: Optional[str]
connector_type: Optional[str]
connector_resource_type: Optional[str]
connector_resource_id_attr: Optional[str]

class ListFlavorsResponse(TypedDict):
index: int
max_size: int
total_pages: int
total: int
items: List[Flavor]
Loading

0 comments on commit 6ac8078

Please sign in to comment.