我们从Python开源项目中，提取了以下50个代码示例，用于说明如何使用six.moves.cPickle.load()（其中部分示例展示的是相关的加载函数，如pickle.load、yaml.load、json.load和np.load）。
def load_params(self, f_, filter_=None):
    """Load pickled parameter values from ``f_`` into ``self._vars_di``.

    ``f_`` holds a pickled dict mapping variable names to arrays. Each value
    is written into the matching variable via ``set_value`` after its shape
    has been checked against the variable's current value.

    Args:
        f_: Readable binary file object positioned at the pickled dict.
        filter_: Optional regular expression. When given, only names that
            match it completely (``re.fullmatch``) are loaded; all other
            names are skipped.

    Raises:
        ValueError: If a loaded array's shape differs from the shape of the
            variable it would overwrite.
        Whatever ``self._vars_di[k]`` raises (KeyError for a dict) when a
        selected name has no matching variable.

    Security: ``pickle.load`` can execute arbitrary code; only load trusted
    files.
    """
    di = pickle.load(f_)
    pat = None if filter_ is None else re.compile(filter_)
    for k, v in di.items():
        if pat is not None and not pat.fullmatch(k):
            continue
        var = self._vars_di[k]
        p = var.get_value(borrow=True)
        if p.shape != v.shape:
            # The variable's current shape is what is *needed*; the pickled
            # array's shape is what we *got*.
            raise ValueError(
                'Shape mismatch, need %s, got %s' % (p.shape, v.shape))
        var.set_value(v)
def read_data_files(self, subset='train'):
    """Reads from data file and returns images and labels in a numpy array.

    Args:
        subset: ``'train'`` reads ``data_batch_1`` .. ``data_batch_5``;
            ``'validation'`` reads ``test_batch`` (all under
            ``self.data_dir``).

    Returns:
        ``(all_images, all_labels)``. The images are the concatenated
        ``'data'`` entries cast to ``np.float32``. The labels are the
        concatenated ``'labels'`` entries.

    Raises:
        ValueError: If ``subset`` is not ``'train'`` or ``'validation'``.
    """
    import sys  # local: only needed for the Python 2/3 pickle switch below

    assert self.data_dir, ('Cannot call `read_data_files` when using synthetic '
                           'data')
    if subset == 'train':
        filenames = [os.path.join(self.data_dir, 'data_batch_%d' % i)
                     for i in range(1, 6)]
    elif subset == 'validation':
        filenames = [os.path.join(self.data_dir, 'test_batch')]
    else:
        raise ValueError('Invalid data subset "%s"' % subset)

    inputs = []
    for filename in filenames:
        # Pickles are binary: text mode ('r') breaks cPickle.load on Python 3.
        with gfile.Open(filename, 'rb') as f:
            if sys.version_info < (3,):
                inputs.append(cPickle.load(f))
            else:
                # 'latin1' decodes Python-2 byte strings (including numpy
                # buffers) losslessly and yields str keys like 'data'.
                inputs.append(cPickle.load(f, encoding='latin1'))
    # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
    # input format.
    all_images = np.concatenate(
        [each_input['data'] for each_input in inputs]).astype(np.float32)
    all_labels = np.concatenate(
        [each_input['labels'] for each_input in inputs])
    return all_images, all_labels
def read_data():
    """Load the pickled train/valid/test splits stored at ``PICKLE_FILENAME``.

    The pickle is a dict with the keys ``<split>_dataset`` and
    ``<split>_labels`` for ``split`` in train, valid and test. The shape of
    every dataset/label pair is printed.

    Returns:
        ``([train, valid, test] datasets, [train, valid, test] labels)``.
    """
    split_names = ('train', 'valid', 'test')
    with open(PICKLE_FILENAME, 'rb') as handle:
        saved = pickle.load(handle)
        pairs = [(saved[name + '_dataset'], saved[name + '_labels'])
                 for name in split_names]
        del saved  # drop the dict so its memory can be reclaimed early

    titles = ('Training set', 'Valid set', 'Test set')
    for title, (data, labels) in zip(titles, pairs):
        print(title, data.shape, labels.shape)

    datasets = [data for data, _ in pairs]
    label_sets = [labels for _, labels in pairs]
    return datasets, label_sets
def load_pickle(filename):
    """Read back a community model that was saved with pickle.

    Parameters
    ----------
    filename : str
        Path of the pickle file holding the community.

    Returns
    -------
    micom.Community
        The deserialized community model.

    Notes
    -----
    Unpickling can execute arbitrary code; only open trusted files.
    """
    with open(filename, mode="rb") as handle:
        community = pickle.load(handle)
    return community
def predict():
    """
    An example of how to load a trained model and use it to predict labels.

    Reads ``best_model.pkl`` from the current directory. It then compiles a
    Theano function from ``classifier.input`` to ``classifier.y_pred`` and
    prints its output for the first 10 examples of the test split returned
    by ``load_data('mnist.pkl.gz')``.
    """
    # load the saved model. Pickles are binary: open in 'rb' (text mode fails
    # on Python 3), and use ``with`` so the file is closed. Only unpickle
    # trusted files.
    with open('best_model.pkl', 'rb') as model_file:
        classifier = pickle.load(model_file)

    # compile a predictor function
    predict_model = theano.function(
        inputs=[classifier.input],
        outputs=classifier.y_pred)

    # We can test it on some examples from test test
    dataset = 'mnist.pkl.gz'
    datasets = load_data(dataset)
    test_set_x, test_set_y = datasets[2]
    test_set_x = test_set_x.get_value()

    predicted_values = predict_model(test_set_x[:10])
    print("Predicted values for the first 10 examples in test set:")
    print(predicted_values)
def history_infos(opt):
    """Restore training-history bookkeeping from a previous run, if any.

    When ``opt.start_from`` is non-empty, the pickle
    ``<start_from>/<start_from with 'save/' removed>.infos-best.pkl`` is
    loaded. Its entries supply the returned values, and missing entries fall
    back to defaults. When ``opt.start_from`` is empty or None, only the
    defaults are used.

    Args:
        opt: Options object; reads ``opt.start_from`` and
            ``opt.load_best_score``.

    Returns:
        ``(opt, infos, iteration, epoch, val_history, train_history)`` with
        ``val_history = [val_result_history, best_val_score, 0.0]`` and
        ``train_history = [loss_history, lr_history]``.
    """
    infos = {}
    # Truthiness treats both '' and None as "no previous run".
    if opt.start_from:
        # open old infos and check if models are compatible
        model_id = opt.start_from
        infos_id = model_id.replace('save/', '') + '.infos-best.pkl'
        # Pickles are binary: text mode breaks cPickle.load on Python 3.
        with open(os.path.join(opt.start_from, infos_id), 'rb') as f:
            infos = cPickle.load(f)
            # NOTE(review): read but never compared with ``opt``; the
            # compatibility check mentioned above is not implemented here.
            # The lookup still fails fast (KeyError) if 'opt' is missing.
            saved_model_opt = infos['opt']

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = infos.get('val_result_history', {})
    loss_history = infos.get('loss_history', {})
    lr_history = infos.get('lr_history', {})
    best_val_score = infos.get('best_val_score', None) if opt.load_best_score == 1 else 0
    val_loss = 0.0
    val_history = [val_result_history, best_val_score, val_loss]
    train_history = [loss_history, lr_history]
    return opt, infos, iteration, epoch, val_history, train_history
def load_batch(fpath, label_key='labels'):
    """Load one CIFAR python batch file.

    Args:
        fpath: Path to the pickled batch (a dict with a ``'data'`` entry).
        label_key: Key under which the labels are stored.

    Returns:
        ``(data, labels)``, with ``data`` reshaped to ``(N, 3, 32, 32)``.
    """
    # ``with`` closes the file even if unpickling or decoding fails.
    with open(fpath, 'rb') as f:
        if sys.version_info < (3,):
            d = cPickle.load(f)
        else:
            d = cPickle.load(f, encoding="bytes")
            # encoding="bytes" yields bytes keys for Python 2 pickles; decode
            # them to str. Keys that are already str (files pickled by
            # Python 3) are kept as they are instead of failing on .decode.
            d = {(k.decode("utf8") if isinstance(k, bytes) else k): v
                 for k, v in d.items()}
    data = d["data"]
    labels = d[label_key]
    data = data.reshape(data.shape[0], 3, 32, 32)
    return data, labels
def from_indra_pickle(path, name=None, version=None, description=None):
    """Build a BEL graph from a pickled list of INDRA statements.

    :param str path: Location of the pickle holding a list of
        :class:`indra.statements.Statement`
    :param str name: Name to give the BEL graph
    :param str version: Version to give the BEL graph
    :param str description: Description to give the BEL graph
    :rtype: pybel.BELGraph
    """
    with open(path, 'rb') as pickle_file:
        indra_statements = load(pickle_file)

    graph_metadata = dict(name=name, version=version, description=description)
    return from_indra_statements(statements=indra_statements, **graph_metadata)
def restore_snapshot(self, filename=None):
    """
    Restore a saved snapshot of the current process from a file

    Warning: this is not thread safe, do not use with multithread program

    Args:
        - filename: path of the saved snapshot; when falsy,
          ``self.get_config_filename("snapshot")`` is used

    Returns:
        - whatever ``self.give_snapshot(snapshot)`` returns (Bool)
    """
    if not filename:
        filename = self.get_config_filename("snapshot")

    # ``with`` closes the file even if unpickling fails. Unpickling can run
    # arbitrary code: only restore snapshots from trusted sources.
    with open(filename, "rb") as fd:
        snapshot = pickle.load(fd)
    return self.give_snapshot(snapshot)

#########################
#   Memory Operations   #
#########################
def maybe_pickle(data_folders, min_num_images_per_class, force=False):
    """Pickle each class sub-directory of ``data_folders`` to ``<dir>.pickle``.

    Existing pickles are reused unless ``force`` is true. Errors while
    saving are printed and otherwise ignored (best effort).

    Args:
        data_folders: Directory whose sub-directories each hold one class.
        min_num_images_per_class: Forwarded to ``load_letter``.
        force: Re-create pickles even when they already exist.

    Returns:
        List of pickle paths, one per sub-directory, in ``os.listdir`` order.
        A path is listed even if saving it failed.
    """
    dataset_names = []
    for folder in os.listdir(data_folders):
        curr_folder_path = os.path.join(data_folders, folder)
        if not os.path.isdir(curr_folder_path):
            continue
        set_filename = curr_folder_path + '.pickle'
        dataset_names.append(set_filename)
        if os.path.exists(set_filename) and not force:
            # You may override by setting force=True.
            print('%s already present - Skipping pickling.' % set_filename)
            continue

        print('Pickling %s.' % set_filename)
        # load_letter presumably loads and normalizes the images.
        dataset = load_letter(curr_folder_path, min_num_images_per_class)
        try:
            # ``with`` closes the file; no explicit close() is needed.
            with open(set_filename, 'wb') as f:
                pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)
        except Exception as e:
            # Best effort: report the failure and move on to the next class.
            print('Unable to save data to', set_filename, ':', e)
    return dataset_names
def read_dataset(data_dir):
    """Return Pascal VOC training and validation records, cached as a pickle.

    If ``data_dir/PascalVoc.pickle`` is missing, the archive at ``DATA_URL``
    is downloaded/extracted into ``data_dir``. The image lists built from its
    ``VOCdevkit`` folder are then pickled. The cached pickle is read in
    either case.

    Returns:
        ``(training_records, validation_records)``.
    """
    pickle_filepath = os.path.join(data_dir, "PascalVoc.pickle")
    if os.path.exists(pickle_filepath):
        print("Found pickle file!")
    else:
        utils.maybe_download_and_extract(data_dir, DATA_URL, is_tarfile=True)
        image_lists = create_image_lists(os.path.join(data_dir, "VOCdevkit"))
        print("Pickling ...")
        with open(pickle_filepath, 'wb') as out_file:
            pickle.dump(image_lists, out_file, pickle.HIGHEST_PROTOCOL)

    with open(pickle_filepath, 'rb') as in_file:
        cached = pickle.load(in_file)
    return cached['training'], cached['validation']
def read_dataset(data_dir):
    """Return MIT Scene Parsing training/validation records, cached as a pickle.

    If ``data_dir/MITSceneParsing.pickle`` is missing, the archive at
    ``DATA_URL`` is downloaded/extracted. The image lists are built from the
    extracted folder, which is named after the archive minus its extension,
    and then pickled. The cached pickle is read in either case.

    Returns:
        ``(training_records, validation_records)``.
    """
    pickle_filepath = os.path.join(data_dir, "MITSceneParsing.pickle")
    if os.path.exists(pickle_filepath):
        print("Found pickle file!")
    else:
        utils.maybe_download_and_extract(data_dir, DATA_URL, is_zipfile=True)
        archive_name = DATA_URL.split("/")[-1]
        scene_folder = os.path.splitext(archive_name)[0]
        image_lists = create_image_lists(os.path.join(data_dir, scene_folder))
        print("Pickling ...")
        with open(pickle_filepath, 'wb') as out_file:
            pickle.dump(image_lists, out_file, pickle.HIGHEST_PROTOCOL)

    with open(pickle_filepath, 'rb') as in_file:
        cached = pickle.load(in_file)
    return cached['training'], cached['validation']
def install_and_load(self):
    """Load the TIMIT train/test pickles into ``self.subset``.

    Each ``TIMIT/<subset>_set.pkl`` (next to this module) holds three
    consecutive pickled objects. They are stored as a 3-item list under
    ``self.subset['train']`` and ``self.subset['test']``. ``'valid'`` is an
    alias of the test split.

    Raises:
        IOError: If a subset file does not exist.
    """
    # TODO automatically install if fails to find anything
    FILE_NOT_FOUND_MSG = (
        'Did not found TIMIT file "%s"'
        ', make sure you download and install the dataset')
    self.subset = {}
    path = os.path.join(os.path.dirname(__file__), 'TIMIT', '%s_set.pkl')
    for subset in ['train', 'test']:
        filepath = path % subset
        if not os.path.exists(filepath):
            raise IOError(
                FILE_NOT_FOUND_MSG % filepath)
        with open(filepath, 'rb') as f:
            # The cyclic GC is paused while unpickling (presumably for
            # speed). try/finally guarantees it is turned back on even if a
            # load fails, e.g. on a truncated file.
            gc.disable()
            try:
                all_data = [pickle.load(f) for _ in range(3)]
            finally:
                gc.enable()
        self.subset[subset] = all_data

    # use same subset for validation / test
    # as TIMIT is small
    self.subset['valid'] = self.subset['test']
def read_dataset(data_dir):
    """Return the celebA train/test/validation image lists, cached as a pickle.

    If ``data_dir/celebA.pickle`` is missing, the archive at ``DATA_URL`` is
    downloaded/extracted. The image lists are then built from the extracted
    folder, their sizes are printed, and they are pickled. The cached pickle
    is read in either case.

    Returns:
        ``(training_images, testing_images, validation_images)``.
    """
    pickle_filepath = os.path.join(data_dir, "celebA.pickle")
    if os.path.exists(pickle_filepath):
        print("Found pickle file!")
    else:
        utils.maybe_download_and_extract(data_dir, DATA_URL, is_zipfile=True)
        archive_name = DATA_URL.split("/")[-1]
        celeb_folder = os.path.splitext(archive_name)[0]
        image_lists = create_image_lists(os.path.join(data_dir, celeb_folder))
        print("Training set: %d" % len(image_lists['train']))
        print("Test set: %d" % len(image_lists['test']))
        print("Validation set: %d" % len(image_lists['validation']))
        print("Pickling ...")
        with open(pickle_filepath, 'wb') as out_file:
            pickle.dump(image_lists, out_file, pickle.HIGHEST_PROTOCOL)

    with open(pickle_filepath, 'rb') as in_file:
        cached = pickle.load(in_file)
    return cached['train'], cached['test'], cached['validation']
def get_word_index(path='reuters_word_index.pkl'):
    """Retrieves the dictionary mapping word indices back to words.

    # Arguments
        path: where to cache the data (relative to `~/.keras/dataset`).

    # Returns
        The word index dictionary.
    """
    path = get_file(path, origin='https://s3.amazonaws.com/text-datasets/reuters_word_index.pkl')
    # ``with`` closes the file even if unpickling raises.
    with open(path, 'rb') as f:
        if sys.version_info < (3,):
            data = cPickle.load(f)
        else:
            # 'latin1' decodes str objects from a (presumably Python 2)
            # pickle.
            data = cPickle.load(f, encoding='latin1')
    return data
def get_word_index(path='imdb_word_index.pkl'):
    """Retrieves the dictionary mapping word indices back to words.

    # Arguments
        path: where to cache the data (relative to `~/.keras/dataset`).

    # Returns
        The word index dictionary.
    """
    path = get_file(path,
                    origin='https://s3.amazonaws.com/text-datasets/imdb_word_index.pkl',
                    md5_hash='72d94b01291be4ff843198d3b0e1e4d7')
    # ``with`` closes the file even if unpickling raises.
    with open(path, 'rb') as f:
        if sys.version_info < (3,):
            data = cPickle.load(f)
        else:
            # 'latin1' decodes str objects from a (presumably Python 2)
            # pickle.
            data = cPickle.load(f, encoding='latin1')
    return data
def load_data(path='mnist.pkl.gz'):
    """Loads the MNIST dataset.

    # Arguments
        path: path where to cache the dataset locally
            (relative to ~/.keras/datasets).

    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    path = get_file(path, origin='https://s3.amazonaws.com/img-datasets/mnist.pkl.gz')
    opener = gzip.open if path.endswith('.gz') else open
    # ``with`` closes the (possibly gzip) file even if unpickling raises.
    with opener(path, 'rb') as f:
        if sys.version_info < (3,):
            data = cPickle.load(f)
        else:
            data = cPickle.load(f, encoding='bytes')
    return data  # (x_train, y_train), (x_test, y_test)
def __init__(self, save_dir=SAVE_DIR, prime_text=PRIME_TEXT, num_sample_symbols=NUM_SAMPLE_SYMBOLS):
    """Load the vocabulary and the latest checkpoint found in ``save_dir``.

    Args:
        save_dir: Directory holding ``chars_vocab.pkl`` and the checkpoints.
        prime_text: Stored on ``self.prime_text``.
        num_sample_symbols: Stored on ``self.num_sample_symbols``.
    """
    self.save_dir = save_dir
    self.prime_text = prime_text
    self.num_sample_symbols = num_sample_symbols

    # Vocabulary and checkpoint both come from ``self.save_dir``, so a
    # caller-supplied directory is honored for both.
    with open(os.path.join(self.save_dir, 'chars_vocab.pkl'), 'rb') as vocab_file:
        self.chars, self.vocab = cPickle.load(vocab_file)
    self.model = Model(len(self.chars), is_sampled=True)

    # polite GPU memory allocation: don't grab everything you can.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.allocator_type = 'BFC'
    self.sess = tf.Session(config=config)

    tf.initialize_all_variables().run(session=self.sess)

    self.checkpoint = tf.train.get_checkpoint_state(self.save_dir)
    if self.checkpoint and self.checkpoint.model_checkpoint_path:
        tf.train.Saver(tf.all_variables()).restore(self.sess, self.checkpoint.model_checkpoint_path)
def update_default_setting(self, key_tree, value):
    """
    Update a default value in the local settings file.

    :param key_tree:
        A tuple containing a tree of dictionary keys. Missing intermediate
        keys are created as empty dicts.

    :param value:
        The value for the setting.

    :returns:
        ``True`` once the settings file has been rewritten.
    """
    # Open the defaults. A Loader-less ``yaml.load`` raises TypeError on
    # PyYAML >= 6. FullLoader (PyYAML >= 5.1) is what ``yaml.load`` used by
    # default on 5.x; older releases fall back to ``yaml.Loader``, their
    # default. The settings file is treated as trusted local input.
    loader = getattr(yaml, "FullLoader", yaml.Loader)
    with open(self._default_settings_path, "rb") as fp:
        defaults = yaml.load(fp, Loader=loader)
    # An empty settings file parses to None; start from an empty mapping.
    if defaults is None:
        defaults = {}

    branch = defaults
    for key in key_tree[:-1]:
        branch.setdefault(key, {})
        branch = branch[key]
    branch[key_tree[-1]] = value

    with open(self._default_settings_path, "w") as fp:
        fp.write(yaml.dump(defaults))
    return True
def model_from_yaml(yaml_string, custom_objects=None):
    """Parses a yaml model configuration file and returns a model instance.

    Security: YAML loading with a non-safe loader can construct arbitrary
    Python objects. Only pass configuration strings from trusted sources.

    Arguments:
        yaml_string: YAML string encoding a model configuration.
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.

    Returns:
        A Keras model instance (uncompiled).

    Raises:
        ImportError: if yaml module is not found.
    """
    if yaml is None:
        raise ImportError('Requires yaml module installed.')
    # A Loader-less ``yaml.load`` raises TypeError on PyYAML >= 6. FullLoader
    # (PyYAML >= 5.1) is what ``yaml.load`` defaulted to on 5.x; older
    # releases fall back to ``yaml.Loader``, their default.
    loader = getattr(yaml, 'FullLoader', yaml.Loader)
    config = yaml.load(yaml_string, Loader=loader)
    return layer_module.deserialize(config, custom_objects=custom_objects)
def to_yaml(self, **kwargs):
    """Serialize the network configuration to a YAML string.

    The result can be turned back into a network with
    `keras.models.model_from_yaml(yaml_string, custom_objects={})`, where
    `custom_objects` maps the names of custom losses / layers / etc. to the
    corresponding functions / classes.

    Arguments:
        **kwargs: Extra keyword arguments forwarded to `yaml.dump()`.

    Returns:
        A YAML string.

    Raises:
        ImportError: if yaml module is not found.
    """
    if yaml is None:
        raise ImportError('Requires yaml module installed.')
    network_config = self._updated_config()
    return yaml.dump(network_config, **kwargs)
def load_data(path='mnist.npz'):
    """Loads the MNIST dataset.

    Arguments:
        path: path where to cache the dataset locally
            (relative to ~/.keras/datasets).

    Returns:
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    path = get_file(
        path, origin='https://s3.amazonaws.com/img-datasets/mnist.npz')
    # ``with`` closes the .npz archive even if a key is missing; the arrays
    # read inside remain valid after it is closed.
    with np.load(path) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
    return (x_train, y_train), (x_test, y_test)
def get_word_index(path='reuters_word_index.json'):
    """Retrieves the dictionary mapping word indices back to words.

    Arguments:
        path: where to cache the data (relative to `~/.keras/dataset`).

    Returns:
        The word index dictionary.
    """
    path = get_file(
        path,
        origin='https://s3.amazonaws.com/text-datasets/reuters_word_index.json')
    # ``with`` closes the file even if the JSON is malformed.
    with open(path) as f:
        return json.load(f)
def __init__(self, player_number, epsilon=0.5, discount=0.8, alpha=1e-4, mode='test', pickle_name=None, lr_decay_fn=None):
    """Create a Q-learning agent.

    Args:
        player_number: Forwarded to ``Agent.__init__``.
        epsilon: Stored on ``self.epsilon``.
        discount: Stored on ``self.discount``.
        alpha: Learning rate, stored on ``self.alpha``.
        mode: ``'train'`` (sets ``self.train = True``) or ``'test'``
            (sets ``self.train = False``).
        pickle_name: If not None, passed to ``self.load`` after ``reset``.
        lr_decay_fn: Stored on ``self.lr_decay``.

    Raises:
        ValueError: If ``mode`` is neither ``'train'`` nor ``'test'``.
    """
    Agent.__init__(self, player_number)
    self.epsilon = epsilon
    self.discount = discount
    self.alpha = alpha
    self.setLearningTarget()
    self.lr_decay = lr_decay_fn
    if mode not in ('train', 'test'):
        # ValueError subclasses Exception, so existing ``except Exception``
        # handlers still catch it, and the reason now travels with the error.
        raise ValueError("no mode '%s' for QlearningAgent" % (mode,))
    self.train = mode == 'train'
    self.reset()
    if pickle_name is not None:
        self.load(pickle_name)
    print('epsilon :', self.epsilon)
    print('learning :', self.alpha)
    print('discount :', self.discount)
def sample(args):
    """Sample text from the word-level model saved in ``args.save_dir``.

    Loads ``config.pkl`` and ``words_vocab.pkl`` and restores the latest
    checkpoint. It then samples ``args.n`` words starting from
    ``args.start``, appends the text to ``result/sequence.txt``, and prints
    it. Nothing is sampled when no checkpoint is found.
    """
    save_dir = args.save_dir

    # import configuration
    with open(os.path.join(save_dir, 'config.pkl'), 'rb') as config_file:
        saved_args = cPickle.load(config_file)
    with open(os.path.join(save_dir, 'words_vocab.pkl'), 'rb') as vocab_file:
        words, vocab = cPickle.load(vocab_file)

    # import the trained model
    model = Model(saved_args, True)
    with tf.Session() as sess:
        # initialize the model
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        ckpt = tf.train.get_checkpoint_state(save_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            return
        saver.restore(sess, ckpt.model_checkpoint_path)
        # sample the new sequence word by word
        literature = model.sample(sess, words, vocab, args.n, args.start, args.sample)
        with codecs.open('result/sequence.txt', 'a', 'utf-8') as out_file:
            out_file.write(literature + '\n\n')
        print(literature)
def load(self):
    """Load the notMNIST splits from ``<data_dir>/notMNIST.pickle``.

    Every dataset/label pair is passed through ``self.reformat`` together
    with ``self.image_size`` and ``self.num_labels``.

    Returns:
        ``(train_dataset, train_labels, valid_dataset, valid_labels,
        test_dataset, test_labels)`` after reformatting.
    """
    with open(os.path.join(self.data_dir, "notMNIST.pickle"), 'rb') as handle:
        saved = pickle.load(handle)
        splits = [(saved[prefix + '_dataset'], saved[prefix + '_labels'])
                  for prefix in ('train', 'valid', 'test')]
        del saved  # hint to help gc free up memory

    results = []
    for dataset, labels in splits:
        dataset, labels = self.reformat(dataset, labels, self.image_size, self.num_labels)
        results += [dataset, labels]
    return tuple(results)
def _best_checkpoint_path(val_loss_file):
    """Return the checkpoint path recorded for the lowest validation loss.

    ``val_loss_file`` holds a JSON object whose keys are loss values written
    as strings; each value contains a ``'checkpoint_path'`` entry.
    """
    with open(val_loss_file, "r") as text_file:
        loss_json = json.loads(text_file.read())
    # dict.keys() has no .sort() on Python 3. min() with a float key picks
    # the same entry that an ascending numeric sort would put first.
    best_loss = min(loss_json, key=float)
    return loss_json[best_loss]['checkpoint_path']


def sample(args):
    """Sample text using the checkpoint with the lowest recorded validation loss.

    Loads ``config.pkl`` and ``chars_vocab.pkl`` from ``args.save_dir``. If
    ``<save_dir>/val_loss.json`` exists, it restores the best checkpoint
    listed there and samples ``args.n`` characters. The result is printed
    and written to ``/data/output/<unix time>.txt``, whose path is printed
    too.
    """
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, True)
    val_loss_file = args.save_dir + '/val_loss.json'
    with tf.Session() as sess:
        saver = tf.train.Saver(tf.all_variables())
        if os.path.exists(val_loss_file):
            saver.restore(sess, _best_checkpoint_path(val_loss_file))
            result = model.sample(sess, chars, vocab, args.n, args.prime, args.sample_rule, args.temperature)
            print(result)
            output = "/data/output/" + str(int(time.time())) + ".txt"
            with open(output, "w") as text_file:
                text_file.write(result)
            print(output)
def empty_network(network):
    """Temporarily move a network's time-series panels to a temp file on disk.

    This is a generator with a single ``yield``, presumably wrapped with
    ``contextlib.contextmanager`` where it is defined (the decorator is not
    visible here). Before the ``yield``, every ``<list_name>_t`` attribute
    of the network is pickled to a temp file and set to None. After it, the
    attributes are reloaded and the temp file is deleted. That cleanup also
    runs if the body of the ``with`` statement raises.
    """
    logger.debug("Storing pypsa timeseries to disk")
    from .components import all_components

    panels = {}
    for c in all_components:
        attr = network.components[c]["list_name"] + "_t"
        panels[attr] = getattr(network, attr)
        setattr(network, attr, None)

    fd, fn = tempfile.mkstemp()
    with os.fdopen(fd, 'wb') as f:
        pickle.dump(panels, f, -1)

    del panels

    gc.collect()
    try:
        yield
    finally:
        # Restore even when the caller's block raised; otherwise the network
        # would be left without its time series and the temp file leaked.
        logger.debug("Reloading pypsa timeseries from disk")
        try:
            with open(fn, 'rb') as f:
                panels = pickle.load(f)
        finally:
            os.remove(fn)
        for attr, pnl in iteritems(panels):
            setattr(network, attr, pnl)
def load(self, config_data):
    """
    Method to store the configuration and load the JSON schema it names.

    The schema is read from ``<ingestclient schema resources>/<name>.json``,
    where ``name`` is ``config_data['schema']['name']``.

    NOTE(review): an earlier docstring also mentioned selecting the validator
    and backend; this body only loads the schema — confirm where that
    happens.

    Args:
        config_data(dict): The configuration dictionary

    Returns:
        None

    Raises:
        ConfigFileError: If ``config_data`` has no ``['schema']['name']``.
    """
    self.config_data = config_data

    # Load the schema file based on the config that was provided
    try:
        schema_name = self.config_data['schema']['name']
    except KeyError as err:
        # Report the missing key from ``err``. Indexing the missing key again
        # here would raise a second KeyError that hides ConfigFileError.
        raise ConfigFileError(
            "The schema name could not be read from the configuration (missing key {}). "
            "Try to update your ingest client library or double check your ingest job "
            "configuration file".format(err))

    schema_path = os.path.join(resource_filename("ingestclient", "schema"),
                               "{}.json".format(schema_name))
    with open(schema_path, 'rt') as schema_file:
        self.schema = json.load(schema_file)
def load_plugins(self):
    """Instantiate the tile and path processor plugins named in the config.

    ``self.config_data["client"][<section>]["class"]`` holds a dotted
    ``package.ClassName``. Each class is imported and instantiated with no
    arguments, and the instances are stored on ``self.tile_processor_class``
    and ``self.path_processor_class``.

    Returns:
        None
    """
    def instantiate(section):
        # Split "pkg.module.Class" into its module path and class name.
        dotted_name = self.config_data["client"][section]["class"]
        module_name, class_name = dotted_name.rsplit('.', 1)
        plugin_cls = getattr(importlib.import_module(module_name), class_name)
        return plugin_cls()

    # Create plugin instances
    self.tile_processor_class = instantiate("tile_processor")
    self.path_processor_class = instantiate("path_processor")
def load_and_display_pickle(datasets, sample_size, title=None):
    """Show up to ``sample_size`` images from each pickled dataset in one figure.

    Args:
        datasets: Sequence of pickle file paths, each holding a sequence of
            images. Each file gets one subplot row.
        sample_size: Maximum number of images shown per dataset (columns).
        title: Optional figure title.

    Returns:
        List with the number of images in each dataset, in ``datasets``
        order. The same list is passed to ``balance_check``.
    """
    fig = plt.figure()
    if title:
        fig.suptitle(title, fontsize=16, fontweight='bold')

    num_of_images = []
    # enumerate supplies the row index directly: O(1), and correct even when
    # the same path appears twice (datasets.index() returns the first match).
    for row, pickle_file in enumerate(datasets):
        with open(pickle_file, 'rb') as f:
            data = pickle.load(f)
        print('Total images in', pickle_file, ':', len(data))
        for index, image in enumerate(data):
            if index == sample_size:
                break
            ax = fig.add_subplot(len(datasets), sample_size,
                                 sample_size * row + index + 1)
            ax.imshow(image)
            ax.set_axis_off()
        num_of_images.append(len(data))

    balance_check(num_of_images)
    plt.show()
    return num_of_images
def str_to_func(s, sandbox=None):
    """Rebuild a function from its serialized ``(code, closure, defaults)`` form.

    Args:
        s: One of the following:
            - the ``(code, closure, defaults)`` tuple/list itself;
            - a path to a file containing that triple pickled;
            - the pickled string itself.
            ``code`` is a pickle of a buffer object (anything with
            ``tobytes()``, or ``tostring()`` on old objects) that holds the
            ``marshal`` dump of a code object.
        sandbox: Mapping used as the new function's globals. When it is not
            a Mapping, this module's ``globals()`` is used.

    Returns:
        A new function built with ``types.FunctionType``.

    Raises:
        ValueError: If ``s`` has an unsupported type.

    Security: the input is unpickled, unmarshalled and made executable.
    Never call this on untrusted data.
    """
    if isinstance(s, (tuple, list)):
        code, closure, defaults = s
    elif isinstance(s, string_types):
        if os.path.isfile(s):  # path to file
            with open(s, 'rb') as f:
                code, closure, defaults = cPickle.load(f)
        else:  # pickled string
            code, closure, defaults = cPickle.loads(s)
    else:
        raise ValueError("Unsupport str_to_func for type:%s" % type(s))

    buf = cPickle.loads(code)
    # ``tostring`` was removed from array.array in Python 3.9 and is
    # deprecated in NumPy. ``tobytes`` returns the same bytes; fall back to
    # ``tostring`` only for objects that lack it (e.g. Python 2 array.array).
    raw = buf.tobytes() if hasattr(buf, 'tobytes') else buf.tostring()
    code = marshal.loads(raw)
    func = types.FunctionType(code=code, name=code.co_name,
                              globals=sandbox if isinstance(sandbox, Mapping) else globals(),
                              closure=closure, argdefs=defaults)
    return func
def load_npy_to_any(path='', name='file.npy'):
    """Load .npy file.

    If the file holds a single element (e.g. a dict saved by
    ``save_any_to_npy``), that element is returned. Otherwise the loaded
    array (or ``.npz`` archive) is returned as is. If loading fails, an
    error is printed and the process exits with status 1.

    Examples
    ---------
    - see save_any_to_npy()
    """
    file_path = os.path.join(path, name)
    try:
        # allow_pickle is needed (NumPy >= 1.16.3) to read object arrays such
        # as saved dicts. Unpickling can run arbitrary code: trusted files only.
        npy = np.load(file_path, allow_pickle=True)
    except Exception:
        print("[!] Fail to load %s" % file_path)
        raise SystemExit(1)
    try:
        return npy.item()
    except (ValueError, AttributeError):
        # Multi-element arrays (ValueError) and .npz archives without .item()
        # (AttributeError) are returned unchanged.
        return npy

# Visualizing npz files
def read_data_files(self, subset='train'):
    """Reads from data file and return images and labels in a numpy array.

    Args:
        subset: ``'train'`` reads ``data_batch_1`` .. ``data_batch_5``;
            ``'validation'`` reads ``test_batch`` (all under
            ``self.data_dir``).

    Returns:
        ``(all_images, all_labels)``. The images are the concatenated
        ``'data'`` entries cast to ``np.float32``. The labels are the
        concatenated ``'labels'`` entries.

    Raises:
        ValueError: If ``subset`` is not ``'train'`` or ``'validation'``.
    """
    import sys  # local: only needed for the Python 2/3 pickle switch below

    if subset == 'train':
        filenames = [os.path.join(self.data_dir, 'data_batch_%d' % i)
                     for i in range(1, 6)]
    elif subset == 'validation':
        filenames = [os.path.join(self.data_dir, 'test_batch')]
    else:
        raise ValueError('Invalid data subset "%s"' % subset)

    inputs = []
    for filename in filenames:
        # Pickles are binary: text mode ('r') breaks cPickle.load on Python 3.
        with gfile.Open(filename, 'rb') as f:
            if sys.version_info < (3,):
                inputs.append(cPickle.load(f))
            else:
                # 'latin1' decodes Python-2 byte strings (including numpy
                # buffers) losslessly and yields str keys like 'data'.
                inputs.append(cPickle.load(f, encoding='latin1'))
    # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
    # input format.
    all_images = np.concatenate(
        [each_input['data'] for each_input in inputs]).astype(np.float32)
    all_labels = np.concatenate(
        [each_input['labels'] for each_input in inputs])
    return all_images, all_labels
def sample(args):
    """Sample characters and print each one with its associated hidden state.

    Loads ``config.pkl`` and ``chars_vocab.pkl`` from ``args.save_dir`` and
    restores the latest checkpoint. It then samples ``args.n`` characters
    primed with ``args.prime``. Nothing is sampled when no checkpoint is
    found.
    """
    config_path = os.path.join(args.save_dir, 'config.pkl')
    with open(config_path, 'rb') as config_file:
        saved_args = cPickle.load(config_file)
    vocab_path = os.path.join(args.save_dir, 'chars_vocab.pkl')
    with open(vocab_path, 'rb') as vocab_file:
        chars, vocab = cPickle.load(vocab_file)

    model = Model(saved_args, training=False)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            return
        saver.restore(sess, ckpt.model_checkpoint_path)
        generated, hidden = model.sample(sess, chars, vocab, args.n, args.prime, args.sample)
        print("Number of characters generated: ", len(generated))
        for idx, char in enumerate(generated):
            print("Generated character: ", char)
            print("Assosciated hidden state:", hidden[idx])
def load_npy_to_any(path='', name='file.npy'):
    """Load .npy file.

    If the file holds a single element (e.g. a dict saved by
    ``save_any_to_npy``), that element is returned. Otherwise the loaded
    array (or ``.npz`` archive) is returned as is. If loading fails, an
    error is printed and the process exits with status 1.

    Examples
    ---------
    - see save_any_to_npy()
    """
    file_path = os.path.join(path, name)
    try:
        # allow_pickle is needed (NumPy >= 1.16.3) to read object arrays such
        # as saved dicts. Unpickling can run arbitrary code: trusted files only.
        npy = np.load(file_path, allow_pickle=True)
    except Exception:
        print("[!] Fail to load %s" % file_path)
        raise SystemExit(1)
    try:
        return npy.item()
    except (ValueError, AttributeError):
        # Multi-element arrays (ValueError) and .npz archives without .item()
        # (AttributeError) are returned unchanged.
        return npy

## Folder functions
def main():
    """Parse the sampling options from the command line and call ``sample``."""
    parser = argparse.ArgumentParser()
    # (flag, type, default, help) for every supported option.
    arg_specs = [
        ('--save_dir', str, 'save',
         'model directory to load stored checkpointed models from'),
        ('-n', int, 200, 'number of words to sample'),
        ('--prime', str, ' ', 'prime text'),
        ('--pick', int, 1, '1 = weighted pick, 2 = beam search pick'),
        ('--width', int, 4, 'width of the beam search'),
        ('--sample', int, 1,
         '0 to use max at each timestep, 1 to sample at each timestep, 2 to sample on spaces'),
    ]
    for flag, arg_type, default, help_text in arg_specs:
        parser.add_argument(flag, type=arg_type, default=default, help=help_text)
    sample(parser.parse_args())
def load(self, local_dir_=None):
    """Load the dataset from the local disk.

    This version does nothing and returns ``None``; presumably concrete
    datasets override it with the actual loading logic.

    Args:
        local_dir_: string or None
            directory to read from; ``None`` means ``Dataset.DEFAULT_DIR``
    """
    return None
def load(self, local_dir_=None):
    """Load CIFAR-10 from ``<local_dir>/cifar10.npz`` into preallocated arrays.

    The archive's ``'images'`` and ``'labels'`` arrays are copied in place
    into ``self.datum`` and ``self.labels``.

    Args:
        local_dir_: Directory holding ``cifar10.npz``; ``None`` uses
            ``self.DEFAULT_DIR``.
    """
    local_dir = self.DEFAULT_DIR if local_dir_ is None else Path(local_dir_)
    # ``with`` closes the .npz archive once the arrays have been copied.
    with np.load(str(local_dir / 'cifar10.npz')) as data_di:
        self.datum[:] = data_di['images']
        self.labels[:] = data_di['labels']
def install(
        self, local_dst_dir_=None, local_src_dir_=None, clean_install_=False):
    '''
    Install the dataset into directly usable format,
    requires downloading for public dataset.

    Reads the five training batches and the test batch from
    ``cifar-10-python.tar.gz`` and writes all 60000 images/labels to
    ``cifar10.npz``.

    Args:
        local_dst_dir_: string or None
            where to install the dataset, None -> "%(default_dir)s"
        local_src_dir_: string or None
            where to find the raw downloaded files, None -> "%(default_dir)s"
        clean_install_: bool
            if True, delete the source tarball after the install
    '''
    local_dst_dir = self.DEFAULT_DIR if local_dst_dir_ is None else Path(local_dst_dir_)
    local_src_dir = self.DEFAULT_DIR if local_src_dir_ is None else Path(local_src_dir_)
    local_dst_dir.mkdir(parents=True, exist_ok=True)
    assert local_src_dir.exists(), 'source directory not found: %s' % local_src_dir

    batch_size = 10000  # images per batch file, as reshaped below
    images = np.empty((60000, 3, 32, 32), dtype=np.uint8)
    labels = np.empty((60000,), dtype=np.uint8)

    def _read_batch(archive, member, start):
        # Unpickle one batch into images/labels[start:start + batch_size].
        # encoding='bytes' yields bytes keys such as b'data'.
        with archive.extractfile(member) as f:
            data_di = pickle.load(f, encoding='bytes')
        stop = start + batch_size
        images[start:stop] = data_di[b'data'].reshape((batch_size, 3, 32, 32))
        labels[start:stop] = np.asarray(data_di[b'labels'], dtype=np.uint8)

    tarfile_name = str(local_src_dir / 'cifar-10-python.tar.gz')
    with tarfile.open(tarfile_name, 'r:gz') as archive:
        for i in range(5):
            _read_batch(archive, 'cifar-10-batches-py/data_batch_%d' % (i + 1), batch_size * i)
        _read_batch(archive, 'cifar-10-batches-py/test_batch', 5 * batch_size)

    np.savez_compressed(
        str(local_dst_dir / 'cifar10.npz'), images=images, labels=labels)
    if clean_install_:
        os.remove(tarfile_name)
def load(self, local_dir_=None):
    """Load MNIST from ``<local_dir>/mnist.npz``.

    Sets ``self.labels`` and ``self.datum`` from the archive's ``'labels'``
    and ``'images'`` arrays. Also sets ``self.label_map`` to
    ``np.arange(10)`` and ``self.imsize`` to ``(1, 28, 28)``.

    Args:
        local_dir_: Directory holding ``mnist.npz``; ``None`` uses
            ``self.DEFAULT_DIR``.
    """
    local_dir = self.DEFAULT_DIR if local_dir_ is None else Path(local_dir_)
    # ``with`` closes the .npz archive; arrays read inside remain valid.
    with np.load(str(local_dir / 'mnist.npz')) as data:
        self.labels = data['labels']
        self.datum = data['images']
    self.label_map = np.arange(10)
    self.imsize = (1, 28, 28)
def pickle_load(filename):
    """Deserialize data from file using gzip compression.

    ``.pkz`` files are gzip-compressed pickles. ``.jz`` files are
    gzip-compressed JSON text, parsed with ``json_loads``. Unpickling can
    execute arbitrary code, so only load trusted ``.pkz`` files.

    Raises:
        ValueError: For any other file extension.
    """
    if filename.endswith('.pkz'):
        mode, reader = 'rb', pickle.load
    elif filename.endswith('.jz'):
        mode, reader = 'rt', (lambda stream: json_loads(stream.read()))
    else:
        raise ValueError(
            'Cannot determine format: {}'.format(os.path.basename(filename)))
    with gzip.open(filename, mode) as stream:
        return reader(stream)