Skip to content
This repository has been archived by the owner on Aug 29, 2023. It is now read-only.

Use user's local ssh config #145

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions kafka_utils/kafka_corruption_check/main.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,9 @@
import logging
import re
import sys
import os
import getpass

from contextlib import closing
from functools import partial
from multiprocessing import Pool
@@ -58,15 +61,47 @@ def chunks(l, n):
def ssh_client(host):
"""Start an ssh client.

The ssh client will attempt to use configs from the user's local ssh config in ~/.ssh/config.

:param host: the host
:type host: str
:returns: ssh client
:rtype: Paramiko client
"""
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(host)
return ssh
client = paramiko.SSHClient()
client._policy = paramiko.WarningPolicy()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

# parse local user ssh config
ssh_config = paramiko.SSHConfig()
user_config_file = os.path.expanduser("~/.ssh/config")
if os.path.exists(user_config_file):
with open(user_config_file) as f:
ssh_config.parse(f)

cfg = {'hostname': host, 'username': getpass.getuser()}

user_config = ssh_config.lookup(cfg['hostname'])

logging.debug("Local user ssh config for host {host}: {user_config}".format(host=host, user_config=user_config))

for k in ('hostname', 'port'):
if k in user_config:
cfg[k] = user_config[k]

if 'user' in user_config:
cfg['username'] = user_config['user']

if 'proxycommand' in user_config:
cfg['sock'] = paramiko.ProxyCommand(user_config['proxycommand'])

if 'identityfile' in user_config:
cfg['key_filename'] = user_config['identityfile']

logging.debug("Overriden ssh config for {host}: {cfg}".format(host=host, cfg=cfg))

client.connect(**cfg)
return client


def report_stderr(host, stderr):