Refactored DB connection sections

* Connect for each post instead of once per session
* Restart the SSH tunnel if it was closed
* Avoid issue where the connection(s) were broken by aborting Prime95
This commit is contained in:
2Shirt 2018-10-03 23:25:15 -06:00
parent 00b3c405d0
commit 117a47e94c

View file

@ -113,31 +113,33 @@ def connect_to_db():
] ]
# Establish SSH tunnel unless one already exists # Establish SSH tunnel unless one already exists
if not ost_db['Tunnel']: if not ost_db['Tunnel'] or ost_db['Tunnel'].poll() is not None:
ost_db['Tunnel'] = popen_program(cmd) ost_db['Tunnel'] = popen_program(cmd)
# Establish SQL connection (try a few times in case SSH is slow) # Establish SQL connection (try a few times in case SSH is slow)
for x in range(5): for x in range(5):
sleep(1) sleep(2)
try: try:
ost_db['Connection'] = mariadb.connect( ost_db['Connection'] = mariadb.connect(
user=DB_USER, password=DB_PASS, database=DB_NAME) user=DB_USER, password=DB_PASS, database=DB_NAME)
ost_db['Cursor'] = ost_db['Connection'].cursor()
except: except:
# Just try again # Just try again
pass pass
else: else:
break break
ost_db['Cursor'] = ost_db['Connection'].cursor()
ost_db['Errors'] = False
def disconnect_from_db(): def disconnect_from_db(reset_errors=False):
"""Disconnect from SQL DB and close SSH tunnel.""" """Disconnect from SQL DB."""
if ost_db['Cursor']: for c in ['Cursor', 'Connection']:
ost_db['Cursor'].close() try:
if ost_db['Connection']: ost_db[c].close()
ost_db['Connection'].close() except:
if ost_db['Tunnel']: # Ignore
ost_db['Tunnel'].kill() pass
ost_db[c] = None
if reset_errors:
ost_db['Errors'] = False
def export_png_graph(name, dev): def export_png_graph(name, dev):
"""Exports PNG graph using gnuplot, returns file path as str.""" """Exports PNG graph using gnuplot, returns file path as str."""
@ -358,6 +360,8 @@ def menu_diags(*args):
if not result['CS']: if not result['CS']:
print_warning('osTicket integration disabled for this run.') print_warning('osTicket integration disabled for this run.')
pause() pause()
ticket_number = get_osticket_number()
disconnect_from_db()
# Save log for non-quick tests # Save log for non-quick tests
global_vars['Date-Time'] = time.strftime("%Y-%m-%d_%H%M_%z") global_vars['Date-Time'] = time.strftime("%Y-%m-%d_%H%M_%z")
global_vars['LogDir'] = '{}/Logs/{}_{}'.format( global_vars['LogDir'] = '{}/Logs/{}_{}'.format(
@ -367,7 +371,6 @@ def menu_diags(*args):
os.makedirs(global_vars['LogDir'], exist_ok=True) os.makedirs(global_vars['LogDir'], exist_ok=True)
global_vars['LogFile'] = '{}/Hardware Diagnostics.log'.format( global_vars['LogFile'] = '{}/Hardware Diagnostics.log'.format(
global_vars['LogDir']) global_vars['LogDir'])
ticket_number = get_osticket_number()
run_tests(diag_modes[int(selection)-1]['Tests'], ticket_number) run_tests(diag_modes[int(selection)-1]['Tests'], ticket_number)
elif selection == 'A': elif selection == 'A':
run_program(['hw-diags-audio'], check=False, pipe=False) run_program(['hw-diags-audio'], check=False, pipe=False)
@ -391,7 +394,7 @@ def menu_diags(*args):
break break
# Done # Done
disconnect_from_db() disconnect_from_db(reset_errors=True)
def osticket_get_ticket_name(ticket_id): def osticket_get_ticket_name(ticket_id):
"""Lookup ticket and return name as str.""" """Lookup ticket and return name as str."""
@ -420,6 +423,9 @@ def osticket_needs_attention(ticket_id):
return # This function has been DISABLED due to a repurposing of that flag return # This function has been DISABLED due to a repurposing of that flag
if not ticket_id: if not ticket_id:
raise GenericError raise GenericError
# Connect to DB
connect_to_db()
if not ost_db['Cursor']: if not ost_db['Cursor']:
# Skip section # Skip section
return return
@ -435,11 +441,15 @@ def osticket_needs_attention(ticket_id):
ost_db['Cursor'].execute(sql_cmd) ost_db['Cursor'].execute(sql_cmd)
except: except:
ost_db['Errors'] = True ost_db['Errors'] = True
disconnect_from_db()
def osticket_post_reply(ticket_id, response): def osticket_post_reply(ticket_id, response):
"""Post a reply to a ticket in osTicket.""" """Post a reply to a ticket in osTicket."""
if not ticket_id: if not ticket_id:
raise GenericError raise GenericError
# Connect to DB
connect_to_db()
if not ost_db['Cursor']: if not ost_db['Cursor']:
# Skip section # Skip section
return return
@ -458,11 +468,15 @@ def osticket_post_reply(ticket_id, response):
ost_db['Cursor'].execute(sql_cmd) ost_db['Cursor'].execute(sql_cmd)
except: except:
ost_db['Errors'] = True ost_db['Errors'] = True
disconnect_from_db()
def osticket_set_drive_result(ticket_id, passed): def osticket_set_drive_result(ticket_id, passed):
"""Marks the pass/fail box for the drive(s) in osTicket.""" """Marks the pass/fail box for the drive(s) in osTicket."""
if not ticket_id: if not ticket_id:
raise GenericError raise GenericError
# Connect to DB
connect_to_db()
if not ost_db['Cursor']: if not ost_db['Cursor']:
# Skip section # Skip section
return return
@ -479,6 +493,7 @@ def osticket_set_drive_result(ticket_id, passed):
ost_db['Cursor'].execute(sql_cmd) ost_db['Cursor'].execute(sql_cmd)
except: except:
ost_db['Errors'] = True ost_db['Errors'] = True
disconnect_from_db()
def pad_with_dots(s, left_pad=True): def pad_with_dots(s, left_pad=True):
"""Replace ' ' padding with '..' for osTicket posts.""" """Replace ' ' padding with '..' for osTicket posts."""
@ -963,7 +978,7 @@ def run_mprime(ticket_number):
TESTS['Progress Out']).split()) TESTS['Progress Out']).split())
run_program('tmux split-window -bd watch -c -n1 -t hw-sensors'.split()) run_program('tmux split-window -bd watch -c -n1 -t hw-sensors'.split())
run_program('tmux resize-pane -y 3'.split()) run_program('tmux resize-pane -y 3'.split())
# Start test # Start test
run_program(['apple-fans', 'max']) run_program(['apple-fans', 'max'])
try: try:
@ -1130,7 +1145,7 @@ def run_nvme_smart(ticket_number):
run_program( run_program(
'sudo smartctl -t short /dev/{}'.format(name).split(), 'sudo smartctl -t short /dev/{}'.format(name).split(),
check=False) check=False)
# Wait and show progress (in 10 second increments) # Wait and show progress (in 10 second increments)
for iteration in range(int(test_length*60/10)): for iteration in range(int(test_length*60/10)):
# Update SMART data # Update SMART data
@ -1206,7 +1221,7 @@ def run_tests(tests, ticket_number=None):
run_badblocks(ticket_number) run_badblocks(ticket_number)
if TESTS['iobenchmark']['Enabled']: if TESTS['iobenchmark']['Enabled']:
run_iobenchmark(ticket_number) run_iobenchmark(ticket_number)
# Show results # Show results
if ticket_number: if ticket_number:
post_drive_results(ticket_number) post_drive_results(ticket_number)
@ -1248,7 +1263,7 @@ def scan_disks(full_paths=False, only_path=None):
TESTS['NVMe/SMART']['Status'][d['name']] = 'Pending' TESTS['NVMe/SMART']['Status'][d['name']] = 'Pending'
TESTS['badblocks']['Status'][d['name']] = 'Pending' TESTS['badblocks']['Status'][d['name']] = 'Pending'
TESTS['iobenchmark']['Status'][d['name']] = 'Pending' TESTS['iobenchmark']['Status'][d['name']] = 'Pending'
for dev, data in devs.items(): for dev, data in devs.items():
# Get SMART attributes # Get SMART attributes
run_program( run_program(
@ -1257,7 +1272,7 @@ def scan_disks(full_paths=False, only_path=None):
dev).split(), dev).split(),
check = False) check = False)
data['smartctl'] = get_smart_details(dev) data['smartctl'] = get_smart_details(dev)
# Get NVMe attributes # Get NVMe attributes
if data['lsblk']['tran'] == 'nvme': if data['lsblk']['tran'] == 'nvme':
cmd = 'sudo nvme smart-log /dev/{} -o json'.format(dev).split() cmd = 'sudo nvme smart-log /dev/{} -o json'.format(dev).split()
@ -1296,7 +1311,7 @@ def scan_disks(full_paths=False, only_path=None):
else: else:
data['Quick Health OK'] = False data['Quick Health OK'] = False
data['SMART Support'] = False data['SMART Support'] = False
# Ask for manual overrides if necessary # Ask for manual overrides if necessary
if TESTS['badblocks']['Enabled'] or TESTS['iobenchmark']['Enabled']: if TESTS['badblocks']['Enabled'] or TESTS['iobenchmark']['Enabled']:
show_disk_details(data) show_disk_details(data)