Custom chat models
This notebook goes over how to create a custom chat model wrapper, in case you want to use your own chat model or a wrapper that isn't directly supported in LangChain.
There are a few required methods that a chat model needs to implement after extending the `SimpleChatModel` class:
- A `_call` method that takes in a list of messages and call options (which includes things like `stop` sequences), and returns a string.
- A `_llmType` method that returns a string, used for logging purposes only.
You can also implement the following optional method:
- A `_streamResponseChunks` method that returns an `AsyncGenerator` and yields `ChatGenerationChunk`s. This allows the LLM to support streaming outputs.
Let's implement a very simple custom chat model that just echoes back the first `n` characters of the input.
import {
SimpleChatModel,
type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { AIMessageChunk, type BaseMessage } from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";
export interface CustomChatModelInput extends BaseChatModelParams {
n: number;
}
export class CustomChatModel extends SimpleChatModel {
n: number;
constructor(fields: CustomChatModelInput) {
super(fields);
this.n = fields.n;
}
_llmType() {
return "custom";
}
async _call(
messages: BaseMessage[],
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): Promise<string> {
if (!messages.length) {
throw new Error("No messages provided.");
}
// Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
// await subRunnable.invoke(params, runManager?.getChild());
if (typeof messages[0].content !== "string") {
throw new Error("Multimodal messages are not supported.");
}
return messages[0].content.slice(0, this.n);
}
async *_streamResponseChunks(
messages: BaseMessage[],
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
if (!messages.length) {
throw new Error("No messages provided.");
}
if (typeof messages[0].content !== "string") {
throw new Error("Multimodal messages are not supported.");
}
// Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
// await subRunnable.invoke(params, runManager?.getChild());
for (const letter of messages[0].content.slice(0, this.n)) {
yield new ChatGenerationChunk({
message: new AIMessageChunk({
content: letter,
}),
text: letter,
});
await runManager?.handleLLMNewToken(letter);
}
}
}
We can now use this like any other chat model:
const chatModel = new CustomChatModel({ n: 4 });
await chatModel.invoke([["human", "I am an LLM"]]);
AIMessage {
content: 'I am',
additional_kwargs: {}
}
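Because it implements the standard chat model interface, it also composes with other runnables. As a minimal sketch (the prompt below is purely illustrative), you could pipe a prompt template into it:

```typescript
import { ChatPromptTemplate } from "@langchain/core/prompts";

// Illustrative prompt; any prompt that resolves to chat messages works,
// since the custom model accepts a list of BaseMessages.
const prompt = ChatPromptTemplate.fromMessages([["human", "{text}"]]);

// `.pipe()` chains the prompt into the custom model.
const chain = prompt.pipe(new CustomChatModel({ n: 4 }));

// The chain formats the message, then the model echoes back the first 4 characters.
await chain.invoke({ text: "I am an LLM" });
// AIMessage { content: "I am", ... }
```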
Streaming is supported as well:
const stream = await chatModel.stream([["human", "I am an LLM"]]);
for await (const chunk of stream) {
console.log(chunk);
}
AIMessageChunk {
content: 'I',
additional_kwargs: {}
}
AIMessageChunk {
content: ' ',
additional_kwargs: {}
}
AIMessageChunk {
content: 'a',
additional_kwargs: {}
}
AIMessageChunk {
content: 'm',
additional_kwargs: {}
}
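Since `stream()` yields `AIMessageChunk`s, you can also reassemble the streamed output into a single message with `concat()`. A small sketch:

```typescript
import { AIMessageChunk } from "@langchain/core/messages";

const chunkStream = await chatModel.stream([["human", "I am an LLM"]]);

// Aggregate the streamed chunks back into a single AIMessageChunk.
let finalChunk: AIMessageChunk | undefined;
for await (const chunk of chunkStream) {
  finalChunk = finalChunk === undefined ? chunk : finalChunk.concat(chunk);
}

console.log(finalChunk?.content);
// "I am"
```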
Richer outputs
If you want to take advantage of LangChain's callback system for functionality like token tracking, you can extend the `BaseChatModel` class and implement the lower-level `_generate` method. It also takes a list of `BaseMessage`s as input, but requires you to construct and return a `ChatResult` object, which wraps the generated `ChatGeneration`s and permits additional metadata such as token usage.
Here's an example:
import { AIMessage, BaseMessage } from "@langchain/core/messages";
import { ChatResult } from "@langchain/core/outputs";
import {
BaseChatModel,
BaseChatModelCallOptions,
BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
export interface AdvancedCustomChatModelOptions
extends BaseChatModelCallOptions {}
export interface AdvancedCustomChatModelParams extends BaseChatModelParams {
n: number;
}
export class AdvancedCustomChatModel extends BaseChatModel<AdvancedCustomChatModelOptions> {
n: number;
static lc_name(): string {
return "AdvancedCustomChatModel";
}
constructor(fields: AdvancedCustomChatModelParams) {
super(fields);
this.n = fields.n;
}
async _generate(
messages: BaseMessage[],
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): Promise<ChatResult> {
if (!messages.length) {
throw new Error("No messages provided.");
}
if (typeof messages[0].content !== "string") {
throw new Error("Multimodal messages are not supported.");
}
// Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
// await subRunnable.invoke(params, runManager?.getChild());
const content = messages[0].content.slice(0, this.n);
const tokenUsage = {
usedTokens: this.n,
};
return {
generations: [{ message: new AIMessage({ content }), text: content }],
llmOutput: { tokenUsage },
};
}
_llmType(): string {
return "advanced_custom_chat_model";
}
}
This will surface the additional returned information in callback events and in the `streamEvents` method:
const chatModel = new AdvancedCustomChatModel({ n: 4 });
const eventStream = await chatModel.streamEvents([["human", "I am an LLM"]], {
version: "v1",
});
for await (const event of eventStream) {
if (event.event === "on_llm_end") {
console.log(JSON.stringify(event, null, 2));
}
}
{
"event": "on_llm_end",
"name": "AdvancedCustomChatModel",
"run_id": "b500b98d-bee5-4805-9b92-532a491f5c70",
"tags": [],
"metadata": {},
"data": {
"output": {
"generations": [
[
{
"message": {
"lc": 1,
"type": "constructor",
"id": [
"langchain_core",
"messages",
"AIMessage"
],
"kwargs": {
"content": "I am",
"additional_kwargs": {}
}
},
"text": "I am"
}
]
],
"llmOutput": {
"tokenUsage": {
"usedTokens": 4
}
}
}
}
}
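Because `_generate` populates `llmOutput`, the same information is also available to callback handlers via `handleLLMEnd`. Here is a rough sketch using an inline handler object (the logging itself is just for illustration):

```typescript
// Reuses the `chatModel` instance created above.
await chatModel.invoke([["human", "I am an LLM"]], {
  callbacks: [
    {
      // `output.llmOutput` carries whatever `_generate` returned,
      // including the custom `tokenUsage` field.
      handleLLMEnd(output) {
        console.log(JSON.stringify(output.llmOutput, null, 2));
      },
    },
  ],
});
// { "tokenUsage": { "usedTokens": 4 } }
```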