Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Both Python 2/3 compatibility #22

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 62 additions & 47 deletions iosCertTrustManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# Python-ASN1 is copyright (c) 2007-2008 by Geert Jansen <[email protected]>.
# see https://github.com/geertj/python-asn1

from __future__ import print_function
import os
import sys
import argparse
Expand All @@ -29,9 +30,23 @@
import hashlib
import subprocess
import string
import binascii
import plistlib


if hasattr(__builtins__, 'raw_input'):
input = raw_input


if hasattr(plistlib, 'readPlist'):
readPlist = plistlib.readPlist
else:
def readPlist(path_or_file):
if isinstance(path_or_file, str):
with open(path_or_file, 'rb') as f:
return plistlib.load(f)
return plistlib.load(path_or_file)


def query_yes_no(question, default="yes"):
"""Ask a yes/no question via raw_input() and return their answer.

Expand All @@ -55,10 +70,10 @@ def query_yes_no(question, default="yes"):

while 1:
sys.stdout.write(question + prompt)
choice = raw_input().lower()
choice = input().lower()
if default is not None and choice == '':
return default
elif choice in valid.keys():
elif choice in list(valid.keys()):
return valid[choice]
else:
sys.stdout.write("Please respond with 'yes' or 'no' "\
Expand Down Expand Up @@ -99,16 +114,16 @@ def start(self):
def enter(self, nr, cls):
"""Start a constructed data value."""
if self.m_stack is None:
raise Error, 'Encoder not initialized. Call start() first.'
raise Error('Encoder not initialized. Call start() first.')
self._emit_tag(nr, ASN1.TypeConstructed, cls)
self.m_stack.append([])

def leave(self):
"""Finish a constructed data value."""
if self.m_stack is None:
raise Error, 'Encoder not initialized. Call start() first.'
raise Error('Encoder not initialized. Call start() first.')
if len(self.m_stack) == 1:
raise Error, 'Tag stack is empty.'
raise Error('Tag stack is empty.')
value = ''.join(self.m_stack[-1])
del self.m_stack[-1]
self._emit_length(len(value))
Expand All @@ -117,17 +132,17 @@ def leave(self):
def write(self, value, nr, typ, cls):
"""Write a primitive data value."""
if self.m_stack is None:
raise Error, 'Encoder not initialized. Call start() first.'
raise Error('Encoder not initialized. Call start() first.')
self._emit_tag(nr, typ, cls)
self._emit_length(len(value))
self._emit(value)

def output(self):
"""Return the encoded output."""
if self.m_stack is None:
raise Error, 'Encoder not initialized. Call start() first.'
raise Error('Encoder not initialized. Call start() first.')
if len(self.m_stack) != 1:
raise Error, 'Stack is not empty.'
raise Error('Stack is not empty.')
output = ''.join(self.m_stack[0])
return output

Expand All @@ -154,7 +169,7 @@ def _emit_tag_long(self, nr, typ, cls):
values.append((nr & 0x7f) | 0x80)
nr >>= 7
values.reverse()
values = map(chr, values)
values = list(map(chr, values))
for val in values:
self._emit(val)

Expand All @@ -177,7 +192,7 @@ def _emit_length_long(self, length):
values.append(length & 0xff)
length >>= 8
values.reverse()
values = map(chr, values)
values = list(map(chr, values))
# really for correctness as this should not happen anytime soon
assert len(values) < 127
head = chr(0x80 | len(values))
Expand All @@ -202,15 +217,15 @@ def __init__(self):
def start(self, data):
"""Start processing `data'."""
if not isinstance(data, str):
raise Error, 'Expecting string instance.'
raise Error('Expecting string instance.')
self.m_stack = [[0, data]]
self.m_tag = None

def peek(self):
"""Return the value of the next tag without moving to the next
TLV record."""
if self.m_stack is None:
raise Error, 'No input selected. Call start() first.'
raise Error('No input selected. Call start() first.')
if self._end_of_input():
return None
if self.m_tag is None:
Expand All @@ -220,7 +235,7 @@ def peek(self):
def read(self):
"""Read a simple value and move to the next TLV record."""
if self.m_stack is None:
raise Error, 'No input selected. Call start() first.'
raise Error('No input selected. Call start() first.')
if self._end_of_input():
return None
tag = self.peek()
Expand All @@ -236,10 +251,10 @@ def eof(self):
def enter(self):
"""Enter a constructed tag."""
if self.m_stack is None:
raise Error, 'No input selected. Call start() first.'
raise Error('No input selected. Call start() first.')
nr, typ, cls = self.peek()
if typ != ASN1.TypeConstructed:
raise Error, 'Cannot enter a non-constructed tag.'
raise Error('Cannot enter a non-constructed tag.')
length = self._read_length()
bytes = self._read_bytes(length)
self.m_stack.append([0, bytes])
Expand All @@ -248,9 +263,9 @@ def enter(self):
def leave(self):
"""Leave the last entered constructed tag."""
if self.m_stack is None:
raise Error, 'No input selected. Call start() first.'
raise Error('No input selected. Call start() first.')
if len(self.m_stack) == 1:
raise Error, 'Tag stack is empty.'
raise Error('Tag stack is empty.')
del self.m_stack[-1]
self.m_tag = None

Expand All @@ -275,10 +290,10 @@ def _read_length(self):
if byte & 0x80:
count = byte & 0x7f
if count == 0x7f:
raise Error, 'ASN1 syntax error'
raise Error('ASN1 syntax error')
bytes = self._read_bytes(count)
bytes = [ ord(b) for b in bytes ]
length = 0L
length = 0
for byte in bytes:
length = (length << 8) | byte
try:
Expand All @@ -301,7 +316,7 @@ def _read_byte(self):
try:
byte = ord(input[index])
except IndexError:
raise Error, 'Premature end of input.'
raise Error('Premature end of input.')
self.m_stack[-1][0] += 1
return byte

Expand All @@ -311,7 +326,7 @@ def _read_bytes(self, count):
index, input = self.m_stack[-1]
bytes = input[index:index+count]
if len(bytes) != count:
raise Error, 'Premature end of input.'
raise Error('Premature end of input.')
self.m_stack[-1][0] += count
return bytes

Expand Down Expand Up @@ -377,7 +392,7 @@ def get_subject(self):
possl = subprocess.Popen(['openssl', 'x509', '-inform', 'DER', '-noout', '-subject', '-nameopt', 'oneline'],
shell=False, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=None)
subjectText, error_text = possl.communicate(self.get_data())
return subjectText
return subjectText.decode('utf-8')
return None

def get_subject_ASN1(self):
Expand Down Expand Up @@ -461,18 +476,18 @@ def is_valid(self):

def _add_record(self, sha, subj, tset, data):
if not self.is_valid():
print " Invalid TrustStore.sqlite3"
print(" Invalid TrustStore.sqlite3")
return
conn = sqlite3.connect(self._path)
c = conn.cursor()
c.execute('SELECT COUNT(*) FROM tsettings WHERE subj=?', [sqlite3.Binary(subj)])
row = c.fetchone()
if row[0] == 0:
c.execute('INSERT INTO tsettings (' + self._hash + ', subj, tset, data) VALUES (?, ?, ?, ?)', [sqlite3.Binary(sha), sqlite3.Binary(subj), sqlite3.Binary(tset), sqlite3.Binary(data)])
print ' Certificate added'
print(' Certificate added')
else:
c.execute('UPDATE tsettings SET ' + self._hash + '=?, tset=?, data=? WHERE subj=?', [sqlite3.Binary(sha), sqlite3.Binary(tset), sqlite3.Binary(data), sqlite3.Binary(subj)])
print ' Existing certificate replaced'
print(' Existing certificate replaced')
conn.commit()
conn.close()

Expand All @@ -487,20 +502,20 @@ def _saveBlob(self, baseName, name, data):
def add_certificate(self, certificate):
# this also populates self._hash
if not self.is_valid():
print " Invalid TrustStore.sqlite3"
print(" Invalid TrustStore.sqlite3")
return
self._add_record(certificate.get_fingerprint(self._hash), certificate.get_subject_ASN1(),
self._tset, certificate.get_data())

def export_certificates(self, base_filename):
if not self.is_valid():
print " Invalid TrustStore.sqlite3"
print(" Invalid TrustStore.sqlite3")
return
conn = sqlite3.connect(self._path)
c = conn.cursor()
index = 1
print
print self._title
print()
print(self._title)
for row in c.execute('SELECT subj, data FROM tsettings'):
cert = Certificate()
cert.load_data(row[1])
Expand All @@ -511,7 +526,7 @@ def export_certificates(self, base_filename):

def export_certificates_data(self, base_filename):
if not self.is_valid():
print " Invalid TrustStore.sqlite3"
print(" Invalid TrustStore.sqlite3")
return
conn = sqlite3.connect(self._path)
c = conn.cursor()
Expand All @@ -529,7 +544,7 @@ def export_certificates_data(self, base_filename):
def import_certificate_data(self, base_filename):
# this also populates self._hash
if not self.is_valid():
print " Invalid TrustStore.sqlite3"
print(" Invalid TrustStore.sqlite3")
return
certificateSubject = self._loadBlob(base_filename, 'subj')
certificateTSet = self._loadBlob(base_filename, 'tset')
Expand All @@ -541,27 +556,27 @@ def import_certificate_data(self, base_filename):
self._add_record(certificateSha, certificateSubject, certificateTSet, certificateData)

def list_certificates(self):
print
print self._title
print()
print(self._title)
if not self.is_valid():
print " Invalid TrustStore.sqlite3"
print(" Invalid TrustStore.sqlite3")
return
conn = sqlite3.connect(self._path)
c = conn.cursor()
for row in c.execute('SELECT data FROM tsettings'):
cert = Certificate()
cert.load_data(row[0])
print " ", cert.get_subject()
print(" ", cert.get_subject())
conn.close()

def delete_certificates(self):
if not self.is_valid():
print " Invalid TrustStore.sqlite3"
print(" Invalid TrustStore.sqlite3")
return
conn = sqlite3.connect(self._path)
c = conn.cursor()
print
print self._title
print()
print(self._title)
todelete = []
for row in c.execute('SELECT subj, data FROM tsettings'):
cert = Certificate()
Expand Down Expand Up @@ -591,7 +606,7 @@ def __init__(self, simulatordir):
self._is_valid = False
infofile = simulatordir + "/device.plist"
if os.path.isfile(infofile):
info = plistlib.readPlist(infofile)
info = readPlist(infofile)
runtime = info["runtime"]
if runtime.startswith(self.runtimeName):
self.version = runtime[len(self.runtimeName):].replace("-", ".")
Expand Down Expand Up @@ -632,7 +647,7 @@ def __init__(self, path):
info_plist = self._path + "/Info.plist"
if os.path.isfile(info_plist):
try:
info = plistlib.readPlist(info_plist)
info = readPlist(info_plist)
self.device_name = info["Device Name"]
self.title = "Backup of " + self.device_name + " - " + str(info["Last Backup Date"])
self._isvalid = True
Expand Down Expand Up @@ -663,15 +678,15 @@ class Program:
def import_to_simulator(self, certificate_filepath, truststore_filepath=None):
cert = Certificate()
cert.load_PEMfile(certificate_filepath)
print cert.get_subject()
print(cert.get_subject())
if truststore_filepath:
if self.always_yes or query_yes_no("Import certificate to " + truststore_filepath, "no") == "yes":
tstore = TrustStore(truststore_filepath, always_yes=self.always_yes)
tstore.add_certificate(cert)
return
for simulator in simulators():
if self.always_yes or query_yes_no("Import certificate to " + simulator.title.encode('utf-8'), "no") == "yes":
print "Importing to " + simulator.truststore_file
print("Importing to " + simulator.truststore_file)
tstore = TrustStore(simulator.truststore_file, always_yes=self.always_yes)
tstore.add_certificate(cert)

Expand All @@ -683,7 +698,7 @@ def addfromdump(self, dump_base_filename, truststore_filepath=None):
return
for simulator in simulators():
if self.always_yes or query_yes_no("Import to " + simulator.title, "no") == "yes":
print "Importing to " + simulator.truststore_file
print("Importing to " + simulator.truststore_file)
tstore = TrustStore(simulator.truststore_file, always_yes=self.always_yes)
tstore.import_certificate_data(dump_base_filename)

Expand Down Expand Up @@ -751,7 +766,7 @@ def run(self):
else:
self.always_yes = False
if args.truststore and not os.path.isfile(args.truststore):
print "invalid file: ", args.truststore
print("invalid file: ", args.truststore)
exit(1)
if args.devicebackup:
if args.list:
Expand All @@ -761,7 +776,7 @@ def run(self):
elif args.dump_base_filename:
self.export_device_trustedcertificates(args.dump_base_filename, True)
else:
print "option not supported"
print("option not supported")
elif args.list:
self.list_simulator_trustedcertificates(args.truststore)
elif args.delete:
Expand All @@ -774,7 +789,7 @@ def run(self):
self.export_simulator_trustedcertificates(args.dump_base_filename, True, args.truststore)
elif args.adddump_base_filename:
self.addfromdump(args.adddump_base_filename, args.truststore)
print
print()

if __name__ == "__main__":
program = Program()
Expand Down