diff --git a/kafka_utils/kafka_corruption_check/main.py b/kafka_utils/kafka_corruption_check/main.py index a861f212..d4c31ca3 100644 --- a/kafka_utils/kafka_corruption_check/main.py +++ b/kafka_utils/kafka_corruption_check/main.py @@ -6,6 +6,9 @@ import logging import re import sys +import os +import getpass + from contextlib import closing from functools import partial from multiprocessing import Pool @@ -58,15 +61,47 @@ def chunks(l, n): def ssh_client(host): """Start an ssh client. + The ssh client will attempt to use configs from the user's local ssh config in ~/.ssh/config. + :param host: the host :type host: str :returns: ssh client :rtype: Paramiko client """ - ssh = paramiko.SSHClient() - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh.connect(host) - return ssh + client = paramiko.SSHClient() + client._policy = paramiko.WarningPolicy() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + # parse local user ssh config + ssh_config = paramiko.SSHConfig() + user_config_file = os.path.expanduser("~/.ssh/config") + if os.path.exists(user_config_file): + with open(user_config_file) as f: + ssh_config.parse(f) + + cfg = {'hostname': host, 'username': getpass.getuser()} + + user_config = ssh_config.lookup(cfg['hostname']) + + logging.debug("Local user ssh config for host {host}: {user_config}".format(host=host, user_config=user_config)) + + for k in ('hostname', 'port'): + if k in user_config: + cfg[k] = user_config[k] + + if 'user' in user_config: + cfg['username'] = user_config['user'] + + if 'proxycommand' in user_config: + cfg['sock'] = paramiko.ProxyCommand(user_config['proxycommand']) + + if 'identityfile' in user_config: + cfg['key_filename'] = user_config['identityfile'] + + logging.debug("Overriden ssh config for {host}: {cfg}".format(host=host, cfg=cfg)) + + client.connect(**cfg) + return client def report_stderr(host, stderr):