我们从 Python 开源项目中提取了以下 41 个代码示例，用于说明如何使用 setproctitle.setproctitle()。
def _initialize(self, worker_id):
    """Perform one-time per-worker setup and set the process title.

    Calls after the first one return immediately (guarded by
    ``self._initialized``).

    Args:
        worker_id: identifier of this worker; rendered with ``%d`` so it
            must be an integer.
    """
    already_done = getattr(self, '_initialized', False)
    if already_done:
        return
    self._initialized = True

    if self.name is None:
        # Fall back to the concrete class name when no explicit name was set.
        self.name = self.__class__.__name__
    self.worker_id = worker_id
    self.pid = os.getpid()

    program = os.path.basename(sys.argv[0])
    title_fields = dict(name=self.name, worker_id=self.worker_id, pid=self.pid)
    self._title = "%(name)s(%(worker_id)d) [%(pid)d]" % title_fields

    # Process title: "<program> - <name>(<worker_id>)".
    proc_fields = dict(pname=program, name=self.name, worker_id=self.worker_id)
    setproctitle.setproctitle("%(pname)s - %(name)s(%(worker_id)d)" % proc_fields)
def data_munging(self):
    """Read data from the replication queue and write it to MongoDB.

    Records this process's PID, optionally renames the process to
    ``mymongo_datamunging`` and then runs a ``DataMunging`` worker over the
    ``replicator_out`` queue, passing it a ``ParseData`` instance.

    See Also:
        :meth:`.replicator`
    """
    self.write_pid(str(os.getpid()))
    if self.setproctitle:
        import setproctitle
        setproctitle.setproctitle('mymongo_datamunging')

    parser_module = ParseData()
    mongo_db = MyMongoDB(config['mongodb'])
    worker = DataMunging(mongo_db, self.queues['replicator_out'])
    worker.run(parser_module)
def __ping(self, title):
    """Periodically ping a well-known host and log connectivity changes.

    Sets the process title to *title*, then loops until ``self.exit`` is
    set. Each pass sends a single ping to ``www.google.com``. It logs an
    error on every failed ping and logs once when the host becomes
    reachable again after an outage. If a GPIO helper is configured, it
    blinks LED 1 for 0.2 s. It then sleeps 15 seconds.

    Args:
        title: process title applied via ``setproctitle``.
    """
    hostname = "www.google.com"
    was_down = False
    setproctitle.setproctitle(title)
    while not self.exit.is_set():
        # Single echo request (-c 1); all output is discarded.
        response = os.system("ping -c 1 -w2 " + hostname + " > /dev/null 2>&1")
        if response != 0:
            self._log.error(hostname + ' is unreachable!')
            # Remember the outage so the recovery can be reported once.
            # (Previously this assigned False, so recovery was never logged.)
            was_down = True
        elif was_down:
            self._log.error(hostname + ' is up again!')
            was_down = False
        # NOTE(review): the original was collapsed onto one line; this
        # heartbeat blink is assumed to run on every pass - confirm.
        if self._gpio is not None:
            self._gpio.led(1, True)  # LED 1 on
            time.sleep(0.2)
            self._gpio.led(1, False)  # LED 1 off
        time.sleep(15)
def run(self, max_loops=-1):
    '''Loop over event structures, raising events and executing tasks.

    Args:
        max_loops: number of loop cycles to run. A non-negative value runs
            at most that many ``loop_cycle`` calls. A negative value (the
            default) delegates to ``loop_session``, which runs until there
            are no events to raise and no tasks to run.

    Returns:
        The result of the last session/cycle, or ``None`` when
        ``max_loops`` is 0.
    '''
    if setproctitle is not None:
        # Title is "eventor: <run_id>" ("eventor: " without a run id). The
        # trailing "." used as a separator in per-task titles is not wanted
        # here, since nothing follows the run id.
        run_id = "%s" % (self.run_id,) if self.run_id else ''
        setproctitle("eventor: %s" % (run_id,))

    if max_loops < 0:
        result = self.loop_session()
    else:
        result = None
        for _ in range(max_loops):
            result = self.loop_cycle()

    human_result = "success" if result else 'failure'
    total_todo, _ = self.count_todos(with_delayeds=True)
    module_logger.info('Processing finished with: %s; outstanding tasks: %s',
                       human_result, total_todo)
    return result
def run(self):
    """Run one backup for this host and report the outcome.

    Runs rsync, rotates old snapshots, creates a new snapshot, stores the
    backup size on the ``HostConfig`` row and signals success. Any
    exception (including ``KeyboardInterrupt``) signals failure and is
    re-raised. If ``setproctitle`` is available, the process title shows
    the backup in progress and is restored afterwards.
    """
    retitle = bool(setproctitle)
    if retitle:
        previous_title = getproctitle()
        setproctitle('[backing up %d: %s]' % (self.pk, self.friendly_name))
    try:
        self.run_rsync()
        self.snapshot_rotate()
        self.snapshot_create()

        sizes = bfs.parse_backup_sizes(
            self.dest_pool, self.hostgroup.name, self.friendly_name,
            self.date_complete)
        # Presumably a decimal byte count as a string: dropping the last six
        # digits truncates it to (decimal) megabytes - TODO confirm.
        size_mb = sizes['size'][0:-6] or '0'
        # Atomic update of size (single UPDATE, no read-modify-write).
        HostConfig.objects.filter(pk=self.pk).update(backup_size_mb=size_mb)

        self.signal_done(True)
    except BaseException:
        self.signal_done(False)
        raise
    finally:
        if retitle:
            setproctitle(previous_title)
def setproctitle(*args, **kwargs):
    """No-op stand-in for ``setproctitle.setproctitle``.

    Accepts and ignores any positional or keyword arguments. Presumably
    used as a fallback when the real package cannot be imported.
    """
    return None
def setproctitle(name):
    """Ignore *name*: no-op placeholder for ``setproctitle.setproctitle``."""
    return None
def _start(self, foreground = True):
    """Validate the target user, set the process title, and start running.

    Without ``self.username``, the daemon refuses to run as root and uses
    the current user. With a username, only root or that user may start
    it. Exits with status 1 on violation.

    Args:
        foreground: when true, log to the console with debug enabled.
            Otherwise daemonize, write the pidfile (if configured) and log
            according to ``self.syslog``.
    """
    uid = os.getuid()
    if self.username is not None:
        self.pw = pwd.getpwnam(self.username)
    else:
        if uid == 0:
            sys.stderr.write("Refusing to run as superuser\n")
            sys.exit(1)
        self.pw = pwd.getpwuid(uid)

    if uid not in (0, self.pw.pw_uid):
        sys.stderr.write("Cannot run as user \"%s\"\n" % (self.username, ))
        sys.exit(1)

    setproctitle(self.procname)

    # Both modes drop privileges and run the pre-daemonize hook first.
    self._drop_priv()
    self.pre_daemonize()
    if foreground:
        self._open_log(syslog=False, debug=True)
    else:
        self._daemonize()
        if self.pidfile:
            self._write_pid()
        self._open_log(syslog=self.syslog)

    self.run()
def run(self):
    """Blink the GPIO pin ``LED`` forever, toggling it every 0.2 seconds."""
    setproctitle.setproctitle('testdaemon')
    self.out = 0
    while True:
        GPIO.output(LED, self.out)
        self.out ^= 1  # alternate between 0 and 1
        time.sleep(0.2)
def post_worker_init(dummy_worker):
    """Mark a worker as ready by prefixing its current process title.

    Appears to be a Gunicorn ``post_worker_init`` server hook. The prefix
    is ``settings.GUNICORN_WORKER_READY_PREFIX``.

    Args:
        dummy_worker: the worker object passed by the caller (unused).
    """
    ready_prefix = settings.GUNICORN_WORKER_READY_PREFIX
    current_title = setproctitle.getproctitle()
    setproctitle.setproctitle(ready_prefix + current_title)
def run(self):
    '''Main loop of the compressor process.

    Creates a thread pool of ``compression.compressor_workers`` threads and
    prefixes the process title with ``[compress] ``. Installs SIGINT/SIGTERM
    handlers that terminate any still-running subprocess in
    ``self.current_subprocs`` and exit with status 0. It then polls
    ``<main.working_directory>/tocompress`` for ``.mak`` files forever,
    compressing each batch on the pool. It never returns normally.
    '''
    def cb_exit_gracefully(signum, frame):
        '''Terminate still-running child processes, then exit with status 0.'''
        self.logger.info("Grace exit command received signum %d", signum)
        for proc in self.current_subprocs:
            # poll() returning None means the child has not exited yet.
            # terminate() sends SIGTERM on POSIX.
            if proc.poll() is None:
                proc.terminate()
        sys.exit(0)

    compressor_workers = int(self.config.get("compression", "compressor_workers"))
    self.logger.info("Compressor process starting up")
    self.pool = ThreadPool(compressor_workers)
    setproctitle("[compress] " + getproctitle())
    signal.signal(signal.SIGINT, cb_exit_gracefully)
    signal.signal(signal.SIGTERM, cb_exit_gracefully)

    # Runs until a signal handler exits the process; the previous trailing
    # sys.exit(0) after this loop was unreachable and has been removed.
    while True:
        tocompress_dir = os.path.join(
            self.config.get("main", "working_directory"), "tocompress")
        files = self.get_files(tocompress_dir, ".mak")
        if files:
            self.pool.map(self.compress_filename, files)
        # Config values are re-read on every pass.
        time.sleep(float(self.config.get(
            "compression", "compression_check_interval")))
def run(self):
    """
    Call :meth:`loop` every :attr:`interval` seconds until told to stop.

    When ``separate_process`` is true, the process title is set to
    ``self.name``, logging is configured, and SIGINT calls :meth:`stop`.
    In that mode the loop also ends once the parent process has exited,
    detected by ``os.getppid()`` changing (init becomes the new parent).
    The loop always ends when the stop event is set. After the first
    :meth:`loop` call, the shared ``_ready`` flag is set while holding
    ``_lock``.
    """
    if self.separate_process:
        setproctitle.setproctitle(self.name)
        self.context.config.configure_logging()

        def sigint_handler(dummy_signum, dummy_frame):
            """ Exit service process when SIGINT is reached. """
            self.stop()

        signal.signal(signal.SIGINT, sigint_handler)

    next_loop_time = 0
    while True:
        orphaned = self.separate_process and os.getppid() != self._parent_pid
        if orphaned or self._stop_event.is_set():
            break

        if time.time() < next_loop_time:
            time.sleep(0.1)
            continue

        self.loop()
        first_call = not next_loop_time
        if first_call and not self.ready:
            self._lock.acquire()
            try:
                self._ready.value = True
            finally:
                self._lock.release()
        next_loop_time = time.time() + self.interval
def tornado_worker(tornado_app, sockets, parent_pid):
    """
    Serve HTTP requests for *tornado_app* on the already-bound *sockets*.

    Sets the process title to ``"<config name>: worker <interface name>"``
    and configures logging. SIGINT stops the server and the IOLoop.
    ``stop_child`` (given the server and *parent_pid*) is scheduled every
    250 ms so the worker can stop once the parent process changes. Blocks
    in the IOLoop until stopped.
    """
    app_settings = tornado_app.settings
    context = app_settings['context']
    setproctitle.setproctitle("{:s}: worker {:s}".format(
        context.config.name, app_settings['interface'].name))
    context.config.configure_logging()

    server = tornado.httpserver.HTTPServer(tornado_app)
    server.add_sockets(sockets)

    def sigint_handler(dummy_signum, dummy_frame):
        """ Stop HTTP server and IOLoop if SIGINT. """
        server.stop()  # stop accepting new requests
        ioloop = tornado.ioloop.IOLoop.instance()
        ioloop.add_callback(ioloop.stop)

    signal.signal(signal.SIGINT, sigint_handler)

    watchdog = tornado.ioloop.PeriodicCallback(
        functools.partial(stop_child, server, parent_pid), 250)
    watchdog.start()

    tornado.ioloop.IOLoop.instance().start()
def command(self):
    """Run the master process that spawns and supervises HTTP workers.

    The process title shows the config name and the full command line.
    Workers are initialised for every Tornado app returned by
    ``get_tornado_apps`` and then started with up to 100 restarts. Ctrl-C
    (``KeyboardInterrupt``) ends supervision quietly.
    """
    title = "{:s}: master process '{:s}'".format(
        self.context.config.name, " ".join(sys.argv))
    setproctitle.setproctitle(title)

    apps = get_tornado_apps(self.context, debug=False)
    for app in apps:
        self.init_workers(app)

    try:
        start_workers(self.workers, max_restarts=100)
    except KeyboardInterrupt:
        pass  # deliberate: exit without a traceback on Ctrl-C
def set_process_name(name):
    """Accept a process *name* and ignore it (no-op placeholder)."""
    return None
def setproctitle(t):
    """No-op replacement for ``setproctitle.setproctitle``; *t* is ignored."""
    pass
def main():
    """Command-line entry point: parse options and run the selected command.

    ``options.command`` is either a callable or a dotted
    ``"module.attr"`` path that is imported lazily. The process title is
    set to the program basename plus its arguments. If the command raises
    and ``--debug`` is set, the traceback is printed and a post-mortem
    debugger is started; otherwise the exception propagates.
    """
    parser = setup_parser()
    argcomplete.autocomplete(parser)
    options = parser.parse_args()
    _setup_logger(options)

    # Support the deprecated -c option
    legacy_config = getattr(options, 'config', None)
    if legacy_config is not None:
        options.configs.append(legacy_config)

    needs_defaults = options.subparser in ('report', 'logs', 'metrics', 'run')
    if needs_defaults:
        _default_region(options)
        _default_account_id(options)

    try:
        command = options.command
        if not callable(command):
            pieces = command.rsplit('.', 1)
            command = getattr(importlib.import_module(pieces[0]), pieces[-1])

        # Set the process name to something cleaner
        title_parts = [os.path.basename(sys.argv[0])] + sys.argv[1:]
        setproctitle(' '.join(title_parts))
        command(options)
    except Exception:
        if not options.debug:
            raise
        traceback.print_exc()
        pdb.post_mortem(sys.exc_info()[-1])
def run(self):
    """Run the daemon: start all worker processes, then wait forever.

    Redirects ``sys.stderr`` to ``self.log_err``. Records in
    ``self.setproctitle`` whether the ``setproctitle`` package is
    importable (the child-process targets check this flag) and, if so,
    titles this process ``mymongo``. Then it creates the shared
    ``replicator_out`` queue and starts daemonic ``scheduler``,
    ``replicator``, ``datamunging`` and ``dataprocess`` child processes.
    Finally it logs a heartbeat every 60 seconds. Never returns.
    """
    self.logger = logging.getLogger(__name__)
    sys.stderr = self.log_err

    # The import itself is the availability check; a preceding find_spec()
    # call was redundant (it returns None instead of raising).
    try:
        import setproctitle
    except ImportError:
        self.setproctitle = False
    else:
        self.setproctitle = True
        setproctitle.setproctitle('mymongo')

    self.logger.info("Running")
    self.queues = dict()
    self.queues['replicator_out'] = Queue()

    procs = dict()
    workers = (
        ('scheduler', self.scheduler),
        ('replicator', self.replicator),
        ('datamunging', self.data_munging),
        ('dataprocess', self.data_process),
    )
    for proc_name, target in workers:
        proc = Process(name=proc_name, target=target)
        # multiprocessing terminates daemonic children when the parent exits.
        proc.daemon = True
        procs[proc_name] = proc
        proc.start()

    while True:
        self.logger.info('Working...')
        time.sleep(60)
def scheduler(self):
    """Run the daemon's job scheduler.

    Records this process's PID, optionally renames the process to
    ``mymongo_scheduler``, schedules ``self.dummy_sched`` at a one-minute
    interval and starts a ``BlockingScheduler``. Exceptions raised while
    adding the job or running the scheduler are logged, not propagated.
    """
    self.write_pid(str(os.getpid()))
    if self.setproctitle:
        import setproctitle
        setproctitle.setproctitle('mymongo_scheduler')

    job_scheduler = BlockingScheduler()
    try:
        job_scheduler.add_job(self.dummy_sched, 'interval', minutes=1)
        job_scheduler.start()
    except Exception as err:
        self.logger.error('Cannot start scheduler. Error: ' + str(err))
def replicator(self):
    """Main replication process: stream MySQL changes into the queue.

    Records this process's PID and optionally renames it to
    ``mymongo_replicator``. It then calls ``mysql.mysql_stream`` with the
    MySQL config, a MongoDB handle and the ``replicator_out`` queue, where
    the replication entries are written.

    See Also:
        :meth:`.data_munging`
    """
    self.write_pid(str(os.getpid()))
    if self.setproctitle:
        import setproctitle
        setproctitle.setproctitle('mymongo_replicator')

    mongo_db = MyMongoDB(config['mongodb'])
    mysql.mysql_stream(config['mysql'], mongo_db, self.queues['replicator_out'])
def data_process(self):
    """Run the data-processing worker against MongoDB.

    Records this process's PID, optionally renames it to
    ``mymongo_dataprocess`` and runs ``ProcessData(<MyMongoDB>).run()``.
    """
    self.write_pid(str(os.getpid()))
    if self.setproctitle:
        import setproctitle
        setproctitle.setproctitle('mymongo_dataprocess')
    ProcessData(MyMongoDB(config['mongodb'])).run()
def main():
    """Entry point of the ``prof-linear`` profiling script.

    Parses ``--nTrials`` (default 10), titles the process, seeds NumPy's
    global RNG with 0 and hands the options to ``prof``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--nTrials', type=int, default=10)
    options = cli.parse_args()

    setproctitle.setproctitle('bamos.optnet.prof-linear')
    npr.seed(0)  # fixed seed so repeated runs are comparable
    prof(options)
def main():
    """Entry point of the ``prof-gurobi`` profiling script.

    Parses ``--nTrials`` (default 10), titles the process, seeds NumPy's
    global RNG with 0 and hands the options to ``prof``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--nTrials', type=int, default=10)
    options = cli.parse_args()

    setproctitle.setproctitle('bamos.optnet.prof-gurobi')
    npr.seed(0)  # fixed seed so repeated runs are comparable
    prof(options)
def appendproctitle(name):
    '''
    Append *name* to the current process title, separated by one space.

    Does nothing when ``HAS_SETPROCTITLE`` is false.
    '''
    if not HAS_SETPROCTITLE:
        return
    new_title = setproctitle.getproctitle() + ' ' + name
    setproctitle.setproctitle(new_title)
def init(ident = None, args = None):
    """
    Initialize logging system.

    Builds a handler by calling ``args.log_handler()``, gives it a
    ``Formatter(ident, handler)``, attaches it to the root logger and sets
    the root level to ``args.log_level``. Without *args*, a
    ``logging.StreamHandler`` (stderr) at WARNING level is used. When
    *ident* is non-empty and setproctitle support is both available and
    enabled, the process title becomes *ident*, followed by
    ``" (<proctitle_extra>)"`` when that is set.

    :param ident: program identifier; defaults to ``basename(sys.argv[0])``.
    :param args: namespace providing ``log_level`` and ``log_handler``.
    """
    # pylint: disable=E1103
    if ident is None:
        ident = os.path.basename(sys.argv[0])
    if args is None:
        args = argparse.Namespace(log_level=logging.WARNING,
                                  log_handler=logging.StreamHandler)

    handler = args.log_handler()
    handler.setFormatter(Formatter(ident, handler))
    root = logging.getLogger()
    root.addHandler(handler)
    root.setLevel(args.log_level)

    if not (ident and have_setproctitle and use_setproctitle):
        return
    if proctitle_extra:
        title = "%s (%s)" % (ident, proctitle_extra)
    else:
        title = ident
    setproctitle.setproctitle(title)
def set_proctitle(title):
    """Best-effort: set the current process title to *title*.

    Uses the ``setproctitle`` package when it is installed, to avoid a
    hard dependency on that C extension. Otherwise it falls back to libc
    through ctypes:

    - BSD: ``setproctitle(3)``; a leading "-" in the format suppresses
      the program-name prefix.
    - Linux: ``prctl(PR_SET_NAME)``; the kernel keeps at most 15 bytes.

    When neither symbol exists, or no C library can be loaded at all
    (e.g. ``find_library`` returns None on Windows), the call is a silent
    no-op instead of raising.

    :param title: the new process title.
    """
    try:
        import setproctitle
    except ImportError:
        pass
    else:
        setproctitle.setproctitle(title)
        return

    import ctypes
    import ctypes.util

    try:
        libc = ctypes.cdll.LoadLibrary(ctypes.util.find_library('c'))
    except (OSError, TypeError):
        # No loadable C library on this platform; stay best-effort.
        return

    title_bytes = title.encode(sys.getdefaultencoding(), 'replace')
    buf = ctypes.create_string_buffer(title_bytes)

    # BSD, maybe also OSX?
    bsd_setproctitle = getattr(libc, 'setproctitle', None)
    if bsd_setproctitle is not None:
        bsd_setproctitle(ctypes.create_string_buffer(b"-%s"), buf)
        return

    # Linux: 15 == PR_SET_NAME.
    prctl = getattr(libc, 'prctl', None)
    if prctl is not None:
        prctl(15, buf, 0, 0, 0)
def run(self):
    """Collect Kafka topic log sizes forever.

    Titles the process ``KafkaExtractCollector`` and calls
    :meth:`handler` once per ``base.config["collector"]["interval_minute"]``
    minutes. The interval is re-read on every pass.
    """
    setproctitle.setproctitle("KafkaExtractCollector")
    while True:
        self.handler()
        interval_minutes = base.config["collector"]["interval_minute"]
        time.sleep(interval_minutes * 60)
def task_wrapper(run_id=None, task=None, step=None, adminq=None, use_process=True, logger_info=None):
    '''Run one step of an eventor task and report the outcome on *adminq*.

    Args:
        run_id: optional run identifier; used only in the process title.
        task: task record. Its ``pid``, ``result`` and ``status`` are updated
            and it is sent on *adminq* twice (update, then result).
        step: the step callable, invoked as ``step(seq_path=task.sequence)``.
        adminq: queue receiving ``TaskAdminMsg`` messages.
        use_process: true when running in a dedicated process. In that case
            the module logger is rebuilt from *logger_info* and the process
            title is set.
        logger_info: logger configuration for ``MpLogger.get_logger``.

    Returns:
        True. Exceptions derived from ``Exception`` raised by the step are
        recorded in ``task.result``/``task.status`` instead of propagating.
    '''
    global module_logger
    if use_process:
        module_logger = MpLogger.get_logger(logger_info=logger_info, name='')

    task.pid = os.getpid()
    # Expose step details to the step's code through the environment.
    os.environ['EVENTOR_STEP_SEQUENCE'] = str(task.sequence)
    os.environ['EVENTOR_STEP_RECOVERY'] = str(task.recovery)
    os.environ['EVENTOR_STEP_NAME'] = str(step.name)

    if use_process and setproctitle is not None:
        run_prefix = "%s." % run_id if run_id else ''
        setproctitle("eventor: %s%s.%s(%s)" % (run_prefix, step.name, task.id_, task.sequence))

    # Update task with PID
    adminq.put(TaskAdminMsg(msg_type=TaskAdminMsgType.update, value=task))
    module_logger.info('[ Step {}/{} ] Trying to run'.format(step.name, task.sequence))

    try:
        # todo: need to pass task resources.
        outcome = step(seq_path=task.sequence, )
    except Exception as e:
        task.result = (e, pickle.dumps(traces(inspect.trace())))
        task.status = TaskStatus.failure
    else:
        task.result = outcome
        task.status = TaskStatus.success

    report = TaskAdminMsg(msg_type=TaskAdminMsgType.result, value=task)
    module_logger.info('[ Step {}/{} ] Completed, status: {}'.format(step.name, task.sequence, str(task.status), ))
    adminq.put(report)
    return True
def main():
    """Command-line entry point: read a message and send it via ``send``.

    Titles the process ``name`` when the ``setproctitle`` package is
    available. Input comes from stdin (``-f -``, the default); when a file
    is given, its path is passed to ``send`` instead. Configuration,
    type and runtime errors from ``send`` exit with their message. In
    ``--dry-run`` mode the returned payload is printed.
    """
    try:
        import setproctitle
    except ImportError:
        pass
    else:
        setproctitle.setproctitle(name)

    dialects = ['sniff'] + sorted(csv.list_dialects())

    # CLI arguments
    cli = argparse.ArgumentParser(prog=name, description=description)
    cli.add_argument('-V', '--version', action='version', version="%(prog)s " + version)
    cli.add_argument('-C', '--config', help='Use a different configuration file')
    cli.add_argument('-s', '--section', help='Configuration file section', default='DEFAULT')
    cli.add_argument('-c', '--channel', help='Send to this channel or @username')
    cli.add_argument('-U', '--url', help='Mattermost webhook URL')
    cli.add_argument('-u', '--username', help='Username')
    cli.add_argument('-i', '--icon', help='Icon')
    formatting = cli.add_mutually_exclusive_group()
    formatting.add_argument('-t', '--tabular', metavar='DIALECT', const='sniff', nargs='?', choices=dialects, help='Parse input as CSV and format it as a table (DIALECT can be one of %(choices)s)')
    formatting.add_argument('-y', '--syntax', default='auto')
    cli.add_argument('-I', '--info', action='store_true', help='Include file information in message')
    cli.add_argument('-n', '--dry-run', '--just-print', action='store_true', help="Don't send, just print the payload")
    cli.add_argument('-f', '--file', default='-', help="Read content from FILE. If - reads from standard input (DEFAULT: %(default)s)")
    opts = cli.parse_args()

    if opts.file == '-':
        message, filename = sys.stdin.read(), None
    else:
        message, filename = '', opts.file

    try:
        payload = send(opts.channel, message, filename, opts.url,
                       opts.username, opts.icon, opts.syntax, opts.tabular,
                       opts.info, opts.dry_run, opts.section, name,
                       opts.config)
    except (configparser.Error, TypeError, RuntimeError) as err:
        sys.exit(str(err))

    if opts.dry_run:
        print(payload)
def main():
    """Time forward passes of the dense and sparse OptNet Sudoku models.

    For board sizes 2 and 3 and batch sizes 1, 64 and 128, runs
    ``--nTrials`` timed passes on slices of ``data/<boardSz>/features.pt``.
    It prints a table of mean +/- std seconds per pass.

    NOTE(review): the dense forward call is disabled (as it was in the
    original), so the "dense forward" column only measures timer overhead.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--nTrials', type=int, default=5)
    parser.add_argument('--Qpenalty', type=float, default=0.1)
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    setproctitle.setproctitle('bamos.sudoku.prof-sparse')

    print('=== nTrials: {}'.format(args.nTrials))
    print('| {:8s} | {:8s} | {:21s} | {:21s} |'.format(
        'boardSz', 'batchSz', 'dense forward (s)', 'sparse forward (s)'))

    for boardSz in [2, 3]:
        with open('data/{}/features.pt'.format(boardSz), 'rb') as fh:
            X = torch.load(fh)
        with open('data/{}/labels.pt'.format(boardSz), 'rb') as fh:
            Y = torch.load(fh)

        for batchSz in [1, 64, 128]:
            dmodel = models.OptNetEq(boardSz, args.Qpenalty, trueInit=True)
            spmodel = models.SpOptNetEq(boardSz, args.Qpenalty, trueInit=True)
            if args.cuda:
                dmodel = dmodel.cuda()
                spmodel = spmodel.cuda()

            dtimes, sptimes = [], []
            for trial in range(args.nTrials):
                rows = slice(trial * batchSz, (trial + 1) * batchSz)
                Xbatch = Variable(X[rows])
                Ybatch = Variable(Y[rows])
                if args.cuda:
                    Xbatch = Xbatch.cuda()
                    Ybatch = Ybatch.cuda()

                tic = time.time()
                # Dense forward pass intentionally not executed (see NOTE).
                dtimes.append(time.time() - tic)

                tic = time.time()
                spmodel(Xbatch)
                sptimes.append(time.time() - tic)

            print('| {:8d} | {:8d} | {:.2e} +/- {:.2e} | {:.2e} +/- {:.2e} |'.format(
                boardSz, batchSz, np.mean(dtimes), np.std(dtimes),
                np.mean(sptimes), np.std(sptimes)))
def init_sender(config):
    """Initialise the Iris sender from *config*.

    Steps, in order:

    - Route SIGINT/SIGTERM/SIGQUIT to ``sender_shutdown`` and optionally
      rename the process.
    - Initialise db/cache/metrics and warm the API caches.
    - Decide whether to mock the Gmail watch renewer and skip real sending
      (debug mode, ``skipgmailwatch``, ``skipsend``). Skipping replaces
      the vendors with a dummy one.
    - Create the module globals ``quota`` and ``coordinator``. The
      coordinator is ZooKeeper-backed when ``sender.zookeeper_cluster`` is
      set, otherwise a non-cluster coordinator driven by config.
    """
    # NOTE(review): send_message is declared global but not assigned here.
    global should_mock_gwatch_renewer, send_message, quota, coordinator

    for signum in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT):
        gevent.signal(signum, sender_shutdown)

    process_title = config['sender'].get('process_title')
    if process_title and isinstance(process_title, basestring):
        setproctitle.setproctitle(process_title)
        logger.info('Changing process name to %s', process_title)

    api_host = config['sender'].get('api_host', 'http://localhost:16649')
    db.init(config)
    cache.init(api_host, config)
    metrics.init(config, 'iris-sender', default_sender_metrics)
    api_cache.cache_priorities()
    api_cache.cache_applications()
    api_cache.cache_modes()

    debug_mode = config['sender'].get('debug')
    if debug_mode:
        logger.info('DEBUG MODE')
        should_mock_gwatch_renewer = True
    should_skip_send = bool(debug_mode)

    should_mock_gwatch_renewer = should_mock_gwatch_renewer or config.get('skipgmailwatch', False)
    should_skip_send = should_skip_send or config.get('skipsend', False)
    if should_skip_send:
        config['vendors'] = [{
            'type': 'iris_dummy',
            'name': 'iris dummy vendor'
        }]

    quota = ApplicationQuota(db, cache.targets_for_role, message_send_enqueue,
                             config['sender'].get('sender_app'))

    zk_hosts = config['sender'].get('zookeeper_cluster', False)
    if not zk_hosts:
        logger.info('ZK cluster info not specified. Using master status from config')
        from iris.coordinator.noncluster import Coordinator
        coordinator = Coordinator(is_master=config['sender'].get('is_master', True),
                                  slaves=config['sender'].get('slaves', []))
        return

    logger.info('Initializing coordinator with ZK: %s', zk_hosts)
    from iris.coordinator.kazoo import Coordinator
    coordinator = Coordinator(zk_hosts=zk_hosts,
                              hostname=socket.gethostname(),
                              port=config['sender'].get('port', 2321),
                              join_cluster=True)
def daemonize(self):
    """
    Detach into a UNIX daemon using the double-fork technique.

    See Stevens' "Advanced Programming in the UNIX Environment" (ISBN
    0201563177) for details. The intermediate parents exit with status 0.
    If either fork fails, an error is written to stderr and the process
    exits with status 1.

    After detaching, the method:

    - redirects stdin/stdout/stderr to the paths in ``self.stdin``,
      ``self.stdout`` and ``self.stderr``;
    - registers ``self.delpid`` with ``atexit``;
    - writes the PID to ``self.pidfile``;
    - sets the process title to ``self.daemon_conf['name']``.

    Uses ``except ... as`` and ``open`` (valid on Python 2.6+ and 3)
    instead of the Python-2-only ``except X, e`` and ``file()``.
    """
    try:
        pid = os.fork()
        if pid > 0:
            # exit first parent
            sys.exit(0)
    except OSError as e:
        sys.stderr.write("fork #1 failed: %d (%s)\n" % (e.errno, e.strerror))
        sys.exit(1)

    # decouple from parent environment
    os.setsid()
    os.umask(0)

    # do second fork
    try:
        pid = os.fork()
        if pid > 0:
            # exit from second parent
            sys.exit(0)
    except OSError as e:
        sys.stderr.write("fork #2 failed: %d (%s)\n" % (e.errno, e.strerror))
        sys.exit(1)

    # redirect standard file descriptors
    sys.stdout.flush()
    sys.stderr.flush()
    # Only the descriptors matter: dup2 copies them onto 0/1/2, so the
    # temporary file objects can be closed once the copies exist.
    with open(self.stdin, 'r') as si, open(self.stdout, 'a+') as so, \
            open(self.stderr, 'a+') as se:
        os.dup2(si.fileno(), sys.stdin.fileno())
        os.dup2(so.fileno(), sys.stdout.fileno())
        os.dup2(se.fileno(), sys.stderr.fileno())

    # write pidfile; closed (and flushed) deterministically
    atexit.register(self.delpid)
    pid = str(os.getpid())
    with open(self.pidfile, 'w+') as pidfile:
        pidfile.write("%s\n" % pid)

    # update proc name
    import setproctitle
    setproctitle.setproctitle(self.daemon_conf['name'])
def main():
    """Train an ICNN classifier on a synthetic 2-D dataset.

    The dataset (``--dataset``) is moons, circles or linear, and the model
    (``--model``) is picnn or ficnn. Output goes to a fresh
    ``<save>/<model>.<dataset>`` directory; any existing one is deleted.
    NumPy and TensorFlow are seeded with ``--seed``.

    Raises:
        ValueError: for an unknown dataset name. argparse ``choices`` makes
            this unreachable from the CLI, but unlike the previous
            ``assert(False)`` it is not stripped under ``python -O``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--save', type=str, default='work')
    parser.add_argument('--nEpoch', type=int, default=100)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--model', type=str, default="picnn",
                        choices=['picnn', 'ficnn'])
    parser.add_argument('--dataset', type=str, default="moons",
                        choices=['moons', 'circles', 'linear'])
    parser.add_argument('--noncvx', action='store_true')
    args = parser.parse_args()

    npr.seed(args.seed)
    tf.set_random_seed(args.seed)

    setproctitle.setproctitle('bamos.icnn.synthetic.{}.{}'.format(args.model, args.dataset))

    save = os.path.join(os.path.expanduser(args.save),
                        "{}.{}".format(args.model, args.dataset))
    # Start every run from an empty output directory.
    if os.path.isdir(save):
        shutil.rmtree(save)
    os.makedirs(save, exist_ok=True)

    if args.dataset == "moons":
        (dataX, dataY) = make_moons(noise=0.3, random_state=0)
    elif args.dataset == "circles":
        (dataX, dataY) = make_circles(noise=0.2, factor=0.5, random_state=0)
        dataY = 1.-dataY  # flip the 0/1 labels
    elif args.dataset == "linear":
        (dataX, dataY) = make_classification(n_features=2, n_redundant=0, n_informative=2,
                                             random_state=1, n_clusters_per_class=1)
        rng = np.random.RandomState(2)
        dataX += 2 * rng.uniform(size=dataX.shape)
    else:
        raise ValueError("unknown dataset: {!r}".format(args.dataset))

    # Column vector of float32 labels.
    dataY = dataY.reshape((-1, 1)).astype(np.float32)
    nFeatures = dataX.shape[1]
    nLabels = 1

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        model = Model(nFeatures, nLabels, sess, args.model, nGdIter=30)
        model.train(args, dataX, dataY)
def main():
    """Train the ``ebundle`` ICNN variant on Olivetti faces with an MSE loss.

    The output directory ``--save`` is recreated from scratch, with a
    ``ckpt`` subdirectory whose path is stored on ``args.ckptDir``. Data is
    loaded from ``data/olivetti``. NumPy and TensorFlow are seeded with
    ``--seed``.

    Exits with a usage error if ``--noncvx`` is given, which this variant
    does not support. This used to be an ``assert``, which ``python -O``
    strips.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--save', type=str, default='work/mse.ebundle')
    parser.add_argument('--nEpoch', type=float, default=50)
    parser.add_argument('--nBundleIter', type=int, default=30)
    parser.add_argument('--trainBatchSz', type=int, default=70)
    parser.add_argument('--noncvx', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    if args.noncvx:
        parser.error('--noncvx is not supported by this model')

    setproctitle.setproctitle('bamos.icnn.comp.mse.ebundle')
    npr.seed(args.seed)
    tf.set_random_seed(args.seed)

    save = os.path.expanduser(args.save)
    if os.path.isdir(save):
        shutil.rmtree(save)
    os.makedirs(save)

    ckptDir = os.path.join(save, 'ckpt')
    args.ckptDir = ckptDir
    if not os.path.exists(ckptDir):
        os.makedirs(ckptDir)

    data = olivetti.load("data/olivetti")

    nTrain = data['trainX'].shape[0]
    nTest = data['testX'].shape[0]
    inputSz = list(data['trainX'][0].shape)
    # Per-sample target shape. Use sample 0 (like inputSz) so a single
    # training example does not raise IndexError.
    outputSz = list(data['trainY'][0].shape)

    print("\n\n" + "="*40)
    print("+ nTrain: {}, nTest: {}".format(nTrain, nTest))
    print("+ inputSz: {}, outputSz: {}".format(inputSz, outputSz))
    print("="*40 + "\n\n")

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        model = Model(inputSz, outputSz, sess)
        model.train(args, data['trainX'], data['trainY'], data['testX'], data['testY'])
def main():
    """Train the gradient-descent ICNN on Olivetti faces with an MSE loss.

    The output directory ``--save`` is recreated from scratch, with a
    ``ckpt`` subdirectory whose path is stored on ``args.ckptDir``. Data is
    loaded from ``data/olivetti``. ``--nGdIter`` is passed to the model,
    and NumPy and TensorFlow are seeded with ``--seed``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--save', type=str, default='work/mse')
    parser.add_argument('--nEpoch', type=float, default=50)
    parser.add_argument('--trainBatchSz', type=int, default=70)
    parser.add_argument('--nGdIter', type=int, default=30)
    parser.add_argument('--noncvx', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    setproctitle.setproctitle('bamos.icnn.comp.mse')
    npr.seed(args.seed)
    tf.set_random_seed(args.seed)

    save = os.path.expanduser(args.save)
    if os.path.isdir(save):
        shutil.rmtree(save)
    os.makedirs(save)

    ckptDir = os.path.join(save, 'ckpt')
    args.ckptDir = ckptDir
    if not os.path.exists(ckptDir):
        os.makedirs(ckptDir)

    data = olivetti.load("data/olivetti")

    nTrain = data['trainX'].shape[0]
    nTest = data['testX'].shape[0]
    inputSz = list(data['trainX'][0].shape)
    # Per-sample target shape. Use sample 0 (like inputSz) so a single
    # training example does not raise IndexError.
    outputSz = list(data['trainY'][0].shape)

    print("\n\n" + "="*40)
    print("+ nTrain: {}, nTest: {}".format(nTrain, nTest))
    print("+ inputSz: {}, outputSz: {}".format(inputSz, outputSz))
    print("="*40 + "\n\n")

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        model = Model(inputSz, outputSz, sess, args.nGdIter)
        model.train(args, data['trainX'], data['trainY'], data['testX'], data['testY'])