Skip to content

Commit

Permalink
Fix editing of configuration and add valid fields for executables in …
Browse files Browse the repository at this point in the history
…AMS-config
  • Loading branch information
koparasy committed Dec 7, 2023
1 parent aa4e44a commit 0b6cc85
Showing 1 changed file with 89 additions and 19 deletions.
108 changes: 89 additions & 19 deletions src/AMSWorkflow/ams_wf/ams_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def slurm_bootstrap(cmd, flux_log_file):
bootstrap_cmd = f"{bootstrap_cmd} {cmd}"
logging.debug(f"Executing command {bootstrap_cmd}")
logging.shutdown()
sp.run(bootstrap_cmd, shell=True)
result = sp.run(bootstrap_cmd, shell=True)
return result.returncode
# NOTE: From this point on we should definitely not use the logging mechanism. We manually shut it down
# to allow the bootstrapped script to use the same logger (this is important in the case of logging into a file)

Expand All @@ -38,6 +39,7 @@ def slurm_bootstrap(cmd, flux_log_file):
else:
logging.critical("Unknown scheduler, cannot bootstrap")
sys.exit()
return 0


class AMSConfig:
Expand All @@ -50,23 +52,38 @@ def validate_keys(level, config, mandatory_fields):
return False
return True

if not validate_keys("root", config, ["user_app", "ml_training", "execution_mode", "db", "stager"]):
def validate_step_field(level, config):
    """Validate the configuration entry of a single workflow step.

    A step entry must provide an ``executable`` that exists on the file
    system and a ``resources`` mapping containing ``num_nodes`` and
    ``num_processes_per_node``.

    Args:
        level: Name of the step being validated; used only in log messages
            and forwarded to ``validate_keys``.
        config: Mapping holding the step's configuration.

    Returns:
        ``True`` when every check passes, ``False`` otherwise. Each failure
        is reported through ``logging.critical``.
    """
    # ``validate_keys`` comes from the enclosing scope.
    if not validate_keys(level, config, ["executable", "resources"]):
        logging.critical(f"Mising fields in {level}")
        return False

    exec_path = Path(config["executable"])
    if not exec_path.exists():
        # Must be an f-string, otherwise the literal text "{exec_path}" is
        # logged instead of the offending path.
        logging.critical(f"Executable {exec_path} does not exist")
        return False

    if not validate_keys(level, config["resources"], ["num_nodes", "num_processes_per_node"]):
        logging.critical(f"Missing fields in resources of {level}")
        return False

    return True

if not validate_keys("root", config, ["user_app", "ml_training", "ml_pruning", "execution_mode", "db", "stager"]):
return False

if not validate_keys("user_app", config["user_app"], ["executable", "resources"]):
if not validate_step_field("user_app", config["user_app"]):
return False

exec_path = Path(config["user_app"]["executable"])
if not exec_path.exists():
logging.critical("Executable {exec_path} does not exist")
if not validate_step_field("ml_training", config["ml_training"]):
return False

if not validate_keys(
"user_app|resources", config["user_app"]["resources"], ["num_nodes", "num_processes_per_node"]
):
if not validate_step_field("ml_pruning", config["ml_pruning"]):
return False

if not validate_keys("ml_training", config["ml_training"], ["num_nodes", "num_processes_per_node"]):
if not validate_keys("ml_training|resources", config["ml_training"]["resources"], ["num_nodes", "num_processes_per_node"]):
return False

if not validate_keys("ml_pruning|resources", config["ml_training"]["resources"], ["num_nodes", "num_processes_per_node"]):
return False

if config["execution_mode"] not in ["sequential", "concurrent"]:
Expand All @@ -81,6 +98,22 @@ def validate_keys(level, config, mandatory_fields):
if "num_clients" not in config["stager"]:
logging.critical("When stager set in mode 'rmq' you need to define the number of rmq clients")
return False
if config["stager"]["mode"]:
rmq_config = config["rmq"]
if not isinstance(rmq_config["service-port"], int):
print(isinstance(int, type(rmq_config["service-port"])))
print(rmq_config["service-port"], type(rmq_config["service-port"]), int)
logging.critical("The RMQ service-port must be an integer type {0}".format(type(rmq_config["service-port"])))
return False
if not Path(rmq_config["rabbitmq-cert"]).exists():
logging.critical("The RMQ certificate file does not exist (or is not not accessible)")
return False

rmq_keys = AMSConfig.to_descr()["rmq"].keys()

if not validate_keys("rmq", config["rmq"], rmq_keys):
return False

return True

@staticmethod
Expand All @@ -92,10 +125,33 @@ def to_descr():
"env_variables": {"VARNAME": "VALUE"},
"resources": {"num_nodes": "XX", "num_processes_per_node": "YY", "num_gpus_per_node": "ZZ"},
},
"ml_training": {"num_nodes": "XX", "num_processes_per_node": "YY", "num_gpus_per_node": "ZZ"},
"ml_training": {
"executable": "path to executable",
"arguments": ["one", "two", "three"],
"env_variables": {"VARNAME": "VALUE"},
"resources": {"num_nodes": "XX", "num_processes_per_node": "YY", "num_gpus_per_node": "ZZ"},
},
"ml_pruning": {
"executable": "path to executable",
"arguments": ["one", "two", "three"],
"env_variables": {"VARNAME": "VALUE"},
"resources": {"num_nodes": "XX", "num_processes_per_node": "YY", "num_gpus_per_node": "ZZ"},
},
"execution_mode": "sequential",
"db": {"path": "path/to/db"},
"stager": {"mode": "filesystem", "num_clients": "number of rmq clients (mandatory only when mode is rmq)"},
"rmq" : {
"service-port": "Port",
"service-host": "server address",
"rabbitmq-erlang-cookie": "magic cookie",
"rabbitmq-name": "rmq server name",
"rabbitmq-password": "password",
"rabbitmq-user": "user",
"rabbitmq-vhost": "virtual host",
"rabbitmq-cert": "path to certificate to establish connection",
"rabbitmq-inbound-queue": "Queue name to send data from outside in the simulation",
"rabbitmq-outbound-queue": "Queue name to send data from the simulation to outside"
}
}


Expand All @@ -116,7 +172,10 @@ def generate_config(args):
logging.critical(f"Environemnt variable EDITOR is not set, example configuration is stored in {args.config}")
sys.exit()
cmd = f"{editor} {args.config}"
sp.run(cmd, shell=True)
result = sp.run(cmd, shell=True)
if result.returncode != 0:
logging.warning(f"{editor} {args.config} returned non zero code")

with open(args.config, "r") as fd:
data = json.load(fd)

Expand All @@ -135,7 +194,10 @@ def validate_config(args):
data = json.load(fd)

if not AMSConfig.validate(data):
logging.critical("Generated configuration file is not valid")
logging.info("Generated configuration file is NOT valid")
return False
logging.info("Generated configuration file IS valid")
return True


def start_cli(parser):
Expand Down Expand Up @@ -167,12 +229,18 @@ def get_cmd():

if is_bootstrapped():
logging.info("Execution is bootstrapped")
return
return False

with open(args.config, "r") as fd:
data = json.load(fd)

if not validate_config(data):
logging.info("Configuration file is not valid, exiting early...")
return False

logging.info("Execution is NOT bootstrapped")
cmd = get_cmd()
bootstrap(cmd, RootSched[args.scheduler], args.flux_log)
return
return (bootstrap(cmd, RootSched[args.scheduler], args.flux_log) == 0)


def main():
Expand All @@ -188,9 +256,11 @@ def main():
parser.add_argument(
"-l", "--log-file", dest="log_file", help="Path to file to store logs (when unspecified stdout/err is used)"
)
sub_parsers = parser.add_subparsers(help="Commands supported by ams deployment tool")
sub_parsers = parser.add_subparsers(dest="command", help="Commands supported by ams deployment tool")
sub_parsers.required = True
start_cli(sub_parsers)
generate_cli(sub_parsers)
validate_cli(sub_parsers)

args = parser.parse_args()
if args.log_file is not None:
Expand All @@ -208,8 +278,8 @@ def main():
level=args.verbose,
)

args.func(args)
return not args.func(args)


if __name__ == "__main__":
main()
sys.exit(main())

0 comments on commit 0b6cc85

Please sign in to comment.