Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Unit Tests for the project #87

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ idna==2.8
lxml==4.4.1
nltk==3.4.5
paramiko==2.6.0
pprint==0.1
pycparser==2.19
PyNaCl==1.3.0
python-dateutil==2.8.0
Expand Down
12 changes: 9 additions & 3 deletions scanners/nexpose_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def start(self, scan_name, target):
print(f'[{self.name}] Starting Scan for Target: {target}')

try:
return self.scan(scan_name, target)
self.scan_status = 'INPROGRESS'
self.scan_results = self.scan(scan_name, target)
return self.scan_results
except:
print(f'[{self.name}] Not able to connect to the {self.name}: ', sys.exc_info())
return False
Expand Down Expand Up @@ -103,6 +105,8 @@ def scan(self, scan_name, target):
self.storage_service.update_by_name(scan_name, scan_data)

return scan_data
self.scan_status = None
self.scan_results = None


def _create_report(self, scan_name):
Expand Down Expand Up @@ -167,7 +171,8 @@ def get_scan_status(self, scan_name, scan_status_list=[]):
'status': scan_status['status']
})

return scan_status_list
self.scan_status = scan_status['status']
return self.scan_status


def get_scan_results(self, scan_name, scan_results={}):
Expand Down Expand Up @@ -195,7 +200,8 @@ def get_scan_results(self, scan_name, scan_results={}):

self._process_results(parsed_report, scan_results)

return scan_results
self.scan_results = scan_results
return self.scan_results

def _process_results(self, report, scan_results):

Expand Down
16 changes: 10 additions & 6 deletions scanners/openvas_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def __init__(self):
transform = EtreeTransform()
self.gmp = Gmp(connection, transform=transform)
self.storage_service = StorageService()
self.scan_status = None
self.scan_results = None

# Login
try:
Expand All @@ -49,9 +51,11 @@ def __init__(self):

def start(self, scan_name, target):
print(f'[{self.name}] Starting Scan for Target: {target}')
self.scan_status = 'INPROGRESS'

try:
return self.scan(scan_name, target)
self.scan_results = self.scan(scan_name, target)
return self.scan_results
except:
print(f'[{self.name}] Not able to connect to the {self.name}: ', sys.exc_info())
return False
Expand All @@ -65,6 +69,7 @@ def scan(self, scan_name, target):
# Creating Target
target_response = self.gmp.create_target(name=scan_name, hosts=[address])
# print('target_response')
self.scan_status = 'INPROGRESS'
pretty_print(target_response)
target_id = target_response.get('id')

Expand Down Expand Up @@ -132,6 +137,7 @@ def _create_report(self, scan_name):
print(f'[{self.name}] Created Report: {report_id} with Task: {task_id}')

self.storage_service.update_by_name(scan_name, scan_data)
self.scan_status = 'COMPLETE'

return scan_data

Expand All @@ -141,8 +147,7 @@ def get_scan_status(self, scan_name, scan_status_list=[]):
if not self.is_valid_scan(scan_name):
return False

scan_data = self.storage_service.get_by_name(scan_name)
scan_status = scan_data.get('OPENVAS', {}).get('scan_status', {})
return self.scan_status
openvas_id = scan_data.get('OPENVAS', {})['openvas_id']
target = scan_data['target']

Expand Down Expand Up @@ -179,7 +184,7 @@ def get_scan_results(self, scan_name, scan_results={}):
if not self.is_valid_scan(scan_name):
return False

scan_data = self.storage_service.get_by_name(scan_name)
return self.scan_results

# if scan_data.get('OPENVAS', {}).get('scan_status').get('status', None) != 'COMPLETE':
# print(f'[{self.name}] Scan is in progress')
Expand Down Expand Up @@ -233,8 +238,7 @@ def _process_results(self, report_response, scan_results={}):

def is_valid_scan(self, scan_name):

scan_data = self.storage_service.get_by_name(scan_name)
if not scan_data:
if self.scan_status is None:
print(f'[{self.name}] Invalid Scan Name: {scan_name}')
return False

Expand Down
10 changes: 10 additions & 0 deletions scanners/zap_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class ZapScanner(Scanner):
def __init__(self):
self.zap = ZAPv2(apikey=API_KEY)
self.storage_service = StorageService()
self.scan_status = None
self.scan_results = None

def start(self, scan_name, target):
print(f'[{self.name}] Starting Scan for Target: {target}')
Expand Down Expand Up @@ -63,6 +65,11 @@ def resume(self, scan_name):

# self.storage_service.update_by_name(scan_name, { status: 'RESUMED' })
return scan
def get_scan_status(self):
return self.scan_status

def get_scan_results(self):
return self.scan_results

def stop(self, scan_name):
if not self.is_valid_scan(scan_name):
Expand Down Expand Up @@ -104,6 +111,8 @@ def scan(self, scan_name, target):
# a_scan_id = self.zap.ascan.scan(target, recurse=True, inscopeonly=None, scanpolicyname=None, method=None, postdata=True)

scan_data = self.storage_service.get_by_name(scan_name)
self.scan_results = scan_data
self.scan_status = 'INPROGRESS'

if not scan_data:
scan_data = {
Expand Down Expand Up @@ -143,6 +152,7 @@ def get_scan_status(self, scan_name, scan_status_list=[]):
spider_scan_status = self.zap.spider.status(zap_id)
passive_scan_records_pending = self.zap.pscan.records_to_scan
active_scan_status = self.zap.ascan.status(zap_id)
self.scan_status = 'INPROGRESS' if int(spider_scan_status) < 100 or int(active_scan_status) < 100 else 'COMPLETE'

scan_status['spider_scan'] = self._parse_status(spider_scan_status)
scan_status['active_scan'] = self._parse_status(active_scan_status)
Expand Down
54 changes: 54 additions & 0 deletions tests/test_nexpose_scanner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import unittest
from unittest.mock import Mock, patch
from scanners.nexpose_scanner import NexposeScanner

class TestNexposeScanner(unittest.TestCase):
def setUp(self):
self.nexpose_mock = Mock()
self.storage_service_mock = Mock()
self.nexpose_scanner = NexposeScanner()
self.nexpose_scanner.nexpose = self.nexpose_mock
self.nexpose_scanner.storage_service = self.storage_service_mock

def test_start(self):
self.nexpose_scanner.start('test_scan', 'http://test_target')
self.nexpose_mock.start_scan.assert_called_once()

def test_scan(self):
self.nexpose_scanner.scan('test_scan', 'http://test_target')
self.nexpose_mock.start_scan.assert_called_once()

def test_get_scan_status(self):
self.nexpose_scanner.get_scan_status('test_scan')
self.nexpose_mock.get_scan.assert_called_once()

def test_get_scan_results(self):
self.nexpose_scanner.get_scan_results('test_scan')
self.nexpose_mock.download_report.assert_called_once()

def test_is_valid_scan(self):
self.nexpose_scanner.is_valid_scan('test_scan')
self.nexpose_mock.get_scan.assert_called_once()

def test_pause(self):
self.nexpose_scanner.pause('test_scan')
self.nexpose_mock.set_scan_status.assert_called_once_with('pause')

def test_resume(self):
self.nexpose_scanner.resume('test_scan')
self.nexpose_mock.set_scan_status.assert_called_once_with('resume')

def test_stop(self):
self.nexpose_scanner.stop('test_scan')
self.nexpose_mock.set_scan_status.assert_called_once_with('stop')

def test_remove(self):
self.nexpose_scanner.remove('test_scan')
self.nexpose_mock.set_scan_status.assert_called_once_with('remove')

def test_list_scans(self):
self.nexpose_scanner.list_scans()
self.nexpose_mock.get_scans.assert_called_once()

if __name__ == '__main__':
unittest.main()
62 changes: 62 additions & 0 deletions tests/test_openvas_scanner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import unittest
from unittest.mock import Mock, patch
from scanners.openvas_scanner import OpenVASScanner

class TestOpenVASScanner(unittest.TestCase):
def setUp(self):
self.openvas_mock = Mock()
self.storage_service_mock = Mock()
self.openvas_scanner = OpenVASScanner()
self.openvas_scanner.openvas = self.openvas_mock
self.openvas_scanner.storage_service = self.storage_service_mock

def test_start(self):
self.openvas_scanner.start('test_scan', 'http://test_target')
self.openvas_mock.start_scan.assert_called_once()

def test_scan(self):
self.openvas_scanner.scan('test_scan', 'http://test_target')
self.openvas_mock.start_scan.assert_called_once()

def test_create_report(self):
self.openvas_scanner._create_report('test_scan')
self.openvas_mock.get_report.assert_called_once()

def test_get_scan_status(self):
self.openvas_scanner.get_scan_status('test_scan')
self.openvas_mock.get_scan.assert_called_once()

def test_get_scan_results(self):
self.openvas_scanner.get_scan_results('test_scan')
self.openvas_mock.get_report.assert_called_once()

def test_process_results(self):
self.openvas_scanner._process_results('test_report', {})
self.openvas_mock.parse.assert_called_once()

def test_is_valid_scan(self):
self.openvas_scanner.is_valid_scan('test_scan')
self.storage_service_mock.get_by_name.assert_called_once()

def test_pause(self):
self.openvas_scanner.pause('test_scan')
self.openvas_mock.set_scan_status.assert_called_once_with('pause')

def test_resume(self):
self.openvas_scanner.resume('test_scan')
self.openvas_mock.set_scan_status.assert_called_once_with('resume')

def test_stop(self):
self.openvas_scanner.stop('test_scan')
self.openvas_mock.set_scan_status.assert_called_once_with('stop')

def test_remove(self):
self.openvas_scanner.remove('test_scan')
self.openvas_mock.set_scan_status.assert_called_once_with('remove')

def test_list_scans(self):
self.openvas_scanner.list_scans()
self.openvas_mock.get_scans.assert_called_once()

if __name__ == '__main__':
unittest.main()
45 changes: 45 additions & 0 deletions tests/test_storage_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import unittest
from tinydb import TinyDB, Query
from core.storage_service import StorageService

class TestStorageService(unittest.TestCase):
def setUp(self):
self.db = TinyDB('test_db.json')
self.storage_service = StorageService()

def test_add(self):
data = {'scan_name': 'test_scan', 'scan_id': '123', 'target': 'http://test_target', 'status': 'INPROGRESS'}
self.storage_service.add(data)
record = self.db.get(Query().scan_name == 'test_scan')
self.assertEqual(record, data)

def test_get_by_name(self):
data = {'scan_name': 'test_scan', 'scan_id': '123', 'target': 'http://test_target', 'status': 'INPROGRESS'}
self.db.insert(data)
record = self.storage_service.get_by_name('test_scan')
self.assertEqual(record, data)

def test_get_by_id(self):
data = {'scan_name': 'test_scan', 'scan_id': '123', 'target': 'http://test_target', 'status': 'INPROGRESS'}
self.db.insert(data)
record = self.storage_service.get_by_id('123')
self.assertEqual(record, data)

def test_update_by_name(self):
data = {'scan_name': 'test_scan', 'scan_id': '123', 'target': 'http://test_target', 'status': 'INPROGRESS'}
new_data = {'scan_name': 'test_scan', 'scan_id': '123', 'target': 'http://test_target', 'status': 'COMPLETE'}
self.db.insert(data)
self.storage_service.update_by_name('test_scan', new_data)
record = self.db.get(Query().scan_name == 'test_scan')
self.assertEqual(record, new_data)

def test_update_by_id(self):
data = {'scan_name': 'test_scan', 'scan_id': '123', 'target': 'http://test_target', 'status': 'INPROGRESS'}
new_data = {'scan_name': 'test_scan', 'scan_id': '123', 'target': 'http://test_target', 'status': 'COMPLETE'}
self.db.insert(data)
self.storage_service.update_by_id('123', new_data)
record = self.db.get(Query().scan_id == '123')
self.assertEqual(record, new_data)

if __name__ == '__main__':
unittest.main()
61 changes: 61 additions & 0 deletions tests/test_zap_scanner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import unittest
from unittest.mock import Mock, patch
from scanners.zap_scanner import ZapScanner

class TestZapScanner(unittest.TestCase):
def setUp(self):
self.zap_mock = Mock()
self.storage_service_mock = Mock()
self.zap_scanner = ZapScanner()
self.zap_scanner.zap = self.zap_mock
self.zap_scanner.storage_service = self.storage_service_mock

def test_start(self):
with patch('scanners.zap_scanner.ZAPv2.urlopen') as urlopen_mock:
self.zap_scanner.start('test_scan', 'http://test_target')
urlopen_mock.assert_called_once_with('http://test_target')

def test_pause(self):
self.zap_scanner.pause('test_scan')
self.zap_mock.spider.pause.assert_called_once()
self.zap_mock.ascan.pause.assert_called_once()

def test_resume(self):
self.zap_scanner.resume('test_scan')
self.zap_mock.spider.resume.assert_called_once()
self.zap_mock.ascan.resume.assert_called_once()

def test_stop(self):
self.zap_scanner.stop('test_scan')
self.zap_mock.spider.stop.assert_called_once()
self.zap_mock.ascan.stop.assert_called_once()

def test_remove(self):
self.zap_scanner.remove('test_scan')
self.zap_mock.spider.removeScan.assert_called_once()
self.zap_mock.ascan.removeScan.assert_called_once()

def test_scan(self):
with patch('scanners.zap_scanner.ZAPv2.urlopen') as urlopen_mock:
self.zap_scanner.scan('test_scan', 'http://test_target')
urlopen_mock.assert_called_once_with('http://test_target')

def test_get_scan_status(self):
self.zap_scanner.get_scan_status('test_scan')
self.zap_mock.spider.status.assert_called_once()
self.zap_mock.ascan.status.assert_called_once()

def test_get_scan_results(self):
self.zap_scanner.get_scan_results('test_scan')
self.zap_mock.core.alerts.assert_called_once()

def test_list_scans(self):
self.zap_scanner.list_scans()
self.zap_mock.ascan.scans.assert_called_once()

def test_is_valid_scan(self):
self.zap_scanner.is_valid_scan('test_scan')
self.zap_mock.ascan.status.assert_called_once()

if __name__ == '__main__':
unittest.main()