Last active
February 25, 2025 17:18
-
-
Save angrypie/0e8f293a81d1e828e767a7c71e895739 to your computer and use it in GitHub Desktop.
zed auth token
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[package]
name = "zeta"
version = "0.1.0"
edition = "2021"

[dependencies]
anyhow = "1.0.96"
base64 = "0.22.1"
# `rand::rngs::OsRng` (rand 0.8) implements the rand_core 0.6 traits that rsa 0.9 expects,
# so no separate `rand_core` dependency is needed.
rand = "0.8.0"
reqwest = "0.12.12"
rsa = "0.9.0"
serde = { version = "1.0.218", features = ["derive"] }
serde_json = "1.0.139"
tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] }
url = "2.5.4"
# Only referenced by currently commented-out code that opens the sign-in URL in a browser.
webbrowser = "1.0.3"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// use mlua::prelude::*; | |
use base64::prelude::*; | |
use url::Url; | |
use rsa::{pkcs1::{DecodeRsaPublicKey, EncodeRsaPublicKey }, RsaPrivateKey, RsaPublicKey}; | |
use anyhow::{Context, Result}; | |
use std::net::{TcpListener, TcpStream}; | |
use std::io::{Write, Read}; | |
use std::time::Duration; | |
use rand::rngs::OsRng; | |
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE}; | |
use serde::{Serialize, Deserialize}; | |
struct PublicKey(RsaPublicKey); | |
impl TryFrom<PublicKey> for String { | |
type Error = anyhow::Error; | |
fn try_from(key: PublicKey) -> Result<Self> { | |
let bytes = key | |
.0 | |
.to_pkcs1_der() | |
.context("failed to serialize public key")?; | |
let string = BASE64_URL_SAFE.encode(&bytes); | |
Ok(string) | |
} | |
} | |
impl TryFrom<String> for PublicKey { | |
type Error = anyhow::Error; | |
fn try_from(value: String) -> Result<Self> { | |
let bytes = BASE64_URL_SAFE | |
.decode(&value) | |
.context("failed to base64-decode public key string")?; | |
let key = Self(RsaPublicKey::from_pkcs1_der(&bytes).context("failed to parse public key")?); | |
Ok(key) | |
} | |
} | |
// /// creates login url that should be opened from browser | |
// fn url(port: u16) -> Result<String> { | |
// let mut rng = thread_rng(); | |
// let bits = 2048; | |
// let private_key = RsaPrivateKey::new(&mut rng, bits).unwrap(); | |
// let public_key = RsaPublicKey::from(&private_key); | |
// //print both private and public key | |
// Ok(format!("https://zed.dev/native_app_signin?native_app_port={}&native_app_public_key={}", port, String::try_from(PublicKey(public_key)).unwrap())) | |
// } | |
// #[mlua::lua_module] | |
// fn zeta_auth(lua: &Lua) -> LuaResult<LuaTable> { | |
// let exports = lua.create_table()?; | |
// exports.set("url", lua.create_function(url)?)?; | |
// Ok(exports) | |
// } | |
// | |
#[tokio::main] | |
async fn main() -> Result<(), Box<dyn std::error::Error>> { | |
let (token_json, user_id) = obtain_token_and_id()?; | |
println!("Obtained auth token: {}", token_json); | |
println!("Obtained user id: {}", user_id); | |
// let response = send_prediction_request(&token).await; | |
// let tmp_token = "bW9MWmdyZTFWNUlKVUdFX2xDWDJ6RmtVSVUzVkhzWWdaSUR3Mk9pTnoyblNRdDJDTkZHbE1TbVlxbzJ4cm5NVQ=="; | |
// let response = send_prediction_request(&tmp_token).await; | |
// | |
// | |
// match response { | |
// Ok(response) => { | |
// println!("Response: {}", response); | |
// } | |
// Err(e) => { | |
// println!("Error: {}", e); | |
// } | |
// } | |
std::thread::sleep(Duration::from_secs(60)); | |
Ok(()) | |
} | |
//returns token and user_id | |
fn obtain_token_and_id() -> Result<(String, String), Box<dyn std::error::Error>> { | |
let port = 8181; // TODO: get free port | |
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))?; | |
println!("Server listening on port {}", port); | |
let bits = 2048; | |
let private_key = RsaPrivateKey::new(&mut OsRng, bits).unwrap(); | |
let public_key = RsaPublicKey::from(&private_key); | |
let url = format!( | |
"https://zed.dev/native_app_signin?native_app_port={}&native_app_public_key={}", | |
port, String::try_from(PublicKey(public_key)).unwrap() | |
); | |
println!("Please visit this URL to authenticate: {}", url); | |
// Open the URL in default browser | |
// if let Err(e) = webbrowser::open(&url) { | |
// println!("Failed to open browser: {}. Please open the URL manually.", e); | |
// } | |
// Set timeout for accepting connections | |
listener.set_nonblocking(true)?; | |
let mut attempts = 0; | |
let max_attempts = 60 * 10; // Wait up to 10 minutes | |
while attempts < max_attempts { | |
match listener.accept() { | |
Ok((mut stream, _)) => { | |
let data = handle_connection(&mut stream, &private_key)?; | |
return Ok(data); | |
} | |
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { | |
std::thread::sleep(Duration::from_secs(1)); | |
attempts += 1; | |
continue; | |
} | |
Err(e) => return Err(e.into()), | |
} | |
} | |
Err("Timeout waiting for authentication".into()) | |
} | |
fn handle_connection(stream: &mut TcpStream, private_key: &RsaPrivateKey) -> Result<(String, String), Box<dyn std::error::Error>> { | |
let mut buffer = [0; 1024]; | |
stream.read(&mut buffer)?; | |
let request = String::from_utf8_lossy(&buffer[..]); | |
// Extract encrypted token from request | |
if let Some((access_token, user_id)) = extract_token_from_request(&request) { | |
// Decrypt the token | |
let token = decrypt_token(access_token, private_key)?; | |
// Send success response | |
let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n\ | |
<html><body><h1>Authentication successful!</h1>\ | |
You can close this window now.</body></html>"; | |
stream.write_all(response.as_bytes())?; | |
Ok((token, user_id)) | |
} else { | |
// Send error response | |
let response = "HTTP/1.1 400 Bad Request\r\n\r\n"; | |
stream.write_all(response.as_bytes())?; | |
Err("No token found in request".into()) | |
} | |
} | |
fn extract_token_from_request(request: &str) -> Option<(String, String)> { | |
// Extract the URL part from the HTTP request | |
let request_line = request.lines().next()?; | |
let url_part = request_line.split_whitespace().nth(1)?; | |
// Parse the full URL (adding a dummy base since we have only path) | |
let full_url = format!("http://localhost{}", url_part); | |
let parsed_url = Url::parse(&full_url).ok()?; | |
// Get query parameters | |
let params: std::collections::HashMap<_, _> = parsed_url.query_pairs().collect(); | |
// Extract access_token and user_id | |
let access_token = params.get("access_token")?.to_string(); | |
println!("access_token: {}", access_token); | |
let user_id = params.get("user_id")?.to_string(); | |
Some((access_token, user_id)) | |
} | |
fn decrypt_token(encrypted_token: String, private_key: &RsaPrivateKey) -> Result<String, Box<dyn std::error::Error>> { | |
let encrypted_bytes = BASE64_URL_SAFE.decode(encrypted_token)?; | |
use rsa::Pkcs1v15Encrypt; | |
let decrypted_bytes = private_key.decrypt(Pkcs1v15Encrypt, &encrypted_bytes)?; | |
let token = String::from_utf8(decrypted_bytes)?; | |
Ok(token) | |
} | |
// Sample inputs for `send_prediction_request`: an edit-history description (`events`) and a
// code excerpt (`excerpt`). Raw string literals (`r#"..."#`) let quotes and newlines appear
// verbatim.
// NOTE(review): lowercase const names trigger the `non_upper_case_globals` lint; consts are
// conventionally SCREAMING_SNAKE_CASE (renaming requires updating their use sites).
const events: &str = r#"User edited "models/customer.rb":
```diff @@ -2,5 +2,5 @@
class Customer def initialize
- @name = name
+ @name = name.capitalize
@email = email
@phone = phone
```"#;
// Excerpt wrapped in `<|editable_region_start|>` / `<|editable_region_end|>` marker tokens,
// with a `<|user_cursor_is_here|>` marker inside; presumably these delimit the editable span
// and the cursor position for the prediction service — confirm against the server format.
const excerpt: &str = r#"```src/components/modals/ZetaReview.tsx
<|editable_region_start|>
button: {
backgroundColor: '#007BFF',
padding: 10,
borderRadius: 5,
},
buttonText: {
color: '#FFFFFF',
fontSize: 16,
},
view: <|user_cursor_is_here|>
}
return (
<View style={{ flex: 1, justifyContent: 'center', rowGap: 20 }}>
<TextButton
<|editable_region_end|>
```"#;
async fn send_prediction_request(auth_token: &str) -> Result<String, Box<dyn std::error::Error>> { | |
let base64_token = BASE64_URL_SAFE.encode(auth_token.as_bytes()); | |
let body = PredictEditsBody { | |
input_events: events.to_string(), | |
input_excerpt: excerpt.to_string(), | |
speculated_output: excerpt.to_string(), | |
outline: None, | |
can_collect_data: false, | |
diagnostic_groups: None, | |
}; | |
let body_str = serde_json::to_string(&body).unwrap(); | |
// Create client | |
let client = reqwest::Client::new(); | |
// Prepare headers | |
let mut headers = HeaderMap::new(); | |
headers.insert( | |
AUTHORIZATION, | |
HeaderValue::from_str(&format!("Bearer {}", base64_token))? | |
); | |
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); | |
// Send request | |
let response = client | |
.post("https://llm.zed.dev/predict_edits/v2") | |
.body(body_str) | |
.headers(headers) | |
.send() | |
.await?; | |
let result = response.text().await?; | |
Ok(result) | |
} | |
#[derive(Deserialize, Serialize)] | |
struct PredictEditsBody { | |
input_events: String, | |
input_excerpt: String, | |
speculated_output: String, | |
outline: Option<String>, | |
can_collect_data: bool, | |
diagnostic_groups: Option<()>, // do not use now | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment