
Project 3: Multimodal Application

⚠️ Timeliness note: this chapter touches on cutting-edge models, pricing, leaderboards, and similar information that can change quickly between releases; always defer to the original papers, official release pages, and API documentation.

Difficulty: ⭐⭐⭐⭐⭐ Advanced | Time: 20-25 hours | Topics: multimodal models, image processing, text generation, front-end/back-end integration


📖 Project Overview

Background

Traditional AI applications usually handle a single modality, such as plain text or images alone. Real-world information, however, is multimodal: text, images, audio, and video are frequently interwoven. A multimodal AI application can process and understand several modalities at once, enabling a much richer interaction experience.

Typical application scenarios include:

  • Image caption generation
  • Text-to-image generation
  • Image-text dialogue
  • Image editing
  • Multimodal search
  • Visual question answering

Project Goals

Build a feature-rich multimodal application that can:

  • Upload and process images
  • Generate image captions
  • Generate images from text prompts
  • Hold image-text conversations
  • Edit images
  • Provide a friendly web interface

Tech Stack

  • LLM: OpenAI GPT-4V / Claude 3 Vision / Qwen-VL (通义千问VL)
  • Image generation: Stable Diffusion / DALL-E 3
  • Image understanding: CLIP / BLIP
  • Backend framework: FastAPI
  • Frontend framework: React + Streamlit
  • Image processing: Pillow, OpenCV
  • Vector database: ChromaDB

🏗️ Project Structure

Text Only
multimodal-app/
├── app/                      # Main application package
│   ├── __init__.py
│   ├── main.py              # FastAPI entry point
│   ├── config.py            # Configuration
│   ├── image/               # Image processing module
│   │   ├── __init__.py
│   │   ├── processor.py     # Image processor
│   │   ├── generator.py     # Image generator
│   │   └── editor.py       # Image editor
│   ├── vision/              # Vision understanding module
│   │   ├── __init__.py
│   │   ├── caption.py       # Image captioning
│   │   ├── vqa.py          # Visual question answering
│   │   └── search.py       # Image search
│   ├── models/              # Model management
│   │   ├── __init__.py
│   │   ├── text_to_image.py # Text-to-image models
│   │   ├── image_to_text.py # Image-to-text models
│   │   └── multimodal.py   # Multimodal models
│   └── api/                 # API routes
│       ├── __init__.py
│       ├── image.py         # Image API
│       ├── caption.py       # Caption API
│       ├── generate.py      # Generation API
│       └── chat.py          # Chat API
├── frontend/                 # Frontend
│   ├── streamlit/           # Streamlit app
│   │   ├── app.py
│   │   └── pages/
│   └── react/              # React app
│       ├── src/
│       └── package.json
├── data/                    # Data directory
│   ├── images/             # Image storage
│   ├── models/             # Model cache
│   └── embeddings/         # Embedding vectors
├── tests/                   # Tests
│   ├── test_image.py
│   ├── test_caption.py
│   └── test_generate.py
├── utils/                   # Utility functions
│   ├── __init__.py
│   ├── image_utils.py      # Image helpers
│   └── logger.py           # Logging helpers
├── requirements.txt         # Python dependencies
├── Dockerfile              # Docker configuration
├── docker-compose.yml      # Docker Compose configuration
└── README.md              # Project README

🎯 Core Features

1. Image Upload & Processing

  • Upload: multiple image formats supported
  • Preprocessing: resizing and format conversion
  • Enhancement: brightness and contrast adjustment
  • Compression: optimizing image file size

2. Image Caption Generation

  • Automatic captions: natural-language descriptions of an image
  • Detailed captions: longer, more detailed descriptions
  • Tag generation: extracting key tags from an image
  • OCR: recognizing text embedded in an image

3. Text-to-Image

  • Generation: create images from a text prompt
  • Style control: steer the style of generated images
  • Parameter tuning: adjust generation parameters
  • Batch generation: produce multiple images at once

4. Image-Text Dialogue

  • Multi-turn dialogue: continuous conversations about images
  • Context understanding: answers take the conversation context into account
  • Image references: reference images within a conversation
  • History: conversation history is preserved

5. Image Editing

  • Local edits: modify a specific region of an image
  • Style transfer: change an image's style
  • Inpainting: repair damaged images
  • Outpainting: extend image content beyond its borders

6. Multimodal Search

  • Text-to-image search: find similar images from a text query
  • Image-to-image search: find similar images from an example image
  • Hybrid search: combine text and image queries
  • Result ranking: rank search results by relevance
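
Multimodal search is typically implemented by embedding both queries and stored images into CLIP's shared vector space and ranking candidates by cosine similarity. The chapter does not show search.py, so this is only a dependency-free sketch of the ranking step; the function names are illustrative, not from the project:

```python
import math

def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

def rank_by_similarity(
    query: list[float],
    candidates: dict[str, list[float]],
) -> list[tuple[str, float]]:
    """Rank candidate image embeddings against a query embedding."""
    scored = [(image_id, cosine_similarity(query, emb))
              for image_id, emb in candidates.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)
```

In practice a vector database such as ChromaDB performs this ranking internally; the sketch only shows what "rank by similarity" means.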

💻 Code Implementation

1. Configuration (app/config.py)

Python
"""
Configuration for the multimodal application.
"""
import os
from pydantic_settings import BaseSettings, SettingsConfigDict

class Settings(BaseSettings):
    """Application settings."""

    # pydantic-settings v2 style (replaces the deprecated inner Config class)
    model_config = SettingsConfigDict(env_file=".env", case_sensitive=True)

    # API settings
    API_HOST: str = "0.0.0.0"
    API_PORT: int = 8000
    API_PREFIX: str = "/api/v1"

    # OpenAI settings
    OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
    OPENAI_MODEL: str = "gpt-4o"
    OPENAI_MAX_TOKENS: int = 500

    # Image generation settings
    DALL_E_API_KEY: str = os.getenv("DALL_E_API_KEY", "")
    DALL_E_MODEL: str = "dall-e-3"
    DALL_E_SIZE: str = "1024x1024"
    DALL_E_QUALITY: str = "standard"

    # Stable Diffusion settings
    SD_API_URL: str = os.getenv("SD_API_URL", "")
    SD_MODEL: str = "runwayml/stable-diffusion-v1-5"
    SD_STEPS: int = 20
    SD_GUIDANCE_SCALE: float = 7.5

    # Image processing settings
    MAX_IMAGE_SIZE: int = 10 * 1024 * 1024  # 10 MB
    ALLOWED_FORMATS: list[str] = [".jpg", ".jpeg", ".png", ".gif", ".webp"]
    MAX_DIMENSION: int = 2048

    # Storage settings
    IMAGE_DIR: str = "./data/images"
    MODEL_CACHE_DIR: str = "./data/models"
    EMBEDDING_DIR: str = "./data/embeddings"

    # Vector database settings
    CHROMA_PERSIST_DIR: str = "./data/embeddings"
    CHROMA_COLLECTION_NAME: str = "images"

    # CLIP settings
    CLIP_MODEL: str = "openai/clip-vit-base-patch32"

# Global settings instance
settings = Settings()
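
Settings reads overrides from a .env file in the working directory. A minimal example with placeholder values (the variable names come from the class above; the values here are dummies, not real credentials):

```
OPENAI_API_KEY=your-openai-key-here
SD_API_URL=http://localhost:7860
```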

2. Image Processor (app/image/processor.py)

Python
"""
Image processing module.
"""
from pathlib import Path
from PIL import Image, ImageEnhance, ImageFilter
import io
import base64
from app.config import settings

class ImageProcessor:
    """Image processor."""

    def __init__(self):
        """Initialize the processor and ensure the storage directory exists."""
        Path(settings.IMAGE_DIR).mkdir(parents=True, exist_ok=True)

    def load_image(self, file_path: str) -> Image.Image:
        """
        Load an image.

        Args:
            file_path: Path to the image file

        Returns:
            PIL Image object
        """
        return Image.open(file_path)

    def save_image(
        self,
        image: Image.Image,
        filename: str,
        format: str = "PNG"
    ) -> str:
        """
        Save an image.

        Args:
            image: PIL Image object
            filename: File name
            format: Image format

        Returns:
            Path the image was saved to
        """
        file_path = Path(settings.IMAGE_DIR) / filename
        image.save(file_path, format=format)
        return str(file_path)

    def resize_image(
        self,
        image: Image.Image,
        max_dimension: int | None = None
    ) -> Image.Image:
        """
        Resize an image while preserving its aspect ratio.

        Args:
            image: PIL Image object
            max_dimension: Maximum edge length

        Returns:
            The resized image
        """
        if max_dimension is None:
            max_dimension = settings.MAX_DIMENSION

        width, height = image.size

        if width <= max_dimension and height <= max_dimension:
            return image

        # Compute the scale factor that fits both edges within max_dimension
        ratio = min(max_dimension / width, max_dimension / height)
        new_width = int(width * ratio)
        new_height = int(height * ratio)

        # Image.Resampling.LANCZOS replaces the deprecated Image.LANCZOS (Pillow 9.1+)
        return image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    def enhance_brightness(
        self,
        image: Image.Image,
        factor: float = 1.0
    ) -> Image.Image:
        """
        Adjust brightness.

        Args:
            image: PIL Image object
            factor: Brightness factor (1.0 keeps the original)

        Returns:
            The adjusted image
        """
        enhancer = ImageEnhance.Brightness(image)
        return enhancer.enhance(factor)

    def enhance_contrast(
        self,
        image: Image.Image,
        factor: float = 1.0
    ) -> Image.Image:
        """
        Adjust contrast.

        Args:
            image: PIL Image object
            factor: Contrast factor (1.0 keeps the original)

        Returns:
            The adjusted image
        """
        enhancer = ImageEnhance.Contrast(image)
        return enhancer.enhance(factor)

    def apply_filter(
        self,
        image: Image.Image,
        filter_type: str = "sharpen"
    ) -> Image.Image:
        """
        Apply a filter.

        Args:
            image: PIL Image object
            filter_type: Filter type (sharpen, blur, edge_enhance, smooth)

        Returns:
            The filtered image
        """
        filters = {
            "sharpen": ImageFilter.SHARPEN,
            "blur": ImageFilter.BLUR,
            "edge_enhance": ImageFilter.EDGE_ENHANCE,
            "smooth": ImageFilter.SMOOTH
        }

        if filter_type not in filters:
            raise ValueError(f"Unsupported filter type: {filter_type}")

        return image.filter(filters[filter_type])

    def crop_image(
        self,
        image: Image.Image,
        box: tuple[int, int, int, int]
    ) -> Image.Image:
        """
        Crop an image.

        Args:
            image: PIL Image object
            box: Crop region (left, top, right, bottom)

        Returns:
            The cropped image
        """
        return image.crop(box)

    def rotate_image(
        self,
        image: Image.Image,
        angle: float
    ) -> Image.Image:
        """
        Rotate an image.

        Args:
            image: PIL Image object
            angle: Rotation angle in degrees

        Returns:
            The rotated image
        """
        return image.rotate(angle, expand=True)

    def image_to_base64(
        self,
        image: Image.Image,
        format: str = "PNG"
    ) -> str:
        """
        Encode an image as a Base64 string.

        Args:
            image: PIL Image object
            format: Image format

        Returns:
            The Base64-encoded string
        """
        buffer = io.BytesIO()
        image.save(buffer, format=format)
        image_bytes = buffer.getvalue()
        return base64.b64encode(image_bytes).decode('utf-8')

    def base64_to_image(
        self,
        base64_string: str
    ) -> Image.Image:
        """
        Decode a Base64 string into an image.

        Args:
            base64_string: The Base64-encoded string

        Returns:
            PIL Image object
        """
        image_bytes = base64.b64decode(base64_string)
        return Image.open(io.BytesIO(image_bytes))

    def validate_image(
        self,
        file_path: str,
        max_size: int | None = None
    ) -> tuple[bool, str | None]:
        """
        Validate an image file.

        Args:
            file_path: Path to the image file
            max_size: Maximum file size in bytes

        Returns:
            (is valid, error message)
        """
        if max_size is None:
            max_size = settings.MAX_IMAGE_SIZE

        file_path = Path(file_path)

        # Check that the file exists
        if not file_path.exists():
            return False, "File does not exist"

        # Check the file size
        if file_path.stat().st_size > max_size:
            return False, f"File size exceeds the limit: {max_size} bytes"

        # Check the file format
        if file_path.suffix.lower() not in settings.ALLOWED_FORMATS:
            return False, f"Unsupported file format: {file_path.suffix}"

        # Try to open the image; catch exceptions so a corrupt file
        # is reported as invalid instead of crashing the caller
        try:
            Image.open(file_path)
        except Exception as e:
            return False, f"Cannot open image: {e}"

        return True, None
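
resize_image preserves aspect ratio by scaling both edges with the same factor. That dimension arithmetic can be isolated and tested without Pillow; fit_within below is a hypothetical helper reproducing the same rule, not part of the project code:

```python
def fit_within(width: int, height: int, max_dimension: int) -> tuple[int, int]:
    """Scale (width, height) to fit inside a max_dimension square,
    preserving aspect ratio; images already small enough are untouched."""
    if width <= max_dimension and height <= max_dimension:
        return width, height
    ratio = min(max_dimension / width, max_dimension / height)
    return int(width * ratio), int(height * ratio)
```

For example, fit_within(4096, 2048, 2048) returns (2048, 1024): the larger edge is halved, and the smaller edge shrinks by the same factor.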

3. Image Captioner (app/vision/caption.py)

Python
"""
Image caption generation module.
"""
from openai import OpenAI
from PIL import Image
import base64
import io
from app.config import settings

class ImageCaptioner:
    """Image caption generator."""

    def __init__(self, api_key: str | None = None):
        """
        Initialize the captioner.

        Args:
            api_key: OpenAI API key
        """
        if api_key is None:
            api_key = settings.OPENAI_API_KEY

        self.client = OpenAI(api_key=api_key)
        self.model = settings.OPENAI_MODEL

    def generate_caption(
        self,
        image: Image.Image,
        detail: str = "auto",
        max_tokens: int | None = None
    ) -> str:
        """
        Generate a caption for an image.

        Args:
            image: PIL Image object
            detail: Level of detail (auto, low, high)
            max_tokens: Maximum number of tokens

        Returns:
            The image caption
        """
        if max_tokens is None:
            max_tokens = settings.OPENAI_MAX_TOKENS

        # Encode the image as base64
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')

        # Build the message
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Please describe the content of this image in detail."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{image_base64}",
                            "detail": detail
                        }
                    }
                ]
            }
        ]

        # Call the API
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_tokens
        )

        return response.choices[0].message.content

    def generate_tags(
        self,
        image: Image.Image,
        num_tags: int = 5
    ) -> list[str]:
        """
        Generate tags for an image.

        Args:
            image: PIL Image object
            num_tags: Number of tags

        Returns:
            List of tags
        """
        # Encode the image as base64
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')

        # Build the message
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": f"Please generate {num_tags} keyword tags for this image, separated by commas."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{image_base64}"
                        }
                    }
                ]
            }
        ]

        # Call the API
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=100
        )

        # Parse the tags
        tags_text = response.choices[0].message.content
        tags = [tag.strip() for tag in tags_text.split(",")]  # strip whitespace around each tag

        return tags[:num_tags]

    def extract_text(
        self,
        image: Image.Image
    ) -> str:
        """
        Extract text from an image (OCR).

        Args:
            image: PIL Image object

        Returns:
            The extracted text
        """
        # Encode the image as base64
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')

        # Build the message
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Please extract all text contained in this image."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{image_base64}"
                        }
                    }
                ]
            }
        ]

        # Call the API
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=500
        )

        return response.choices[0].message.content
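
generate_tags splits the model's reply on ASCII commas only, but a model may answer with full-width commas, enumeration marks, or newlines. A more defensive parser could normalize all of these; parse_tags is a hypothetical helper, and the delimiter set is an assumption about likely model output:

```python
import re

def parse_tags(tags_text: str, num_tags: int = 5) -> list[str]:
    """Split a model reply into tags on common delimiters and drop blanks."""
    parts = re.split(r"[,,、;\n]+", tags_text)
    tags = [part.strip() for part in parts if part.strip()]
    return tags[:num_tags]
```

For example, parse_tags("cat, dog、bird\nfish", 3) yields ["cat", "dog", "bird"], where the naive split(",") would return the last three items fused together.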

4. Image Generator (app/image/generator.py)

Python
"""
Image generation module.
"""
from openai import OpenAI
import requests
from app.config import settings

class ImageGenerator:
    """Image generator."""

    def __init__(self, api_key: str | None = None):
        """
        Initialize the generator.

        Args:
            api_key: OpenAI API key
        """
        if api_key is None:
            api_key = settings.OPENAI_API_KEY

        self.client = OpenAI(api_key=api_key)
        self.model = settings.DALL_E_MODEL
        self.size = settings.DALL_E_SIZE
        self.quality = settings.DALL_E_QUALITY

    def generate_image(
        self,
        prompt: str,
        n: int = 1,
        size: str | None = None,
        quality: str | None = None
    ) -> list[str]:
        """
        Generate images from a text prompt.

        Args:
            prompt: Text prompt
            n: Number of images (note: dall-e-3 only supports n=1)
            size: Image size
            quality: Image quality

        Returns:
            List of image URLs
        """
        if size is None:
            size = self.size
        if quality is None:
            quality = self.quality

        # Call the DALL-E API
        response = self.client.images.generate(
            model=self.model,
            prompt=prompt,
            n=n,
            size=size,
            quality=quality
        )

        # Extract the URLs
        urls = [image.url for image in response.data]

        return urls

    def edit_image(
        self,
        image_path: str,
        prompt: str,
        mask_path: str | None = None,
        n: int = 1
    ) -> list[str]:
        """
        Edit an image.

        Note: the required `prompt` parameter must come before the optional
        `mask_path`, and the edits endpoint is only supported by dall-e-2
        (omitting the model parameter defaults to it).

        Args:
            image_path: Path to the source image
            prompt: Edit prompt
            mask_path: Path to the mask image (optional)
            n: Number of images

        Returns:
            List of image URLs
        """
        # Read the image ("with" closes the file automatically)
        with open(image_path, "rb") as f:
            image = f.read()

        mask = None
        if mask_path:
            with open(mask_path, "rb") as f:
                mask = f.read()

        # Call the DALL-E edits API
        response = self.client.images.edit(
            image=image,
            mask=mask,
            prompt=prompt,
            n=n,
            size=self.size
        )

        # Extract the URLs
        urls = [image.url for image in response.data]

        return urls

    def create_variation(
        self,
        image_path: str,
        n: int = 1
    ) -> list[str]:
        """
        Create variations of an image.

        Note: like edits, variations are only supported by dall-e-2.

        Args:
            image_path: Path to the source image
            n: Number of variations

        Returns:
            List of image URLs
        """
        # Read the image
        with open(image_path, "rb") as f:
            image = f.read()

        # Call the DALL-E variations API
        response = self.client.images.create_variation(
            image=image,
            n=n,
            size=self.size
        )

        # Extract the URLs
        urls = [image.url for image in response.data]

        return urls

    def download_image(self, url: str, save_path: str) -> str:
        """
        Download an image.

        Args:
            url: Image URL
            save_path: Destination path

        Returns:
            The destination path
        """
        response = requests.get(url, timeout=60)
        response.raise_for_status()

        with open(save_path, "wb") as f:
            f.write(response.content)

        return save_path

5. Visual Question Answering (app/vision/vqa.py)

Python
"""
Visual question answering module.
"""
from typing import Any
from openai import OpenAI
from PIL import Image
import base64
import io
from app.config import settings

class VisualQA:
    """Visual question answering."""

    def __init__(self, api_key: str | None = None):
        """
        Initialize the VQA component.

        Args:
            api_key: OpenAI API key
        """
        if api_key is None:
            api_key = settings.OPENAI_API_KEY

        self.client = OpenAI(api_key=api_key)
        self.model = settings.OPENAI_MODEL
        self.conversation_history: list[dict[str, Any]] = []

    def ask(
        self,
        image: Image.Image,
        question: str,
        max_tokens: int | None = None
    ) -> str:
        """
        Ask a question about an image.

        Args:
            image: PIL Image object
            question: The question
            max_tokens: Maximum number of tokens

        Returns:
            The answer
        """
        if max_tokens is None:
            max_tokens = settings.OPENAI_MAX_TOKENS

        # Encode the image as base64
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')

        # Build the new user message
        user_message = {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": question
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{image_base64}"
                    }
                }
            ]
        }

        # Earlier turns must precede the new question, otherwise the model
        # would see the conversation out of order
        messages = self.conversation_history + [user_message]

        # Call the API
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_tokens
        )

        answer = response.choices[0].message.content

        # Append the exchange to the history
        self.conversation_history.append(user_message)
        self.conversation_history.append({
            "role": "assistant",
            "content": answer
        })

        return answer

    def clear_history(self):
        """Clear the conversation history."""
        self.conversation_history = []

    def get_history(self) -> list[dict[str, Any]]:
        """Return the conversation history."""
        return self.conversation_history
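
Because every call to ask appends the full base64-encoded image into conversation_history, long conversations grow quickly and can exceed the model's context window. One simple mitigation is to keep only the most recent turns before each request; trim_history below is a hypothetical helper (max_turns counts user/assistant pairs), a sketch rather than part of the project:

```python
from typing import Any

def trim_history(
    history: list[dict[str, Any]],
    max_turns: int = 5,
) -> list[dict[str, Any]]:
    """Keep only the last max_turns (user, assistant) message pairs."""
    return history[-2 * max_turns:]
```

Inside ask, this would be applied as `messages = trim_history(self.conversation_history) + [user_message]`, at the cost of the model forgetting older turns.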

6. FastAPI Application (app/main.py)

Python
"""
FastAPI application entry point.
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

from app.config import settings
from app.api import image, caption, generate, chat

# Create the FastAPI application
app = FastAPI(
    title="Multimodal Application",
    description="A multimodal AI application built on GPT-4V",
    version="1.0.0"
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register the routers
app.include_router(
    image.router,
    prefix=settings.API_PREFIX,
    tags=["image"]
)
app.include_router(
    caption.router,
    prefix=settings.API_PREFIX,
    tags=["caption"]
)
app.include_router(
    generate.router,
    prefix=settings.API_PREFIX,
    tags=["generate"]
)
app.include_router(
    chat.router,
    prefix=settings.API_PREFIX,
    tags=["chat"]
)

@app.get("/")
async def root():  # "async def" defines a coroutine function
    """Root endpoint."""
    return {
        "message": "Multimodal Application",
        "version": "1.0.0",
        "docs": "/docs"
    }

@app.get("/health")
async def health_check():
    """Health check."""
    return {"status": "healthy"}

if __name__ == "__main__":
    uvicorn.run(
        "app.main:app",
        host=settings.API_HOST,
        port=settings.API_PORT,
        reload=True
    )

7. Image API (app/api/image.py)

Python
"""
Image API.
"""
from fastapi import APIRouter, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from pathlib import Path
import uuid

from app.config import settings
from app.image.processor import ImageProcessor

router = APIRouter()

# Initialize the image processor
image_processor = ImageProcessor()

@router.post("/image/upload")
async def upload_image(file: UploadFile = File(...)):
    """
    Upload an image.

    Args:
        file: The uploaded file

    Returns:
        Upload result
    """
    # Check the file format
    file_extension = Path(file.filename).suffix.lower()
    if file_extension not in settings.ALLOWED_FORMATS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file format: {file_extension}"
        )

    # Generate a unique file name
    file_id = str(uuid.uuid4())
    file_path = Path(settings.IMAGE_DIR) / f"{file_id}{file_extension}"

    # Save the file ("await" waits for the async read to complete)
    content = await file.read()
    with open(file_path, "wb") as f:
        f.write(content)

    # Validate the image
    is_valid, error = image_processor.validate_image(str(file_path))
    if not is_valid:
        file_path.unlink()
        raise HTTPException(status_code=400, detail=error)

    return JSONResponse(
        status_code=200,
        content={
            "message": "Image uploaded successfully",
            "image_id": file_id,
            "file_name": file.filename,
            "file_path": str(file_path)
        }
    )

@router.get("/image/{image_id}/info")
async def get_image_info(image_id: str):
    """
    Get image metadata.

    Args:
        image_id: Image ID

    Returns:
        Image metadata
    """
    # Locate the image file
    image_dir = Path(settings.IMAGE_DIR)
    image_files = list(image_dir.glob(f"{image_id}.*"))

    if not image_files:
        raise HTTPException(status_code=404, detail="Image not found")

    image_path = image_files[0]
    image = image_processor.load_image(str(image_path))

    return {
        "image_id": image_id,
        "file_name": image_path.name,
        "file_path": str(image_path),
        "size": image.size,
        "mode": image.mode,
        "format": image.format
    }

@router.delete("/image/{image_id}")
async def delete_image(image_id: str):
    """
    Delete an image.

    Args:
        image_id: Image ID

    Returns:
        Deletion result
    """
    # Locate the image file
    image_dir = Path(settings.IMAGE_DIR)
    image_files = list(image_dir.glob(f"{image_id}.*"))

    if not image_files:
        raise HTTPException(status_code=404, detail="Image not found")

    # Delete the file(s)
    for image_file in image_files:
        image_file.unlink()

    return {"message": "Image deleted successfully"}

8. Caption API (app/api/caption.py)

Python
"""
Image caption API.
"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from app.config import settings
from app.vision.caption import ImageCaptioner
from app.image.processor import ImageProcessor
from pathlib import Path

router = APIRouter()

# Initialize the components
captioner = ImageCaptioner()
image_processor = ImageProcessor()

class CaptionRequest(BaseModel):  # Pydantic BaseModel: automatic validation and serialization
    """Caption request."""
    image_id: str
    detail: str | None = "auto"
    max_tokens: int | None = None

class TagsRequest(BaseModel):
    """Tags request."""
    image_id: str
    num_tags: int | None = 5

@router.post("/caption/generate")
async def generate_caption(request: CaptionRequest):
    """
    Generate an image caption.

    Args:
        request: Caption request

    Returns:
        Caption result
    """
    # Locate the image file
    image_dir = Path(settings.IMAGE_DIR)
    image_files = list(image_dir.glob(f"{request.image_id}.*"))

    if not image_files:
        raise HTTPException(status_code=404, detail="Image not found")

    # Load the image
    image = image_processor.load_image(str(image_files[0]))

    # Generate the caption
    caption = captioner.generate_caption(
        image,
        detail=request.detail,
        max_tokens=request.max_tokens
    )

    return {
        "image_id": request.image_id,
        "caption": caption
    }

@router.post("/caption/tags")
async def generate_tags(request: TagsRequest):
    """
    Generate image tags.

    Args:
        request: Tags request

    Returns:
        Tags result
    """
    # Locate the image file
    image_dir = Path(settings.IMAGE_DIR)
    image_files = list(image_dir.glob(f"{request.image_id}.*"))

    if not image_files:
        raise HTTPException(status_code=404, detail="Image not found")

    # Load the image
    image = image_processor.load_image(str(image_files[0]))

    # Generate the tags
    tags = captioner.generate_tags(
        image,
        num_tags=request.num_tags
    )

    return {
        "image_id": request.image_id,
        "tags": tags
    }

@router.post("/caption/ocr")
async def extract_text(image_id: str):
    """
    Extract text from an image (OCR).

    Args:
        image_id: Image ID

    Returns:
        The extracted text
    """
    # Locate the image file
    image_dir = Path(settings.IMAGE_DIR)
    image_files = list(image_dir.glob(f"{image_id}.*"))

    if not image_files:
        raise HTTPException(status_code=404, detail="Image not found")

    # Load the image
    image = image_processor.load_image(str(image_files[0]))

    # Extract the text
    text = captioner.extract_text(image)

    return {
        "image_id": image_id,
        "text": text
    }

9. Generation API (app/api/generate.py)

Python
"""
Image generation API.
"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from app.config import settings
from app.image.generator import ImageGenerator
from pathlib import Path
import uuid

router = APIRouter()

# Initialize the image generator
generator = ImageGenerator()

class GenerateRequest(BaseModel):
    """Generation request."""
    prompt: str
    n: int | None = 1
    size: str | None = None
    quality: str | None = None

@router.post("/generate/image")
async def generate_image(request: GenerateRequest):
    """
    Generate images.

    Args:
        request: Generation request

    Returns:
        Generation result
    """
    try:
        # Generate the images
        urls = generator.generate_image(
            prompt=request.prompt,
            n=request.n,
            size=request.size,
            quality=request.quality
        )

        # Download the images
        image_ids = []
        for url in urls:
            image_id = str(uuid.uuid4())
            file_path = Path(settings.IMAGE_DIR) / f"{image_id}.png"
            generator.download_image(url, str(file_path))
            image_ids.append(image_id)

        return {
            "prompt": request.prompt,
            "image_ids": image_ids,
            "urls": urls
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

10. Chat API (app/api/chat.py)

Python
"""
Image-text chat API.
"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from app.config import settings
from app.vision.vqa import VisualQA
from app.image.processor import ImageProcessor
from pathlib import Path

router = APIRouter()

# Initialize the components.
# Note: this single VisualQA instance (and its history) is shared by all clients.
vqa = VisualQA()
image_processor = ImageProcessor()

class ChatRequest(BaseModel):
    """Chat request."""
    image_id: str
    question: str
    max_tokens: int | None = None

@router.post("/chat")
async def chat(request: ChatRequest):
    """
    Image-text chat.

    Args:
        request: Chat request

    Returns:
        Chat result
    """
    # Locate the image file
    image_dir = Path(settings.IMAGE_DIR)
    image_files = list(image_dir.glob(f"{request.image_id}.*"))

    if not image_files:
        raise HTTPException(status_code=404, detail="Image not found")

    # Load the image
    image = image_processor.load_image(str(image_files[0]))

    # Ask the question
    answer = vqa.ask(
        image,
        request.question,
        max_tokens=request.max_tokens
    )

    return {
        "image_id": request.image_id,
        "question": request.question,
        "answer": answer
    }

@router.post("/chat/history/clear")
async def clear_chat_history():
    """
    Clear the conversation history.

    Returns:
        Result
    """
    vqa.clear_history()
    return {"message": "Conversation history cleared"}

@router.get("/chat/history")
async def get_chat_history():
    """
    Get the conversation history.

    Returns:
        The conversation history
    """
    history = vqa.get_history()
    return {"history": history}
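
Because the router holds one module-level vqa instance, every API client shares the same conversation history. A common fix is to key histories by a client-supplied session id; below is a minimal sketch (SessionStore and its method names are illustrative, not part of the project):

```python
from typing import Any

class SessionStore:
    """Per-session conversation histories, keyed by a client-supplied id."""

    def __init__(self) -> None:
        self._histories: dict[str, list[dict[str, Any]]] = {}

    def history(self, session_id: str) -> list[dict[str, Any]]:
        # setdefault creates an empty history on first access
        return self._histories.setdefault(session_id, [])

    def append(self, session_id: str, role: str, content: Any) -> None:
        self.history(session_id).append({"role": role, "content": content})

    def clear(self, session_id: str) -> None:
        self._histories.pop(session_id, None)
```

A ChatRequest could then carry an optional session_id field, and the endpoint would read and write that session's history instead of the shared instance's.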

11. Streamlit前端 (frontend/streamlit/app.py)

Python
"""
Streamlit前端应用
"""
import streamlit as st
import requests
from pathlib import Path
from PIL import Image
import io
import base64

# 配置页面
st.set_page_config(
    page_title="多模态应用",
    page_icon="🖼️",
    layout="wide"
)

# API配置
API_BASE_URL = "http://localhost:8000/api/v1"

def main():
    """主函数"""
    st.title("🖼️ 多模态应用")

    # 侧边栏
    with st.sidebar:
        st.header("功能选择")

        # 功能选择
        page = st.radio(
            "选择功能",
            ["图像上传", "图像描述", "图像生成", "图文对话"]
        )

    # 根据选择显示不同页面
    if page == "图像上传":
        show_upload_page()
    elif page == "图像描述":
        show_caption_page()
    elif page == "图像生成":
        show_generate_page()
    elif page == "图文对话":
        show_chat_page()

def show_upload_page():
    """显示上传页面"""
    st.header("📤 图像上传")

    # 上传图像
    uploaded_file = st.file_uploader(
        "上传图像",
        type=["jpg", "jpeg", "png", "gif", "webp"]
    )

    if uploaded_file:
        # 显示图像
        image = Image.open(uploaded_file)
        st.image(image, caption="上传的图像", use_column_width=True)

        # 上传到服务器
        if st.button("上传"):
            with st.spinner("正在上传..."):
                files = {"file": uploaded_file}
                response = requests.post(
                    f"{API_BASE_URL}/image/upload",
                    files=files
                )

                if response.status_code == 200:
                    result = response.json()
                    st.success("上传成功!")
                    st.json(result)
                    st.session_state["current_image_id"] = result["image_id"]
                else:
                    st.error(f"上传失败: {response.text}")

def show_caption_page():
    """显示描述页面"""
    st.header("📝 图像描述")

    # 选择图像ID
    image_id = st.text_input("图像ID", value=st.session_state.get("current_image_id", ""))

    if not image_id:
        st.warning("请先上传图像")
        return

    # 获取图像信息
    try:
        response = requests.get(f"{API_BASE_URL}/image/{image_id}/info")
        if response.status_code == 200:
            info = response.json()

            # 显示图像
            image = Image.open(info["file_path"])
            st.image(image, caption=info["file_name"], use_column_width=True)
    except:
        st.error("无法加载图像")
        return

    # Generate a caption
    col1, col2 = st.columns(2)
    with col1:
        detail = st.selectbox("Caption detail level", ["auto", "low", "high"])
    with col2:
        max_tokens = st.number_input("Max tokens", value=500, min_value=100, max_value=1000)

    if st.button("Generate caption"):
        with st.spinner("Generating caption..."):
            response = requests.post(
                f"{API_BASE_URL}/caption/generate",
                json={
                    "image_id": image_id,
                    "detail": detail,
                    "max_tokens": max_tokens
                }
            )

            if response.status_code == 200:
                result = response.json()
                st.success("Caption generated!")
                st.write(result["caption"])
            else:
                st.error(f"Generation failed: {response.text}")

    # Generate tags
    st.divider()
    st.subheader("🏷️ Generate Tags")

    num_tags = st.number_input("Number of tags", value=5, min_value=1, max_value=10)

    if st.button("Generate tags"):
        with st.spinner("Generating tags..."):
            response = requests.post(
                f"{API_BASE_URL}/caption/tags",
                json={
                    "image_id": image_id,
                    "num_tags": num_tags
                }
            )

            if response.status_code == 200:
                result = response.json()
                st.success("Tags generated!")
                # st.badge is not available in the pinned Streamlit 1.28;
                # render the tags as inline code instead
                st.markdown(" ".join(f"`{tag}`" for tag in result["tags"]))
            else:
                st.error(f"Generation failed: {response.text}")

def show_generate_page():
    """Render the generation page"""
    st.header("🎨 Image Generation")

    # Prompt input
    prompt = st.text_area("Image description", height=100)

    # Parameters
    col1, col2 = st.columns(2)
    with col1:
        n = st.number_input("Number of images", value=1, min_value=1, max_value=4)
    with col2:
        size = st.selectbox("Image size", ["1024x1024", "512x512", "256x256"])

    # Generate images
    if st.button("Generate image"):
        if not prompt:
            st.warning("Please enter an image description")
            return

        with st.spinner("Generating image..."):
            response = requests.post(
                f"{API_BASE_URL}/generate/image",
                json={
                    "prompt": prompt,
                    "n": n,
                    "size": size
                }
            )

            if response.status_code == 200:
                result = response.json()
                st.success("Image generated!")

                # Display the generated images; surface failures instead of
                # swallowing them with a bare except/pass
                for image_id in result["image_ids"]:
                    try:
                        info_response = requests.get(f"{API_BASE_URL}/image/{image_id}/info")
                        info_response.raise_for_status()
                        info = info_response.json()
                        image = Image.open(info["file_path"])
                        st.image(image, caption=info["file_name"], use_column_width=True)
                    except Exception:
                        st.warning(f"Could not display image {image_id}")
            else:
                st.error(f"Generation failed: {response.text}")

def show_chat_page():
    """Render the chat page"""
    st.header("💬 Visual Chat")

    # Select an image ID
    image_id = st.text_input("Image ID", value=st.session_state.get("current_image_id", ""))

    if not image_id:
        st.warning("Please upload an image first")
        return

    # Fetch image info and display the image
    try:
        response = requests.get(f"{API_BASE_URL}/image/{image_id}/info")
        response.raise_for_status()
        info = response.json()

        image = Image.open(info["file_path"])
        st.image(image, caption=info["file_name"], use_column_width=True)
    except Exception:
        st.error("Failed to load image")
        return

    # Initialize the conversation history
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Render the conversation history
    for message in st.session_state.chat_history:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # User input
    if prompt := st.chat_input("Ask a question about the image..."):
        # Show the user message
        st.session_state.chat_history.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Get the AI reply (note: only the latest question is sent; the
        # backend does not receive the earlier turns of the conversation)
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = requests.post(
                    f"{API_BASE_URL}/chat",
                    json={
                        "image_id": image_id,
                        "question": prompt
                    }
                )

                if response.status_code == 200:
                    result = response.json()
                    st.markdown(result["answer"])
                    st.session_state.chat_history.append({
                        "role": "assistant",
                        "content": result["answer"]
                    })
                else:
                    st.error(f"Request failed: {response.text}")

if __name__ == "__main__":
    main()
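The Streamlit pages above all go through the same REST endpoints, so they can also be exercised without the UI. Below is a minimal stdlib-only sketch of calling the caption endpoint from a plain script; the request and response fields mirror the frontend code, but the base URL is an assumption and the backend must be running for the actual network call:

```python
import json
import urllib.request

API_BASE_URL = "http://localhost:8000"  # assumed; keep in sync with the frontend's API_BASE_URL

def caption_payload(image_id: str, detail: str = "auto", max_tokens: int = 500) -> dict:
    """Build the JSON body for POST /caption/generate, mirroring the fields the page sends."""
    if detail not in ("auto", "low", "high"):
        raise ValueError(f"unsupported detail level: {detail!r}")
    return {"image_id": image_id, "detail": detail, "max_tokens": max_tokens}

def generate_caption(image_id: str, **kwargs) -> str:
    """POST the payload and return the generated caption (requires a running backend)."""
    req = urllib.request.Request(
        f"{API_BASE_URL}/caption/generate",
        data=json.dumps(caption_payload(image_id, **kwargs)).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())["caption"]
```

This kind of thin client is also a convenient starting point for scripted batch captioning once the backend is up.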

12. Dependencies (requirements.txt)

Text Only
fastapi==0.104.1
uvicorn[standard]==0.24.0
python-multipart==0.0.6
pydantic==2.5.0
pydantic-settings==2.1.0
openai==1.3.0
pillow==10.1.0
requests==2.31.0
streamlit==1.28.0
python-dotenv==1.0.0

13. Environment variables (.env.example)

Text Only
# OpenAI API configuration
OPENAI_API_KEY=your_openai_api_key_here

# Image generation configuration
DALL_E_API_KEY=your_openai_api_key_here
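These variables have to be loaded at application startup. requirements.txt pins pydantic-settings for that job; as a minimal stdlib-only sketch of the kind of loader app/config.py might contain (the class and field names here are assumptions, not the project's actual code):

```python
import os

class Settings:
    """Minimal settings loader; field names mirror .env.example (assumed)."""

    def __init__(self) -> None:
        self.openai_api_key = os.getenv("OPENAI_API_KEY", "")
        # DALL·E requests fall back to the same OpenAI key if no separate one is set
        self.dall_e_api_key = os.getenv("DALL_E_API_KEY") or self.openai_api_key
```

In the real project, pydantic-settings adds type validation and .env-file parsing on top of this pattern.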

🚀 Deployment

1. Local Deployment

Step 1: Clone the project

Bash
git clone https://github.com/yourusername/multimodal-app.git
cd multimodal-app

Step 2: Create a virtual environment

Bash
python -m venv venv

# Activate the virtual environment
# Windows
venv\Scripts\activate
# Linux/Mac
source venv/bin/activate

Step 3: Install dependencies

Bash
pip install -r requirements.txt

Step 4: Configure environment variables

Bash
cp .env.example .env
# Edit .env and fill in your API keys

Step 5: Start the backend service

Bash
python -m app.main

Step 6: Start the frontend service

Bash
cd frontend/streamlit
streamlit run app.py

2. Docker Deployment

Bash
# Build the image
docker build -t multimodal-app .

# Run the container
docker run -d \
  --name multimodal-app \
  -p 8000:8000 \
  -p 8501:8501 \
  -e OPENAI_API_KEY=your_api_key_here \
  multimodal-app
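The build step above assumes a Dockerfile at the project root. One is not shown in this chapter, so the following is only a sketch of what it might look like for this layout (base image, ports, and start command are all assumptions):

```dockerfile
FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
# 8000: FastAPI backend, 8501: Streamlit frontend
EXPOSE 8000 8501
CMD ["sh", "-c", "python -m app.main & streamlit run frontend/streamlit/app.py --server.port 8501 --server.address 0.0.0.0"]
```

Running both processes in one container keeps the demo simple; for production you would normally split the backend and frontend into separate containers behind a compose file or orchestrator.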

🔧 Extension Directions

1. Feature Extensions

  • Video processing: support video upload and processing
  • Audio processing: support speech recognition and generation
  • 3D models: support 3D model processing
  • AR/VR: integrate AR/VR features
  • Real-time processing: support real-time image processing

2. Performance Optimization

  • Caching: cache processing results
  • Async processing: full support for asynchronous operations
  • Batch processing: support batch image processing
  • GPU acceleration: use GPUs to speed up processing

3. User Experience

  • Live preview: preview processing results in real time
  • Drag-and-drop upload: support drag-and-drop uploads
  • History: save operation history
  • Sharing: support sharing results

4. Enterprise Features

  • Access control: user permission management
  • Audit logs: operation auditing
  • API rate limiting: throttle API access
  • Data backup: automatic data backups

📚 What You Will Learn

After completing this project, you will have mastered:

  • Multimodal AI: understanding how multimodal AI works
  • Image processing: core image processing techniques
  • Image generation: how text-to-image generation works
  • Visual question answering: building a VQA system
  • API integration: working fluently with a variety of AI APIs
  • Frontend/backend integration: wiring a frontend to an API backend
  • System design: designing a complex multimodal system

🎉 Get Started

Now that you have seen the full implementation of the multimodal application, go build your own!

Recommended learning order:

1. Start with image upload and captioning
2. Add image generation
3. Implement visual chat
4. Finish with the more advanced features

Happy learning! 💪