Some details about the TFHE operator behind the neural network in advanced_examples

I used model.compile(inputset, device=device, verbose=True, show_mlir=True) in FullyConnectedNeuralNetworkOnMNIST.ipynb to print some circuit details, but I can’t understand them.

Computation Graph for _clear_forward_proxy
--------------------------------------------------------------------------------
%0 = _x_181                                                                                               # EncryptedTensor<int4, shape=(1, 784)>         ∈ [-1, 6]
%1 = [[1 0 1 .. ... .. 0 0 0]]                                                                            # ClearTensor<int4, shape=(784, 392)>           ∈ [-4, 4]             
%2 = matmul(%0, %1)                                                                                       # EncryptedTensor<int12, shape=(1, 392)>        ∈ [-1387, 458]        
%3 = [-1 -1 -2  ...  -1 -1 -1]                                                                            # ClearTensor<int4, shape=(392,)>               ∈ [-5, 1]
%4 = add(%2, %3)                                                                                          # EncryptedTensor<int12, shape=(1, 392)>        ∈ [-1392, 457]
%5 = round_bit_pattern(%4, lsbs_to_remove=5, overflow_protection=False, exactness=Exactness.EXACT)        # EncryptedTensor<int12, shape=(1, 392)>        ∈ [-1376, 448]        @ /features/fc0/Gemm.matmul_rounding
%6 = subgraph(%5)                                                                                         # EncryptedTensor<uint3, shape=(1, 392)>        ∈ [0, 7]
%7 = [[ 0  1  1 ... -1 -1  1]]                                                                            # ClearTensor<int4, shape=(392, 10)>            ∈ [-5, 4] 
%8 = matmul(%6, %7)                                                                                       # EncryptedTensor<int8, shape=(1, 10)>          ∈ [-82, 101] 
return %8

MLIR
--------------------------------------------------------------------------------
module {
  func.func @_clear_forward_proxy(%arg0: tensor<1x784x!FHE.esint<12>>) -> tensor<1x10x!FHE.esint<9>> {
    %cst = arith.constant dense<"0x0001......00101"> : tensor<784x392xi5>
    %0 = "FHELinalg.matmul_eint_int"(%arg0, %cst) : (tensor<1x784x!FHE.esint<12>>, tensor<784x392xi5>) -> tensor<1x392x!FHE.esint<12>>
    %cst_0 = arith.constant dense<"0x000E......F0F0F"> : tensor<392xi4>
    %1 = "FHELinalg.add_eint_int"(%0, %cst_0) : (tensor<1x392x!FHE.esint<12>>, tensor<392xi4>) -> tensor<1x392x!FHE.esint<12>>
    %2 = "FHELinalg.round"(%1) : (tensor<1x392x!FHE.esint<12>>) -> tensor<1x392x!FHE.esint<7>>
    %cst_1 = arith.constant dense<"0x0000......0000"> : tensor<128xi64>
    %3 = "FHELinalg.apply_lookup_table"(%2, %cst_1) : (tensor<1x392x!FHE.esint<7>>, tensor<128xi64>) -> tensor<1x392x!FHE.eint<9>>
    %cst_2 = arith.constant dense<"0x1F010......0011F"> : tensor<392x10xi5>
    %4 = "FHELinalg.to_signed"(%3) : (tensor<1x392x!FHE.eint<9>>) -> tensor<1x392x!FHE.esint<9>>
    %5 = "FHELinalg.matmul_eint_int"(%4, %cst_2) : (tensor<1x392x!FHE.esint<9>>, tensor<392x10xi5>) -> tensor<1x10x!FHE.esint<9>>
    return %5 : tensor<1x10x!FHE.esint<9>>
  }
}

programmable_bootstrap_count: 2352
programmable_bootstrap_count_per_parameter: {
BootstrapKeyParam(polynomial_size=256, glwe_dimension=6, input_lwe_dimension=587, level=3, base_log=9, variance=0.000000): 1960
BootstrapKeyParam(polynomial_size=16384, glwe_dimension=1, input_lwe_dimension=862, level=2, base_log=15, variance=0.000000): 392
}
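The two PBS counts can be cross-checked with simple arithmetic. The sketch below rests on an assumed breakdown: one small PBS per removed LSB for each of the 392 first-layer outputs (rounding), plus one large PBS per output (the subgraph TLU). This is an interpretation of the printed statistics, not something the compiler reports directly. It also checks that a table indexed by a 7-bit value needs 128 entries, which matches tensor<128xi64> in the MLIR.

```python
# Cross-check of the printed PBS statistics (an interpretation, not compiler output).
# Assumptions: round_bit_pattern costs one PBS per removed LSB per value,
# and subgraph (FHELinalg.apply_lookup_table) costs one PBS per value.

n_values = 392         # first-layer outputs, shape (1, 392)
acc_bits = 12          # esint<12> accumulator before rounding
lsbs_to_remove = 5     # from round_bit_pattern(..., lsbs_to_remove=5, ...)

rounding_pbs = n_values * lsbs_to_remove     # assumed small set (polynomial_size=256)
tlu_pbs = n_values                           # assumed large set (polynomial_size=16384)
lut_size = 2 ** (acc_bits - lsbs_to_remove)  # entries in a table indexed by esint<7>

# programmable_bootstrap_count_per_parameter, keyed here by polynomial_size
reported = {256: 1960, 16384: 392}

print(rounding_pbs == reported[256], tlu_pbs == reported[16384])  # True True
print(rounding_pbs + tlu_pbs, lut_size)                           # 2352 128
```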
  1. Correspondence between PBS and computation graph
  • There are two types of PBS: the first uses smaller parameters (polynomial_size=256) and the second uses larger ones (polynomial_size=16384). Does the first type correspond to round_bit_pattern in the computation graph, and does the second correspond to %6 = subgraph(%5) (the LUT that fuses quantization, ReLU, and truncation)?
  2. round_bit_pattern
  • As I understand it, round_bit_pattern adds a carry and then truncates. How does this map to TFHE operations? Why are its parameters so small? Does it change the ciphertext modulus or the plaintext modulus?
  3. The meaning of the IR
  • (tensor<1x392x!FHE.esint<12>>) -> tensor<1x392x!FHE.esint<7>>
  • Does this IR mean that the plaintext precision decreases from 12 bits to 7 bits (i.e. the plaintext modulus goes from 2^12 to 2^7)?
  4. FHELinalg.to_signed
  • The documentation states that FHELinalg.to_signed casts an unsigned integer tensor to a signed one. Does this mean that its TFHE implementation in the backend only changes how values are encoded and decoded?

Hi @gyu,

  • round_bit_pattern in the computation graph is the MLIR FHELinalg.round op: it reduces precision (here from 12 to 7 bits) before the TLU. It maps to the smaller parameter set (polynomial_size=256): the count 1960 = 392 * 5 matches lsbs_to_remove=5 applied to 392 values, so the compiler decomposes the rounding into several PBS internally, one per removed bit. The remaining 392 PBS on the larger parameter set (polynomial_size=16384) are the TLU of %6 = subgraph(%5), one per neuron.

  • Its parameters are smaller because each of these rounding PBS works at a much lower precision than the 7-bit TLU, and lower precision allows lighter cryptographic parameters.

  • From the user's perspective, it reduces the plaintext bit-width; under the hood, it outputs a fresh ciphertext under a different parameter set. So yes, it changes the plaintext modulus (bit-width). The ciphertext modulus itself is not reduced; what changes is the parameter set used for the output ciphertext.

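  • tensor<…!FHE.esint<12>> → tensor<…!FHE.esint<7>> means the message precision drops from 12 bits to 7 bits. It is not just a truncated 12-bit value: it is a rounded and rescaled value encoded in a smaller plaintext space. For example, the range [-1376, 448] of %5 in the graph corresponds to [-43, 14] in the 7-bit type, i.e. the values divided by 2^5.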

  • FHELinalg.to_signed is a cast between the unsigned and signed interpretations of the same encrypted bits. It needs no PBS; it is essentially a type (metadata) change for the subsequent signed ops.
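To make the rounding semantics concrete, here is a clear-integer model of the values (not of how TFHE computes them). It assumes round-half-up (add 2^(k-1), then drop k bits with an arithmetic shift) and ignores overflow_protection; both function names are hypothetical, not Concrete APIs. The first mirrors the graph-level round_bit_pattern, which keeps the original scale; the second returns the shifted value that the esint<7> result of FHELinalg.round would carry.

```python
def round_bit_pattern_clear(x: int, lsbs_to_remove: int) -> int:
    """Graph-level view (hypothetical helper): round x to the nearest multiple
    of 2**lsbs_to_remove, ties up. The scale of x is kept (low bits become 0)."""
    half = 1 << (lsbs_to_remove - 1)  # the "carry" added before truncation
    return ((x + half) >> lsbs_to_remove) << lsbs_to_remove


def fhe_round_clear(x: int, lsbs_to_remove: int) -> int:
    """MLIR-level view (hypothetical helper): same rounding, but the low bits
    are dropped, so the result fits in (12 - lsbs_to_remove) bits here."""
    half = 1 << (lsbs_to_remove - 1)
    return (x + half) >> lsbs_to_remove  # Python's >> is an arithmetic (floor) shift


lo, hi = -1392, 457  # range of %4 in the computation graph
print(round_bit_pattern_clear(lo, 5), round_bit_pattern_clear(hi, 5))  # -1376 448
print(fhe_round_clear(lo, 5), fhe_round_clear(hi, 5))  # -43 14
```

The first line reproduces the [-1376, 448] range reported for %5, and the second stays inside the signed 7-bit range [-64, 63], which is consistent with the 128-entry table used by the TLU.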

Thank you so much!
I have some other questions about library usage.

Is there a way to extract the individual computation steps from the computation graph?

The most complex operations in this computation graph are %5 = round_bit_pattern and %6 = subgraph(%5), especially the subgraph, which involves operator fusion.

Could I extract these two processes from the computation graph as functions?

This would let me modify them, for example by changing %6 = subgraph(%5), and see how the change affects the overall ML inference result.

I am not sure how you could do that simply from the fhe_circuit graph; it would require quite a bit of hacking. Instead, you could change the ONNX/ML part and use rounding explicitly, rather than relying on the power_of_two_scaling used in this notebook.
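If the goal is only to study how changes to %5 and %6 would affect the results, one option that avoids touching fhe_circuit is to re-implement those two steps in the clear. The sketch below is hypothetical. It does not extract anything from the circuit, and the function names are made up. scale_in and scale_out are placeholders, not the quantization parameters Concrete ML actually chose. It also assumes the subgraph is dequantize, then ReLU, then requantize to uint3 (as described in the question), with round-half-up requantization.

```python
import math
from typing import Callable, List

LSBS_TO_REMOVE = 5  # taken from round_bit_pattern(%4, lsbs_to_remove=5, ...)


def rounding_step(acc: List[int]) -> List[int]:
    """Clear stand-in for %5: round each accumulator to a multiple of 2**5 (ties up)."""
    half = 1 << (LSBS_TO_REMOVE - 1)
    return [((a + half) >> LSBS_TO_REMOVE) << LSBS_TO_REMOVE for a in acc]


def make_subgraph(scale_in: float, scale_out: float, n_bits: int = 3) -> Callable[[int], int]:
    """Hypothetical stand-in for %6: dequantize, ReLU, requantize to uint<n_bits>.

    scale_in and scale_out are placeholders, not Concrete ML's real values.
    """
    q_max = (1 << n_bits) - 1

    def lut(v: int) -> int:
        real = max(v * scale_in, 0.0)            # dequantize, then ReLU
        q = math.floor(real / scale_out + 0.5)   # requantize (round half up)
        return min(max(q, 0), q_max)             # clip to [0, 2**n_bits - 1]

    return lut


def hidden_layer(acc: List[int], subgraph: Callable[[int], int]) -> List[int]:
    """Apply %5 and then %6 to the first-layer accumulators (%4)."""
    return [subgraph(v) for v in rounding_step(acc)]


if __name__ == "__main__":
    acc = [-1392, -10, 100, 300, 457]  # sample values inside the graph's range [-1392, 457]
    baseline = hidden_layer(acc, make_subgraph(scale_in=1 / 64, scale_out=1.0))
    modified = hidden_layer(acc, make_subgraph(scale_in=1 / 64, scale_out=2.0))
    print(baseline)  # [0, 0, 2, 5, 7]
    print(modified)  # [0, 0, 1, 2, 4]
```

Swapping in a modified subgraph only changes the hidden activations here. To see the effect on the final logits, you would also need the second-layer weights (%7), which the printed graph shows only in elided form.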