Python six.moves 模块：StringIO() 实例源码

我们从 Python 开源项目中提取了以下 50 个代码示例，用于说明如何使用 six.moves.StringIO()。

项目:devsecops-example-helloworld    作者:boozallen    | 项目源码 | 文件源码
def testRunnerRegistersResult(self):
        """TextTestRunner.run() must pass its result to registerResult once.

        unittest2.runner.registerResult is temporarily replaced by a fake that
        counts calls and checks it receives our result object; the original
        hook is restored through addCleanup.
        """
        originalRegisterResult = unittest2.runner.registerResult
        # Restore the real hook even if an assertion below fails.
        self.addCleanup(setattr, unittest2.runner, 'registerResult',
                        originalRegisterResult)

        result = unittest2.TestResult()
        runner = unittest2.TextTestRunner(stream=StringIO())
        # Make the runner hand out our result object instead of building one.
        runner._makeResult = lambda: result

        self.wasRegistered = 0

        def fakeRegisterResult(thisResult):
            self.wasRegistered += 1
            self.assertEqual(thisResult, result)
        unittest2.runner.registerResult = fakeRegisterResult

        runner.run(unittest2.TestSuite())
        self.assertEqual(self.wasRegistered, 1)
项目:devsecops-example-helloworld    作者:boozallen    | 项目源码 | 文件源码
def test_startTestRun_stopTestRun_called(self):
        """An empty run still emits startTestRun then stopTestRun, in order."""
        recorded = []

        class QuietLoggingResult(LoggingResult):
            # No separator and no error report, so nothing extra is printed.
            separator2 = ''

            def printErrors(self):
                pass

        class RecordingRunner(unittest2.TextTestRunner):
            def __init__(self, events):
                super(RecordingRunner, self).__init__(StringIO())
                self._events = events

            def _makeResult(self):
                return QuietLoggingResult(self._events)

        RecordingRunner(recorded).run(unittest2.TestSuite())
        self.assertEqual(recorded, ['startTestRun', 'stopTestRun'])
项目:deb-python-pyvmomi    作者:openstack    | 项目源码 | 文件源码
def SerializeFaultDetail(val, info=None, version=None, nsMap=None, encoding=None):
   """Serialize a fault via SoapSerializer.SerializeFaultDetail and return the text.

   @param val the fault to serialize
   @param info the field; defaults to a generic Object named "object"
   @param version the version; when None it is taken from val._version,
          or BASE_VERSION if that attribute is missing
   @param nsMap a dict of xml ns -> prefix
   @param encoding passed through to SoapSerializer
   @raise TypeError if version is None and val is not a MethodFault
   """
   if version is None:
      try:
         if not isinstance(val, MethodFault):
            raise TypeError('{0} is not a MethodFault'.format(str(val)))
         version = val._version
      except AttributeError:
         version = BASE_VERSION
   if info is None:
      info = Object(name="object", type=object, version=version, flags=0)

   out = StringIO()
   serializer = SoapSerializer(out, version, nsMap, encoding)
   serializer.SerializeFaultDetail(val, info)
   return out.getvalue()

## SOAP serializer
#
项目:ranger-agent    作者:openstack    | 项目源码 | 文件源码
def formatException(self, exc_info, record=None):
        """Format exception output with CONF.logging_exception_prefix.

        Without a record, defer to logging.Formatter.formatException. With a
        record, the traceback text is split on '\\n' and every resulting line
        (including the empty one left by a trailing newline) is prefixed with
        CONF.logging_exception_prefix rendered against record.__dict__.
        """
        if not record:
            return logging.Formatter.formatException(self, exc_info)

        stringbuffer = moves.StringIO()
        try:
            traceback.print_exception(exc_info[0], exc_info[1], exc_info[2],
                                      None, stringbuffer)
            lines = stringbuffer.getvalue().split('\n')
        finally:
            stringbuffer.close()

        prefix_format = CONF.logging_exception_prefix
        if '%(asctime)' in prefix_format:
            # record.asctime is not guaranteed to exist yet; set it so the
            # prefix can be rendered.
            record.asctime = self.formatTime(record, self.datefmt)

        # The rendered prefix does not depend on the line, so build it once.
        prefix = prefix_format % record.__dict__
        return '\n'.join(prefix + line for line in lines)
项目:Theano-Deep-learning    作者:GeekLiB    | 项目源码 | 文件源码
def __str__(self):
        """Return a multi-line report describing the node and its aliasing."""
        node = self.node
        op = node.op
        buf = StringIO()

        def emit(*parts):
            print(*parts, file=buf)

        emit("  node:", node)
        emit("  node.inputs:", [(str(v), id(v)) for v in node.inputs])
        emit("  node.outputs:", [(str(v), id(v)) for v in node.outputs])
        emit("  view_map:", getattr(op, 'view_map', {}))
        emit("  destroy_map:", getattr(op, 'destroy_map', {}))
        emit("  aliased output:", self.output_idx)
        emit("  aliased output storage:", self.out_storage)
        # Only mention alias targets that actually exist.
        if self.in_alias_idx:
            emit("  aliased to inputs:", self.in_alias_idx)
        if self.out_alias_idx:
            emit("  aliased to outputs:", self.out_alias_idx)
        return buf.getvalue()
项目:niceman    作者:ReproNim    | 项目源码 | 文件源码
def _test_progress_bar(backend, len, increment):
    """Drive a DialogUI progress bar from 0% to 100% and check its output.

    :param backend: progress-bar backend name passed to get_progressbar.
    :param len: length of the fill string given to get_progressbar (the name
        shadows the builtin; it is kept for caller compatibility).
    :param increment: if true, each update is a delta of 1 (step 0 skipped);
        otherwise each update is the absolute step value.
    """
    out = StringIO()
    # Build a fill string of exactly `len` characters. The old pattern
    # '123456890' was only 9 characters (the 7 was missing), so the result
    # came out shorter than `len`.
    fill_str = ('1234567890' * (len // 10 + 1))[:len]
    pb = DialogUI(out).get_progressbar('label', fill_str, maxval=10, backend=backend)
    pb.start()
    # Steps 0..10 inclusive. In increment mode step 0 is skipped, so at most
    # 10 increments of 1 are applied and maxval=10 is never exceeded.
    for x in range(11):
        if not (increment and x == 0):
            pb.update(x if not increment else 1, increment=increment)
        out.flush()  # needed atm
        pstr = out.getvalue()
        ok_startswith(pstr.lstrip('\r'), 'label:')
        assert_re_in(r'.*\b%d%%.*' % (10*x), pstr)
        if backend == 'progressbar':
            assert_in('ETA', pstr)
    pb.finish()
    ok_endswith(out.getvalue(), '\n')
项目:col-aws-clients    作者:collectrium    | 项目源码 | 文件源码
def test_deployment(self):
        """Deploy the stage and the custom domain, then do both once more."""
        deployer = APIGatewayDeployer(
            api_name='Sample',
            region_name=self.region_name,
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            swagger_file=StringIO(json.dumps(self.swagger_json)),
            domain_name="example.com",
        )
        stage_kwargs = dict(stage='development',
                            lambda_function_name='api_lambda',
                            lambda_version='development')
        domain_kwargs = dict(stage='development', base_path='v1',
                             certificate_body="", certificate_private_key="",
                             certificate_chain='')
        # Presumably the second round checks re-deploying over an existing
        # deployment; the calls are identical both times.
        for _ in range(2):
            deployer.deploy_stage(**stage_kwargs)
            deployer.deploy_domain(**domain_kwargs)
项目:kobo    作者:release-engineering    | 项目源码 | 文件源码
def test_iter_chunks(self):
        """iter_chunks splits lists, ranges, generators, strings and files."""
        self.assertEqual(list(iter_chunks([], 100)), [])
        self.assertEqual(list(iter_chunks(list(range(5)), 1)), [[0], [1], [2], [3], [4]])
        self.assertEqual(list(iter_chunks(list(range(5)), 2)), [[0, 1], [2, 3], [4]])
        self.assertEqual(list(iter_chunks(list(range(5)), 5)), [[0, 1, 2, 3, 4]])
        self.assertEqual(list(iter_chunks(list(range(6)), 2)), [[0, 1], [2, 3], [4, 5]])

        self.assertEqual(list(iter_chunks(range(5), 2)), [[0, 1], [2, 3], [4]])
        self.assertEqual(list(iter_chunks(range(6), 2)), [[0, 1], [2, 3], [4, 5]])
        self.assertEqual(list(iter_chunks(range(1, 6), 2)), [[1, 2], [3, 4], [5]])
        self.assertEqual(list(iter_chunks(range(1, 7), 2)), [[1, 2], [3, 4], [5, 6]])

        def gen(num):
            for i in range(num):
                yield i+1
        self.assertEqual(list(iter_chunks(gen(5), 2)), [[1, 2], [3, 4], [5]])

        self.assertEqual(list(iter_chunks("01234", 2)), ["01", "23", "4"])
        self.assertEqual(list(iter_chunks("012345", 2)), ["01", "23", "45"])

        expected_lines = (10 * ["1234567890\n"]) + ["\n"]

        # Close the fixture file deterministically; it used to be leaked.
        chunks_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "chunks_file")
        with open(chunks_path, "r") as file_obj:
            self.assertEqual(list(iter_chunks(file_obj, 11)), expected_lines)

        string_io = StringIO((10 * "1234567890\n") + "\n")
        self.assertEqual(list(iter_chunks(string_io, 11)), expected_lines)
项目:LIS-Tempest    作者:LIS    | 项目源码 | 文件源码
def formatException(self, exc_info, record=None):
        """Format exception output with CONF.logging_exception_prefix.

        Without a record, defer to logging.Formatter.formatException. With a
        record, the traceback text is split on '\\n' and every resulting line
        (including the empty one left by a trailing newline) is prefixed with
        CONF.logging_exception_prefix rendered against record.__dict__.
        """
        if not record:
            return logging.Formatter.formatException(self, exc_info)

        stringbuffer = moves.StringIO()
        try:
            traceback.print_exception(exc_info[0], exc_info[1], exc_info[2],
                                      None, stringbuffer)
            lines = stringbuffer.getvalue().split('\n')
        finally:
            stringbuffer.close()

        prefix_format = CONF.logging_exception_prefix
        if '%(asctime)' in prefix_format:
            # record.asctime is not guaranteed to exist yet; set it so the
            # prefix can be rendered.
            record.asctime = self.formatTime(record, self.datefmt)

        # The rendered prefix does not depend on the line, so build it once.
        prefix = prefix_format % record.__dict__
        return '\n'.join(prefix + line for line in lines)
项目:Trusted-Platform-Module-nova    作者:BU-NU-CLOUD-SP16    | 项目源码 | 文件源码
def test_list_without_host(self):
        """`list()` with no host prints a header plus the instance's flavor and host."""
        output = StringIO()
        saved_stdout = sys.stdout
        sys.stdout = output
        try:
            with mock.patch.object(objects.InstanceList, 'get_by_filters') as get:
                get.return_value = objects.InstanceList(
                    objects=[fake_instance.fake_instance_obj(
                        context.get_admin_context(), host='foo-host',
                        flavor=self.fake_flavor,
                        system_metadata={})])
                self.commands.list()
        finally:
            # Always restore the stream that was active before (not
            # sys.__stdout__), even if list() raises.
            sys.stdout = saved_stdout
        result = output.getvalue()

        self.assertIn('node', result)   # check the header line
        self.assertIn('m1.tiny', result)    # flavor.name
        self.assertIn('foo-host', result)
项目:Trusted-Platform-Module-nova    作者:BU-NU-CLOUD-SP16    | 项目源码 | 文件源码
def test_list_with_host(self):
        """`list(host=...)` prints a header plus the instance's flavor and the host."""
        output = StringIO()
        saved_stdout = sys.stdout
        sys.stdout = output
        try:
            with mock.patch.object(objects.InstanceList, 'get_by_host') as get:
                get.return_value = objects.InstanceList(
                    objects=[fake_instance.fake_instance_obj(
                        context.get_admin_context(),
                        flavor=self.fake_flavor,
                        system_metadata={})])
                self.commands.list(host='fake-host')
        finally:
            # Always restore the stream that was active before (not
            # sys.__stdout__), even if list() raises.
            sys.stdout = saved_stdout
        result = output.getvalue()

        self.assertIn('node', result)   # check the header line
        self.assertIn('m1.tiny', result)    # flavor.name
        self.assertIn('fake-host', result)
项目:Trusted-Platform-Module-nova    作者:BU-NU-CLOUD-SP16    | 项目源码 | 文件源码
def _test_archive_deleted_rows(self, mock_db_archive, verbose=False):
        """Call archive_deleted_rows(20) with stdout captured and check the output.

        Verbose mode must print a table of archived row counts; otherwise
        nothing may be printed.
        """
        self.useFixture(fixtures.MonkeyPatch('sys.stdout', StringIO()))
        self.commands.archive_deleted_rows(20, verbose=verbose)
        mock_db_archive.assert_called_once_with(20)
        output = sys.stdout.getvalue()
        if not verbose:
            self.assertEqual(0, len(output))
            return
        expected = '''\
+-----------+-------------------------+
| Table     | Number of Rows Archived |
+-----------+-------------------------+
| consoles  | 5                       |
| instances | 10                      |
+-----------+-------------------------+
'''
        self.assertEqual(expected, output)
项目:Trusted-Platform-Module-nova    作者:BU-NU-CLOUD-SP16    | 项目源码 | 文件源码
def test_download_data_dest_path_write_fails(self, show_mock, open_mock):
        """An exception raised while writing downloaded data propagates."""
        class FakeDiskException(Exception):
            pass

        class Exceptionator(StringIO):
            # File-like sink whose every write fails, simulating a full disk.
            def write(self, _):
                raise FakeDiskException('Disk full!')

        glance_client = mock.MagicMock()
        glance_client.call.return_value = [1, 2, 3]
        service = glance.GlanceImageService(glance_client)

        self.assertRaises(FakeDiskException, service.download,
                          mock.sentinel.ctx, mock.sentinel.image_id,
                          data=Exceptionator())
项目:Hawkeye    作者:tozhengxq    | 项目源码 | 文件源码
def preview_sql(self, url, step, **args):
        """Capture the SQL a run would execute instead of executing it.

        Asks with_engine for a 'mock' engine whose executor appends each
        statement (stringified, plus the optional second argument) to an
        in-memory buffer, then runs :meth:`PythonScript.run
        <migrate.versioning.script.py.PythonScript.run>` against it.

        :returns: SQL file
        """
        captured = StringIO()

        def record(s, p=''):
            captured.write(str(s) + p)

        args['engine_arg_strategy'] = 'mock'
        args['engine_arg_executor'] = record

        @with_engine
        def go(url, step, **kw):
            self.run(kw.pop('engine'), step)
            return captured.getvalue()

        return go(url, step, **args)
项目:gixy    作者:yandex    | 项目源码 | 文件源码
def serialize(self, items):
        """Convert parsed config values back into config-file text.

        This is the inverse of config parsing. *items* maps each key either
        to a nested OrderedDict, emitted as a ``[key]`` section whose body is
        serialized recursively, or to a ``(value, help)`` pair. A truthy help
        text is written as a ``; help`` line above ``key = value``.
        """
        out = StringIO()
        for key, value in items.items():
            # isinstance also accepts OrderedDict subclasses as sections.
            if isinstance(value, OrderedDict):
                out.write('\n[%s]\n' % key)
                out.write(self.serialize(value))
            else:
                value, help_text = value
                if help_text:
                    out.write('; %s\n' % help_text)
                out.write('%s = %s\n' % (key, value))
        return out.getvalue()
项目:devsecops-example-helloworld    作者:boozallen    | 项目源码 | 文件源码
def startTest(self, test):
        """Count *test* as run and, when buffering, redirect stdout/stderr."""
        self.testsRun += 1
        self._mirrorOutput = False
        if not self.buffer:
            return
        # Create the capture buffers lazily, the first time they are needed.
        if self._stderr_buffer is None:
            self._stderr_buffer = StringIO()
            self._stdout_buffer = StringIO()
        sys.stdout, sys.stderr = self._stdout_buffer, self._stderr_buffer
项目:devsecops-example-helloworld    作者:boozallen    | 项目源码 | 文件源码
def captured_output(stream_name):
    """Swap ``sys.<stream_name>`` for a fresh StringIO while suspended.

    Generator intended for use as a context manager (the decorator is not
    visible here; presumably contextlib.contextmanager). Yields the
    replacement stream and always puts the original stream back afterwards.
    """
    saved_stream = getattr(sys, stream_name)
    replacement = StringIO()
    setattr(sys, stream_name, replacement)
    try:
        yield replacement
    finally:
        setattr(sys, stream_name, saved_stream)
项目:devsecops-example-helloworld    作者:boozallen    | 项目源码 | 文件源码
def test_new_runner_old_case(self):
        """A unittest2 runner can execute a plain unittest.TestCase."""
        class Test(unittest.TestCase):
            def testOne(self):
                pass

        runner = unittest2.TextTestRunner(resultclass=resultFactory,
                                          stream=StringIO())
        outcome = runner.run(unittest2.TestSuite([Test('testOne')]))
        self.assertEqual(outcome.testsRun, 1)
        self.assertEqual(len(outcome.errors), 0)
项目:devsecops-example-helloworld    作者:boozallen    | 项目源码 | 文件源码
def test_old_runner_new_case(self):
        """The stdlib unittest runner can execute a unittest2.TestCase."""
        class Test(unittest2.TestCase):
            def testOne(self):
                self.assertDictEqual({}, {})

        runner = unittest.TextTestRunner(stream=StringIO())
        outcome = runner.run(unittest.TestSuite([Test('testOne')]))
        self.assertEqual(outcome.testsRun, 1)
        self.assertEqual(len(outcome.errors), 0)
项目:devsecops-example-helloworld    作者:boozallen    | 项目源码 | 文件源码
def testFailFastSetByRunner(self):
        """failfast=True on the runner is visible on the result it passes along."""
        runner = unittest2.TextTestRunner(stream=StringIO(), failfast=True)
        self.testRan = False

        def fake_test(result):
            # Used as the "test": the runner calls it with its result object.
            self.testRan = True
            self.assertTrue(result.failfast)

        runner.run(fake_test)
        self.assertTrue(self.testRan)
项目:devsecops-example-helloworld    作者:boozallen    | 项目源码 | 文件源码
def test_NonExit(self):
        """With exit=False, main() returns a program object that has .result."""
        quiet_runner = unittest2.TextTestRunner(stream=StringIO())
        program = unittest2.main(argv=["foobar"], exit=False,
                                 testRunner=quiet_runner,
                                 testLoader=self.FooBarLoader())
        self.assertTrue(hasattr(program, 'result'))
项目:devsecops-example-helloworld    作者:boozallen    | 项目源码 | 文件源码
def test_Exit(self):
        """With exit=True, main() raises SystemExit."""
        main_kwargs = dict(argv=["foobar"],
                           testRunner=unittest2.TextTestRunner(stream=StringIO()),
                           exit=True,
                           testLoader=self.FooBarLoader())
        self.assertRaises(SystemExit, unittest2.main, **main_kwargs)
项目:devsecops-example-helloworld    作者:boozallen    | 项目源码 | 文件源码
def test_ExitAsDefault(self):
        """Without an explicit exit argument, main() raises SystemExit."""
        main_kwargs = dict(argv=["foobar"],
                           testRunner=unittest2.TextTestRunner(stream=StringIO()),
                           testLoader=self.FooBarLoader())
        self.assertRaises(SystemExit, unittest2.main, **main_kwargs)
项目:devsecops-example-helloworld    作者:boozallen    | 项目源码 | 文件源码
def testRunner(self):
        """The result created by TextTestRunner.run() ends up in signals._results."""
        outcome = unittest2.TextTestRunner(stream=StringIO()).run(unittest2.TestSuite())
        self.assertIn(outcome, unittest2.signals._results)
项目:devsecops-example-helloworld    作者:boozallen    | 项目源码 | 文件源码
def test_locals(self):
        runner = unittest.TextTestRunner(stream=io.StringIO(), tb_locals=True)
        result = runner.run(unittest.TestSuite())
        self.assertEqual(True, result.tb_locals)
项目:devsecops-example-helloworld    作者:boozallen    | 项目源码 | 文件源码
def test_works_with_result_without_startTestRun_stopTestRun(self):
        """The runner copes with a result class that lacks startTestRun/stopTestRun."""
        class OldTextResult(OldTestResult):
            separator2 = ''

            def __init__(self, *_):
                super(OldTextResult, self).__init__()

            def printErrors(self):
                pass

        runner = unittest2.TextTestRunner(stream=StringIO(),
                                          resultclass=OldTextResult)
        runner.run(unittest2.TestSuite())
项目:devsecops-example-helloworld    作者:boozallen    | 项目源码 | 文件源码
def test_pickle_unpickle(self):
        """A TextTestRunner survives a pickle round-trip (Issue #7197).

        This is required by test_multiprocessing under Windows in verbose mode.
        """
        original_stream = StringIO(u("foo"))
        runner = unittest2.TextTestRunner(original_stream)
        for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
            clone = pickle.loads(pickle.dumps(runner, protocol=proto))
            # StringIO objects never compare equal, so compare their contents.
            self.assertEqual(clone.stream.getvalue(), original_stream.getvalue())
项目:devsecops-example-helloworld    作者:boozallen    | 项目源码 | 文件源码
def getRunner(self):
        """Return a TextTestRunner using resultFactory that writes to a memory buffer."""
        sink = StringIO()
        return unittest2.TextTestRunner(stream=sink, resultclass=resultFactory)
项目:python-bileanclient    作者:openstack    | 项目源码 | 文件源码
def __enter__(self):
        """Redirect sys.stdout into self.stringio and return self.

        The previous stream is kept in self.real_stdout.
        """
        self.real_stdout = sys.stdout
        capture = moves.StringIO()
        self.stringio = capture
        sys.stdout = capture
        return self
项目:Dshield    作者:ywjt    | 项目源码 | 文件源码
def testDuckTyping(self):
        """parse() accepts any object with a read() method, not only real streams."""
        class StringPassThrough(object):
            # Minimal stream look-alike that only forwards read().
            def __init__(self, stream):
                self.stream = stream

            def read(self, *args, **kwargs):
                return self.stream.read(*args, **kwargs)

        wrapped = StringPassThrough(StringIO('2014 January 19'))
        self.assertEqual(parse(wrapped), datetime(2014, 1, 19))
项目:Dshield    作者:ywjt    | 项目源码 | 文件源码
def testParseStream(self):
        """parse() reads the date text from a file-like stream."""
        self.assertEqual(parse(StringIO('2014 January 19')), datetime(2014, 1, 19))
项目:http-prompt    作者:eliangcs    | 项目源码 | 文件源码
def _colorize(self, text, token_type):
        """Render *text* as a single *token_type* token through self.formatter.

        Returns *text* unchanged when no formatter is configured.
        """
        formatter = self.formatter
        if formatter:
            sink = StringIO()
            formatter.format([(token_type, text)], sink)
            return sink.getvalue()
        return text
项目:python-apt-mirror-updater    作者:xolox    | 项目源码 | 文件源码
def gather_eol_dates(context, directory=DISTRO_INFO_DIRECTORY):
    """
    Collect release end-of-life dates from distro-info-data CSV files.

    :param context: An execution context created by :mod:`executor.contexts`.
    :param directory: Pathname of a directory holding ``*.csv`` files with
                      end-of-life dates (a string, defaults to
                      :data:`DISTRO_INFO_DIRECTORY`).
    :returns: A dictionary like :data:`KNOWN_EOL_DATES`: lower-cased file
              base name -> {series: end-of-life timestamp as an int}.
    """
    known_dates = {}
    if not context.is_directory(directory):
        return known_dates
    for entry in context.list_entries(directory):
        basename, extension = os.path.splitext(entry)
        if extension.lower() != '.csv':
            continue
        per_series = {}
        known_dates[basename.lower()] = per_series
        raw_csv = context.read_file(os.path.join(directory, entry))
        for row in csv.DictReader(StringIO(raw_csv)):
            series = row.get('series')
            # Prefer the server EOL column when the file has one.
            eol = row.get('eol-server') or row.get('eol')
            if not (series and eol):
                continue
            # Assumes parse_date() returns a 6-field time tuple; the three -1
            # fields pad it to the 9-tuple time.mktime() expects. TODO confirm.
            per_series[series] = int(time.mktime(parse_date(eol) + (-1, -1, -1)))
    return known_dates
项目:axibot    作者:storborg    | 项目源码 | 文件源码
def process_upload(app, document, filename):
    """Turn an uploaded document into a job.

    A document starting with ``{`` is treated as a serialized Job and
    deserialized. Anything else is handed to planning.plan_job() together
    with *filename*.
    """
    # startswith() instead of document[0]: an empty upload no longer raises
    # IndexError here and falls through to the planner instead.
    if document.startswith('{'):
        return Job.deserialize(StringIO(document))
    return planning.plan_job(document, filename=filename)
项目:performance    作者:python    | 项目源码 | 文件源码
def main(loops, level):
    """Time *loops* solves of the hexiom board for *level*.

    Returns the time measured with perf.perf_counter() across all
    iterations. Raises ValueError if loops < 1 (there would be no output to
    check; previously this failed with UnboundLocalError) and AssertionError
    if the solver's output differs from the expected solution.
    """
    if loops < 1:
        raise ValueError("loops must be >= 1, got %r" % (loops,))
    board, solution = LEVELS[level]
    order = DESCENDING
    strategy = Done.FIRST_STRATEGY

    board = board.strip()
    expected = solution.rstrip()

    range_it = xrange(loops)
    t0 = perf.perf_counter()

    for _ in range_it:
        stream = StringIO()
        solve_file(board, strategy, order, stream)
        output = stream.getvalue()
        # Presumably drops the buffer before the next iteration so it can be
        # freed promptly.
        stream = None

    dt = perf.perf_counter() - t0

    # Compare line by line, ignoring trailing whitespace.
    output = '\n'.join(line.rstrip() for line in output.splitlines())
    if output != expected:
        raise AssertionError("got a wrong answer:\n%s\nexpected: %s"
                             % (output, expected))

    return dt
项目:intel-iot-refkit    作者:intel    | 项目源码 | 文件源码
def _save_output_data(self):
        """Copy captured stdout/stderr text into _stdout_data/_stderr_data.

        The streams may not be StringIO yet (e.g. when the test fails during
        __call__). The first AttributeError stops saving, so stderr is only
        saved if stdout was.
        """
        for name in ('stdout', 'stderr'):
            try:
                setattr(self, '_%s_data' % name, getattr(sys, name).getvalue())
            except AttributeError:
                break
项目:Chromium_DepotTools    作者:p07r0457    | 项目源码 | 文件源码
def reset(self):
        """Start over with an empty output buffer and no collected messages."""
        self.out, self.messages = StringIO(), []
项目:Chromium_DepotTools    作者:p07r0457    | 项目源码 | 文件源码
def tokenize_str(code):
    """Tokenize the source string *code* and return all tokens as a list."""
    source = StringIO(code)
    tokens = tokenize.generate_tokens(source.readline)
    return list(tokens)
项目:node-gn    作者:Shouqun    | 项目源码 | 文件源码
def reset(self):
        # Replace the output sink and message list with fresh, empty ones.
        fresh_out = StringIO()
        self.out = fresh_out
        self.messages = list()
项目:node-gn    作者:Shouqun    | 项目源码 | 文件源码
def tokenize_str(code):
    """Return the list of tokens that tokenize produces for *code*."""
    buf = StringIO(code)
    return list(tokenize.generate_tokens(buf.readline))
项目:deb-python-pyvmomi    作者:openstack    | 项目源码 | 文件源码
def _SerializeToUnicode(val, info=None, version=None, nsMap=None):
   """Serialize *val* with SoapSerializer.Serialize and return the text.

   When version is None it is inferred: from val.Item._version for lists,
   otherwise from val._version, falling back to BASE_VERSION on
   AttributeError. If both val and version are None, '' is returned.
   """
   if version is None:
      try:
         if isinstance(val, list):
            version = val.Item._version
         elif val is None:
            # neither val nor version is given
            return ''
         else:
            version = val._version
      except AttributeError:
         version = BASE_VERSION
   if info is None:
      info = Object(name="object", type=object, version=version, flags=0)

   out = StringIO()
   SoapSerializer(out, version, nsMap).Serialize(val, info)
   return out.getvalue()

## Serialize fault detail
#
# Serializes a fault as the content of the detail element in a
# soapenv:Fault (i.e. without a LocalizedMethodFault wrapper).
#
# This function assumes CheckField(info, val) was already called
# @param val the value to serialize
# @param info the field
# @param version the version
# @param nsMap a dict of xml ns -> prefix
# @return the serialized object as a unicode string
项目:artemis    作者:QUVA-Lab    | 项目源码 | 文件源码
def unzip_gz(data):
    """Decompress gzip-compressed *data* (bytes) and return the raw bytes."""
    # gzip operates on bytes, so the in-memory file must be a BytesIO. On
    # Python 3, StringIO only accepts str and raised TypeError here. The
    # with-block also closes the GzipFile, which was previously leaked.
    import io
    with gzip.GzipFile(fileobj=io.BytesIO(data)) as archive:
        return archive.read()
项目:artemis    作者:QUVA-Lab    | 项目源码 | 文件源码
def __init__(self, log_file_path=None, print_to_console=True, prefix=None):
        """
        :param log_file_path: The path to save the records, or None if you just want to keep it in memory
        :param print_to_console: Stored as self._print_to_console (presumably
            controls echoing to the console; not used in this method).
        :param prefix: Stored unchanged as self.prefix.
        """
        self._print_to_console = print_to_console
        if log_file_path is not None:
            make_file_dir(log_file_path)
            self.log = open(log_file_path, 'w')
        else:
            self.log = StringIO()
        self._log_file_path = log_file_path
        self.old_stdout = _ORIGINAL_STDOUT
        # The previous `None if prefix is None else prefix` was always just `prefix`.
        self.prefix = prefix
项目:Theano-Deep-learning    作者:GeekLiB    | 项目源码 | 文件源码
def test_record_good():
    """
    Record a sequence of events, then replay it exactly, and check that the
    Record class:
        1) records it correctly
        2) does not raise any errors on playback
    """
    num_lines = 10
    script = [str(k) + '\n' for k in xrange(num_lines)]

    # Record the sequence of events.
    output = StringIO()
    recorder = Record(file_object=output, replay=False)
    for line in script:
        recorder.handle_line(line)

    # The recording must be exactly the concatenation of the events.
    output_value = output.getvalue()
    assert output_value == ''.join(script)

    # Replaying the identical sequence must not raise.
    playback_checker = Record(file_object=StringIO(output_value), replay=True)
    for line in script:
        playback_checker.handle_line(line)
项目:Theano-Deep-learning    作者:GeekLiB    | 项目源码 | 文件源码
def test_record_bad():
    """
    Record a sequence of events, then deviate from it on playback, and
    check that the Record class detects the mismatch (MismatchError).
    """
    # Record a sequence of events
    output = StringIO()
    recorder = Record(file_object=output, replay=False)
    num_lines = 10
    for i in xrange(num_lines):
        recorder.handle_line(str(i) + '\n')

    # Replaying a prefix of the recorded events must not raise.
    playback_checker = Record(file_object=StringIO(output.getvalue()),
                              replay=True)
    for i in xrange(num_lines // 2):
        playback_checker.handle_line(str(i) + '\n')

    # Deviating from the recorded sequence must raise MismatchError.
    try:
        playback_checker.handle_line('0\n')
    except MismatchError:
        pass
    else:
        # The old message had a double space from adjacent literals.
        raise AssertionError("Failed to detect mismatch between recorded "
                             "sequence and repetition of it.")
项目:Theano-Deep-learning    作者:GeekLiB    | 项目源码 | 文件源码
def test_record_mode_good():
    """
    Like test_record_good, but some events are recorded by the theano
    RecordMode. The exact string value of the record is not checked here.
    """
    num_lines = 10

    # Record a sequence of events, interleaving manual lines with calls to a
    # function compiled under RecordMode.
    output = StringIO()
    recorder = Record(file_object=output, replay=False)
    record_mode = RecordMode(recorder)
    x = iscalar()
    f = function([x], x, mode=record_mode, name='f')
    for k in xrange(num_lines):
        recorder.handle_line(str(k) + '\n')
        f(k)

    # Replaying the same events with a freshly compiled function must not raise.
    playback_checker = Record(file_object=StringIO(output.getvalue()),
                              replay=True)
    playback_mode = RecordMode(playback_checker)
    x = iscalar()
    f = function([x], x, mode=playback_mode, name='f')
    for k in xrange(num_lines):
        playback_checker.handle_line(str(k) + '\n')
        f(k)
项目:Theano-Deep-learning    作者:GeekLiB    | 项目源码 | 文件源码
def test_pydotprint_cond_highlight():
    """
    This is a REALLY PARTIAL TEST, written to help debugging.

    pydotprint with cond_highlight=True on a graph without an IfElse node
    must log exactly one explanatory message.
    """
    if not theano.printing.pydot_imported:
        raise SkipTest('pydot not available')

    x = tensor.dvector()
    f = theano.function([x], x * 2)
    f([1, 2, 3, 4])

    log_buffer = StringIO()
    capture = logging.StreamHandler(log_buffer)
    capture.setLevel(logging.DEBUG)
    default_handler = theano.logging_default_handler

    # Route theano's log output into log_buffer for the duration of the call.
    logger = theano.theano_logger
    logger.removeHandler(default_handler)
    logger.addHandler(capture)
    try:
        theano.printing.pydotprint(f, cond_highlight=True,
                                   print_output_file=False)
    finally:
        logger.addHandler(default_handler)
        logger.removeHandler(capture)

    expected = ('pydotprint: cond_highlight is set but there'
                ' is no IfElse node in the graph\n')
    assert log_buffer.getvalue() == expected
项目:Theano-Deep-learning    作者:GeekLiB    | 项目源码 | 文件源码
def test2_invalid_neg(self):
        """max_and_argmax with axis -3 on a rand(2, 3) input must raise ValueError."""
        n = as_tensor_variable(rand(2, 3))
        old_stderr = sys.stderr
        # Silence anything written to stderr while the error is provoked.
        sys.stderr = StringIO()
        try:
            try:
                eval_outputs(max_and_argmax(n, -3))
            except ValueError:
                pass
            else:
                # Unlike `assert False`, an explicit raise is not stripped by -O.
                raise AssertionError("ValueError not raised for axis=-3")
        finally:
            sys.stderr = old_stderr
项目:Theano-Deep-learning    作者:GeekLiB    | 项目源码 | 文件源码
def test2_invalid_neg(self):
        """argmax/argmin with axis -3 on a rand(2, 3) input must raise ValueError."""
        for fct, nfct in [(argmax, numpy.argmax), (argmin, numpy.argmin)]:
            n = as_tensor_variable(rand(2, 3))
            old_stderr = sys.stderr
            # Silence anything written to stderr while the error is provoked.
            sys.stderr = StringIO()
            try:
                try:
                    eval_outputs(fct(n, -3))
                except ValueError:
                    pass
                else:
                    # Unlike `assert False`, an explicit raise is not stripped by -O.
                    raise AssertionError("ValueError not raised for axis=-3")
            finally:
                sys.stderr = old_stderr
项目:Theano-Deep-learning    作者:GeekLiB    | 项目源码 | 文件源码
def test2_invalid_neg(self):
        """max/min with axis -3 on a rand(2, 3) input must raise ValueError."""
        for fct in [max, min]:
            n = as_tensor_variable(rand(2, 3))
            old_stderr = sys.stderr
            # Silence anything written to stderr while the error is provoked.
            sys.stderr = StringIO()
            try:
                try:
                    eval_outputs(fct(n, -3))
                except ValueError:
                    pass
                else:
                    # Unlike `assert False`, an explicit raise is not stripped by -O.
                    raise AssertionError("ValueError not raised for axis=-3")
            finally:
                sys.stderr = old_stderr