orhankislal commented on a change in pull request #14: Image loader python
module, and demo notebook
URL: https://github.com/apache/madlib-site/pull/14#discussion_r294964778
##########
File path: community-artifacts/madlib_image_loader.py
##########
@@ -0,0 +1,384 @@
+#
+# Python module to load images into postgres or greenplum db, for
+# use with madlib deep_learning module.
+#
+# The image tables created will have at least 3 columns:
+# (id SERIAL, x REAL[], y). Each row is 1 image,
+# with image data represented by x (a 3D array of type "real"), and
+# y (category) as text. id is just a unique identifier for each image,
+# so they don't get mixed up during prediction.
+#
+# ImageLoader.ROWS_PER_FILE = 1000 by default; this is the number of rows per
+# temporary file (or StringIO buffer) loaded at once.
+#
+
+# User API is through ImageLoader and DbCredentials class constructors,
+# and ImageLoader.load_np_array_to_table
+#
+# 1. Create objects:
+#
+# db_creds = DbCredentials(db_name='madlib', user=None, password='', host='localhost', port=5432)
+#
+# iloader = ImageLoader(db_creds, num_workers, table_name=None)
+#
+# 2. Perform parallel image loading:
+#
+# iloader.load_np_array_to_table(data_x, data_y, table_name, append=False, img_names=None, no_temp_files=False)
+#
+# data_x contains image data in np.array format, and data_y is a 1D np.array of the image categories (labels).
+#
+# Default database credentials are: localhost port 5432, madlib db, the current OS user, and no
+# password. Calling DbCredentials() with no arguments selects these credentials (the connection
+# itself is made later by ImageLoader), but any of them can be overridden.
+#
+# append=False attempts to create a new table, while append=True appends more images to an existing table.
+#
+# If the user passes a table_name while creating the ImageLoader object, it will be used for all further
+# calls to load_np_array_to_table. It can be changed by passing it as a parameter during the
+# actual call to load_np_array_to_table, and if so future calls will load to that table name instead.
+# This avoids needing to pass the table_name again every time, but also allows it to be changed at
+# any time. (Note: the ImageLoader constructor in this file currently accepts only db_creds and
+# num_workers, and load_np_array_to_table requires table_name.)
+#
+# EXPERIMENTAL: If no_temp_files=True, the operation will happen without writing out
+# the tables to temporary files before loading them. Instead, an in-memory file-like
+# buffer (StringIO) will be used to build the tables before loading.
+#
+# img_names: this is currently unused, but we plan to use it when we add support for loading images
+# from disk.
+
+import numpy as np
+from keras.preprocessing import image
+from keras.datasets import cifar10
+import keras
+import sys
+import os
+import re
+import gc
+import random
+import string
+import psycopg2 as db
+from multiprocessing import Pool, current_process
+from shutil import rmtree
+import time
+import signal
+import traceback
+import exceptions
+from cStringIO import StringIO
+
class SignalException(Exception):
    """Raised by _worker_sig_handler when a worker process receives a signal."""
+
def _worker_sig_handler(signum, frame):
    """Signal handler that init_worker installs in pool workers.

    Turns the received signal into a SignalException so it surfaces as an
    ordinary exception in the worker (which _call_worker's handler catches).

    :param signum: number of the delivered signal.
    :param frame: stack frame current at delivery; printed for SIGSEGV.
    :raises SignalException: always, with a message naming the signal.
    """
    if signum == signal.SIGINT:
        msg = "Received SIGINT in worker."
    elif signum == signal.SIGTERM:
        msg = "Received SIGTERM in worker."
        # _worker_cleanup takes one (unused) positional argument; calling it
        # with none raised TypeError here and hid the SignalException below.
        _worker_cleanup(None)
    elif signum == signal.SIGSEGV:
        msg = "Received SIGSEGV in worker."
        traceback.print_stack(frame)
    else:
        msg = "Received unknown signal in worker"

    raise SignalException(msg)
+
def _call_worker(data):
    """Load one chunk of rows into the database from a pool worker.

    Uses this process's ``iloader`` singleton (configured by init_worker).
    The chunk is loaded from an in-memory buffer when no_temp_files is set,
    otherwise via a temporary file.

    :param data: chunk of (x, y) rows to load.
    :raises: whatever the load raised, after the worker's temporary
        directory has been removed.
    """
    try:
        if iloader.no_temp_files:
            iloader._just_load(data)
        else:
            iloader._write_tmp_file_and_load(data)
    except Exception:
        # A worker's stack trace is not shown by the parent, so print it here
        # (the exception message itself is printed by the mother process).
        # It is printed before cleanup so it survives a failing rm_temp_dir.
        print("\n{0}: Error loading images:".format(iloader.pr_name))
        print(traceback.format_exc())
        if iloader.tmp_dir:
            iloader.rm_temp_dir()
        # Bare ``raise`` keeps the original traceback; ``raise e`` restarts it
        # at this line under Python 2.
        raise
+
+def _worker_cleanup(dummy):
+ # Called when worker process is terminated
+ if iloader.tmp_dir:
+ iloader.rm_temp_dir()
+
def init_worker(mother_pid, table_name, append, no_temp_files, db_creds):
    """Pool initializer: configure this worker's ImageLoader singleton.

    Creates an ImageLoader (which registers itself as the module-level
    ``iloader``) and copies the load settings onto it. It then installs
    SIGINT/SIGSEGV handlers, creates a temporary directory unless
    no_temp_files is set, and connects to the database.

    :param mother_pid: pid of the parent (mother) process.
    :param table_name: table this worker loads into.
    :param append: whether the load appends to an existing table.
    :param no_temp_files: if True, load from in-memory buffers instead of
        temporary files.
    :param db_creds: DbCredentials used to connect.
    :raises: any setup error, after the temporary directory is removed.
    """
    pr = current_process()
    print("Initializing {0} [pid {1}]".format(pr.name, pr.pid))

    # Bound before the try so the except block can always inspect it. A
    # local named ``iloader`` was unbound there when ImageLoader() raised,
    # and UnboundLocalError replaced the real error.
    loader = None
    try:
        loader = ImageLoader(db_creds=db_creds)
        loader.mother_pid = mother_pid
        loader.table_name = table_name
        loader.append = append
        loader.no_temp_files = no_temp_files
        loader.img_names = None
        # NOTE(review): _worker_sig_handler also handles SIGTERM, but it is
        # not registered here -- confirm whether that is intentional.
        signal.signal(signal.SIGINT, _worker_sig_handler)
        signal.signal(signal.SIGSEGV, _worker_sig_handler)
        if not no_temp_files:
            loader.mk_temp_dir()
        loader.db_connect()
    except Exception:
        print("\nException in {0} init_worker:".format(pr.name))
        print(traceback.format_exc())
        if loader is not None and loader.tmp_dir:
            loader.rm_temp_dir()
        # Bare ``raise`` preserves the original traceback.
        raise
+
class DbCredentials:
    """Connection settings for the target Postgres/Greenplum database.

    :param db_name: database name (default 'madlib').
    :param user: database user; when falsy, the current OS user is used
        ($USER if set, otherwise getpass.getuser()).
    :param password: database password (default '', i.e. none).
    :param host: server host (default 'localhost').
    :param port: server port (default 5432).
    """
    def __init__(self, db_name='madlib', user=None, password='',
                 host='localhost', port=5432):
        if user:
            self.user = user
        else:
            # $USER is not always set (cron, containers, some service
            # managers); fall back to getpass instead of raising KeyError.
            self.user = os.environ.get("USER")
            if not self.user:
                import getpass
                self.user = getpass.getuser()

        self.db_name = db_name
        self.password = password
        self.host = host
        self.port = port
+
+class ImageLoader:
+ def __init__(self, db_creds=None, num_workers=None):
+ self.num_workers = num_workers
+ self.append = False
+ self.img_num = 0
+ self.db_creds = db_creds
+ self.db_conn = None
+ self.db_cur = None
+ self.tmp_dir = None
+ self.mother = False
+ self.pr_name = current_process().name
+
+ global iloader # Singleton per process
+ iloader = self
+
+ def _random_string(self):
+ return ''.join([random.choice(string.ascii_letters + string.digits)
for n in xrange(10)])
+
+ def mk_temp_dir(self):
+ self.tmp_dir = '/tmp/madlib_{0}'.format(self._random_string())
+ os.mkdir(self.tmp_dir)
+ print("{0}: Created temporary directory {0}".format(self.pr_name,
self.tmp_dir))
+
+ def rm_temp_dir(self):
+ rmtree(self.tmp_dir)
+ self.tmp_dir = None
+ print("{0}: Removed temporary directory {0}".format(self.pr_name,
self.tmp_dir))
+
+ def db_connect(self):
+ if self.db_cur:
+ return
+
+ db_name = self.db_creds.db_name
+ user = self.db_creds.user
+ host = self.db_creds.host
+ port = self.db_creds.port
+ password = self.db_creds.password
+ connection_string =\
+ "dbname={0} user={1} host={2} port={3}".format(db_name,
+ user,
+ host,
+ port)
+
+ try:
+ self.db_conn = db.connect(connection_string)
+ self.db_cur = self.db_conn.cursor()
+ self.db_conn.autocommit = True
+
+ except (Exception, db.DatabaseError) as error:
+ self.db_close()
+ print(error)
+ raise error
+ print("{0}: Connected to {1} db.".format(self.pr_name,
self.db_creds.db_name))
+
+ def db_exec(self, query, args=None, echo=True):
+ if self.db_cur is not None:
+ if echo:
+ print "Executing: {0}".format(query)
+ self.db_cur.execute(query, args)
+ if echo:
+ print self.db_cur.statusmessage
+ else:
+ print("{0}: db_cur is None in db_exec--aborting", self.pr_name)
+ raise Exception
+
+ def db_close(self):
+ if self.db_cur is not None:
+ self.db_cur.close()
+ self.db_cur = None
+ else:
+ print("{0}: WARNING: db_cur is None in db_close", self.pr_name)
+ if isinstance(self.db_conn, db.extensions.connection):
+ self.db_conn.close()
+ self.db_conn = None
+
+ def _gen_lines(self, data, img_names=None):
+ for i, row in enumerate(data):
+ x, y = row
+ line = str(x.tolist())
+ line = line.replace('[','{').replace(']','}')
+ if img_names:
+ line = '"{0}", "{1}", "{2}"\n'.format(line, y, img_names[i])
+ else:
+ line = '{0}|{1}\n'.format(line, y)
+ yield line
+
+ def _write_file(self, file_object, data, img_names=None):
+ lines = self._gen_lines(data, img_names)
+ file_object.writelines(lines)
+ # Do we actually need this?
+# file_object.write('\.\n')
+
+ ROWS_PER_FILE = 1000
+
+ # Copies from open file-like object f into database
+ def _copy_into_db(self, f, data):
+ table_name = self.table_name
+ img_names = self.img_names
+
+ if img_names:
+ self.db_cur.copy_from(f, table_name, sep=',',
columns=['x','y','img_name'])
+ else:
+ self.db_cur.copy_from(f, table_name, sep='|', columns=['x','y'])
+
+ print("{0}: Loaded {1} images into {2}".format(self.pr_name,
len(data), self.table_name))
+
+ # Use in-memory buffer as file-like object to load a block of data into db
+ # (no temp files written)
+ def _just_load(self, data):
+ f = StringIO()
+ self._write_file(f, data)
+ self._copy_into_db(f, data)
+ f.close()
+
+ # Write out a temporary file and then load it into db as a table
+ def _write_tmp_file_and_load(self, data):
+ table_name = self.table_name
+
+ if not self.tmp_dir:
+ print("{0}: Can't find temporary directory...
aborting.".format(self.pr_name))
+ raise Exception
+
+ filename = os.path.join(self.tmp_dir, '{0}{1:04}.tmp'.format(
+ table_name, self.img_num))
+
+ self.img_num += 1
+ with file(filename, 'w') as f:
+ self._write_file(f, data)
+
+ print("{0}: Wrote {1} images to {2}".format(self.pr_name, len(data),
filename))
+
+ with file(filename, 'r') as f:
+ self._copy_into_db(f, data)
+
+ def load_np_array_to_table(self, data_x, data_y, table_name, append=False,
+ img_names=None, no_temp_files=False):
+ """
+ Loads a numpy array into db. For append=False, creates a new table
and loads the data.
+ For append=True, appends data to existing table. Throws an
exception if
+ append=False and table_name already exists, or if append=True and
table_name
+ does not exist. Makes use of worker processes initialized during
ImageLoader
+ object creation to load in parallel.
+ @data_x independent variable data, a numpy array of images. Size of
first dimension is
+ number of images. Rest of dimensions determined by image
resolution and number
+ of channels.
+ @data_y dependent variable data (image classes), as an numpy array
+ @table_name Name of table in db to load data into
+ @append Whether to create a new table (False) or append to an existing
one (True).
+ If unspecified, default is False
+ @img_names If not None, a list of the image names corresponding to
elements of the
+ data_x numpy array. If present, this is included as a
column in the table.
+ @no_temp_files If specified, no temporary files are written--all
operations are performed
+ in-memory.
+
+ """
+ start_time = time.time()
+ self.mother = True
+ self.append = append
+ self.table_name = table_name
+ self.db_connect()
+
+ if self.append:
+ # Validate that table already exists
+ try:
+ self.db_exec("SELECT count(*) FROM
{0}".format(self.table_name), echo=False)
+ except(Exception, db.DatabaseError):
+ raise exceptions.RuntimeError("Table {0} does not exist in {1}
db. Use append=False to create it before loading."
Review comment:
We might want to raise the actual error message as well, in case the problem
is something other than the table not existing. For example, the user might
not have access to the table.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services