Custom chat models

This notebook goes over how to create a custom chat model wrapper, in case you want to use your own chat model or a different wrapper than one that is directly supported in LangChain.

There are a few required methods that a chat model needs to implement after extending the SimpleChatModel class:

  • A _call method that takes in a list of messages and call options (which include things like stop sequences), and returns a string.
  • A _llmType method that returns a string. Used for logging purposes only.

You can also implement the following optional method:

  • A _streamResponseChunks method that returns an AsyncGenerator and yields ChatGenerationChunks. This allows the LLM to support streaming outputs.

Let’s implement a very simple custom chat model that just echoes back the first n characters of the input.

import {
  SimpleChatModel,
  type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { AIMessageChunk, type BaseMessage } from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";

export interface CustomChatModelInput extends BaseChatModelParams {
  n: number;
}

export class CustomChatModel extends SimpleChatModel {
  n: number;

  constructor(fields: CustomChatModelInput) {
    super(fields);
    this.n = fields.n;
  }

  _llmType() {
    return "custom";
  }

  async _call(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<string> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    return messages[0].content.slice(0, this.n);
  }

  async *_streamResponseChunks(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): AsyncGenerator<ChatGenerationChunk> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    for (const letter of messages[0].content.slice(0, this.n)) {
      yield new ChatGenerationChunk({
        message: new AIMessageChunk({
          content: letter,
        }),
        text: letter,
      });
      await runManager?.handleLLMNewToken(letter);
    }
  }
}

We can now use this like any other chat model:

const chatModel = new CustomChatModel({ n: 4 });

await chatModel.invoke([["human", "I am an LLM"]]);

AIMessage {
  content: 'I am',
  additional_kwargs: {}
}
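Because it behaves like any other chat model, it also composes with other runnables. Here is a minimal sketch (the prompt template and chain below are illustrative, not part of the original example) that pipes a ChatPromptTemplate into the custom model:

import { ChatPromptTemplate } from "@langchain/core/prompts";

// Illustrative only: a one-message prompt piped into the custom model.
const prompt = ChatPromptTemplate.fromMessages([["human", "{input}"]]);
const chain = prompt.pipe(new CustomChatModel({ n: 4 }));

// The chain formats the prompt, then calls the model, so the result is an AIMessage.
const result = await chain.invoke({ input: "I am an LLM" });
console.log(result.content);
// "I am"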

And it supports streaming:

const stream = await chatModel.stream([["human", "I am an LLM"]]);

for await (const chunk of stream) {
console.log(chunk);
}
AIMessageChunk {
content: 'I',
additional_kwargs: {}
}
AIMessageChunk {
content: ' ',
additional_kwargs: {}
}
AIMessageChunk {
content: 'a',
additional_kwargs: {}
}
AIMessageChunk {
content: 'm',
additional_kwargs: {}
}
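The echo model above ignores call options entirely. As a rough sketch (not part of the original example), the _call implementation in the CustomChatModel class could be adapted to honor the standard stop option; truncating at the first matching stop sequence is an assumption about the desired behavior:

  async _call(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<string> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    let text = messages[0].content.slice(0, this.n);
    // `options.stop` is populated by calls like `chatModel.invoke(messages, { stop: ["\n"] })`
    // or by `.bind({ stop: [...] })`.
    for (const stopSequence of options.stop ?? []) {
      const index = text.indexOf(stopSequence);
      if (index !== -1) {
        text = text.slice(0, index);
      }
    }
    return text;
  }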

Richer outputs

If you want to take advantage of LangChain's callback system for functionality like token tracking, you can extend the BaseChatModel class and implement the lower-level _generate method. It also takes a list of BaseMessages as input, but requires you to construct and return a ChatResult object containing one or more ChatGenerations, which permits additional metadata such as token usage. Here's an example:

import { AIMessage, BaseMessage } from "@langchain/core/messages";
import { ChatResult } from "@langchain/core/outputs";
import {
  BaseChatModel,
  BaseChatModelCallOptions,
  BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";

export interface AdvancedCustomChatModelOptions
  extends BaseChatModelCallOptions {}

export interface AdvancedCustomChatModelParams extends BaseChatModelParams {
  n: number;
}

export class AdvancedCustomChatModel extends BaseChatModel<AdvancedCustomChatModelOptions> {
  n: number;

  static lc_name(): string {
    return "AdvancedCustomChatModel";
  }

  constructor(fields: AdvancedCustomChatModelParams) {
    super(fields);
    this.n = fields.n;
  }

  async _generate(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    const content = messages[0].content.slice(0, this.n);
    const tokenUsage = {
      usedTokens: this.n,
    };
    return {
      generations: [{ message: new AIMessage({ content }), text: content }],
      llmOutput: { tokenUsage },
    };
  }

  _llmType(): string {
    return "advanced_custom_chat_model";
  }
}

This will pass the additional returned information through in callback events and in the streamEvents method:

const chatModel = new AdvancedCustomChatModel({ n: 4 });

const eventStream = await chatModel.streamEvents([["human", "I am an LLM"]], {
  version: "v1",
});

for await (const event of eventStream) {
  if (event.event === "on_llm_end") {
    console.log(JSON.stringify(event, null, 2));
  }
}
{
  "event": "on_llm_end",
  "name": "AdvancedCustomChatModel",
  "run_id": "b500b98d-bee5-4805-9b92-532a491f5c70",
  "tags": [],
  "metadata": {},
  "data": {
    "output": {
      "generations": [
        [
          {
            "message": {
              "lc": 1,
              "type": "constructor",
              "id": [
                "langchain_core",
                "messages",
                "AIMessage"
              ],
              "kwargs": {
                "content": "I am",
                "additional_kwargs": {}
              }
            },
            "text": "I am"
          }
        ]
      ],
      "llmOutput": {
        "tokenUsage": {
          "usedTokens": 4
        }
      }
    }
  }
}
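The same llmOutput is also available to callback handlers, for example via handleLLMEnd. A rough sketch (the inline handler object and logging are illustrative, not part of the original example):

await chatModel.invoke([["human", "I am an LLM"]], {
  callbacks: [
    {
      handleLLMEnd(output) {
        // `llmOutput` carries whatever extra metadata `_generate` returned,
        // e.g. { tokenUsage: { usedTokens: 4 } } for this model.
        console.log(output.llmOutput);
      },
    },
  ],
});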
