Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Unit Test for CompressOutbox to Tempfile Creation Refactoring (SOFTWARE-5540) #176

Open
wants to merge 10 commits into
base: 2.x
Choose a base branch
from
53 changes: 22 additions & 31 deletions common/gratia/common/sandbox_mgmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import glob
import time
import shutil
import tempfile
import tarfile

from gratia.common.config import ConfigProxy
Expand Down Expand Up @@ -409,18 +410,8 @@ def SearchOutstandingRecord():

def GenerateFilename(prefix, current_dir):
'''Generate a filename of the for current_dir/prefix.$pid.ConfigFragment.gratia.xml__Unique'''
filename = prefix + str(global_state.RecordPid) + '.' + Config.get_GratiaExtension() \
+ '__XXXXXXXXXX'
filename = os.path.join(current_dir, filename)
mktemp_pipe = os.popen('mktemp -q "' + filename + '"')
if mktemp_pipe != None:
filename = mktemp_pipe.readline()
mktemp_pipe.close()
filename = filename.strip()
if filename != r'':
return filename

raise IOError
fn_prefix = f'{prefix}.{global_state.RecordPid}.{Config.get_GratiaExtension()}__'
return tempfile.NamedTemporaryFile(prefix=fn_prefix, dir=current_dir, delete=False, mode='w')

def UncompressOutbox(staging_name, target_dir):

Expand Down Expand Up @@ -487,20 +478,21 @@ def CompressOutbox(probe_dir, outbox, outfiles):
DebugPrint(0, msg + ':' + exc)
raise InternalError(msg) from exc

staging_name = GenerateFilename('tz.', staged_store)
DebugPrint(1, 'Compressing outbox in tar.bz2 file: ' + staging_name)
with GenerateFilename('tz', staged_store) as temp_tarfile:
staging_name = temp_tarfile.name
DebugPrint(1, 'Compressing outbox in tar.bz2 file: ' + staging_name)

try:
tar = tarfile.open(staging_name, 'w:bz2')
except KeyboardInterrupt:
raise
except SystemExit:
raise
except Exception as e:
DebugPrint(0, 'Warning: Exception caught while opening tar.bz2 file: ' + staging_name + ':')
DebugPrint(0, 'Caught exception: ', e)
DebugPrintTraceback()
return False
try:
tar = tarfile.open(staging_name, 'w:bz2')
except KeyboardInterrupt:
raise
except SystemExit:
raise
except Exception as e:
DebugPrint(0, 'Warning: Exception caught while opening tar.bz2 file: ' + staging_name + ':')
DebugPrint(0, 'Caught exception: ', e)
DebugPrintTraceback()
return False

try:
for f in outfiles:
Expand Down Expand Up @@ -599,12 +591,11 @@ def OpenNewRecordFile(dirIndex):
raise InternalError(msg) from exc

try:
filename = GenerateFilename('r.', working_dir)
DebugPrint(3, 'Creating file:', filename)
outstandingRecordCount += 1
f = open(filename, 'w')
dirIndex = index
return (f, dirIndex)
with GenerateFilename('r', working_dir) as recordfile:
DebugPrint(3, 'Creating file:', recordfile.name)
outstandingRecordCount += 1
dirIndex = index
return (recordfile, dirIndex)
except Exception as exc:
msg = 'ERROR: Caught exception while creating file'
DebugPrint(0, msg + ': ', exc)
Expand Down
6 changes: 2 additions & 4 deletions common/gratia/common/xml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,8 @@ def UsageCheckXmldoc(xmlDoc, external, resourceType=None):
subdir = os.path.join(Config.get_DataFolder(), "quarantine", 'subdir.' + Config.getFilenameFragment())
if not os.path.exists(subdir):
os.mkdir(subdir)
fn = sandbox_mgmt.GenerateFilename("r.", subdir)
writer = open(fn, 'w')
usageRecord.writexml(writer)
writer.close()
with sandbox_mgmt.GenerateFilename("r", subdir) as writer:
usageRecord.writexml(writer)
usageRecord.unlink()
continue

Expand Down
120 changes: 120 additions & 0 deletions test/test_sandbox_mgmt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#!/bin/env python

import glob
import os
import shutil
import tarfile
import tempfile
import unittest
from unittest.mock import patch, PropertyMock

from common.gratia.common import sandbox_mgmt

class SandboxMgmtTests(unittest.TestCase):

@patch('gratia.common.config.ConfigProxy.get_GratiaExtension', create=True, return_value='test-extension')
def test_GenerateFilename(self, mock_config):
"""GenerateFilename creates a temporary file and returns the path to the file
"""
prefix = 'test-prefix'
temp_dir = '/tmp'

try:
with sandbox_mgmt.GenerateFilename(prefix, temp_dir) as filename:
self.assertTrue(os.path.exists(filename.name),
f'Failed to create temporary file ({filename.name})')
self.assertEqual(temp_dir.rstrip('/'),
os.path.dirname(filename.name),
f'Temporary file {filename.name} placed in the wrong directory')
self.assertRegex(filename.name,
rf'{temp_dir}/*{prefix}\.\d+\.{mock_config.return_value}__\w+',
'Unexpected file name format')
finally:
try:
filename.close()
os.remove(filename.name)
except (FileNotFoundError, NameError):
# don't need to clean up what's not there
pass

class CompressOutboxTests(unittest.TestCase):
def setUp(self):
#provision test environment
jeff-takaki marked this conversation as resolved.
Show resolved Hide resolved
self.probe_dir = tempfile.mkdtemp()
self.outbox = os.path.join(self.probe_dir, 'outbox')
os.makedirs(self.outbox, exist_ok=True)
self.outfiles = ['testfile1', 'testfile2']
# add content to the files
for fname in self.outfiles:
with open(os.path.join(self.outbox, fname), 'w') as f:
f.write('test content')
jeff-takaki marked this conversation as resolved.
Show resolved Hide resolved

def tearDown(self):
#Remove probe_dir after test
shutil.rmtree(self.probe_dir)

@patch('gratia.common.config.ConfigProxy.getFilenameFragment', create=True, return_value ='test-filename')
@patch('gratia.common.config.ConfigProxy.get_GratiaExtension', create=True, return_value ='test-extension')
def test_compress_outbox(self, mock_gratia_ext, mock_file_frag):
"""CompressOutbox compresses the files in the outbox directory
and stores the resulting tarball in probe_dir/staged.
"""
#Parameters for function
probe_dir = self.probe_dir
outbox = self.outbox
outfiles = self.outfiles

#Assert that CompressOutbox returns True
result = sandbox_mgmt.CompressOutbox(probe_dir, outbox, outfiles)
self.assertTrue(result)
brianhlin marked this conversation as resolved.
Show resolved Hide resolved

@patch('gratia.common.config.ConfigProxy.getFilenameFragment', create=True, return_value ='test-filename')
@patch('gratia.common.config.ConfigProxy.get_GratiaExtension', create=True, return_value ='test-extension')
def test_tarball_creation(self, mock_gratia_ext, mock_filefrag):
"""
Assert that tarball is created in the correct location
"""
#Parameters for function
probe_dir = self.probe_dir
outbox = self.outbox
outfiles = self.outfiles

sandbox_mgmt.CompressOutbox(probe_dir, outbox, outfiles)

#Finds tarball that matches GenerateFilename function output
os.chdir(f'{probe_dir}/staged/store')
for tarball in glob.glob("tz.*.test-extension__*"):
jeff-takaki marked this conversation as resolved.
Show resolved Hide resolved
return(tarball)
jeff-takaki marked this conversation as resolved.
Show resolved Hide resolved

self.assertTrue(os.path.exists(path_to_tarball/{tarball}))
jeff-takaki marked this conversation as resolved.
Show resolved Hide resolved

@patch('gratia.common.config.ConfigProxy.getFilenameFragment', create=True, return_value ='test-filename')
@patch('gratia.common.config.ConfigProxy.get_GratiaExtension', create=True, return_value ='test-extension')
jeff-takaki marked this conversation as resolved.
Show resolved Hide resolved
def test_tarball_contents(self, mock_gratia_ext, mock_filefrag):
"""
Assert that unpacked tarball contains files from outfiles
"""
#Parameters for function
probe_dir = self.probe_dir
outbox = self.outbox
outfiles = self.outfiles

sandbox_mgmt.CompressOutbox(probe_dir, outbox, outfiles)

#Finds tarball that matches GenerateFilename() output
os.chdir(f'{probe_dir}/staged/store')
jeff-takaki marked this conversation as resolved.
Show resolved Hide resolved
for tarball in glob.glob("tz.*.test-extension__*"):
return(tarball)

#Gets names of files within tarball
file_obj= tarfile.open(tarball,"r")
namelist=file_obj.getnames()
for names in namelist:
return(names)
jeff-takaki marked this conversation as resolved.
Show resolved Hide resolved
file_obj.close()
jeff-takaki marked this conversation as resolved.
Show resolved Hide resolved

# Sort both lists to ensure order-independent comparison
names.sort()
outfiles.sort()

self.assertListEqual(names, outfiles)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good assertion! Another one would be to also check that the contents are expected