Skip to content

Commit

Permalink
Fix custom model documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Jun 5, 2024
1 parent 949ed5c commit 8c47f80
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
8 changes: 5 additions & 3 deletions docs/references/network_architectures.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ You can also build your own encoder factory.
# your own neural network
class CustomEncoder(nn.Module):
def __init__(self, obsevation_shape, feature_size):
def __init__(self, observation_shape, feature_size):
super().__init__()
self.feature_size = feature_size
self.fc1 = nn.Linear(observation_shape[0], 64)
self.fc2 = nn.Linear(64, feature_size)
Expand Down Expand Up @@ -72,7 +73,8 @@ controls.
.. code-block:: python
class CustomEncoderWithAction(nn.Module):
def __init__(self, obsevation_shape, action_size, feature_size):
def __init__(self, observation_shape, action_size, feature_size):
super().__init__()
self.feature_size = feature_size
self.fc1 = nn.Linear(observation_shape[0] + action_size, 64)
self.fc2 = nn.Linear(64, feature_size)
Expand All @@ -90,7 +92,7 @@ controls.
def create(self, observation_shape):
return CustomEncoder(observation_shape, self.feature_size)
def create_with_action(observation_shape, action_size, discrete_action):
def create_with_action(self, observation_shape, action_size, discrete_action):
return CustomEncoderWithAction(observation_shape, action_size, self.feature_size)
@staticmethod
Expand Down
17 changes: 17 additions & 0 deletions docs/tutorials/customize_neural_network.rst
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,20 @@ Now, you can customize actor-critic algorithms.
actor_encoder_factory=encoder_factory,
critic_encoder_factory=encoder_factory,
).create()
Make your models loadable
-------------------------

If you want ``load_learnable`` method to load the algorithm configuration including
your encoder configuration, you need to register your encoder factory.

.. code-block:: python
from d3rlpy.models.encoders import register_encoder_factory
# register your own encoder factory
register_encoder_factory(CustomEncoderFactory)
# load algorithm from d3
dqn = d3rlpy.load_learnable("model.d3")

0 comments on commit 8c47f80

Please sign in to comment.