博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
将libFM模型变换成tensorflow可serving的形式
阅读量:5082 次
发布时间:2019-06-13

本文共 3148 字,大约阅读时间需要 10 分钟。

fm_model是libFM生成的模型

model.ckpt是可以tensorflow serving的模型结构

亲测输出正确。

 

代码:

1 import tensorflow as tf 2  3 # libFM model 4 def load_fm_model(file_name): 5     state = '' 6     fid = 0 7     max_fid = 0 8     w0 = 0.0 9     wj = {}10     v = {}11     k = 012     with open(file_name) as f:13         for line in f:14             line = line.rstrip()15             if 'global bias W0' in line:16                 state = 'w0'17                 fid = 018                 continue19             elif 'unary interactions Wj' in line:20                 state = 'wj'21                 fid = 022                 continue23             elif 'pairwise interactions Vj,f' in line:24                 state = 'v'25                 fid = 026                 continue27 28             if state == 'w0':29                 fv = float(line)30                 w0 = fv31             elif state == 'wj':32                 fv = float(line)33                 if fv != 0:34                     wj[fid] = fv35                 fid += 136                 max_fid = max(max_fid, fid)37             elif state == 'v':38                 fv = [float(_v) for _v in line.split(' ')]39                 k = len(fv)40                 if any([_v!=0 for _v in fv]):41                     v[fid] = fv42                 fid += 143                 max_fid = max(max_fid, fid)44     return w0, wj, v, k, max_fid45 46 _w0, _wj, _v, _k, _max_fid = load_fm_model('libfm_model_file')47 48 # max feature_id49 n = _max_fid50 print 'n', n51 52 # vector dimension53 k = _k54 print 'k', k55 56 # write fm algorithm57 w0 = tf.constant(_w0)58 w1c = tf.constant([_wj.get(fid, 0) for fid in xrange(n)], shape=[n])59 w1 = tf.Variable(w1c)60 #print 'w1', w161 62 vec = []63 for fid in xrange(n):64     vec.append(_v.get(fid, [0]*k))65 w2c = tf.constant(vec, shape=[n,k])66 w2 = tf.Variable(w2c)67 print 'w2', w268 69 # inputs70 x = tf.placeholder(tf.string, [None])71 batch = tf.shape(x)[0]72 x_s = tf.string_split(x)73 inds = tf.stack([tf.cast(x_s.indices[:,0], tf.int64), tf.string_to_number(x_s.values, tf.int64)], axis=1)74 x_sparse = tf.sparse.SparseTensor(indices=inds, values=tf.ones([tf.shape(inds)[0]]), dense_shape=[batch,n])75 x_ = tf.sparse.to_dense(x_sparse)76 77 w2_rep = tf.reshape(tf.tile(w2, [batch,1]), [-1,n,k])78 print 'w2_rep', w2_rep79 80 x_rep = tf.reshape(tf.tile(tf.reshape(x_, [batch*n, 1]), [1,k]), [-1,n,k])81 print 'x_rep', x_rep82 x_rep2 = tf.square(x_rep)83 84 #print tf.multiply(w2_rep,x_rep)85 #print tf.reduce_sum(tf.multiply(w2_rep,x_rep), axis=1)86 q = tf.square(tf.reduce_sum(tf.multiply(w2_rep, x_rep), axis=1))87 h = tf.reduce_sum(tf.multiply(tf.square(w2_rep), x_rep2), axis=1)88 89 y = w0 + tf.reduce_sum(tf.multiply(x_, w1), axis=1) +\90     1.0/2 * tf.reduce_sum(q-h, axis=1)91 92 saver = tf.train.Saver()93 with tf.Session() as sess:94     sess.run(tf.global_variables_initializer())95     #a = sess.run(y, feed_dict={x_:x_train,y_:y_train,batch:70})96     #print a97     save_path = "./model.ckpt"98     tf.saved_model.simple_save(sess, save_path, inputs={
"x": x}, outputs={
"y": y})

 

参考:

 (开头借鉴此文,但其有不少细节错误)

转载于:https://www.cnblogs.com/yaoyaohust/p/10472780.html

你可能感兴趣的文章
17Web服务器端控件
查看>>
历年春节日期
查看>>
关于消除MySQL输入错误后的警报声
查看>>
新开的博客先和大家打个招呼吧!
查看>>
小工具系列之json查看小工具
查看>>
SharePoint 元数据服务-PowerShell
查看>>
在Mac上安装MIT的scheme
查看>>
进程子进程linux系统编程之管道(一):匿名管道和pipe函数
查看>>
输出整数java处理大实数
查看>>
【leetcode】Valid Parentheses
查看>>
《构建高性能Web站点》观后感
查看>>
css transform 3D幻灯片特效
查看>>
批量下载网站图片的Python实用小工具
查看>>
Python 文件对象和方法
查看>>
java 反射机制--根据属性名获取属性值
查看>>
MVC模式在Java Web应用程序中的实例分析
查看>>
oracle update left join 写法
查看>>
VR中运动控制器的传送系统
查看>>
freemarker热部署(Intellij Idea)
查看>>
用 Go 编写一个简单的 WebSocket 推送服务
查看>>