# SymbolicModel

## `pytorch_symbolic.SymbolicModel`

`pytorch_symbolic.SymbolicModel(inputs: Tuple[SymbolicData, ...] | List[SymbolicData] | SymbolicData, outputs: Tuple[SymbolicData, ...] | List[SymbolicData] | SymbolicData, enable_cuda_graphs=False, enable_forward_codegen=None)`

Bases: `nn.Module`
A PyTorch model that replays operations defined in the graph. All operations that were required to transform the inputs into the outputs will be replayed in the same order, but on the real data provided as input to this model.
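To make the "record, then replay" idea concrete, here is a minimal sketch in plain Python. This is not pytorch_symbolic's implementation: the `SymbolicNode` class and `replay` function are hypothetical stand-ins that record arithmetic on placeholders and replay it later on real numbers, just as SymbolicModel replays recorded layers and operations on real tensors.

```python
class SymbolicNode:
    """A placeholder that records every operation applied to it (illustrative only)."""

    def __init__(self, op=None, parents=()):
        self.op = op            # callable replayed later on real data
        self.parents = parents  # nodes whose replayed values feed this op

    def __add__(self, other):
        return SymbolicNode(op=lambda a, b: a + b, parents=(self, other))

    def __mul__(self, other):
        return SymbolicNode(op=lambda a, b: a * b, parents=(self, other))


def replay(node, values):
    """Recursively replay recorded ops; `values` maps input nodes to real data."""
    if node in values:
        return values[node]
    args = [replay(parent, values) for parent in node.parents]
    return node.op(*args)


# Define the computation symbolically, then run it on real data:
x, y = SymbolicNode(), SymbolicNode()
out = (x + y) * y
print(replay(out, {x: 2, y: 3}))  # (2 + 3) * 3 = 15
```

SymbolicModel does the same in spirit, with tensors and `nn.Module` layers instead of numbers and lambdas.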
Example:

```python
from torch import nn
from pytorch_symbolic import Input, SymbolicModel

input1 = Input((10,))
input2 = Input((10,))
x = input1 + input2
x = nn.Linear(x.features, 1)(x)
model = SymbolicModel((input1, input2), x)
```
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `inputs` | `Tuple[SymbolicData, ...] \| List[SymbolicData] \| SymbolicData` | A collection of SymbolicData nodes that represent the input data used by the model. You provide the concrete data yourself when the model is used. If you define multiple inputs here, be prepared to pass multiple inputs during training/inference. | *required* |
| `outputs` | `Tuple[SymbolicData, ...] \| List[SymbolicData] \| SymbolicData` | A collection of SymbolicTensors that end the computation. These nodes return your final computation result, so if you define multiple outputs, SymbolicModel will return a tuple of tensors. | *required* |
| `enable_cuda_graphs` | `bool` | If True, the model will be converted to a CUDA Graph after creation. This requires a CUDA-capable device. CUDA Graphs can greatly speed up the execution of some models, but not all models are compatible with them: if your model includes non-deterministic behaviour, for example, it likely won't work. | `False` |
Attributes:

| Name | Type | Description |
|---|---|---|
| `inputs` | `tuple` | Non-modifiable tuple of input nodes |
| `outputs` | `tuple` | Non-modifiable tuple of output nodes |
`input_shape` *property*

Return the shape of the input or, in the case of multiple inputs, a tuple of their shapes.

`output_shape` *property*

Return the shape of the output or, in the case of multiple outputs, a tuple of their shapes.
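The "single shape, or a tuple of shapes" convention described above can be illustrated with a small plain-Python sketch. The `collect_shapes` function below is hypothetical, not part of the library; it only shows the packing rule these properties follow according to the documentation.

```python
def collect_shapes(shapes):
    """Illustrative packing rule: one node yields its shape directly;
    several nodes yield a tuple of their shapes."""
    return shapes[0] if len(shapes) == 1 else tuple(shapes)


print(collect_shapes([(10,)]))           # single input -> (10,)
print(collect_shapes([(10,), (3, 32)]))  # multiple inputs -> ((10,), (3, 32))
```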
`forward`

This function is executed by `__call__`. Do not use it directly; use `__call__` instead.

Warning!

This function will be overwritten by `_replace_forward_with_codegen` if `enable_forward_codegen` is True. If this happened and you want to see the generated source, print `self._generated_forward_source`.
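A toy sketch of what forward code generation means, assuming nothing about pytorch_symbolic's actual codegen: build the source text of a `forward` function from a recorded list of operations, compile it with `exec`, and keep the source string around for inspection, analogous to printing `self._generated_forward_source`.

```python
# Hypothetical illustration of forward codegen -- not the library's code.
ops = ["h = x + y", "out = h * h"]

# Assemble Python source for a forward function from the recorded ops.
generated_forward_source = (
    "def forward(x, y):\n"
    + "".join(f"    {line}\n" for line in ops)
    + "    return out\n"
)

# Compile the source and extract the resulting function.
namespace = {}
exec(generated_forward_source, namespace)
forward = namespace["forward"]

print(forward(1, 2))  # (1 + 2) ** 2 = 9
print(generated_forward_source)  # inspect the generated source
```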