我们从 Python 开源项目中提取了以下 2 个代码示例，用于说明如何使用 utils.save_model()（两个示例均为 Python 2 代码，其中第二个示例里的 save_model() 调用已被注释掉）。
def train():
    """Train an RCNNModel and checkpoint it periodically with save_model().

    Builds the train/valid/test ``TextIterator`` datasets and an
    ``RCNNModel``, optionally restores parameters from ``model_dir``, then
    runs ``NEPOCH`` epochs over ``train_data``.  Every ``disp_freq`` batches
    the running mean cost and the mean of the per-batch ``error_rate``
    values are logged; every ``dump_freq`` batches the parameters are saved
    to ``./model/parameters_<elapsed seconds>.pkl``.  ``evaluate`` is called
    at the end of every epoch.

    All configuration (file paths, hyper-parameters, frequencies, ``log``)
    comes from module-level names defined elsewhere in the file.

    Returns:
        -1 as soon as a NaN or Inf training cost is seen; otherwise None.
    """
    log.info('loading dataset...')
    train_data = TextIterator(train_file, n_batch=batch_size, maxlen=maxlen)
    valid_data = TextIterator(valid_file, n_batch=batch_size, maxlen=maxlen)
    test_data = TextIterator(test_file, n_batch=batch_size, maxlen=maxlen,
                             mode=2)

    log.info('building models....')
    model = RCNNModel(n_input=n_input, n_vocab=VOCABULARY_SIZE,
                      n_hidden=n_hidden, cell='gru', optimizer=optimizer,
                      dropout=dropout, sim=sim, maxlen=maxlen,
                      batch_size=batch_size)

    start = time.time()
    if os.path.isfile(model_dir):
        log.info('loading checkpoint parameters.... %s', model_dir)
        model = load_model(model_dir, model)
    if goto_line != 0:
        # Skip ahead to line `goto_line` of the training data (presumably to
        # resume an interrupted run).
        train_data.goto_line(goto_line)
        log.info('goto line: %s', goto_line)

    log.info('training start...')
    for epoch in xrange(NEPOCH):  # Python 2 code base (xrange)
        costs = 0
        idx = 0          # batches read this epoch, including skipped ones
        n_trained = 0    # batches actually passed to model.train
        error_rate_list = []
        # Reset per epoch so the error handler never reports a previous
        # epoch's batch, nor hits a NameError when the failure happens
        # before the first batch is unpacked.
        x = y = None
        try:
            for (x, xmask), (y, ymask), label in train_data:
                idx += 1
                # The model is built for a fixed batch size; skip any batch
                # whose last dimension differs.
                if x.shape[-1] != batch_size:
                    continue
                cost, error_rate = model.train(x, xmask, y, ymask, label, lr)
                costs += cost
                n_trained += 1
                error_rate_list.append(error_rate)
                if np.isnan(cost) or np.isinf(cost):
                    log.error('Nan Or Inf detected!')
                    log.error('x shape: %s, y shape: %s',
                              np.shape(x), np.shape(y))
                    log.error('cost: %s, error_rate: %s', cost, error_rate)
                    return -1
                if idx % disp_freq == 0:
                    # Each error_rate is iterable (it is flattened with
                    # chain.from_iterable) and its mean is reported as
                    # "Accuracy". NOTE(review): confirm which quantity
                    # model.train returns despite the variable name.
                    # n_trained >= 1 here: skipped batches `continue` above.
                    log.info('epoch: %d, idx: %d cost: %.3f, Accuracy: %.3f '
                             % (epoch, idx, costs / n_trained,
                                np.mean(list(itertools.chain.from_iterable(
                                    error_rate_list)))))
                if idx % dump_freq == 0:
                    save_model('./model/parameters_%.2f.pkl'
                               % (time.time() - start), model)
        except Exception:
            # Best-effort, as before: record the failure with its traceback
            # and carry on to evaluation instead of aborting the whole run.
            log.exception('training failed in epoch %d after %d batches',
                          epoch, idx)
            if x is not None and y is not None:
                log.error('max x: %s, max y: %s', np.max(x), np.max(y))
                log.error('x shape: %s, y shape: %s',
                          np.shape(x), np.shape(y))
        evaluate(train_data, valid_data, test_data, model)
    log.info("Finished. Time = " + str(time.time() - start))
def test():
    """Score an RCNNModel on per-epoch data splits and log cost/accuracy.

    Builds an ``RCNNModel`` and optionally restores its parameters from
    ``model_dir``.  For each of ``NEPOCH`` epochs it runs ``model.predict``
    over every full-size batch of ``<train_file>.train.<epoch>``, logs the
    mean cost and accuracy, then calls ``evaluate`` on
    ``<train_file>.valid.<epoch>``.  Only ``model.predict`` is used on the
    batches (presumably no parameter updates), and checkpointing via
    ``save_model`` is left disabled.

    All configuration (file paths, hyper-parameters, ``log``) comes from
    module-level names defined elsewhere in the file.

    Returns:
        -1 as soon as a NaN or Inf cost is seen; otherwise None.
    """
    log.info('loading dataset...')
    log.info('building models....')
    model = RCNNModel(n_input=n_input, n_vocab=VOCABULARY_SIZE,
                      n_hidden=n_hidden, cell='gru', optimizer=optimizer,
                      dropout=dropout, sim=sim, maxlen=maxlen,
                      batch_size=batch_size)
    # This function only predicts; the old message said "training".
    log.info('testing start....')
    start = time.time()
    if os.path.isfile(model_dir):
        log.info('loading checkpoint parameters.... %s', model_dir)
        model = load_model(model_dir, model)

    for epoch in xrange(NEPOCH):  # Python 2 code base (xrange)
        costs = []
        acc_list = []
        train_path = train_file + ".train." + str(epoch)
        valid_path = train_file + ".valid." + str(epoch)
        train_data = TextIterator(train_path, n_batch=batch_size,
                                  maxlen=maxlen)
        valid_data = TextIterator(valid_path, n_batch=batch_size,
                                  maxlen=maxlen)
        for (x, xmask), (y, ymask), label in train_data:
            # The model is built for a fixed batch size; skip any batch
            # whose last dimension differs.
            if x.shape[-1] != batch_size:
                continue
            cost, acc = model.predict(x, xmask, y, ymask, label)
            costs.append(cost)
            acc_list.append(acc)
            mean_cost = np.mean(cost)
            if np.isnan(mean_cost) or np.isinf(mean_cost):
                log.error('Nan Or Inf detected!')
                log.error('x: %s', x)
                log.error('x shape: %s', np.shape(x))
                log.error('y: %s', y)
                log.error('y shape: %s', np.shape(y))
                return -1

        # Checkpointing is disabled in this example. To save the parameters
        # after every epoch, re-enable:
        # save_model('./model/parameters_%.2f.pkl' % (time.time() - start), model)

        if costs:
            # cost and acc are per-batch iterables (hence chain.from_iterable);
            # the means are taken over the values flattened across all batches
            # of the epoch.
            log.info('epoch: %d, cost: %.3f, Accuracy: %.3f ' % (
                epoch,
                np.mean(list(itertools.chain.from_iterable(costs))),
                np.mean(list(itertools.chain.from_iterable(acc_list)))))
        else:
            # np.mean([]) would silently log nan (with a RuntimeWarning).
            log.warning('epoch: %d, no full-size batches in %s',
                        epoch, train_path)
        loss, acc = evaluate(valid_data, model)
        log.info('validation cost: %.3f, Accuracy: %.3f' % (loss, acc))
    log.info("Finished. Time = " + str(time.time() - start))