Skip to content

Instantly share code, notes, and snippets.

@angrypie
Last active February 25, 2025 17:18
Show Gist options
  • Save angrypie/0e8f293a81d1e828e767a7c71e895739 to your computer and use it in GitHub Desktop.
Save angrypie/0e8f293a81d1e828e767a7c71e895739 to your computer and use it in GitHub Desktop.
zed auth token
[package]
name = "zeta"
version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = "1.0.96"
base64 = "0.22.1"
# Supplies `rand::rngs::OsRng`, the randomness source for RSA key generation.
rand = "0.8.0"
# NOTE(review): rand_core is not referenced directly by the code; rand 0.8 and rsa 0.9 appear
# to build on rand_core 0.6, so this 0.9 entry may only add a second copy — confirm and drop.
rand_core = "0.9.2"
reqwest = "0.12.12"
rsa = "0.9.0"
serde = { version = "1.0.218", features = ["derive"] }
serde_json = "1.0.139"
# NOTE(review): the "time" feature is not enabled, so `tokio::time::sleep` is unavailable;
# the code currently falls back to `std::thread::sleep`.
tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] }
url = "2.5.4"
# Only needed by the automatic browser launch, which is currently commented out.
webbrowser = "1.0.3"
// use mlua::prelude::*;
use base64::prelude::*;
use url::Url;
use rsa::{pkcs1::{DecodeRsaPublicKey, EncodeRsaPublicKey }, RsaPrivateKey, RsaPublicKey};
use anyhow::{Context, Result};
use std::net::{TcpListener, TcpStream};
use std::io::{Write, Read};
use std::time::Duration;
use rand::rngs::OsRng;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use serde::{Serialize, Deserialize};
/// RSA public key that is handed to zed.dev during native-app sign-in.
///
/// Its string form is the key's PKCS#1 DER encoding, base64url-encoded (padded alphabet).
struct PublicKey(RsaPublicKey);

impl TryFrom<PublicKey> for String {
    type Error = anyhow::Error;

    /// Encodes the key as base64url(PKCS#1 DER).
    fn try_from(key: PublicKey) -> Result<Self> {
        let PublicKey(inner) = key;
        let der = inner
            .to_pkcs1_der()
            .context("failed to serialize public key")?;
        Ok(BASE64_URL_SAFE.encode(&der))
    }
}

impl TryFrom<String> for PublicKey {
    type Error = anyhow::Error;

    /// Parses a base64url(PKCS#1 DER) string back into a public key.
    fn try_from(value: String) -> Result<Self> {
        let der = BASE64_URL_SAFE
            .decode(&value)
            .context("failed to base64-decode public key string")?;
        RsaPublicKey::from_pkcs1_der(&der)
            .map(Self)
            .context("failed to parse public key")
    }
}
// /// creates login url that should be opened from browser
// fn url(port: u16) -> Result<String> {
// let mut rng = thread_rng();
// let bits = 2048;
// let private_key = RsaPrivateKey::new(&mut rng, bits).unwrap();
// let public_key = RsaPublicKey::from(&private_key);
// //print both private and public key
// Ok(format!("https://zed.dev/native_app_signin?native_app_port={}&native_app_public_key={}", port, String::try_from(PublicKey(public_key)).unwrap()))
// }
// #[mlua::lua_module]
// fn zeta_auth(lua: &Lua) -> LuaResult<LuaTable> {
// let exports = lua.create_table()?;
// exports.set("url", lua.create_function(url)?)?;
// Ok(exports)
// }
//
/// Entry point: runs the browser sign-in flow and prints the credentials it yields.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let (auth_token, uid) = obtain_token_and_id()?;
    println!("Obtained auth token: {}", auth_token);
    println!("Obtained user id: {}", uid);

    // The edit-prediction experiment is disabled. To try it, call
    // `send_prediction_request(&auth_token).await` and print the Ok / Err result.
    // Do not commit real tokens into this file.

    // NOTE(review): a synchronous sleep inside an async fn blocks the runtime thread for the
    // whole minute; its purpose is not evident from this file — confirm it is still wanted.
    let linger = Duration::from_secs(60);
    std::thread::sleep(linger);
    Ok(())
}
/// Runs the zed.dev browser sign-in flow and returns `(access_token, user_id)`.
///
/// Binds a loopback listener on an OS-assigned port and generates a fresh 2048-bit RSA key
/// pair. It then prints the sign-in URL, which carries the port and the encoded public key,
/// and waits up to ten minutes for the browser callback. The callback must contain the
/// encrypted token, which `handle_connection` decrypts.
///
/// Connections that do not carry a valid token (favicon requests, speculative preconnects,
/// malformed requests) are logged and ignored rather than aborting the flow.
///
/// # Errors
/// Fails if the listener cannot be bound or configured, or if key generation/encoding fails.
/// It also fails when `accept` returns an unexpected I/O error, or when no valid callback
/// arrives before the deadline.
fn obtain_token_and_id() -> Result<(String, String), Box<dyn std::error::Error>> {
    // Port 0 lets the OS pick a free port; the real port is embedded in the sign-in URL.
    let listener = TcpListener::bind("127.0.0.1:0")?;
    let port = listener.local_addr()?.port();
    println!("Server listening on port {}", port);

    let bits = 2048;
    let private_key = RsaPrivateKey::new(&mut OsRng, bits)?;
    let public_key = RsaPublicKey::from(&private_key);
    let public_key_string = String::try_from(PublicKey(public_key))?;
    let url = format!(
        "https://zed.dev/native_app_signin?native_app_port={}&native_app_public_key={}",
        port, public_key_string
    );
    println!("Please visit this URL to authenticate: {}", url);
    // Automatic browser launch is disabled; to enable it:
    // if let Err(e) = webbrowser::open(&url) {
    //     println!("Failed to open browser: {}. Please open the URL manually.", e);
    // }

    // Poll a non-blocking listener so the wait is bounded by a deadline.
    listener.set_nonblocking(true)?;
    let deadline = std::time::Instant::now() + Duration::from_secs(60 * 10);
    while std::time::Instant::now() < deadline {
        match listener.accept() {
            Ok((mut stream, _)) => {
                // On some platforms (e.g. BSD/macOS) accepted sockets inherit the listener's
                // non-blocking mode, which would make the first read fail with WouldBlock.
                // Switch back to blocking and bound reads so a silent connection can't stall us.
                stream.set_nonblocking(false)?;
                stream.set_read_timeout(Some(Duration::from_secs(10)))?;
                match handle_connection(&mut stream, &private_key) {
                    Ok(data) => return Ok(data),
                    // A stray request must not abort sign-in; keep waiting for the callback.
                    Err(e) => eprintln!("Ignoring connection without a valid token: {}", e),
                }
            }
            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                std::thread::sleep(Duration::from_secs(1));
            }
            // Transient accept failures: just try again.
            Err(ref e)
                if matches!(
                    e.kind(),
                    std::io::ErrorKind::Interrupted | std::io::ErrorKind::ConnectionAborted
                ) => {}
            Err(e) => return Err(e.into()),
        }
    }
    Err("Timeout waiting for authentication".into())
}
/// Reads the sign-in callback from `stream`, decrypts its token and answers the browser.
///
/// Only the HTTP request line (capped at `MAX_REQUEST_LINE` bytes) is read and parsed.
/// On success the browser receives a 200 page and `(decrypted_token, user_id)` is returned.
/// If the token is missing or cannot be decrypted, a 400 is sent and an error is returned.
///
/// # Errors
/// - socket read/write failures;
/// - a request without both `access_token` and `user_id` query parameters;
/// - any base64/RSA/UTF-8 failure reported by `decrypt_token`.
fn handle_connection(stream: &mut TcpStream, private_key: &RsaPrivateKey) -> Result<(String, String), Box<dyn std::error::Error>> {
    use std::io::{BufRead, BufReader};

    // Generous upper bound for the request line; the encrypted token is a few hundred bytes.
    const MAX_REQUEST_LINE: u64 = 8 * 1024;

    // A single `read` may return only part of the request line, so keep reading until the
    // newline (or the cap). Only the bytes actually received are converted — no trailing
    // zero padding from a fixed-size buffer.
    let mut raw = Vec::new();
    BufReader::new((&*stream).take(MAX_REQUEST_LINE)).read_until(b'\n', &mut raw)?;
    let request = String::from_utf8_lossy(&raw);

    let Some((access_token, user_id)) = extract_token_from_request(&request) else {
        stream.write_all(b"HTTP/1.1 400 Bad Request\r\n\r\n")?;
        return Err("No token found in request".into());
    };

    let token = match decrypt_token(access_token, private_key) {
        Ok(token) => token,
        Err(e) => {
            // Answer the browser instead of resetting the connection. The decryption error
            // is the one worth reporting, so a failure to write this reply is ignored.
            let _ = stream.write_all(b"HTTP/1.1 400 Bad Request\r\n\r\n");
            return Err(e);
        }
    };

    let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n\
                    <html><body><h1>Authentication successful!</h1>\
                    You can close this window now.</body></html>";
    stream.write_all(response.as_bytes())?;
    Ok((token, user_id))
}
/// Extracts the `access_token` and `user_id` query parameters from a raw HTTP request.
///
/// Only the first line (`METHOD target VERSION`) is inspected. Query values are decoded by
/// `url`'s form-urlencoded parser, and when a key repeats, its last occurrence wins.
/// The (still encrypted) access token is printed once found. Returns `None` if the request
/// line, the target, or either parameter is missing.
fn extract_token_from_request(request: &str) -> Option<(String, String)> {
    let target = request.lines().next()?.split_whitespace().nth(1)?;

    // The target is origin-form ("/path?query"); prefix a dummy origin so `Url` can parse it.
    let parsed = Url::parse(&format!("http://localhost{}", target)).ok()?;

    let last_value = |key: &str| {
        parsed
            .query_pairs()
            .filter(|(k, _)| k == key)
            .last()
            .map(|(_, v)| v.into_owned())
    };

    let access_token = last_value("access_token")?;
    println!("access_token: {}", access_token);
    let user_id = last_value("user_id")?;
    Some((access_token, user_id))
}
/// Decrypts a base64url-encoded, RSA PKCS#1 v1.5 encrypted token into a UTF-8 string.
///
/// # Errors
/// Invalid base64, RSA decryption failure, or a plaintext that is not valid UTF-8.
fn decrypt_token(encrypted_token: String, private_key: &RsaPrivateKey) -> Result<String, Box<dyn std::error::Error>> {
    use rsa::Pkcs1v15Encrypt;

    let ciphertext = BASE64_URL_SAFE.decode(encrypted_token)?;
    let plaintext = private_key.decrypt(Pkcs1v15Encrypt, &ciphertext)?;
    Ok(String::from_utf8(plaintext)?)
}
/// Canned "recent edits" fixture that `send_prediction_request` sends as `input_events`:
/// a sentence naming the edited file followed by a fenced diff hunk.
// NOTE(review): the lower-case name triggers the `non_upper_case_globals` lint; renaming it
// also requires updating `send_prediction_request`.
const events: &str = r#"User edited "models/customer.rb":
```diff @@ -2,5 +2,5 @@
class Customer def initialize
- @name = name
+ @name = name.capitalize
@email = email
@phone = phone
```"#;
/// Canned code-excerpt fixture that `send_prediction_request` sends as both `input_excerpt`
/// and `speculated_output`. It is a fence tagged with a file path that wraps the region
/// between `<|editable_region_start|>` and `<|editable_region_end|>`, with a
/// `<|user_cursor_is_here|>` marker inside it.
// NOTE(review): lower-case const name (see `non_upper_case_globals`); kept for callers.
const excerpt: &str = r#"```src/components/modals/ZetaReview.tsx
<|editable_region_start|>
button: {
backgroundColor: '#007BFF',
padding: 10,
borderRadius: 5,
},
buttonText: {
color: '#FFFFFF',
fontSize: 16,
},
view: <|user_cursor_is_here|>
}
return (
<View style={{ flex: 1, justifyContent: 'center', rowGap: 20 }}>
<TextButton
<|editable_region_end|>
```"#;
/// Sends an edit-prediction request built from the canned `events` / `excerpt` fixtures to
/// `https://llm.zed.dev/predict_edits/v2` and returns the raw response body.
///
/// `auth_token` is base64url-encoded once more before being sent as a bearer token.
/// NOTE(review): whether the endpoint expects this extra encoding is not established here.
///
/// The HTTP status is not checked: non-2xx bodies are returned as `Ok` so the caller can
/// inspect server error messages.
///
/// # Errors
/// Serialization failure, an invalid header value, or a transport/body-read error.
async fn send_prediction_request(auth_token: &str) -> Result<String, Box<dyn std::error::Error>> {
    let base64_token = BASE64_URL_SAFE.encode(auth_token.as_bytes());
    let body = PredictEditsBody {
        input_events: events.to_string(),
        input_excerpt: excerpt.to_string(),
        speculated_output: excerpt.to_string(),
        outline: None,
        can_collect_data: false,
        diagnostic_groups: None,
    };
    // Propagate instead of panicking on a serialization failure.
    let body_str = serde_json::to_string(&body)?;

    let mut auth_value = HeaderValue::from_str(&format!("Bearer {}", base64_token))?;
    // Keep the credential out of Debug output and logs.
    auth_value.set_sensitive(true);

    let mut headers = HeaderMap::new();
    headers.insert(AUTHORIZATION, auth_value);
    headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));

    let client = reqwest::Client::new();
    let response = client
        .post("https://llm.zed.dev/predict_edits/v2")
        .body(body_str)
        .headers(headers)
        .send()
        .await?;
    let result = response.text().await?;
    Ok(result)
}
/// JSON body for the `predict_edits/v2` request built by `send_prediction_request`.
///
/// Keep the field order stable: serde's derive emits JSON keys in declaration order.
#[derive(Deserialize, Serialize)]
struct PredictEditsBody {
    /// Text describing recent user edits (the `events` fixture in this file).
    input_events: String,
    /// Code excerpt around the cursor (the `excerpt` fixture in this file).
    input_excerpt: String,
    /// Expected model output; the caller passes the excerpt itself.
    speculated_output: String,
    /// Optional file outline; always `None` in this file.
    outline: Option<String>,
    /// Presumably whether the server may retain the request as data — confirm.
    can_collect_data: bool,
    diagnostic_groups: Option<()>, // unused for now; always None (serialized as null)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment