Created
May 28, 2019 04:45
-
-
Save soumith/af07e5cf66174566d56c07bd983bcd89 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstdint>
#include <iostream>
#include <limits>
#include <tuple>

#include <torch/torch.h>
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
// using namespace at;
using namespace torch;
void submodular_select(Tensor candidate_points, Tensor features_done, Tensor features) | |
{ | |
int max_idx = -1; | |
float max_value = -1e-9; | |
for (int i=0; i < candidate_points.size(0); i++) | |
{ | |
std::vector<Tensor> temp; | |
if (candidate_points.item<int>() == 1) | |
{ | |
temp.push_back(features_done); | |
temp.push_back(features[candidate_points[i]]); | |
auto stacked_temp = stack(temp); | |
std::cout << std::get<0>(stacked_temp.max(1,false)) << std::endl; | |
float value = std::get<0>(stacked_temp.max(1,false)).sum().item<float>(); | |
if (value > max_value) | |
{ | |
max_value = value; | |
max_idx = i; | |
} | |
} | |
} | |
std::cout<<"Max Value" << max_value << std::endl; | |
std::cout << "Max Index" << max_idx << std::endl; | |
// return max_idx; | |
} | |
int main() | |
{ | |
int num_data_points = 6000; | |
int num_features = 256; | |
int batch_size = 64; | |
Tensor features = torch::randn({num_data_points, num_features}, dtype(kFloat)); | |
Tensor done = torch::randint(0, num_data_points, batch_size*4, dtype(kLong)); // Already Sampled Points | |
Tensor done_index = torch::arange(0, batch_size*3, kLong).squeeze(); | |
Tensor features_done = features.index(done_index); | |
Tensor candidate_points = torch::ones(num_data_points, dtype(kLong)); | |
auto scatter_val = torch::zeros(num_data_points, dtype(kLong)); | |
candidate_points = candidate_points.scatter_(0, done, scatter_val); | |
for (int batch=0; batch < batch_size; batch++) | |
{ | |
submodular_select(candidate_points, features_done, features); | |
// std::cout<<max_idx<<std::endl; | |
// done[batch_size*3+1] = max_idx; | |
// features_done = features[done[batch_size*3+batch+1]]; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment