我们从Python开源项目中，提取了以下6个代码示例，用于说明如何使用theano.tensor.max_and_argmax()。
def test_optimization(self):
    """Check which op remains after optimizing ``max_and_argmax`` graphs.

    If only the max output is requested, the op should be replaced by a
    faster one: the optimized graph must contain a single ``CAReduce``
    node.  If both outputs are requested, the graph must contain a single
    ``tensor.MaxAndArgmax`` node.  Both cases are checked for axes
    0, 1 and -1 of a matrix input.
    """
    opt_mode = theano.compile.mode.get_default_mode().including(
        'canonicalize', 'fast_run')

    for axis in (0, 1, -1):
        # NOTE(review): `data` is never passed to the compiled functions;
        # the draw is kept so the global random state advances as before.
        data = numpy.asarray(numpy.random.rand(2, 3), dtype=config.floatX)
        n = tensor.matrix()

        # Case 1: only the max output is part of the graph.
        max_only = tensor.max_and_argmax(n, axis)[0]
        f = function([n], max_only, mode=opt_mode)
        nodes = f.maker.fgraph.toposort()
        assert len(nodes) == 1
        assert isinstance(nodes[0].op, CAReduce)

        # Case 2: both the max and the argmax outputs are requested.
        max_and_arg = tensor.max_and_argmax(n, axis)
        f = function([n], max_and_arg, mode=opt_mode)
        nodes = f.maker.fgraph.toposort()
        assert len(nodes) == 1
        assert isinstance(nodes[0].op, tensor.MaxAndArgmax)
def _vitabi_forward(e_t, score_prev, trans):
    """Compute one forward step: best score and best index over axis 2.

    The three inputs are broadcast into a rank-3 tensor, summed, and
    reduced with ``max_and_argmax`` along the last axis.

    :param e_t: 1D: Batch, 2D: n_y
    :param score_prev: 1D: Batch, 2D: n_y
    :param trans: 1D: n_y, 2D: n_y
    :return: tuple ``(max_scores_t, max_nodes_t)`` where ``max_nodes_t``
        is the argmax cast to ``int32``
    """
    # Insert broadcastable axes so the operands line up as
    # (batch, n_y, n_y) once added to `trans`.
    prev_b = score_prev.dimshuffle(0, 'x', 1)
    emit_b = e_t.dimshuffle(0, 1, 'x')

    # Same left-to-right addition order as before: (prev + trans) + emit.
    combined = prev_b + trans + emit_b

    best_scores, best_nodes = T.max_and_argmax(combined, axis=2)
    best_nodes = T.cast(best_nodes, dtype='int32')
    return best_scores, best_nodes
def viterbi(obs_potentials, chain_potentials):
    """Build the symbolic best-path computation over a sequence.

    ``theano.scan`` iterates over the leading axis of ``obs_potentials``.
    At each step the running rank-2 score, the current rank-1 observation
    vector and the rank-3 ``chain_potentials`` are broadcast together,
    summed, and reduced with ``max_and_argmax`` along axis 0.  The argmax
    of the final score matrix is then split into a (row, column) pair and
    handed to ``trace_back`` together with the per-step argmax tensors.

    :param obs_potentials: sequence scanned over; each step yields a
        rank-1 tensor (it is dimshuffled with pattern ``('x', 'x', 0)``)
    :param chain_potentials: rank-3 tensor, passed unchanged to every step
    :return: ``trace_back(path, u, v)[1:]`` -- the result of the
        file-level ``trace_back`` helper with its first element dropped
        (presumably the decoded path; verify against ``trace_back``)
    """
    def inner_function(obs, prior_result, chain_potentials):
        # Broadcast to rank 3: prior -> (d0, d1, 1), obs -> (1, 1, d2).
        combined = (prior_result.dimshuffle(0, 1, 'x') +
                    obs.dimshuffle('x', 'x', 0) +
                    chain_potentials)
        maxscore, maxarg = T.max_and_argmax(combined, axis=0)
        return maxscore, maxarg

    # Initial running score: zeros shaped like one slice of the potentials.
    initial = T.zeros_like(chain_potentials[0])
    # Only the score is recurrent; the argmax is collected (outputs_info
    # entry is None) for the later back-trace.
    [score, path], _ = theano.scan(fn=inner_function,
                                   outputs_info=[initial, None],
                                   sequences=[obs_potentials],
                                   non_sequences=chain_potentials)

    final_score = score[-1]
    # argmax without an axis returns an index into the flattened matrix.
    flat_best = T.argmax(final_score)
    n_cols = final_score.shape[1]
    # Floor division keeps the row index an integer; plain `/` is true
    # division under Python 3 and would produce a float index.
    u = flat_best // n_cols
    v = flat_best % n_cols
    return trace_back(path, u, v)[1:]
def test_argmax_pushdown():
    """Check the argmax-pushdown optimization through both softmax forms.

    For each of ``softmax_graph`` and ``softmax_op``:

    * when only the argmax output is used, the optimized graph must be
      reduced to the ``_max_and_argmax`` op followed by an OutputGuard;
    * when the max output is used, the softmax must stay in the graph
      (Elemwise, Softmax, CAReduce with a Maximum scalar op, OutputGuard).
    """
    x = tensor.matrix()
    for sm in (softmax_graph, softmax_op):
        # --- argmax only: the max_and_argmax is pushed down -------------
        activations = sm(tensor.exp(tensor.tanh(sigmoid(x))))
        out = tensor.max_and_argmax(activations, axis=-1)[1]
        fgraph = gof.FunctionGraph([x], [out])
        theano.compile.mode.optdb.query(
            theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)

        nodes = fgraph.toposort()
        assert len(nodes) == 2  # the op itself, then an OutputGuard
        assert nodes[0].op == tensor.basic._max_and_argmax
        assert str(nodes[1].op) == 'OutputGuard'
        assert check_stack_trace(
            fgraph, ops_to_check=tensor.basic._max_and_argmax)

        # --- max is used: the max_and_argmax is NOT pushed down ---------
        x = tensor.matrix()
        activations = sm(tensor.exp(tensor.tanh(sigmoid(x))))
        out = tensor.max_and_argmax(activations, axis=-1)[0]
        fgraph = gof.FunctionGraph([x], [out])

        assert hasattr(fgraph.outputs[0].tag, 'trace')

        # Silence the argmax_pushdown_bug warning while optimizing, and
        # always restore the previous setting afterwards.
        saved_flag = config.warn.argmax_pushdown_bug
        config.warn.argmax_pushdown_bug = False
        try:
            theano.compile.mode.optdb.query(
                theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
        finally:
            config.warn.argmax_pushdown_bug = saved_flag

        nodes = fgraph.toposort()
        assert len(nodes) == 4  # an OutputGuard is last
        assert isinstance(nodes[0].op, tensor.Elemwise)
        assert isinstance(nodes[1].op, Softmax)
        assert isinstance(nodes[2].op, tensor.CAReduce)
        assert isinstance(nodes[2].op.scalar_op, theano.scalar.Maximum)
        assert str(nodes[3].op) == 'OutputGuard'
def test_argmax_pushdown_bias():
    """Check the argmax-pushdown optimization through ``softmax_with_bias``.

    * With ``argmax`` as the output, the optimized graph must consist of
      DimShuffle, Elemwise, MaxAndArgmax and a trailing OutputGuard.
    * With the max output of ``max_and_argmax``, the graph must keep
      SoftmaxWithBias, followed by a CAReduce with a Maximum scalar op
      and a trailing OutputGuard.
    """
    # --- argmax output ---------------------------------------------------
    x = tensor.matrix()
    b = tensor.vector()

    out = tensor.argmax(softmax_with_bias(x, b), axis=-1)
    fgraph = gof.FunctionGraph([x, b], [out])
    theano.compile.mode.optdb.query(
        theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)

    types_to_check = (tensor.DimShuffle, tensor.Elemwise, tensor.MaxAndArgmax)
    nodes = fgraph.toposort()
    assert len(nodes) == 4
    # zip stops after the three expected types; the fourth node is
    # checked separately below.
    for node, expected_type in zip(nodes, types_to_check):
        assert isinstance(node.op, expected_type)
    assert str(nodes[3].op) == 'OutputGuard'
    assert check_stack_trace(fgraph, ops_to_check=types_to_check)

    # --- max output ------------------------------------------------------
    x = tensor.matrix()
    b = tensor.vector()
    out = tensor.max_and_argmax(softmax_with_bias(x, b), axis=-1)[0]
    fgraph = gof.FunctionGraph([x, b], [out])

    # Silence the argmax_pushdown_bug warning while optimizing, and always
    # restore the previous setting afterwards.
    saved_flag = config.warn.argmax_pushdown_bug
    config.warn.argmax_pushdown_bug = False
    try:
        theano.compile.mode.optdb.query(
            theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
    finally:
        config.warn.argmax_pushdown_bug = saved_flag

    nodes = fgraph.toposort()
    assert len(nodes) == 3
    assert isinstance(nodes[0].op, SoftmaxWithBias)
    assert isinstance(nodes[1].op, tensor.CAReduce)
    assert isinstance(nodes[1].op.scalar_op, theano.scalar.Maximum)
    assert str(nodes[2].op) == 'OutputGuard'
    assert check_stack_trace(
        fgraph, ops_to_check=(SoftmaxWithBias, tensor.CAReduce))