Created
March 20, 2024 15:24
-
-
Save snopoke/f88ed941c92bf04150a9e64d3da8c4d3 to your computer and use it in GitHub Desktop.
LangChain runnable split/merge proof of concept (POC)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Demonstrate a method to create a pipeline that can handle dynamic splits in the pipeline based on the input type. | |
""" | |
import functools | |
import operator | |
from typing import Any | |
from langchain_core.callbacks import CallbackManagerForChainRun | |
from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda, RunnableSerializable, patch_config | |
from langchain_core.runnables.base import RunnableEach | |
from langchain_core.runnables.utils import Input, Output | |
# Module-wide trace switch read by the helper functions below: each helper
# prints a trace line only while this list is truthy (non-empty), so appending
# any item turns tracing on -- the `__main__` guard at the bottom does this.
# NOTE(review): a list rather than a bool is presumably used so the flag can be
# toggled in place by code holding a reference to it; confirm before changing.
verbose = list()
def get_input_list(item: int) -> list[int]:
    """Expand an integer into the list ``[0, 1, ..., item - 1]``.

    This is a "split" step in the demo: a scalar goes in and a list comes out,
    so the following wrapped step fans out over the elements.

    Args:
        item: Exclusive upper bound of the generated range.

    Returns:
        ``list(range(item))`` -- empty when ``item`` is 0 or negative.
    """
    out = list(range(item))
    # Explicit `if` instead of `verbose and print(...)`: short-circuit
    # evaluation should not be used as a statement for control flow.
    if verbose:
        print(f" get_input_list: in: {item}, out: {out}")
    return out
def get_input_string(item: object) -> str:
    """Return the string form of ``item`` repeated twice (e.g. ``4 -> "44"``).

    Args:
        item: Any value; it is converted with ``str()``.

    Returns:
        ``str(item) * 2``.
    """
    out = str(item) * 2
    # Explicit `if` instead of `verbose and print(...)`.
    if verbose:
        print(f" get_input_string: in: {item}, out: {out}")
    return out
def process(item: Any) -> Any:
    """Return ``item * 2``.

    Numbers are doubled; sequences such as strings or lists are repeated
    twice (Python's ``*`` semantics).

    Args:
        item: Any value supporting ``* 2``.

    Returns:
        ``item * 2``.
    """
    out = item * 2
    # Explicit `if` instead of `verbose and print(...)`.
    if verbose:
        print(f" process: in: {item}, out: {out}")
    return out
def process_merge(item: list[Any]) -> str:
    """Join the elements of ``item`` into one ``-``-separated string.

    This is the merge step of the demo, e.g. ``[0, 1, 2] -> "0-1-2"``. Each
    element is converted with ``str()``; any iterable is accepted.

    Args:
        item: The values to merge.

    Returns:
        The stringified elements joined by ``"-"``.
    """
    out = "-".join(map(str, item))
    # Explicit `if` instead of `verbose and print(...)`.
    if verbose:
        print(f" process_merge: in: {item}, out: {out}")
    return out
class FlexibleRunnableEach(RunnableEach):
    """Runnable that splits the pipeline when its input is a list.

    * List input: each element is fed back through this same runnable via
      ``self.batch``. Nested lists are therefore split recursively, and the
      bound runnable only ever sees the innermost non-list items.
    * Any other input: the bound runnable is invoked on it directly.

    The output mirrors the (possibly nested) list shape of the input.
    """

    # This shouldn't really extend `RunnableEach`; it does so only for
    # demonstration purposes. It should rather be a standalone runnable
    # extending RunnableSerializable.

    def _invoke(
        self,
        inputs: Input | list[Input],
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Output | list[Output]:
        """Split on list input, otherwise delegate to the bound runnable.

        Args:
            inputs: A single input, or a (possibly nested) list of inputs.
            run_manager: Callback manager of the current run; child callbacks
                are attached to the nested calls.
            config: Config of the current run.
            **kwargs: Forwarded to the nested calls on both branches.

        Returns:
            The bound runnable's output for a single input, or a list of
            outputs (same shape as ``inputs``) for list input.
        """
        child_config = patch_config(config, callbacks=run_manager.get_child())
        if isinstance(inputs, list):
            # `self.batch` (not `self.bound.batch`) re-enters this runnable for
            # every element, which is what makes the split recursive.
            return self.batch(inputs, child_config, **kwargs)
        # Forward **kwargs here too; previously they were silently dropped on
        # this branch while the list branch passed them on.
        return self.bound.invoke(inputs, child_config, **kwargs)
class MergeRunnable(RunnableSerializable[Input, Output]):
    """Runnable that merges a split pipeline back into a single value.

    The bound runnable always receives a list:
    * a list input is passed through unchanged;
    * any other input is wrapped in a one-element list.

    It is meant to sit directly in the chain, not inside
    ``FlexibleRunnableEach``, so that it sees the whole list at once rather
    than one element at a time.
    """

    # Runnable that receives the (possibly wrapped) list and produces the
    # merged result.
    bound: Runnable[Input, Output]

    class Config:
        # `Runnable` is not a pydantic model, so pydantic must be allowed to
        # accept it as the type of the `bound` field.
        arbitrary_types_allowed = True

    def _invoke(
        self,
        inputs: Input | list[Input],
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Output:
        """Normalize ``inputs`` to a list and hand it to the bound runnable.

        Args:
            inputs: A list of inputs, or a single input to be wrapped.
            run_manager: Callback manager of the current run; a child manager
                is attached to the bound call.
            config: Config of the current run.
            **kwargs: Forwarded to the bound runnable's ``invoke``.

        Returns:
            The bound runnable's single (merged) output.
        """
        merged_input = inputs if isinstance(inputs, list) else [inputs]
        return self.bound.invoke(
            merged_input,
            patch_config(config, callbacks=run_manager.get_child()),
            **kwargs,  # previously dropped silently
        )

    def invoke(
        self,
        inputs: Input | list[Input],
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> Output:
        """Run the merge step inside a traced run (callbacks and config handling).

        Args:
            inputs: A list of inputs, or a single input to be wrapped.
            config: Optional run config.
            **kwargs: Forwarded through to the bound runnable.

        Returns:
            The bound runnable's merged output.
        """
        return self._call_with_config(self._invoke, inputs, config, **kwargs)
def main():
    """Build the four demo pipelines and run each one with ``run_pipe``.

    The pipelines, by step function, are:
    1. get_input_list -> get_input_list -> process
    2. get_input_list -> merge(process_merge) -> get_input_string
    3. get_input_string -> get_input_string -> process
    4. get_input_string -> get_input_string -> merge(process_merge)
    """
    # Tip: calling LangChain's `set_debug(True)` here enables its debug
    # tracing (it is not imported in this file).

    def to_step(fn):
        # `process_merge` is always used as a merge step; every other helper
        # becomes a plain RunnableLambda.
        runnable = RunnableLambda(fn)
        if fn is process_merge:
            return MergeRunnable(bound=runnable)
        return runnable

    layouts = (
        (get_input_list, get_input_list, process),
        (get_input_list, process_merge, get_input_string),
        (get_input_string, get_input_string, process),
        (get_input_string, get_input_string, process_merge),
    )
    # Build every pipeline first, then run them in order.
    pipelines = [[to_step(fn) for fn in layout] for layout in layouts]
    for pipeline in pipelines:
        run_pipe(pipeline)
def run_pipe(steps, inputs=None):
    """Chain ``steps`` together and print the chain's output for each input.

    Every step except ``MergeRunnable`` instances is wrapped in
    ``FlexibleRunnableEach``, so a list produced by one step splits the rest
    of the pipeline. ``MergeRunnable`` steps are left unwrapped so they
    receive the whole list.

    Args:
        steps: Non-empty iterable of runnables, in execution order.
        inputs: Values to feed the chain, one ``invoke`` per value. Defaults to
            ``([5, 6], 4)``: one list input and one scalar input.

    Raises:
        ValueError: If ``steps`` is empty.
    """
    steps = list(steps)
    if not steps:
        # Otherwise functools.reduce fails with an opaque TypeError.
        raise ValueError("run_pipe() needs at least one step")
    if inputs is None:
        # Built per call so the default list is never shared between calls.
        inputs = ([5, 6], 4)

    # Wrap each step in a runnable that will split the pipeline if the input
    # is a list. This could be made more explicit if we knew whether a step is
    # expected to return a list or not.
    def _wrap_step(step: RunnableSerializable) -> RunnableSerializable:
        if isinstance(step, MergeRunnable):
            return step
        return FlexibleRunnableEach(bound=step)

    chain = functools.reduce(operator.or_, map(_wrap_step, steps))
    print(f"\nRunning chain: {chain}\n================================")
    for val in inputs:
        output = chain.invoke(val)
        print(f" input: '{val}', output: '{output}'")
        print(" -----------------------------")
if __name__ == "__main__":
    # Executed as a script: switch on the helpers' trace output (they print
    # whenever `verbose` is non-empty), then run every demo pipeline.
    verbose.extend([1])
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment