The following 50 code examples, extracted from open source Python projects, illustrate how to use six.moves.zip_longest().
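As a quick orientation before the project examples: unlike the built-in zip, zip_longest keeps yielding until the longest input is exhausted, padding exhausted inputs with fillvalue (None by default). A minimal demonstration:

from six.moves import zip_longest

# zip stops at the shortest input; zip_longest pads the shorter ones.
print(list(zip_longest('AB', 'xyz')))
# [('A', 'x'), ('B', 'y'), (None, 'z')]
print(list(zip_longest('AB', 'xyz', fillvalue='-')))
# [('A', 'x'), ('B', 'y'), ('-', 'z')]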
def __call__(cls, *args, **kwds):
    if len(args) > len(cls._fields):
        raise TypeError("%s() got too many positional arguments" % cls.__name__)
    values = {}
    for f, v in zip_longest(cls._fields, args):
        if v is None:
            v = kwds.get(f)
        elif f in kwds:
            raise TypeError("%s() got multiple values for argument '%s'"
                            % (cls.__name__, f))
        if v is None:
            if hasattr(cls, '_default_' + f):
                v = getattr(cls, '_default_' + f)
            else:
                raise TypeError("%s() missing required argument '%s'"
                                % (cls.__name__, f))
        elif hasattr(cls, '_clean_' + f):
            v = getattr(cls, '_clean_' + f)(v)
        values[f] = v
    return type.__call__(cls, **values)
def check_images_change(objects):
    for obj in objects:
        if obj['kind'] not in ('Deployment', 'DaemonSet', 'PetSet'):
            continue
        kube_obj = kubernetes.get_pykube_object_if_exists(obj)
        if kube_obj is None:
            continue
        old_obj = kube_obj.obj
        old_containers = old_obj['spec']['template']['spec']['containers']
        old_images = [c['image'] for c in old_containers]
        new_containers = obj['spec']['template']['spec']['containers']
        new_images = [c['image'] for c in new_containers]
        for old_image, new_image in zip_longest(old_images, new_images):
            if old_image != new_image:
                return old_image, new_image
    return False
def test_multiplicative_adjustments(self, name, data, lookback, adjustments,
                                    missing_value, perspective_offset,
                                    expected):
    array = AdjustedArray(data, NOMASK, adjustments, missing_value)
    for _ in range(2):  # Iterate 2x to ensure adjusted_arrays are re-usable.
        window_iter = array.traverse(
            lookback,
            perspective_offset=perspective_offset,
        )
        for yielded, expected_yield in zip_longest(window_iter, expected):
            check_arrays(yielded, expected_yield)
def test_overwrite_adjustment_cases(self, name, baseline, lookback,
                                    adjustments, missing_value,
                                    perspective_offset, expected):
    array = AdjustedArray(baseline, NOMASK, adjustments, missing_value)
    for _ in range(2):  # Iterate 2x to ensure adjusted_arrays are re-usable.
        window_iter = array.traverse(
            lookback,
            perspective_offset=perspective_offset,
        )
        for yielded, expected_yield in zip_longest(window_iter, expected):
            check_arrays(yielded, expected_yield)
def test_move_items(item_name):
    """Ensure that everything loads correctly."""
    try:
        item = getattr(six.moves, item_name)
        if isinstance(item, types.ModuleType):
            __import__("six.moves." + item_name)
    except AttributeError:
        if item_name == "zip_longest" and sys.version_info < (2, 6):
            py.test.skip("zip_longest only available on 2.6+")
    except ImportError:
        if item_name == "winreg" and not sys.platform.startswith("win"):
            py.test.skip("Windows only module")
        if item_name.startswith("tkinter"):
            if not have_tkinter:
                py.test.skip("requires tkinter")
            if item_name == "tkinter_ttk" and sys.version_info[:2] <= (2, 6):
                py.test.skip("ttk only available on 2.7+")
        if item_name.startswith("dbm_gnu") and not have_gdbm:
            py.test.skip("requires gdbm")
        raise
    if sys.version_info[:2] >= (2, 6):
        assert item_name in dir(six.moves)
def get(self, content=True, links=True, comments=True):
    _content = self.gen_content_reqs(self.ids) if content else []
    _links = self.gen_links_reqs(self.ids) if links else []
    _comments = self.get_comments(self.ids, self.metas) if comments else ()

    def gen_posts():
        for content, links, comments in zip_longest(
                api.imap(_content), api.imap(_links), _comments):
            post = {}
            if content:
                post.update(content.json())
            post.update({
                'links': links.json() if links else None,
                'comments': self.extract_comments(comments)
            })
            if post:
                yield post
        logger.info('[Posts.gen_posts <gen>] Processed.')

    return PostsResult(gen_posts)
def calc_consistency_score(segment_one, segment_two, offset_one, offset_two):
    """Calculate the number of bases aligned to the same reference bases in
    two alignments.

    :param segment_one: Pysam aligned segment.
    :param segment_two: Pysam aligned segment.
    :param offset_one: Hard clipping offset for the first alignment.
    :param offset_two: Hard clipping offset for the second alignment.
    :returns: Number of matching base alignments.
    :rtype: int
    """
    matches_one = aligned_pairs_to_matches(
        segment_one.get_aligned_pairs(), offset_one)
    matches_two = aligned_pairs_to_matches(
        segment_two.get_aligned_pairs(), offset_two)
    score = 0
    for matches in zip_longest(matches_one, matches_two, fillvalue=False):
        if matches[0] == matches[1]:
            score += 1
    return score
def _get_digests(images):
    images = list(images)
    if not images:
        return {}
    for image in images:
        fabricio.run(
            'docker pull %s' % image,
            ignore_errors=True,
            quiet=False,
            use_cache=True,
        )
    command = (
        "docker inspect --type image --format '{{index .RepoDigests 0}}' %s"
    ) % ' '.join(images)
    digests = fabricio.run(command, ignore_errors=True, use_cache=True)
    return dict(zip_longest(images, filter(None, digests.splitlines())))
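Here zip_longest guarantees a dictionary entry for every requested image even when docker inspect yields fewer digest lines than images (empty lines from failed lookups are dropped by filter(None, ...)); images without a digest simply map to None. A toy illustration with invented values:

from six.moves import zip_longest

images = ['nginx:latest', 'redis:5', 'missing/image']
digest_lines = ['nginx@sha256:aaa', 'redis@sha256:bbb', '']  # last lookup failed
print(dict(zip_longest(images, filter(None, digest_lines))))
# {'nginx:latest': 'nginx@sha256:aaa', 'redis:5': 'redis@sha256:bbb', 'missing/image': None}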
def _update_relation(self, id_, relation_type, samples, positions=[],
                     label=None, relation=None):
    """Update existing relation."""
    if relation is None:
        relation = self.resolwe.relation.get(id=id_)

    to_delete = copy.copy(relation.entities)
    to_add = []

    for sample, position in zip_longest(samples, positions):
        entity_obj = {'entity': sample, 'position': position}
        if entity_obj in relation.entities:
            to_delete.remove(entity_obj)
        else:
            to_add.append(entity_obj)

    if to_add:
        relation.add_sample(*to_add)
    if to_delete:
        relation.remove_samples(*[obj['entity'] for obj in to_delete])
    if label != relation.label:
        relation.label = label
        relation.save()
def test_no_adjustments(self, name, data, lookback, adjustments,
                        missing_value, expected):
    array = AdjustedArray(data, NOMASK, adjustments, missing_value)
    for _ in range(2):  # Iterate 2x to ensure adjusted_arrays are re-usable.
        window_iter = array.traverse(lookback)
        for yielded, expected_yield in zip_longest(window_iter, expected):
            self.assertEqual(yielded.dtype, data.dtype)
            assert_array_equal(yielded, expected_yield)
def test_multiplicative_adjustments(self, name, data, lookback, adjustments,
                                    missing_value, expected):
    array = AdjustedArray(data, NOMASK, adjustments, missing_value)
    for _ in range(2):  # Iterate 2x to ensure adjusted_arrays are re-usable.
        window_iter = array.traverse(lookback)
        for yielded, expected_yield in zip_longest(window_iter, expected):
            assert_array_equal(yielded, expected_yield)
def test_overwrite_adjustment_cases(self, name, data, lookback, adjustments,
                                    missing_value, expected):
    array = AdjustedArray(data, NOMASK, adjustments, missing_value)
    for _ in range(2):  # Iterate 2x to ensure adjusted_arrays are re-usable.
        window_iter = array.traverse(lookback)
        for yielded, expected_yield in zip_longest(window_iter, expected):
            self.assertEqual(yielded.dtype, data.dtype)
            assert_array_equal(yielded, expected_yield)
def test_mppovm_expectation(nr_sites, width, local_dim, rank, nopovm, rgen):
    # Verify that :func:`povm.MPPovm.expectations()` produces
    # correct results.
    pmap = nopovm.probability_map
    mpnopovm = povm.MPPovm.from_local_povm(nopovm, width)
    # Use a random MPO rho for testing (instead of a positive MPO).
    rho = factory.random_mpo(nr_sites, local_dim, rank, rgen)
    reductions = mpsmpo.reductions_mpo(rho, width)
    # Compute expectation values with mpnopovm.expectations(), which
    # uses mpnopovm.probability_map.
    expectations = list(mpnopovm.expectations(rho))
    assert len(expectations) == nr_sites - width + 1

    for evals_mp, rho_red in zip_longest(expectations, reductions):
        # Compute expectation values by constructing each tensor
        # product POVM element.
        rho_red_matrix = rho_red.to_array_global().reshape(
            (local_dim**width,) * 2)
        evals = []
        for factors in it.product(nopovm, repeat=width):
            elem = utils.mkron(*factors)
            evals.append(np.trace(np.dot(elem, rho_red_matrix)))
        evals = np.array(evals).reshape((len(nopovm),) * width)
        # Compute expectation with a different construction. In the
        # end, this is (should be, we verify it here) equivalent to
        # what `mpnopovm.expectations()` does.
        evals_ten = rho_red.ravel().to_array()
        for _ in range(width):
            evals_ten = np.tensordot(evals_ten, pmap, axes=(0, 1))
        assert_array_almost_equal(evals_ten, evals)
        assert_array_almost_equal(evals_mp.to_array(), evals)
def grouper(n, iterable, fillvalue=None):
    "grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx"
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)
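This is the classic grouper recipe from the itertools documentation: replicating the same iterator object n times makes zip_longest pull n consecutive elements into each tuple, padding the final incomplete group with fillvalue. For example:

>>> list(grouper(3, 'ABCDEFG', 'x'))
[('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]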
def tileswrap(ihtORsize, numtilings, floats, wrapwidths, ints=[], readonly=False):
    '''
    Returns num-tilings tile indices corresponding to the floats and ints,
    wrapping some floats.

    :param ihtORsize: integer or IHT object. An index hash table or a positive
        integer specifying the upper range of returned indices (the number of
        possible tile indices)
    :param numtilings: integer. The number of tilings desired. For best
        results, it should be a power of two greater than or equal to four
        times the number of floats
    :param floats: list. A list of real values making up the input vector
    :param wrapwidths: list. For each float, the width (in tiles) at which its
        coordinate wraps around, or a false value for no wrapping
    :param ints*: list. Optional list of integers to get different hashings
    :param readonly*: boolean.
    '''
    qfloats = [floor(f * numtilings) for f in floats]
    Tiles = []
    for tiling in range(numtilings):
        tilingX2 = tiling * 2
        coords = [tiling]
        b = tiling
        for q, width in zip_longest(qfloats, wrapwidths):
            c = (q + b % numtilings) // numtilings
            coords.append(c % width if width else c)
            b += tilingX2
        coords.extend(ints)
        Tiles.append(hashcoords(coords, ihtORsize, readonly))
    return Tiles
def _add(s, A, B):
    if len(A) == 1:
        return s.element_class(s, [A[0] + B[0]] + list(B[1:]))
    elif len(B) == 1:
        return s.element_class(s, [A[0] + B[0]] + list(A[1:]))
    ret = []
    for x, y in zip_longest(A, B, fillvalue=0):
        ret += [(x + y)]
    return s.element_class(s, ret)
def _equ(s, A, B):
    from six.moves import zip_longest
    if len(A) == 1 and len(B) == 1:
        return A[0] == B[0]
    return all([x == y for x, y in zip_longest(A, B, fillvalue=0)])
def _add(s, A, B):
    # A + B
    if len(A) == 1 and len(A[0]) == 1:
        A = A[0][0]
        return s.element_class(s, [[A + B[0][0]] + list(B[0][1:])] + list(B[1:]))
    elif len(B) == 1 and len(B[0]) == 1:
        B = B[0][0]
        return s.element_class(s, [[A[0][0] + B] + list(A[0][1:])] + list(A[1:]))
    ret = []
    for x, y in zip_longest(A, B, fillvalue=[0]):
        t = []
        for xs, ys in zip_longest(x, y, fillvalue=0):
            t += [xs + ys]
        ret += [t]
    return s.element_class(s, ret)
def _equ(s, A, B):
    if len(A) == 1 and len(A[0]) == 1 and len(B) == 1 and len(B[0]) == 1:
        return A[0][0] == B[0][0]
    return all([all([s == t for s, t in zip_longest(x, y, fillvalue=0)])
                for x, y in zip(A, B)])
def main(j, args, params, tags, tasklet):
    params.result = page = args.page

    # parse params
    multiple_selection = False
    m = re.search(r'\{\{\s*dropdown\s*:\s*(.*)\n', args.macrostr)
    if m:
        attributes = []
        macro_params = [part.strip().split('=') for part in m.group(1).split('|')]
        for pair in macro_params:
            if pair[0].lower() == 'multiple':
                multiple_selection = True
            if len(pair) == 2:
                attributes.append('{}="{}"'.format(*pair))
            else:
                attributes.append(pair[0])
    else:
        attributes = []

    if multiple_selection:
        page.addJS(jsLink='/jslib/old/multiple_selection/multiple_selection.js')
        page.addCSS('/jslib/old/multiple_selection/multiple_selection.css')

    current_option = 0
    options = []
    lines = re.findall(r'\s*(\*+)\s+(.*)', args.cmdstr)
    for ((current_level, current_text), (next_level, _)) in zip_longest(
            lines, lines[1:], fillvalue=('', '')):
        if len(current_level) < len(next_level):
            options.append('<optgroup label="{0}">'.format(current_text))
            continue
        options.append('<option value="{0}">{0}</option>'.format(current_text))
        if len(current_level) > len(next_level):
            options.append('</optgroup>')

    page.addMessage('<select {}>{}</select>'.format(''.join(attributes), ''.join(options)))
    return params
def _calc_maxes(self):
    array = [self.header] + self.rows
    return [max(len(str(s)) for s in ss)
            for ss in zip_longest(*array, fillvalue='')]
def _get_printable_row(self, row):
    maxes = self._calc_maxes()
    return '| ' + ' | '.join([('{0: <%d}' % m).format(r)
                              for r, m in zip_longest(row, maxes, fillvalue='')]) + ' |'
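The two helpers above combine into a simple ASCII table printer: zip_longest(*array, fillvalue='') transposes rows into columns (padding ragged rows with empty strings), so each column width is the maximum cell width in that column. A self-contained sketch of the same idea (the data here is illustrative):

from six.moves import zip_longest

header = ['name', 'lang']
rows = [['six', 'Python'], ['zip_longest']]          # ragged row
maxes = [max(len(str(s)) for s in col)
         for col in zip_longest(header, *rows, fillvalue='')]
for row in [header] + rows:
    print('| ' + ' | '.join(('{0: <%d}' % m).format(r)
          for r, m in zip_longest(row, maxes, fillvalue='')) + ' |')
# | name        | lang   |
# | six         | Python |
# | zip_longest |        |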
def _get_minibatch_feed_dict(self, target_q_values,
                             non_terminal_minibatch, terminal_minibatch):
    """
    Helper to construct the feed_dict for train_op. Takes the non-terminal
    and terminal minibatches as well as the max q-values computed from the
    target network for non-terminal states. Computes the expected q-values
    based on discounted future reward.

    @return: feed_dict to be used for train_op
    """
    assert len(target_q_values) == len(non_terminal_minibatch)

    states = []
    expected_q = []
    actions = []

    # Compute expected q-values to plug into the loss function
    minibatch = itertools.chain(non_terminal_minibatch, terminal_minibatch)
    for item, target_q in zip_longest(minibatch, target_q_values, fillvalue=0):
        state, action, reward, _, _ = item
        states.append(state)
        # target_q will be 0 for terminal states due to fillvalue in zip_longest
        expected_q.append(reward + self.config.reward_discount * target_q)
        actions.append(utils.one_hot(action, self.env.action_space.n))

    return {
        self.network.x_placeholder: states,
        self.network.q_placeholder: expected_q,
        self.network.action_placeholder: actions,
    }
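The trick worth noting: chaining the non-terminal transitions (which have target q-values) before the terminal ones (which do not) and letting fillvalue=0 supply the missing targets yields the Bellman target r + gamma * max_a Q_target(s', a) for non-terminal states and plain r for terminal ones, in a single loop. A stripped-down numeric sketch with invented values:

import itertools
from six.moves import zip_longest

gamma = 0.99
non_terminal = [('s0', 1.0), ('s1', 0.5)]    # (state, reward) pairs
terminal = [('s2', 10.0)]
target_q_values = [3.0, 2.0]                 # one per non-terminal transition

for (state, reward), target_q in zip_longest(
        itertools.chain(non_terminal, terminal), target_q_values, fillvalue=0):
    # terminal transitions pair with the fillvalue 0, so their target is just the reward
    print(state, round(reward + gamma * target_q, 2))
# s0 3.97
# s1 2.48
# s2 10.0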
def group(n, iterable, fill_value=None):
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fill_value)
def _data_iter(cls, documents, y):
    class DocIter(object):
        def __init__(self, documents, y):
            self.y = y
            self.documents = documents

        def __iter__(self):
            for sample, targets in zip_longest(self.documents, self.y):
                targets = cls._target_list(targets)
                yield (sample, targets)

    return DocIter(documents, y)
def __bind_commands(self):
    if not self.parallel:
        for attr in ['complete_kill', 'do_kill', 'do_status']:
            delattr(FrameworkConsole, attr)
    for name, func in get_commands():
        longname = 'do_{}'.format(name)
        # set the behavior of the console command (multi-processed or not)
        # setattr(Console, longname, MethodType(FrameworkConsole.start_process_template(func)
        #         if self.parallel and func.behavior.is_multiprocessed else func, self))
        setattr(Console, longname, MethodType(func, self))
        # retrieve parts of the function's docstring to make the console
        # command's docstring
        parts = func.__doc__.split(':param ')
        description = parts[0].strip()
        arguments = [" ".join([l.strip() for l in x.split(":")[-1].split('\n')])
                     for x in parts[1:]]
        docstring = COMMAND_DOCSTRING["description"].format(description)
        if len(arguments) > 0:
            arg_descrs = [' - {}:\t{}'.format(n, d or "[no description]")
                          for n, d in list(zip_longest(
                              signature(func).parameters.keys(), arguments or []))
                          if n is not None]
            docstring += COMMAND_DOCSTRING["arguments"].format('\n'.join(arg_descrs))
        if hasattr(func, 'examples') and isinstance(func.examples, list):
            args_examples = [' >>> {} {}'.format(name, e) for e in func.examples]
            docstring += COMMAND_DOCSTRING["examples"].format('\n'.join(args_examples))
        setattr(getattr(getattr(Console, longname), '__func__'), '__doc__', docstring)
        # set the autocomplete list of values (can be lazy by using lambda) if relevant
        if hasattr(func, 'autocomplete'):
            setattr(Console, 'complete_{}'.format(name),
                    MethodType(FrameworkConsole.complete_template(func.autocomplete), self))
        if hasattr(func, 'reexec_on_emptyline') and func.reexec_on_emptyline:
            self.reexec.append(name)
def test_zip_longest():
    from six.moves import zip_longest
    it = zip_longest(range(2), range(1))
    assert six.advance_iterator(it) == (0, 0)
    assert six.advance_iterator(it) == (1, None)
def grouper(iterable, n, fillvalue=None):
    """Collect data into fixed-length chunks or blocks."""
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)
def grouper(n, iterable, padvalue=None):
    "grouper(3, 'abcdefg', 'x') --> ('a','b','c'), ('d','e','f'), ('g','x','x')"
    return zip_longest(*[iter(iterable)] * n, fillvalue=padvalue)
def test_basic_assembly(web_fixture, basic_scenarios):
    """An application is built by extending UserInterface, and defining this
       UserInterface in an .assemble() method.

       To define the UserInterface, several Views are defined. Views are mapped
       to URLs. When a user GETs the URL of a View, a page is rendered back to
       the user. How that page is created can happen in different ways, as
       illustrated by each scenario of this test.
    """
    fixture = basic_scenarios
    wsgi_app = web_fixture.new_wsgi_app(site_root=fixture.MainUI)
    browser = Browser(wsgi_app)

    # GETting the URL results in the HTML for that View
    with warnings.catch_warnings(record=True) as caught_warnings:
        warnings.simplefilter('always')
        browser.open('/')
        assert browser.title == 'Hello'

    warning_messages = [six.text_type(i.message) for i in caught_warnings]
    assert len(warning_messages) == len(fixture.expected_warnings)
    for caught, expected_message in zip_longest(warning_messages, fixture.expected_warnings):
        assert expected_message in caught

    if fixture.content_includes_p:
        [message] = browser.xpath('//p')
        assert message.text == 'Hello world!'

    # The headers are set correctly
    response = browser.last_response
    assert response.text == fixture.expected_content
    assert response.content_type == 'text/html'
    assert response.charset == 'utf-8'

    # Invalid URLs do not exist
    with warnings.catch_warnings(record=True):
        browser.open('/nonexistantview/', status=404)
def expected_deprecation_warnings(expected_warnings):
    with warnings.catch_warnings(record=True) as caught_warnings:
        warnings.simplefilter('always')
        yield

    warning_messages = [six.text_type(i.message) for i in caught_warnings]
    assert len(warning_messages) == len(expected_warnings)
    for caught, expected_message in zip_longest(warning_messages, expected_warnings):
        assert expected_message in caught
def grouper(iterable, n=2, fillvalue=None):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)
def _pretty_pos(aset):
    """
    Helper function for pretty-printing a sentence with its POS tags.

    :param aset: The POS annotation set of the sentence to be printed.
    :type aset: list(AttrDict)
    :return: The text of the sentence and its POS tags.
    :rtype: str
    """
    outstr = ""
    outstr += "POS annotation set ({0.ID}) {0.POS_tagset} in sentence {0.sent.ID}:\n\n".format(aset)

    # list the target spans and their associated aset index
    overt = sorted(aset.POS)

    sent = aset.sent
    s0 = sent.text
    s1 = ''
    s2 = ''
    i = 0
    adjust = 0
    for j, k, lbl in overt:
        assert j >= i, ('Overlapping targets?', (j, k, lbl))
        s1 += ' ' * (j - i) + '-' * (k - j)
        if len(lbl) > (k - j):
            # add space in the sentence to make room for the annotation index
            amt = len(lbl) - (k - j)
            s0 = s0[:k + adjust] + '~' * amt + s0[k + adjust:]  # '~' to prevent line wrapping
            s1 = s1[:k + adjust] + ' ' * amt + s1[k + adjust:]
            adjust += amt
        s2 += ' ' * (j - i) + lbl.ljust(k - j)
        i = k

    long_lines = [s0, s1, s2]

    outstr += '\n\n'.join(map('\n'.join, zip_longest(*mimic_wrap(long_lines), fillvalue=' '))).replace('~', ' ')
    outstr += "\n"
    return outstr
def _create_relation(self, relation_type, samples, positions=[], label=None):
    """Create group relation with the given samples and positions."""
    if not isinstance(samples, list):
        raise ValueError("`samples` argument must be list.")
    if not isinstance(positions, list):
        raise ValueError("`positions` argument must be list.")
    if positions:
        if len(samples) != len(positions):
            raise ValueError("`samples` and `positions` arguments must be of the same length.")

    relation_data = {
        'type': relation_type,
        'collection': self.id,
        'entities': []
    }

    for sample, position in zip_longest(samples, positions):
        entity_dict = {'entity': get_sample_id(sample)}
        if position:
            entity_dict['position'] = position
        relation_data['entities'].append(entity_dict)

    if label:
        relation_data['label'] = label

    return self.resolwe.relation.create(**relation_data)
def izip_equal(*iterables):
    """
    Zip and raise an exception if lengths are not equal.

    Taken from a solution by Martijn Pieters, here:
    http://stackoverflow.com/questions/32954486/zip-iterators-asserting-for-equal-length-in-python

    :param iterables: the iterables to zip together
    :return: a generator of tuples, which raises ValueError on a length mismatch
    """
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if any(sentinel is c for c in combo):
            raise ValueError('Iterables have different lengths')
        yield combo
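Usage follows directly from the definition: equal-length inputs zip normally, while a length mismatch raises ValueError at the point it is detected, after the shared prefix has already been yielded:

>>> for pair in izip_equal([1, 2, 3], ['a', 'b']):
...     print(pair)
(1, 'a')
(2, 'b')
Traceback (most recent call last):
  ...
ValueError: Iterables have different lengths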
def tileswrap(ihtORsize, numtilings, floats, wrapwidths, ints=[], readonly=False):
    """returns num-tilings tile indices corresponding to the floats and ints,
    wrapping some floats"""
    qfloats = [floor(f * numtilings) for f in floats]
    Tiles = []
    for tiling in range(numtilings):
        tilingX2 = tiling * 2
        coords = [tiling]
        b = tiling
        for q, width in zip_longest(qfloats, wrapwidths):
            c = (q + b % numtilings) // numtilings
            coords.append(c % width if width else c)
            b += tilingX2
        coords.extend(ints)
        Tiles.append(hashcoords(coords, ihtORsize, readonly))
    return Tiles
def grouper(iterable, n, fillvalue=None):
    """Collect data into fixed-length chunks or blocks."""
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)
def as_index_map(keys, values):
    '''
    Return
    ------
    list : the `values` that are not None, without duplicates
    map : an index map from each key to the index of its value

    Example
    -------
    >>> print(as_index_map([1, 2, 3, 4], [1, 1, None, 2]))
    ... # ([1, 2], {1: 0, 2: 0, 3: None, 4: 1})
    '''
    seen = defaultdict(lambda: -1)
    keys_to_index = {}
    ret_values = []
    for k, v in zip_longest(keys, values):
        if v is None:
            seen[v] = None
        elif seen[v] < 0:
            ret_values.append(v)
            seen[v] = len(ret_values) - 1
        keys_to_index[k] = seen[v]
    return ret_values, keys_to_index

# ===========================================================================
# Misc
# ===========================================================================
def local_sum(mpas, embed_tensor=None, length=None, slices=None):
    """Embed local MPAs on a linear chain and sum as MPA.

    We return the sum over :func:`embed_slice(length, slices[i], mpas[i],
    embed_tensor) <embed_slice>` as MPA.

    If ``slices`` is omitted, we use :func:`regular_slices(length, width,
    offset) <regular_slices>` with :code:`offset = 1`, :code:`width =
    len(mpas[0])` and :code:`length = len(mpas) + width - offset`.

    If ``slices`` is omitted or if the slices just described are given, we
    call :func:`_local_sum_identity()`, which gives a smaller virtual
    dimension than naive embedding and summing.

    :param mpas: List of local MPAs.
    :param embed_tensor: Defaults to square identity matrix (see
        :func:`_embed_ltens_identity` for details)
    :param length: Length of the resulting chain, ignored unless slices is
        given.
    :param slices: ``slice[i]`` specifies the position of ``mpas[i]``,
        optional.
    :returns: An MPA.
    """
    # Check whether we can fall back to :func:`_local_sum_identity`
    # even though `slices` is given.
    if slices is not None:
        assert length is not None
        slices = tuple(slices)
        reg = regular_slices(length, slices[0].stop - slices[0].start, offset=1)
        if all(s == t for s, t in zip_longest(slices, reg)):
            slices = None
    # If `slices` is not given, use :func:`_local_sum_identity`.
    if slices is None:
        return _local_sum_identity(tuple(mpas), embed_tensor)
    mpas = (embed_slice(length, slice_, mpa, embed_tensor)
            for mpa, slice_ in zip(mpas, slices))
    return sumup(mpas)


############################################################
#  Functions for dealing with local operations on tensors  #
############################################################
def depart_table(self, node):
    lines = self.table[1:]
    fmted_rows = []
    colwidths = self.table[0]
    realwidths = colwidths[:]
    separator = 0
    # don't allow paragraphs in table cells for now
    for line in lines:
        if line == 'sep':
            separator = len(fmted_rows)
        else:
            cells = []
            for i, cell in enumerate(line):
                par = my_wrap(cell, width=colwidths[i])
                if par:
                    maxwidth = max(column_width(x) for x in par)
                else:
                    maxwidth = 0
                realwidths[i] = max(realwidths[i], maxwidth)
                cells.append(par)
            fmted_rows.append(cells)

    def writesep(char='-'):
        out = ['+']
        for width in realwidths:
            out.append(char * (width + 2))
            out.append('+')
        self.add_text(''.join(out) + self.nl)

    def writerow(row):
        lines = zip_longest(*row)
        for line in lines:
            out = ['|']
            for i, cell in enumerate(line):
                if cell:
                    adjust_len = len(cell) - column_width(cell)
                    out.append(' ' + cell.ljust(realwidths[i] + 1 + adjust_len))
                else:
                    out.append(' ' * (realwidths[i] + 2))
            out.append('|')
            self.add_text(''.join(out) + self.nl)

    for i, row in enumerate(fmted_rows):
        if separator and i == separator:
            writesep('=')
        else:
            writesep('-')
        writerow(row)
    writesep('-')
    self.table = None
    self.end_state(wrap=False)
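Inside writerow, zip_longest(*row) transposes a table row of wrapped cells (each cell a list of text lines) into physical output lines, padding short cells with None so every physical line still has one entry per column. The same pattern in isolation (cell contents invented):

from six.moves import zip_longest

row = [['first cell', 'wrapped'], ['second']]   # first cell wrapped to 2 lines
for line in zip_longest(*row):
    print(line)
# ('first cell', 'second')
# ('wrapped', None)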
def _batch_grouping(batch, batch_size, rng, batch_filter):
    """ batch: contains
        [
            (name, [list of data]),
            (name, [list of data]),
            (name, [list of data]),
            ...
        ]

    Note
    ----
    We assume the shape[0] (or length) of all "data" and "others" are the same
    """
    if len(batch) == 0:
        yield None
    else:
        # create batch of indices for each file (indices is the start
        # index of each batch)
        indices = [list(range(0, X[0].shape[0], batch_size))
                   for name, X in batch]
        # shuffle if possible
        if rng is not None:
            [rng.shuffle(i) for i in indices]
        # ====== create batch of data ====== #
        for idx in zip_longest(*indices):
            ret = []
            for start, (name, X) in zip(idx, batch):
                # skip this input if it does not have enough data
                if start is None:
                    continue
                # pick data from each given input
                end = start + batch_size
                _ = [x[start:end] for x in X]
                ret.append(_)
            ret = [np.concatenate(x, axis=0) for x in zip(*ret)]
            # shuffle 1 more time
            N = list(set([r.shape[0] for r in ret]))
            if len(N) > 1:
                raise ValueError("The shape[0] of Data is different, found "
                                 "%d different lengths: %s" % (len(N), str(N)))
            N = N[0]
            if rng is not None:
                permutation = rng.permutation(N)
                ret = [r[permutation] for r in ret]
            # return the batches
            for start in range(0, N, batch_size):
                end = start + batch_size
                _ = batch_filter([x[start:end] for x in ret])
                # always return a tuple or list
                if _ is not None:
                    yield _ if isinstance(_, (tuple, list)) else (_,)
def group(batch):
    """ batch: contains
        [
            (name, [list of data], [list of others]),
            (name, [list of data], [list of others]),
            (name, [list of data], [list of others]),
            ...
        ]

    Note
    ----
    We assume the shape[0] (or length) of all "data" and "others" are the same
    """
    rng = np.random.RandomState(1208)
    batch_size = 64
    indices = [list(range((b[1][0].shape[0] - 1) // batch_size + 1))
               for b in batch]
    # shuffle if possible
    if rng is not None:
        [rng.shuffle(i) for i in indices]
    # ====== create batch of data ====== #
    for idx in zip_longest(*indices):
        ret = []
        for i, b in zip(idx, batch):
            # skip this input if it does not have enough batches
            if i is None:
                continue
            # pick data from each given input
            name, data, others = b[0], b[1], b[2:]
            start = i * batch_size
            end = start + batch_size
            _ = [d[start:end] for d in data] + \
                [o[start:end] for o in others]
            ret.append(_)
        ret = [np.concatenate(x, axis=0) for x in zip(*ret)]
        # shuffle 1 more time
        if rng is not None:
            permutation = rng.permutation(ret[0].shape[0])
            ret = [r[permutation] for r in ret]
        # return the batches
        for i in range((ret[0].shape[0] - 1) // batch_size + 1):
            start = i * batch_size
            end = start + batch_size
            _ = [x[start:end] for x in ret]
            # always return a tuple or list
            if _ is not None:
                yield _ if isinstance(_, (tuple, list)) else (_,)
def _annotation_ascii_FEs(sent):
    '''
    ASCII string rendering of the sentence along with a single target and
    its FEs. Secondary and tertiary FE layers are included if present.
    'sent' can be an FE annotation set or an LU sentence with a single
    target. Line-wrapped to limit the display width.
    '''
    feAbbrevs = OrderedDict()
    posspec = []  # POS-specific layer spans (e.g., Supp[ort], Cop[ula])
    posspec_separate = False
    for lyr in ('Verb', 'Noun', 'Adj', 'Adv', 'Prep', 'Scon', 'Art'):
        if lyr in sent and sent[lyr]:
            for a, b, lbl in sent[lyr]:
                if lbl == 'X':
                    # skip this, which covers an entire phrase typically
                    # containing the target and all its FEs (but do display
                    # the Gov)
                    continue
                if any(1 for x, y, felbl in sent.FE[0] if x <= a < y or a <= x < b):
                    # overlap between one of the POS-specific layers and the
                    # first FE layer: show POS-specific layers on a separate line
                    posspec_separate = True
                # lowercase Cop=>cop, Non-Asp=>nonasp, etc. to distinguish
                # from FE names
                posspec.append((a, b, lbl.lower().replace('-', '')))
    if posspec_separate:
        POSSPEC = _annotation_ascii_FE_layer(posspec, {}, feAbbrevs)
    FE1 = _annotation_ascii_FE_layer(
        sorted(sent.FE[0] + (posspec if not posspec_separate else [])),
        sent.FE[1], feAbbrevs)
    FE2 = FE3 = None
    if 'FE2' in sent:
        FE2 = _annotation_ascii_FE_layer(sent.FE2[0], sent.FE2[1], feAbbrevs)
        if 'FE3' in sent:
            FE3 = _annotation_ascii_FE_layer(sent.FE3[0], sent.FE3[1], feAbbrevs)

    for i, j in sent.Target:
        FE1span, FE1name, FE1exp = FE1
        if len(FE1span) < j:
            FE1span += ' ' * (j - len(FE1span))
        if len(FE1name) < j:
            FE1name += ' ' * (j - len(FE1name))
            FE1[1] = FE1name
        FE1[0] = FE1span[:i] + FE1span[i:j].replace(' ', '*').replace('-', '=') + FE1span[j:]

    long_lines = [sent.text]
    if posspec_separate:
        long_lines.extend(POSSPEC[:2])
    long_lines.extend([FE1[0], FE1[1] + FE1[2]])  # lines with no length limit
    if FE2:
        long_lines.extend([FE2[0], FE2[1] + FE2[2]])
        if FE3:
            long_lines.extend([FE3[0], FE3[1] + FE3[2]])
    long_lines.append('')

    outstr = '\n'.join(map('\n'.join, zip_longest(*mimic_wrap(long_lines), fillvalue=' ')))
    if feAbbrevs:
        outstr += '(' + ', '.join('='.join(pair) for pair in feAbbrevs.items()) + ')'
        assert len(feAbbrevs) == len(dict(feAbbrevs)), 'Abbreviation clash'
    outstr += "\n"
    return outstr