Skip to content

Instantly share code, notes, and snippets.

@ikouchiha47
Last active August 1, 2024 20:35
Show Gist options
  • Save ikouchiha47/b0fe05e134b50db5f1acea8841f13ba8 to your computer and use it in GitHub Desktop.
Save ikouchiha47/b0fe05e134b50db5f1acea8841f13ba8 to your computer and use it in GitHub Desktop.
LLM-powered (llama3) test generator for Neovim
-- codebrewer.lua
--
-- test-writer.lua
--
local M = {}
-- `vim.uv` is the libuv binding name since Neovim 0.10; `vim.loop` is the
-- older, deprecated alias, kept as a fallback for earlier releases.
local uv = vim.uv or vim.loop
local async = require 'plenary.async'
--- Return the whole current buffer as a single string.
-- @treturn string every line of buffer 0, joined with '\n'
local function get_buffer_content()
  local buffer_lines = vim.api.nvim_buf_get_lines(0, 0, -1, false)
  local joined = table.concat(buffer_lines, '\n')
  return joined
end
--- Return the byte column of the last byte of the character starting at `col`.
-- getpos() reports the column of the *first* byte of the last selected
-- character; this walks over UTF-8 continuation bytes (0x80-0xBF) so a
-- multi-byte final character is not cut in half. Columns past the end of the
-- line (e.g. v:maxcol in linewise mode) are returned unchanged.
-- @tparam string line
-- @tparam integer col 1-based byte column
-- @treturn integer
local function selection_char_end(line, col)
  while col < #line do
    local byte = line:byte(col + 1)
    if byte < 0x80 or byte > 0xBF then
      break
    end
    col = col + 1
  end
  return col
end

--- Return the text of the last visual selection in the current buffer.
-- Reads the '< and '> marks, so it reflects the most recent visual selection
-- (not an explicitly typed :{range}).
-- @treturn string selected text, lines joined with '\n'
local function get_selected_text()
  local start_pos = vim.fn.getpos "'<"
  local end_pos = vim.fn.getpos "'>"
  local lines = vim.api.nvim_buf_get_lines(0, start_pos[2] - 1, end_pos[2], false)
  if #lines == 0 then
    return ''
  end
  if #lines == 1 then
    lines[1] = string.sub(lines[1], start_pos[3], selection_char_end(lines[1], end_pos[3]))
  else
    lines[1] = string.sub(lines[1], start_pos[3])
    local last = lines[#lines]
    lines[#lines] = string.sub(last, 1, selection_char_end(last, end_pos[3]))
  end
  return table.concat(lines, '\n')
end
--- Open a new split holding `result` in a scratch ('nofile') buffer.
-- Non-streaming alternative to append_to_result_buffer: shows the whole
-- result at once. The work is deferred with vim.schedule().
-- @tparam string result text to display
local function display_result(result)
  vim.schedule(function()
    -- Read the filetype *before* :new, while the source buffer is still
    -- current; afterwards vim.bo refers to the new, empty buffer.
    local source_filetype = vim.bo.filetype
    vim.cmd 'new'
    local buf = vim.api.nvim_get_current_buf()
    vim.api.nvim_buf_set_lines(buf, 0, -1, false, vim.split(result, '\n'))
    vim.bo[buf].filetype = source_filetype
    vim.bo[buf].buftype = 'nofile'
  end)
end
--- Open a new split with an empty scratch ('nofile') buffer for model output.
-- The buffer gets the same filetype as the buffer that was current at call
-- time, so the generated code is highlighted like the source.
-- @treturn integer handle of the new buffer
local function create_result_buffer()
  -- Capture this before :new switches the current buffer; reading
  -- vim.bo.filetype afterwards would return the new buffer's empty filetype.
  local source_filetype = vim.bo.filetype
  vim.cmd 'new'
  local buf = vim.api.nvim_get_current_buf()
  vim.bo[buf].filetype = source_filetype
  vim.bo[buf].buftype = 'nofile'
  return buf
end
-- Unterminated tail of streamed model output, carried between calls to
-- append_to_result_buffer until its newline arrives. Module-level, so it is
-- shared by every result buffer.
local line_buffer = ''
--- Append a chunk of streamed text to `buf`, writing only complete lines.
-- Any trailing fragment without a newline is kept in line_buffer and
-- prefixed to the next chunk. Deferred with vim.schedule().
-- @tparam integer buf result buffer handle
-- @tparam string text raw chunk (may contain zero or more newlines)
local function append_to_result_buffer(buf, text)
  vim.schedule(function()
    local lines = vim.split(line_buffer .. text, '\n')
    -- The last element is the unterminated fragment ('' when the text ended
    -- in '\n'); hold it back until its line is complete.
    line_buffer = table.remove(lines)
    -- The user may have closed/wiped the result buffer while streaming.
    if #lines == 0 or not vim.api.nvim_buf_is_loaded(buf) then
      return
    end
    vim.bo[buf].modifiable = true
    vim.api.nvim_buf_set_lines(buf, -1, -1, false, lines)
    vim.bo[buf].modifiable = false
    -- Follow the output in the window that shows `buf` (if any). Window 0 is
    -- whatever window is current, which may display a different buffer.
    local win = vim.fn.bufwinid(buf)
    if win ~= -1 then
      vim.api.nvim_win_set_cursor(win, { vim.api.nvim_buf_line_count(buf), 0 })
    end
  end)
end
--- Language name to put in a prompt; falls back to 'given' when the source
-- buffer has no filetype, so the prompt never reads "the  code".
-- @tparam string|nil language
-- @treturn string
local function prompt_language(language)
  if language == nil or language == '' then
    return 'given'
  end
  return language
end

--- Prompt builders. Each takes the source buffer's filetype and returns the
-- instruction passed to the model alongside the code.
M.prompt = {
  test_prompt = function(language)
    return string.format(
      'Write tests for the %s code. There should be at least one test case for success and one for failure. Add more test cases as required for each branch, for full code coverage. Only output the test code and nothing else.',
      prompt_language(language)
    )
  end,
  refactor_prompt = function(language)
    return string.format(
      'Refactor the %s code. The refactor should include, but is not limited to: improving time complexity, better design principles, code locality, and single responsibility. If a refactor is not needed, do not unnecessarily refactor the code.',
      prompt_language(language)
    )
  end,
}
--- Build the instruction passed to `ollama run` for a run type.
-- @tparam string filetype language name inserted into the prompt
-- @tparam string run_type 'test' or 'refactor'; anything else gives a no-op prompt
-- @treturn string
local function build_prompt(filetype, run_type)
  if run_type == 'test' then
    return M.prompt.test_prompt(filetype)
  elseif run_type == 'refactor' then
    return M.prompt.refactor_prompt(filetype)
  end
  return 'Do nothing and exit'
end

--- Run the llama3-coder model on `content` and stream its output into a new
-- scratch split.
--
-- `ollama` is spawned directly (no `sh -c`). The code is written to its
-- stdin, so quotes and backslashes in the code need no shell escaping (a
-- POSIX `echo` may rewrite backslash sequences). A SIGTERM from
-- M.cancel_generation reaches ollama itself instead of a wrapper shell.
--
-- @tparam string content code to send to the model
-- @tparam string filetype language name used in the prompt
-- @tparam string run_type 'test' or 'refactor'
-- @return the luv process handle, or nil if ollama could not be started
local function run_ollama_async(content, filetype, run_type)
  local prompt = build_prompt(filetype, run_type)
  local result_buf = create_result_buffer()

  local stdin = uv.new_pipe(false)
  local stdout = uv.new_pipe(false)
  local stderr = uv.new_pipe(false)
  local handle
  local error_msg = ''
  local exit_code, exit_signal = 0, 0

  -- Completion needs three events: process exit, stdout EOF and stderr EOF.
  -- libuv may report the exit before the pipes are drained, so pipes are
  -- closed on EOF rather than in the exit callback to avoid losing output.
  local pending = 3

  local function close_handle(h)
    if h and not h:is_closing() then
      h:close()
    end
  end

  local function finish()
    pending = pending - 1
    if pending > 0 then
      return
    end
    -- Scheduled after every append already queued by the stdout reader, so
    -- line_buffer holds the real final fragment when this runs.
    vim.schedule(function()
      if line_buffer ~= '' then
        -- A lone '\n' makes append_to_result_buffer write the last,
        -- unterminated line and reset line_buffer itself.
        append_to_result_buffer(result_buf, '\n')
      end
      if exit_code ~= 0 then
        vim.api.nvim_echo({ { 'Generation failed: ' .. error_msg, 'ErrorMsg' } }, false, {})
      elseif exit_signal ~= 0 then
        vim.api.nvim_echo({ { 'Generation stopped by signal ' .. exit_signal, 'WarningMsg' } }, false, {})
      else
        vim.api.nvim_echo({ { 'Generation complete', 'Normal' } }, false, {})
      end
    end)
  end

  -- Read callback that forwards data to `on_chunk`; EOF (or a read error)
  -- closes the pipe and counts as that stream's completion.
  local function reader(pipe, on_chunk)
    return function(err, chunk)
      if err then
        error_msg = error_msg .. err
      end
      if chunk then
        on_chunk(chunk)
      else
        close_handle(pipe)
        finish()
      end
    end
  end

  local function on_exit(code, signal)
    exit_code, exit_signal = code, signal
    close_handle(handle)
    -- Don't clobber a newer generation started after a cancel.
    if M.current_handle == handle then
      M.current_handle = nil
    end
    finish()
  end

  local spawn_err
  handle, spawn_err = uv.spawn('ollama', {
    args = { 'run', 'llama3-coder', prompt },
    stdio = { stdin, stdout, stderr },
  }, on_exit)
  if not handle then
    close_handle(stdin)
    close_handle(stdout)
    close_handle(stderr)
    vim.api.nvim_echo(
      { { 'Generation failed: could not start ollama: ' .. tostring(spawn_err), 'ErrorMsg' } },
      false,
      {}
    )
    return nil
  end
  M.current_handle = handle

  stdout:read_start(reader(stdout, function(chunk)
    append_to_result_buffer(result_buf, chunk)
  end))
  stderr:read_start(reader(stderr, function(chunk)
    error_msg = error_msg .. chunk
  end))

  -- Same stdin the previous `echo '<code>' | ollama run ...` pipeline
  -- produced: the code, a trailing newline, then EOF.
  stdin:write(content .. '\n')
  stdin:shutdown(function()
    close_handle(stdin)
  end)

  return handle
end
--- Shared body of :Refactor and :GenerateTests.
-- Refuses to start while another generation is running, picks the input
-- text, and launches the model.
-- @tparam integer|nil range 0 (or nil) when the command had no range, which
--   means "use the whole buffer"; otherwise the last visual selection is used
-- @tparam string run_type 'test' or 'refactor'
local function start_generation(range, run_type)
  if M.current_handle then
    vim.api.nvim_echo({ { 'Generation already in progress', 'WarningMsg' } }, false, {})
    return
  end
  local content
  if (range or 0) == 0 then
    content = get_buffer_content()
  else
    content = get_selected_text()
  end
  local filetype = vim.bo.filetype
  vim.api.nvim_echo({ { 'Generation in progress...', 'Normal' } }, false, {})
  run_ollama_async(content, filetype, run_type)
end

--- Ask the model to refactor the buffer (range == 0) or the visual selection.
M.refactor = async.void(function(range)
  start_generation(range, 'refactor')
end)

--- Ask the model to write tests for the buffer (range == 0) or the visual selection.
M.generate_tests = async.void(function(range)
  start_generation(range, 'test')
end)
--- Cancel the running generation (tests or refactor), if any, with SIGTERM.
-- Clearing M.current_handle right away lets a new generation start
-- immediately.
function M.cancel_generation()
  local handle = M.current_handle
  if not handle then
    vim.api.nvim_echo({ { 'No generation in progress.', 'WarningMsg' } }, false, {})
    return
  end
  M.current_handle = nil
  uv.process_kill(handle, 'sigterm')
  vim.api.nvim_echo({ { 'Generation cancelled.', 'WarningMsg' } }, false, {})
end
-- Plugin spec returned to the plugin manager: a dependency name plus a
-- `config` function (the shape of a lazy.nvim spec — confirm against the
-- manager in use). `config` registers the three user commands.
return {
  'nvim-lua/plenary.nvim',
  config = function()
    local create_command = vim.api.nvim_create_user_command

    create_command('GenerateTests', function(opts)
      M.generate_tests(opts.range)
    end, { range = true })

    create_command('Refactor', function(opts)
      M.refactor(opts.range)
    end, { range = true })

    create_command('CancelGeneration', M.cancel_generation, {})
  end,
}
.PHONY: gen.model.tester cache.model.tester

# Create the llama3-coder model from ./Modelfile, then run it once with empty
# input (presumably to pre-load it before the first editor request).
# Recipe lines must start with a TAB character.
gen.model.tester:
	ollama create llama3-coder -f Modelfile
	ollama run llama3-coder < /dev/null

# Run the already-created llama3-coder model once with empty input
# (presumably to warm/load it).
cache.model.tester:
	ollama run llama3-coder < /dev/null
# Ollama Modelfile for the llama3-coder model (build with: make gen.model.tester).
# Base model.
FROM llama3
# Low sampling temperature (presumably for more focused, repeatable code output).
# The context window can also be set here, e.g. PARAMETER num_ctx 8192.
PARAMETER temperature 0.2
# System prompt for every request.
# NOTE(review): the plugin also uses this model for :Refactor, but this prompt only mentions tests.
SYSTEM You are a coding assistant who can write tests for the provided code in the provided language.
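  • Have Ollama installed and running.
  • Run `make gen.model.tester` to create the llama3-coder model from the Modelfile.
  • You can change the context window in the Modelfile (e.g. `PARAMETER num_ctx 8192`).
  • Open a file in Neovim, visually select a function/class/struct (or select nothing to use the whole buffer), and run :GenerateTests (or :Refactor). Use :CancelGeneration to stop a running generation.
  • This will probably not work for very large inputs, which may exceed the model's context window.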
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment