AI model selection, registration, and type-safe usage
enum ModelType {
  // Text generation models
  TEXT_SMALL = 'text:small', // Fast, simple responses
  TEXT_MEDIUM = 'text:medium', // Balanced performance
  TEXT_LARGE = 'text:large', // Complex reasoning
  // Embedding models
  TEXT_EMBEDDING = 'text:embedding',
  // Image models
  IMAGE_GENERATION = 'image:generation',
  IMAGE_ANALYSIS = 'image:analysis',
  // Audio models
  SPEECH_TO_TEXT = 'speech:to:text',
  TEXT_TO_SPEECH = 'text:to:speech',
  // Specialized models
  CODE_GENERATION = 'code:generation',
  CLASSIFICATION = 'classification'
}
// Text generation parameters
interface TextGenerationParams {
  prompt: string;
  messages?: Message[];
  model?: string; // Provider-specific model name (used by the provider handlers below)
  temperature?: number; // 0.0 - 2.0
  maxTokens?: number;
  topP?: number;
  frequencyPenalty?: number;
  presencePenalty?: number;
  stopSequences?: string[];
  systemPrompt?: string;
}
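// The Message type referenced above is not defined in this section; a minimal
// sketch of the assumed shape:
interface Message {
  role: 'system' | 'user' | 'assistant';
  content: string;
}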
// Embedding parameters
interface EmbeddingParams {
  input: string | string[];
  model?: string;
  dimensions?: number;
}
// Image generation parameters
interface ImageGenerationParams {
  prompt: string;
  negativePrompt?: string;
  width?: number;
  height?: number;
  steps?: number;
  seed?: number;
  style?: string;
}
// Speech-to-text parameters
interface SpeechToTextParams {
  audio: Buffer | string; // Audio data or URL
  language?: string;
  format?: 'json' | 'text' | 'srt';
  temperature?: number;
}
// Register a model handler
runtime.registerModel(
  ModelType.TEXT_LARGE,
  async (runtime, params) => {
    // Model implementation
    const response = await callAPI(params);
    return response.text;
  },
  'openai', // provider name
  100 // priority (higher = preferred)
);
// Register multiple models from a plugin
const modelPlugin: Plugin = {
  name: 'openai-models',
  models: [
    {
      type: ModelType.TEXT_LARGE,
      handler: handleTextGeneration,
      provider: 'openai',
      priority: 100
    },
    {
      type: ModelType.TEXT_EMBEDDING,
      handler: handleEmbedding,
      provider: 'openai',
      priority: 100
    }
  ]
};
type ModelHandler<T = any, R = any> = (
  runtime: IAgentRuntime,
  params: T
) => Promise<R>;

interface ModelRegistration {
  type: ModelType;
  handler: ModelHandler;
  provider: string;
  priority: number;
}
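// Because ModelHandler is generic over its parameter and result types, handlers
// can be typed against the parameter interfaces above. Illustrative sketch only;
// callAPI stands in for a real provider call, as in the registration example:
const typedTextHandler: ModelHandler<TextGenerationParams, string> = async (_runtime, params) => {
  const response = await callAPI(params);
  return response.text;
};
const typedEmbeddingHandler: ModelHandler<EmbeddingParams, number[]> = async (_runtime, params) => {
  const response = await callAPI(params);
  return response.embedding;
};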
// Text generation
const response = await runtime.useModel(
  ModelType.TEXT_LARGE,
  {
    prompt: "Explain quantum computing",
    temperature: 0.7,
    maxTokens: 500
  }
);
// Get embeddings
const embedding = await runtime.useModel(
  ModelType.TEXT_EMBEDDING,
  { input: "Text to embed" }
);
// Generate image
const image = await runtime.useModel(
  ModelType.IMAGE_GENERATION,
  {
    prompt: "A sunset over mountains",
    width: 1024,
    height: 1024,
    steps: 50
  }
);
// Speech to text
const transcript = await runtime.useModel(
  ModelType.SPEECH_TO_TEXT,
  {
    audio: audioBuffer,
    language: 'en',
    format: 'json'
  }
);
// Use specific provider
const response = await runtime.useModel(
  ModelType.TEXT_LARGE,
  { prompt: "Hello" },
  'anthropic' // Force specific provider
);
// Get available providers
const providers = runtime.getModelProviders(ModelType.TEXT_LARGE);
console.log('Available providers:', providers);
// ['openai', 'anthropic', 'ollama']
// Higher priority providers are preferred
runtime.registerModel(ModelType.TEXT_LARGE, handlerA, 'provider-a', 100);
runtime.registerModel(ModelType.TEXT_LARGE, handlerB, 'provider-b', 90);
runtime.registerModel(ModelType.TEXT_LARGE, handlerC, 'provider-c', 80);
// Will use provider-a (priority 100)
await runtime.useModel(ModelType.TEXT_LARGE, params);
// Automatic fallback on failure
class ModelRouter {
  async useModel(type: ModelType, params: any, preferredProvider?: string) {
    const providers = this.getProvidersByPriority(type, preferredProvider);
    for (const provider of providers) {
      try {
        return await provider.handler(this.runtime, params);
      } catch (error) {
        this.logger.warn(`Provider ${provider.name} failed:`, error);
        // Try the next provider in priority order
      }
    }
    // Either no providers are registered for this type or every one of them failed
    throw new Error(`No provider succeeded for ${type}`);
  }
}
class OpenAIModelProvider {
  private client: OpenAI;

  constructor(runtime: IAgentRuntime) {
    const apiKey = runtime.getSetting('OPENAI_API_KEY');
    this.client = new OpenAI({ apiKey });
  }

  async handleTextGeneration(params: TextGenerationParams) {
    const response = await this.client.chat.completions.create({
      model: params.model || 'gpt-4',
      messages: params.messages || [
        { role: 'user', content: params.prompt }
      ],
      temperature: params.temperature,
      max_tokens: params.maxTokens,
      top_p: params.topP,
      frequency_penalty: params.frequencyPenalty,
      presence_penalty: params.presencePenalty,
      stop: params.stopSequences
    });
    return response.choices[0].message.content;
  }

  async handleEmbedding(params: EmbeddingParams) {
    const response = await this.client.embeddings.create({
      model: 'text-embedding-3-small',
      input: params.input,
      dimensions: params.dimensions
    });
    return Array.isArray(params.input)
      ? response.data.map(d => d.embedding)
      : response.data[0].embedding;
  }

  register(runtime: IAgentRuntime) {
    // Wrap the methods so they match the (runtime, params) handler signature
    runtime.registerModel(
      ModelType.TEXT_LARGE,
      (_runtime, params) => this.handleTextGeneration(params),
      'openai',
      100
    );
    runtime.registerModel(
      ModelType.TEXT_EMBEDDING,
      (_runtime, params) => this.handleEmbedding(params),
      'openai',
      100
    );
  }
}
class AnthropicModelProvider {
  private client: Anthropic;

  constructor(runtime: IAgentRuntime) {
    const apiKey = runtime.getSetting('ANTHROPIC_API_KEY');
    this.client = new Anthropic({ apiKey });
  }

  async handleTextGeneration(params: TextGenerationParams) {
    const response = await this.client.messages.create({
      model: params.model || 'claude-3-opus-20240229',
      messages: params.messages || [
        { role: 'user', content: params.prompt }
      ],
      max_tokens: params.maxTokens || 1000,
      temperature: params.temperature,
      system: params.systemPrompt
    });
    return response.content[0].text;
  }

  register(runtime: IAgentRuntime) {
    runtime.registerModel(
      ModelType.TEXT_LARGE,
      (_runtime, params) => this.handleTextGeneration(params),
      'anthropic',
      95 // Slightly lower priority than OpenAI
    );
  }
}
class OllamaModelProvider {
  private baseUrl: string;

  constructor(runtime: IAgentRuntime) {
    this.baseUrl = runtime.getSetting('OLLAMA_BASE_URL') || 'http://localhost:11434';
  }

  async handleTextGeneration(params: TextGenerationParams) {
    const response = await fetch(`${this.baseUrl}/api/generate`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({
        model: params.model || 'llama2',
        prompt: params.prompt,
        stream: false, // return a single JSON object instead of a stream
        options: {
          temperature: params.temperature,
          num_predict: params.maxTokens,
          top_p: params.topP,
          stop: params.stopSequences
        }
      })
    });
    const data = await response.json();
    return data.response;
  }

  async handleEmbedding(params: EmbeddingParams) {
    const response = await fetch(`${this.baseUrl}/api/embeddings`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({
        model: params.model || 'all-minilm',
        prompt: params.input
      })
    });
    const data = await response.json();
    return data.embedding;
  }

  register(runtime: IAgentRuntime) {
    // Lower priority for local models
    runtime.registerModel(
      ModelType.TEXT_LARGE,
      (_runtime, params) => this.handleTextGeneration(params),
      'ollama',
      50
    );
    runtime.registerModel(
      ModelType.TEXT_EMBEDDING,
      (_runtime, params) => this.handleEmbedding(params),
      'ollama',
      50
    );
  }
}
// Runtime automatically selects best available provider
const response = await runtime.useModel(
  ModelType.TEXT_LARGE,
  { prompt: "Hello" }
);
// Selection order:
// 1. Check if preferred provider specified
// 2. Sort available providers by priority
// 3. Try each provider until success
// 4. Cache successful provider for session
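// A sketch of that selection flow, assuming registrations are kept per model type.
// The class, field, and method names here are illustrative, not the runtime's API:
class ProviderSelector {
  private sessionCache = new Map<ModelType, ModelRegistration>();

  constructor(private registrations: Map<ModelType, ModelRegistration[]>) {}

  async select(runtime: IAgentRuntime, type: ModelType, params: any, preferredProvider?: string) {
    const byPriority = [...(this.registrations.get(type) ?? [])]
      .sort((a, b) => b.priority - a.priority);
    // 1. A preferred provider, if given, jumps to the front
    const preferred = byPriority.filter(r => r.provider === preferredProvider);
    // 4. (from a previous call) the last successful provider is tried early
    const cached = this.sessionCache.get(type);
    const ordered = [
      ...preferred,
      ...(cached && !preferred.includes(cached) ? [cached] : []),
      ...byPriority.filter(r => !preferred.includes(r) && r !== cached)
    ];
    // 2. + 3. Walk the ordered list until a provider succeeds
    for (const registration of ordered) {
      try {
        const result = await registration.handler(runtime, params);
        this.sessionCache.set(type, registration); // cache the winner for the session
        return result;
      } catch {
        // fall through to the next candidate
      }
    }
    throw new Error(`No provider succeeded for ${type}`);
  }
}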
// Select model based on context
async function selectModelForTask(runtime: IAgentRuntime, task: string) {
  const complexity = analyzeComplexity(task);
  if (complexity < 0.3) {
    // Simple task - use small model
    return runtime.useModel(ModelType.TEXT_SMALL, {
      prompt: task,
      temperature: 0.3
    });
  } else if (complexity < 0.7) {
    // Medium complexity - use medium model
    return runtime.useModel(ModelType.TEXT_MEDIUM, {
      prompt: task,
      temperature: 0.5
    });
  } else {
    // Complex task - use large model
    return runtime.useModel(ModelType.TEXT_LARGE, {
      prompt: task,
      temperature: 0.7,
      maxTokens: 2000
    });
  }
}
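// analyzeComplexity is not defined above; a rough heuristic sketch (assumed, not
// part of the runtime) that maps a task description to a 0-1 score:
function analyzeComplexity(task: string): number {
  const lengthScore = Math.min(task.length / 1000, 1);
  const reasoningMarkers = ['explain', 'compare', 'analyze', 'design', 'prove'];
  const markerScore = reasoningMarkers.some(m => task.toLowerCase().includes(m)) ? 0.5 : 0;
  return Math.min(lengthScore * 0.6 + markerScore, 1);
}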
// Track and optimize model usage costs
class CostOptimizedModelRouter {
  private costs = {
    'openai': { [ModelType.TEXT_LARGE]: 0.03, [ModelType.TEXT_EMBEDDING]: 0.0001 },
    'anthropic': { [ModelType.TEXT_LARGE]: 0.025 },
    'ollama': { [ModelType.TEXT_LARGE]: 0, [ModelType.TEXT_EMBEDDING]: 0 }
  };

  async useModel(type: ModelType, params: any, maxCost?: number) {
    const providers = this.getProvidersByCost(type, maxCost);
    for (const provider of providers) {
      try {
        const result = await provider.handler(this.runtime, params);
        // Track usage
        this.trackUsage(provider.name, type, params);
        return result;
      } catch (error) {
        continue;
      }
    }
    throw new Error(`No provider within budget succeeded for ${type}`);
  }

  private getProvidersByCost(type: ModelType, maxCost?: number) {
    return this.providers
      .filter(p => {
        // Use ?? so free providers (cost 0) are not mistaken for unknown ones
        const cost = this.costs[p.name]?.[type] ?? Infinity;
        return maxCost === undefined || cost <= maxCost;
      })
      .sort((a, b) => {
        const costA = this.costs[a.name]?.[type] ?? Infinity;
        const costB = this.costs[b.name]?.[type] ?? Infinity;
        return costA - costB;
      });
  }
}
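// Illustrative usage: capping cost per call routes the request to the cheapest
// provider that fits the budget (a free local ollama model in this sketch, if registered):
const costRouter = new CostOptimizedModelRouter();
const budgetResponse = await costRouter.useModel(
  ModelType.TEXT_LARGE,
  { prompt: "Summarize this paragraph" },
  0.01 // maxCost per call, in the same units as the cost table above
);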
class ModelCache {
  private cache = new Map<string, { result: any; timestamp: number }>();
  private ttl = 60 * 60 * 1000; // 1 hour

  getCacheKey(type: ModelType, params: any): string {
    return `${type}:${JSON.stringify(params)}`;
  }

  get(type: ModelType, params: any): any | null {
    const key = this.getCacheKey(type, params);
    const cached = this.cache.get(key);
    if (!cached) return null;
    if (Date.now() - cached.timestamp > this.ttl) {
      this.cache.delete(key);
      return null;
    }
    return cached.result;
  }

  set(type: ModelType, params: any, result: any) {
    const key = this.getCacheKey(type, params);
    this.cache.set(key, {
      result,
      timestamp: Date.now()
    });
  }
}
// Use with runtime
const cache = new ModelCache();
async function cachedModelCall(runtime: IAgentRuntime, type: ModelType, params: any) {
  // Check cache
  const cached = cache.get(type, params);
  if (cached) return cached;
  // Make call
  const result = await runtime.useModel(type, params);
  // Cache result
  cache.set(type, params, result);
  return result;
}
interface ModelUsageMetrics {
  provider: string;
  modelType: ModelType;
  count: number;
  totalTokens: number;
  totalDuration: number;
  avgDuration: number;
  errors: number;
  cost: number;
}
class ModelMonitor {
  private metrics = new Map<string, ModelUsageMetrics>();

  async trackUsage(
    provider: string,
    type: ModelType,
    params: any,
    result: any,
    duration: number
  ) {
    const key = `${provider}:${type}`;
    if (!this.metrics.has(key)) {
      this.metrics.set(key, {
        provider,
        modelType: type,
        count: 0,
        totalTokens: 0,
        totalDuration: 0,
        avgDuration: 0,
        errors: 0,
        cost: 0
      });
    }
    const metrics = this.metrics.get(key)!;
    metrics.count++;
    metrics.totalDuration += duration;
    metrics.avgDuration = metrics.totalDuration / metrics.count;
    // Estimate tokens (simplified)
    if (type === ModelType.TEXT_LARGE) {
      const tokens = this.estimateTokens(params.prompt) +
        this.estimateTokens(result);
      metrics.totalTokens += tokens;
      metrics.cost += this.calculateCost(provider, type, tokens);
    }
    // Emit metrics event
    await this.runtime.emit(EventType.MODEL_USED, {
      runtime: this.runtime,
      modelType: type,
      provider,
      params,
      result,
      duration,
      metrics
    });
  }
}
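// estimateTokens and calculateCost are used above but not shown; minimal sketches,
// assuming roughly four characters per token and illustrative per-1K-token prices.
// On the class they would be private methods:
function estimateTokens(text: string): number {
  return Math.ceil((text ?? '').length / 4);
}
function calculateCost(provider: string, type: ModelType, tokens: number): number {
  const pricePer1kTokens: Record<string, number> = { openai: 0.03, anthropic: 0.025, ollama: 0 };
  return ((pricePer1kTokens[provider] ?? 0) * tokens) / 1000;
}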
async function modelCallWithRetry(
  runtime: IAgentRuntime,
  type: ModelType,
  params: any,
  maxRetries = 3
) {
  let lastError: Error | undefined;
  for (let i = 0; i < maxRetries; i++) {
    try {
      return await runtime.useModel(type, params);
    } catch (error) {
      lastError = error as Error;
      // Check if retryable
      if (isRateLimitError(error)) {
        // Wait with exponential backoff
        const delay = Math.pow(2, i) * 1000;
        await new Promise(resolve => setTimeout(resolve, delay));
        continue;
      }
      // Non-retryable error
      throw error;
    }
  }
  // Only rate-limit errors reach this point; rethrow the last one
  throw lastError!;
}
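// isRateLimitError is assumed above; a simple sketch that treats HTTP 429 or a
// rate-limit message as retryable (the error shape is an assumption):
function isRateLimitError(error: unknown): boolean {
  const err = error as { status?: number; message?: string };
  return err?.status === 429 || /rate limit/i.test(err?.message ?? '');
}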
// Fallback to simpler models on failure
async function modelCallWithFallback(
  runtime: IAgentRuntime,
  params: TextGenerationParams
) {
  const modelHierarchy = [
    ModelType.TEXT_LARGE,
    ModelType.TEXT_MEDIUM,
    ModelType.TEXT_SMALL
  ];
  for (const modelType of modelHierarchy) {
    try {
      return await runtime.useModel(modelType, params);
    } catch (error) {
      runtime.logger.warn(`Model ${modelType} failed, trying fallback`);
      if (modelType === ModelType.TEXT_SMALL) {
        // Last option failed
        throw error;
      }
    }
  }
}
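// The two strategies compose: retry transient rate limits on the large model first,
// then walk down the model hierarchy if it still fails. Illustrative wiring only:
async function robustTextGeneration(runtime: IAgentRuntime, params: TextGenerationParams) {
  try {
    return await modelCallWithRetry(runtime, ModelType.TEXT_LARGE, params);
  } catch {
    return await modelCallWithFallback(runtime, params);
  }
}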