Last active
September 29, 2019 12:29
-
-
Save Nucs/4cd1220cc676945447f663ee93114578 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| using System; | |
| using System.Collections.Generic; | |
| using System.Diagnostics; | |
| using System.Diagnostics.CodeAnalysis; | |
| using System.Drawing; | |
| using System.Drawing.Imaging; | |
| using System.IO; | |
| using System.Linq; | |
| using System.Runtime.CompilerServices; | |
| using System.Text; | |
| using System.Threading.Tasks; | |
| using NumSharp; | |
| using NumSharp.Backends; | |
| using NumSharp.Backends.Unmanaged; | |
| using Tensorflow; | |
| using static Tensorflow.Binding; | |
| using Buffer = System.Buffer; | |
| namespace ConsoleApp2 | |
| { | |
class Program
{
    /// <summary>
    /// Entry point: builds the denoiser, loads every file in <c>./github_dataset</c> as a bitmap,
    /// and trains for 100 epochs.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the training targets are clones of the training inputs, so the network is
    /// taught an identity mapping (zero residual). Confirm whether separate clean/noisy sets were intended.
    /// </remarks>
    static void Main(string[] args)
    {
        var model = new DnCNN();

        var datasetDir = new DirectoryInfo("./github_dataset");
        List<Bitmap> inputs = datasetDir
            .GetFiles()
            .Select(file => new Bitmap(file.FullName))
            .ToList();

        // Independent copies so input and target bitmaps are distinct objects.
        List<Bitmap> targets = inputs
            .Select(bitmap => (Bitmap)bitmap.Clone())
            .ToList();

        model.Train(inputs, targets, 100);
    }
}
/// <summary>
/// DnCNN-style residual denoiser built with TensorFlow.NET. A stack of 20 3x3 convolutions
/// predicts a residual, which is subtracted from the input image.
/// </summary>
/// <remarks>
/// <see cref="GenerateDataset"/> scales pixel data to [0, 1] float32. <see cref="Evaluate"/>
/// rescales the network output to [0, 255] bytes. The graph and its <see cref="Session"/> are
/// created once in the constructor. The session is released by <see cref="Dispose"/>.
/// </remarks>
public class DnCNN : IDisposable
{
    // Normaliser applied to the loss.
    // NOTE(review): Train feeds the whole dataset on every step rather than mini-batches of
    // this size, so this only scales the loss.
    private const int batch_size = 128;

    private readonly Tensor X, Y_, Y;
    private readonly Tensor loss;
    private readonly Operation optimizer;
    private readonly Session sess;

    /// <summary>
    /// Builds the placeholders, the network, a scalar L2 loss and an Adam optimizer
    /// (learning rate 0.001). Then initialises all global variables.
    /// </summary>
    public DnCNN()
    {
        X = tf.placeholder(tf.float32, shape: (-1, -1, -1, 3), name: "input_image");
        Y_ = tf.placeholder(tf.float32, shape: (-1, -1, -1, 3), name: "clean_image");
        Y = BuildModel(X);

        // Scalar L2 reconstruction loss: sum((Y_ - Y)^2) / (2 * batch_size).
        // The previous relu(Y_ - Y) is zero whenever Y >= Y_, so it never penalised overshooting.
        loss = (0.5 / batch_size) * tf.reduce_sum(tf.square(Y_ - Y));
        optimizer = tf.train.AdamOptimizer(0.001f, name: "AdamOptimizer").minimize(loss);

        sess = new Session();
        sess.run(tf.global_variables_initializer());
    }

    /// <summary>
    /// Builds the residual network. The layers are conv1 (64 filters) + ReLU, then conv2..conv19
    /// (64 filters, no bias) each followed by ReLU, then a linear 3-filter conv20.
    /// </summary>
    /// <param name="input">Image batch tensor (the <c>X</c> placeholder).</param>
    /// <param name="is_training">Currently unused. Kept for signature compatibility; presumably
    /// intended to toggle batch normalisation as in the reference DnCNN, which is not implemented here.</param>
    /// <returns><paramref name="input"/> minus the predicted residual.</returns>
    private Tensor BuildModel(Tensor input, bool is_training = true)
    {
        // Without non-linearities, the stacked convolutions collapse into a single linear filter.
        var output = tf.nn.relu(tf.layers.conv2d(input, 64, new int[] {3, 3}, name: "conv1", padding: "same"));
        for (int i = 2; i < 20; i++)
        {
            output = tf.nn.relu(tf.layers.conv2d(output, 64, new int[] {3, 3}, name: "conv" + i, padding: "same", use_bias: false));
        }

        // The last layer stays linear: the predicted residual may be negative.
        output = tf.layers.conv2d(output, 3, new int[] {3, 3}, name: "conv20", padding: "same", use_bias: false);
        return input - output;
    }

    /// <summary>
    /// Trains the network, one optimizer step per epoch over the full dataset. Prints the loss
    /// and timing each epoch and saves a checkpoint after each epoch.
    /// </summary>
    /// <param name="inputImages">Network inputs, one per training example.</param>
    /// <param name="outputImages">Targets, paired index-by-index with <paramref name="inputImages"/>.</param>
    /// <param name="epochs">Number of optimizer steps.</param>
    /// <param name="checkpointPath">Where the checkpoint is written after every epoch. Pass null
    /// or empty to skip saving. The default keeps the previously hard-coded location.</param>
    /// <exception cref="ArgumentNullException">Either image list is null.</exception>
    /// <exception cref="ArgumentException">The two lists differ in length or are empty.</exception>
    public void Train(List<Bitmap> inputImages, List<Bitmap> outputImages, int epochs = 1,
                      string checkpointPath = @"E:\Downloads\ciao.ckpt")
    {
        if (inputImages == null) throw new ArgumentNullException(nameof(inputImages));
        if (outputImages == null) throw new ArgumentNullException(nameof(outputImages));
        if (inputImages.Count != outputImages.Count)
            throw new ArgumentException("Input and output image lists must have the same length.", nameof(outputImages));

        var sw = Stopwatch.StartNew();
        NDArray x_train = GenerateDataset(inputImages);
        NDArray y_train = GenerateDataset(outputImages);
        var saver = new Saver();
        print($"Dataset created in {sw.ElapsedMilliseconds}ms");
        sw.Restart();

        for (int i = 0; i < epochs; i++)
        {
            sess.run(optimizer, (X, x_train), (Y_, y_train));

            // Evaluate the loss after the step so progress is actually visible.
            var lossValue = sess.run(loss, new FeedItem(X, x_train), new FeedItem(Y_, y_train));
            print($"iter {i.ToString("000")}: loss {lossValue}, {sw.ElapsedMilliseconds}ms");
            sw.Restart();

            if (!string.IsNullOrEmpty(checkpointPath))
                saver.save(sess, checkpointPath);
        }
    }

    /// <summary>
    /// Runs the network on <paramref name="inputImages"/> and returns the first result as a bitmap.
    /// </summary>
    /// <param name="inputImages">Images to denoise. Only the first output is returned.</param>
    /// <returns>The first output image, rescaled to [0, 255] and clipped to the byte range.</returns>
    public Bitmap Evaluate(List<Bitmap> inputImages)
    {
        var sw = Stopwatch.StartNew();
        NDArray x_eval = GenerateDataset(inputImages);
        var output = sess.run(Y, new FeedItem(X, x_eval));
        sw.Stop();
        print($"Inference done in {sw.ElapsedMilliseconds}ms");

        // The network works in [0, 1]. Rescale, then clamp so the byte cast cannot overflow or wrap.
        output = np.clip(255 * output, 0, 255);
        return image(output[0].astype(NPTypeCode.Byte));
    }

    /// <summary>
    /// Stacks the bitmaps into one NDArray along the first axis and scales values to [0, 1] float32.
    /// This matches the float32 placeholders and the 255x rescale in <see cref="Evaluate"/>.
    /// </summary>
    /// <param name="imgs">Images to stack. They presumably must share width, height and channel
    /// count (3 channels to match the placeholders) — TODO confirm for 32bpp sources.</param>
    /// <exception cref="ArgumentNullException"><paramref name="imgs"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="imgs"/> is empty.</exception>
    public NDArray GenerateDataset(List<Bitmap> imgs)
    {
        if (imgs == null) throw new ArgumentNullException(nameof(imgs));
        if (imgs.Count == 0) throw new ArgumentException("At least one image is required.", nameof(imgs));

        NDArray stacked = np.vstack(imgs.Select(img => img.ToNDArray(false, false)).ToArray());
        // Assumes 8-bit channel values (0..255) from ToNDArray — TODO confirm.
        return stacked.astype(NPTypeCode.Single) / 255f;
    }

    /// <summary>
    /// Converts an image array to a <see cref="Bitmap"/>. Accepts either (1, height, width, channels)
    /// or a single (height, width, channels) image.
    /// </summary>
    /// <param name="nd">Pixel data, expected to already be byte-typed.</param>
    /// <exception cref="ArgumentNullException"><paramref name="nd"/> is null.</exception>
    public Bitmap image(NDArray nd)
    {
        if (nd == null) throw new ArgumentNullException(nameof(nd));

        // Width/height are read from axes 2/1, i.e. a batched (1, h, w, c) layout.
        // Evaluate passes output[0], which has dropped the batch axis, so restore it.
        if (nd.ndim == 3)
            nd = nd.reshape(1, nd.shape[0], nd.shape[1], nd.shape[2]);

        return nd.ToBitmap(nd.shape[2], nd.shape[1]);
    }

    /// <summary>Releases the TensorFlow session owned by this instance.</summary>
    public void Dispose() => sess?.Dispose();
}
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment