Created
January 29, 2019 13:54
-
-
Save ByungSunBae/1c31629a3b1119a27ae559fad616a30e to your computer and use it in GitHub Desktop.
A simple TensorFlow operation in Go (key point: summing the elements of a vector)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
"github.com/kniren/gota/dataframe" | |
tf "github.com/tensorflow/tensorflow/tensorflow/go" | |
"github.com/tensorflow/tensorflow/tensorflow/go/op" | |
"os" | |
) | |
// errcheck panics with e when e is non-nil and does nothing otherwise.
// It keeps the straight-line example code in main free of repeated
// "if err != nil" blocks.
func errcheck(e error) {
	if e != nil {
		panic(e)
	}
}
func main() { | |
// example.csv is dummy data that made in python. | |
// So, you can generate data in python or R. | |
file, err := os.Open("example.csv") | |
errcheck(err) | |
defer file.Close() | |
example := dataframe.ReadCSV(file) | |
//fmt.Println(example) | |
root := op.NewScope() | |
x_pl := op.Placeholder(root.SubScope("feature"), tf.Float, op.PlaceholderShape(tf.MakeShape(-1, 1))) | |
y_pl := op.Placeholder(root.SubScope("target"), tf.Float, op.PlaceholderShape(tf.MakeShape(-1, 1))) | |
fmt.Println(x_pl.Op.Name(), y_pl.Op.Name()) | |
multiplication := op.Mul(root, x_pl, y_pl) | |
// I spend time because of entering axis of op.Sum. | |
axis := op.Const(root, []int32{0, 1}) | |
sum_vl := op.Sum(root, multiplication, axis) | |
graph, err := root.Finalize() | |
errcheck(err) | |
var sess *tf.Session | |
sess, err = tf.NewSession(graph, &tf.SessionOptions{}) | |
errcheck(err) | |
x_col_tmp := example.Col("x").Float() | |
x_col := [400][1]float32{} | |
for idx, val := range x_col_tmp { | |
x_col[idx][0] = float32(val) | |
} | |
y_col_tmp := example.Col("y").Float() | |
y_col := [400][1]float32{} | |
for idx, val := range y_col_tmp { | |
y_col[idx][0] = float32(val) | |
} | |
var x_val, y_val *tf.Tensor | |
if x_val, err = tf.NewTensor(x_col); err != nil { | |
panic(err.Error()) | |
} | |
if y_val, err = tf.NewTensor(y_col); err != nil { | |
panic(err.Error()) | |
} | |
var results []*tf.Tensor | |
if results, err = sess.Run(map[tf.Output]*tf.Tensor{ | |
x_pl: x_val, | |
y_pl: y_val, | |
}, []tf.Output{sum_vl}, nil); err != nil { | |
panic(err.Error()) | |
} | |
for _, result := range results { | |
fmt.Println(result.Value().(float32)) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment