diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index 5dbadf9c38..4a00d8332f 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -24,6 +24,17 @@ class Memory(BaseModel): index: DefaultDict[str, list[SerializeAsAny[Message]]] = Field(default_factory=lambda: defaultdict(list)) ignore_id: bool = False + def _filter_expired_messages(self, messages: list[Message]) -> list[Message]: + """Filter out expired messages from the given list. + + Args: + messages: List of messages to filter. + + Returns: + List of non-expired messages. + """ + return [message for message in messages if not message.is_expired()] + def add(self, message: Message): """Add a new message to storage, while updating the index""" if self.ignore_id: @@ -40,11 +51,13 @@ def add_batch(self, messages: Iterable[Message]): def get_by_role(self, role: str) -> list[Message]: """Return all messages of a specified role""" - return [message for message in self.storage if message.role == role] + messages = [message for message in self.storage if message.role == role] + return self._filter_expired_messages(messages) def get_by_content(self, content: str) -> list[Message]: """Return all messages containing a specified content""" - return [message for message in self.storage if content in message.content] + messages = [message for message in self.storage if content in message.content] + return self._filter_expired_messages(messages) def delete_newest(self) -> "Message": """delete the newest message from the storage""" @@ -75,11 +88,13 @@ def count(self) -> int: def try_remember(self, keyword: str) -> list[Message]: """Try to recall all messages containing a specified keyword""" - return [message for message in self.storage if keyword in message.content] + messages = [message for message in self.storage if keyword in message.content] + return self._filter_expired_messages(messages) def get(self, k=0) -> list[Message]: """Return the most recent k memories, return all when k=0""" - return self.storage[-k:] + messages = self.storage[-k:] + return self._filter_expired_messages(messages) def find_news(self, observed: list[Message], k=0) -> list[Message]: """find news (previously unseen messages) from the most recent k memories, from all memories when k=0""" @@ -94,7 +109,8 @@ def find_news(self, observed: list[Message], k=0) -> list[Message]: def get_by_action(self, action) -> list[Message]: """Return all messages triggered by a specified Action""" index = any_to_str(action) - return self.index[index] + messages = self.index[index] + return self._filter_expired_messages(messages) def get_by_actions(self, actions: Set) -> list[Message]: """Return all messages triggered by specified Actions""" @@ -104,9 +120,12 @@ def get_by_actions(self, actions: Set) -> list[Message]: if action not in self.index: continue rsp += self.index[action] - return rsp + return self._filter_expired_messages(rsp) @handle_exception def get_by_position(self, position: int) -> Optional[Message]: - """Returns the message at the given position if valid, otherwise returns None""" - return self.storage[position] + """Returns the message at the given position if valid and not expired, otherwise returns None""" + message = self.storage[position] + if message and message.is_expired(): + return None + return message diff --git a/metagpt/schema.py b/metagpt/schema.py index 52badcc21a..a82143c16b 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -240,6 +240,8 @@ class Message(BaseModel): sent_from: str = Field(default="", validate_default=True) send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True) metadata: Dict[str, Any] = Field(default_factory=dict) # metadata for `content` and `instruct_content` + ttl: int = Field(default=-1, validate_default=True) # Time-To-Live in seconds, -1 means never expire + created_at: float = Field(default_factory=time.time) # Creation time in seconds since epoch @field_validator("id", mode="before") @classmethod @@ -415,6 +417,18 @@ def is_user_message(self) -> bool: def is_ai_message(self) -> bool: return self.role == "assistant" + def is_expired(self) -> bool: + """Check if the message has expired based on its TTL. + + Returns: + bool: True if the message has expired, False otherwise. + Messages with ttl=-1 never expire. + """ + if self.ttl == -1: + return False + current_time = time.time() + return current_time - self.created_at > self.ttl + class UserMessage(Message): """便于支持OpenAI的消息 diff --git a/tests/metagpt/memory/test_memory.py b/tests/metagpt/memory/test_memory.py index a072b61deb..8b066bb71a 100644 --- a/tests/metagpt/memory/test_memory.py +++ b/tests/metagpt/memory/test_memory.py @@ -2,6 +2,8 @@ # -*- coding: utf-8 -*- # @Desc : the unittest of Memory +import time + from metagpt.actions import UserRequirement from metagpt.memory.memory import Memory from metagpt.schema import Message @@ -55,3 +57,135 @@ def test_memory(): memory.clear() assert memory.count() == 0 assert len(memory.index) == 0 + + +def test_message_ttl_and_created_at(): + """Test Message class ttl and created_at fields""" + # Test default values + message = Message(content="test message", role="user") + assert message.ttl == -1 + assert message.created_at > 0 + + # Test custom ttl + message_with_ttl = Message(content="test message with ttl", role="user", ttl=60) + assert message_with_ttl.ttl == 60 + + # Test is_expired method for ttl=-1 (never expire) + assert message.is_expired() == False + + # Test serialization + dumped = message.dump() + loaded = Message.load(dumped) + assert loaded.ttl == message.ttl + assert abs(loaded.created_at - message.created_at) < 0.001 + + +def test_message_expiration(): + """Test message expiration functionality""" + # Create a message that expires in 1 second + message = Message(content="expiring message", role="user", ttl=1) + assert message.is_expired() == False + + # Wait for message to expire + time.sleep(1.1) + assert message.is_expired() == True + + # Create a message that never expires + message_never_expire = Message(content="never expiring message", role="user", ttl=-1) + assert message_never_expire.is_expired() == False + + # Wait and verify it still doesn't expire + time.sleep(0.5) + assert message_never_expire.is_expired() == False + + +def test_memory_filter_expired_messages(): + """Test Memory class filtering expired messages""" + memory = Memory() + + # Create messages with different TTLs + message1 = Message(content="never expire", role="user1", ttl=-1) + message2 = Message(content="expire in 1 sec", role="user2", ttl=1) + message3 = Message(content="never expire too", role="user1", ttl=-1) + + # Add all messages to memory + memory.add_batch([message1, message2, message3]) + assert memory.count() == 3 + + # Wait for message2 to expire + time.sleep(1.1) + + # Test get() method filters expired messages + messages = memory.get() + assert len(messages) == 2 + assert message2 not in messages + + # Test get_by_role() method filters expired messages + messages = memory.get_by_role("user2") + assert len(messages) == 0 + + # Test get_by_role() for non-expired messages + messages = memory.get_by_role("user1") + assert len(messages) == 2 + + # Test get_by_content() method filters expired messages + messages = memory.get_by_content("expire") + assert len(messages) == 2 + assert message2 not in messages + + # Test try_remember() method filters expired messages + messages = memory.try_remember("expire") + assert len(messages) == 2 + assert message2 not in messages + + # Test get_by_action() method filters expired messages + messages = memory.get_by_action(UserRequirement) + assert len(messages) == 2 + assert message2 not in messages + + # Test get_by_actions() method filters expired messages + messages = memory.get_by_actions({UserRequirement}) + assert len(messages) == 2 + assert message2 not in messages + + # Test get_by_position() method returns None for expired messages + # Note: message2 is at position 1 in storage + message = memory.get_by_position(1) + assert message is None + + +def test_memory_backward_compatibility(): + """Test backward compatibility with existing code""" + memory = Memory() + + # Create messages without specifying ttl (should use default -1) + message1 = Message(content="message1", role="user1") + message2 = Message(content="message2", role="user2") + + # Verify default ttl is -1 + assert message1.ttl == -1 + assert message2.ttl == -1 + + # Add to memory + memory.add_batch([message1, message2]) + + # Verify all retrieval methods work as before + assert memory.count() == 2 + + messages = memory.get() + assert len(messages) == 2 + + messages = memory.get_by_role("user1") + assert len(messages) == 1 + assert messages[0].content == "message1" + + messages = memory.get_by_content("message") + assert len(messages) == 2 + + # Verify messages don't expire + time.sleep(0.5) + assert message1.is_expired() == False + assert message2.is_expired() == False + + messages = memory.get() + assert len(messages) == 2