You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@madlib.apache.org by GitBox <gi...@apache.org> on 2019/06/18 18:56:23 UTC

[GitHub] [madlib-site] orhankislal commented on a change in pull request #14: Image loader python module, and demo notebook

orhankislal commented on a change in pull request #14: Image loader python module, and demo notebook
URL: https://github.com/apache/madlib-site/pull/14#discussion_r294958231
 
 

 ##########
 File path: community-artifacts/madlib_image_loader.py
 ##########
 @@ -0,0 +1,384 @@
+#
+# Python module to load images into postgres or greenplum db, for 
+#  use with madlib deep_learning module.
+#
+# The image tables created will have at least 3 columns:
+#     (id SERIAL, x REAL[], y).  Each row is 1 image,
+#     with image data represented by x (a 3D array of type "real"), and
+#     y (category) as text.  id is just a unique identifier for each image,
+#     so they don't get mixed up during prediction.
+#
+#   ImageLoader.ROWS_PER_FILE = 1000 by default; this is the number of rows per
+#      temporary file (or StringIO buffer) loaded at once.
+#
+
+#   User API is through ImageLoader and DbCredentials class constructors,
+#     and ImageLoader.load_np_array_to_table
+#
+#     1. Create objects:
+#
+#           db_creds = DbCredentials(db_name='madlib', user=None, password='', host='localhost', port=5432)
+#
+#           iloader = ImageLoader(db_creds, num_workers, table_name=None)
+#
+#     2. Perform parallel image loading:
+#
+#           iloader.load_np_array_to_table(data_x, data_y, table_name, append=False, img_names=None, no_temp_files=False)
+#
+#   data_x contains image data in np.array format, and data_y is a 1D np.array of the image categories (labels).
+#
+#   Default database credentials are: localhost port 5432, madlib db, no password.  Calling the default
+#     constructor DbCredentials() will attempt to connect using these credentials, but any of them can be
+#     overridden.
+#
+#   append=False attempts to create a new table, while append=True appends more images to an existing table.
+#
+#   If the user passes a table_name while creating ImageLoader object, it will be used for all further
+#     calls to load_np_array_to_table.  It can be changed by passing it as a parameter during the
+#     actual call to load_np_array_to_table, and if so future calls will load to that table name instead.
+#     This avoids needing to pass the table_name again every time, but also allows it to be changed at
+#     any time.
+#
+#   EXPERIMENTAL:  If no_temp_files=True, the operation will happen without writing out
+#                  the tables to temporary files before loading them.  Instead,
+#                  an in-memory filelike buffer (StringIO) will be used to build
+#                  the tables before loading.
+#   
+#   img_names:  this is currently unused, but we plan to use it when we add support for loading images
+#               from disk.
+
+import numpy as np
+from keras.preprocessing import image
+from keras.datasets import cifar10
+import keras
+import sys
+import os
+import re
+import gc
+import random
+import string
+import psycopg2 as db
+from multiprocessing import Pool, current_process
+from shutil import rmtree
+import time
+import signal
+import traceback
+import exceptions
+from cStringIO import StringIO
+
+class SignalException(Exception):
+    """Raised by the worker signal handler after a signal is received."""
+
+def _worker_sig_handler(signum, frame):
+    """Handle a signal delivered to a worker process.
+
+    Builds a message naming the signal, does the signal-specific work
+    (temp-dir cleanup for SIGTERM, stack dump for SIGSEGV) and then raises
+    SignalException carrying that message.
+
+    signum -- number of the signal that was received
+    frame  -- stack frame that was executing when the signal arrived
+    """
+    if signum == signal.SIGINT:
+        msg = "Received SIGINT in worker."
+    elif signum == signal.SIGTERM:
+        msg = "Received SIGTERM in worker."
+        # _worker_cleanup takes one (ignored) positional argument; calling
+        #  it with none raised TypeError instead of cleaning up.
+        _worker_cleanup(None)
+    elif signum == signal.SIGSEGV:
+        msg = "Received SIGSEGV in worker."
+        traceback.print_stack(frame)
+    else:
+        msg = "Received unknown signal in worker"
+
+    raise SignalException(msg)
+
+def _call_worker(data):
+    """Load one chunk of data using this process's ImageLoader singleton.
+
+    Dispatches to iloader._just_load when iloader.no_temp_files is set,
+    otherwise to iloader._write_tmp_file_and_load.  On any exception the
+    temporary directory (if one exists) is removed, the traceback is
+    printed, and the exception is re-raised.
+    """
+    try:
+        if iloader.no_temp_files:
+            iloader._just_load(data)
+        else:
+            iloader._write_tmp_file_and_load(data)
+    except Exception:
+        if iloader.tmp_dir:
+            iloader.rm_temp_dir()
+        # For some reason, when an exception is raised in a worker, the
+        #  stack trace doesn't get shown.  So we have to print it ourselves
+        #  (the actual exception msg will get printed by the mother process).
+        #
+        print("\n{0}: Error loading images:".format(iloader.pr_name))
+        print(traceback.format_exc())
+        # Bare raise keeps the original traceback; `raise e` would restart
+        #  it from this line under Python 2.
+        raise
+
+def _worker_cleanup(dummy=None):
+    """Remove this process's temporary directory, if it has one.
+
+    Called when worker process is terminated.
+
+    dummy -- ignored.  It now defaults to None so the function can also be
+             called with no arguments, as the signal handler does.
+    """
+    if iloader.tmp_dir:
+        iloader.rm_temp_dir()
+
+def init_worker(mother_pid, table_name, append, no_temp_files, db_creds):
+    """Set up the per-process ImageLoader for a worker.
+
+    Presumably used as the multiprocessing Pool initializer -- confirm
+    against the caller.  Creates the ImageLoader (which registers itself as
+    the process-wide `iloader`), copies the run settings onto it, installs
+    the SIGINT and SIGSEGV handlers, creates a temp dir unless
+    no_temp_files is set, and connects to the database.
+
+    mother_pid    -- stored on the loader as mother_pid
+    table_name    -- stored on the loader as table_name
+    append        -- accepted but not used by the code visible here
+    no_temp_files -- when true, no temporary directory is created
+    db_creds      -- DbCredentials used for the database connection
+
+    Any exception is printed with its traceback and re-raised.
+    """
+    pr = current_process()
+    print("Initializing {0} [pid {1}]".format(pr.name, pr.pid))
+
+    # Bound before the try block so the except clause can tell whether the
+    #  loader was ever constructed; otherwise a failure inside ImageLoader()
+    #  would surface as an UnboundLocalError and hide the real error.
+    loader = None
+    try:
+        loader = ImageLoader(db_creds=db_creds)
+        loader.mother_pid = mother_pid
+        loader.table_name = table_name
+        loader.no_temp_files = no_temp_files
+        loader.img_names = None
+        signal.signal(signal.SIGINT, _worker_sig_handler)
+        signal.signal(signal.SIGSEGV, _worker_sig_handler)
+        if not no_temp_files:
+            loader.mk_temp_dir()
+        loader.db_connect()
+    except Exception:
+        if loader is not None and loader.tmp_dir:
+            loader.rm_temp_dir()
+        print("\nException in {0} init_worker:".format(pr.name))
+        print(traceback.format_exc())
+        # Bare raise preserves the original traceback.
+        raise
+
+class DbCredentials:
+    """Plain holder for the settings needed to open a database connection.
+
+    db_name  -- database name (default 'madlib')
+    user     -- database user; when not given, $USER is used
+    password -- password (default empty)
+    host     -- server host (default 'localhost')
+    port     -- server port (default 5432)
+    """
+    def __init__(self, db_name='madlib', user=None, password='', host='localhost', port=5432):
+        # No user supplied: fall back to the login name in the environment.
+        self.user = user if user else os.environ["USER"]
+
+        self.db_name, self.password = db_name, password
+        self.host, self.port = host, port
+
+class ImageLoader:
+    def __init__(self, db_creds=None, num_workers=None, table_name=None):
+        """Create a loader and register it as this process's singleton.
+
+        db_creds    -- DbCredentials used later by db_connect
+        num_workers -- number of worker processes to use
+        table_name  -- optional default table name; accepted here because
+                       the usage notes at the top of this module document
+                       ImageLoader(db_creds, num_workers, table_name=None)
+        """
+        self.num_workers = num_workers
+        self.table_name = table_name
+        self.append = False
+        self.img_num = 0
+        self.db_creds = db_creds
+        self.db_conn = None
+        self.db_cur = None
+        self.tmp_dir = None
+        self.mother = False
+        self.pr_name = current_process().name
+        # Read by _call_worker / _copy_into_db; give them defined defaults so
+        #  a loader that never went through init_worker has the attributes.
+        self.img_names = None
+        self.no_temp_files = False
+
+        global iloader  # Singleton per process
+        iloader = self
+
+    def _random_string(self):
+        """Return 10 characters chosen at random from ASCII letters and digits."""
+        alphabet = string.ascii_letters + string.digits
+        picks = []
+        for _ in xrange(10):
+            picks.append(random.choice(alphabet))
+        return ''.join(picks)
+
+    def mk_temp_dir(self):
+        """Create /tmp/madlib_<random string> and record it in self.tmp_dir.
+
+        Raises whatever os.mkdir raises if the directory cannot be created;
+        in that case self.tmp_dir is left unchanged.
+        """
+        new_dir = '/tmp/madlib_{0}'.format(self._random_string())
+        os.mkdir(new_dir)
+        # Record the path only once it exists, so error-path cleanup never
+        #  tries to remove a directory that was not created.
+        self.tmp_dir = new_dir
+        # {1} is the directory; the message used to repeat {0} (process name).
+        print("{0}: Created temporary directory {1}".format(self.pr_name, self.tmp_dir))
+
+    def rm_temp_dir(self):
+        """Delete the directory named by self.tmp_dir and reset it to None.
+
+        Does nothing when no temporary directory is recorded.
+        """
+        removed = self.tmp_dir
+        if removed is None:
+            return
+        rmtree(removed)
+        self.tmp_dir = None
+        # Report the path that was removed; the message used to repeat {0}
+        #  (process name), and tmp_dir was already None by the time it printed.
+        print("{0}: Removed temporary directory {1}".format(self.pr_name, removed))
+
+    def db_connect(self):
+        """Open the database connection and cursor described by self.db_creds.
+
+        Returns immediately if a cursor already exists.  On success
+        self.db_conn and self.db_cur are set and autocommit is enabled.
+        On failure the partial connection is closed via db_close, the error
+        is printed, and the exception is re-raised.
+        """
+        if self.db_cur:
+            return
+
+        creds = self.db_creds
+        # Keyword arguments let the driver do any quoting of the values.
+        conn_args = dict(dbname=creds.db_name,
+                         user=creds.user,
+                         host=creds.host,
+                         port=creds.port)
+        # The password was previously read but never sent.  Only pass it when
+        #  one was given, so the empty default behaves as before.
+        if creds.password:
+            conn_args['password'] = creds.password
+
+        try:
+            self.db_conn = db.connect(**conn_args)
+            self.db_cur = self.db_conn.cursor()
+            self.db_conn.autocommit = True
+
+        except Exception as error:
+            self.db_close()
+            print(error)
+            # Bare raise preserves the original traceback.
+            raise
+        print("{0}: Connected to {1} db.".format(self.pr_name, self.db_creds.db_name))
+
+    def db_exec(self, query, args=None, echo=True):
+        """Execute query on the open cursor.
+
+        query -- SQL text passed to the cursor's execute
+        args  -- optional parameters passed along with query
+        echo  -- when true, print the query before running it and the
+                 cursor's status message afterwards
+
+        Raises Exception if no cursor is open.
+        """
+        if self.db_cur is None:
+            # The message used to be passed to print with a comma, so the
+            #  {0} placeholder was never filled in.
+            msg = "{0}: db_cur is None in db_exec--aborting".format(self.pr_name)
+            print(msg)
+            raise Exception(msg)
+
+        if echo:
+            print("Executing: {0}".format(query))
+        self.db_cur.execute(query, args)
+        if echo:
+            print(self.db_cur.statusmessage)
+
+    def db_close(self):
+        """Close the cursor and the connection, if open, and reset both to None.
+
+        Prints a warning when there is no cursor to close.
+        """
+        if self.db_cur is not None:
+            self.db_cur.close()
+            self.db_cur = None
+        else:
+            # Use .format(); passing pr_name as a second print argument left
+            #  the {0} placeholder unfilled.
+            print("{0}: WARNING: db_cur is None in db_close".format(self.pr_name))
+        if isinstance(self.db_conn, db.extensions.connection):
+            self.db_conn.close()
+            self.db_conn = None
+
+    def _gen_lines(self, data, img_names=None):
+        """Yield one text line per (x, y) row of data.
+
+        x is rendered via tolist() with '[' / ']' swapped for '{' / '}'.
+        Without img_names each line is 'x|y'; with img_names the three
+        values are double-quoted and comma-separated.
+        """
+        # NOTE(review): the img_names form uses ', ' rather than '|' as the
+        #  separator -- confirm this matches what the loader expects.
+        for idx, (pixels, label) in enumerate(data):
+            array_text = str(pixels.tolist()).replace('[', '{').replace(']', '}')
+            if not img_names:
+                yield '{0}|{1}\n'.format(array_text, label)
+            else:
+                yield '"{0}", "{1}", "{2}"\n'.format(array_text, label, img_names[idx])
+
+    def _write_file(self, file_object, data, img_names=None):
+        """Write the lines generated from data (and img_names) to file_object."""
+        file_object.writelines(self._gen_lines(data, img_names))
+        # Do we actually need this?
+#        file_object.write('\.\n')
+
+    # Number of rows per temporary file (or StringIO buffer) loaded at once.
+    ROWS_PER_FILE = 1000
+
+    # Copies from open file-like object f into database
+    def _copy_into_db(self, f, data):
+        table_name = self.table_name
+        img_names = self.img_names
+
+        if img_names:
+            self.db_cur.copy_from(f, table_name, sep=',', columns=['x','y','img_name'])
 
 Review comment:
   We should use | here as well.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services